File size: 5,186 Bytes
3238dca | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | """
StruCTA: Full model composition.
Abstraction → Structured Encoder → Privacy Verification → Reasoning Decoder → De-Abstraction
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Any, Tuple
from .config import StruCTAConfig
from .encoder import PrivacyGraphTransformer
from .decoder import StructuredReasoningDecoder
from .privacy import PrivacyVerificationModule
from .abstraction import AbstractionLayer, AbstractDocument
from .deabstraction import DeAbstractionLayer
class StruCTA(nn.Module):
    """
    Full StruCTA model for end-to-end privacy-preserving reasoning.

    Composes a graph encoder, an optional privacy verification module, a
    reasoning decoder, and the abstraction / de-abstraction layers that map
    raw text to an abstract document and back.
    """

    def __init__(self, config: "StruCTAConfig"):
        """Build all sub-components from ``config``.

        The privacy verification module is only instantiated when
        ``config.use_privacy_verification`` is truthy; otherwise
        ``self.privacy_module`` is ``None`` and verification is skipped.
        """
        super().__init__()
        self.config = config
        self.encoder = PrivacyGraphTransformer(config)
        self.decoder = StructuredReasoningDecoder(config)
        if config.use_privacy_verification:
            self.privacy_module = PrivacyVerificationModule(config)
        else:
            self.privacy_module = None
        # NOTE: the attribute name is misspelled ("abstration"); it is kept
        # unchanged because external code may already access it by this name.
        self.abstration = AbstractionLayer(use_ner_model=False)
        self.deabstraction = DeAbstractionLayer()

    def _check_privacy(self, encoder_hidden):
        """Run privacy verification on the encoder output, if enabled.

        Returns:
            A ``(report, passed)`` tuple. ``report`` is ``None`` and
            ``passed`` is ``True`` when no privacy module is configured;
            otherwise ``passed`` is true only if both ``report.struct_ok``
            and ``report.tokens_ok`` are truthy.
        """
        if self.privacy_module is None:
            return None, True
        report = self.privacy_module.verify(encoder_hidden)
        return report, bool(report.struct_ok and report.tokens_ok)

    def forward(
        self,
        node_features: torch.Tensor,
        node_types: Optional[torch.Tensor] = None,
        degree: Optional[torch.Tensor] = None,
        spd: Optional[torch.Tensor] = None,
        edge_index: Optional[torch.Tensor] = None,
        edge_types: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.Tensor] = None,
        graph_positions: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, Any]:
        """Encode the graph, verify privacy, and optionally decode.

        All graph arguments are forwarded to the encoder unchanged;
        ``encoder_mask`` is passed to the encoder as ``key_padding_mask`` and
        to the decoder as ``encoder_mask``.

        Returns:
            A dict with keys ``"encoder_hidden"``, ``"logits"``,
            ``"blocked"`` and ``"report"``. ``"logits"`` is ``None`` when the
            privacy check fails (``"blocked"`` is then ``True``) or when
            ``decoder_input_ids`` is not given. ``"report"`` is ``None`` when
            privacy verification is disabled.
        """
        encoder_hidden = self.encoder(
            node_features,
            node_types=node_types,
            degree=degree,
            spd=spd,
            edge_index=edge_index,
            edge_types=edge_types,
            attention_mask=attention_mask,
            key_padding_mask=encoder_mask,
        )

        report, passed = self._check_privacy(encoder_hidden)
        if not passed:
            # Never run the decoder on a representation that failed the check.
            return {
                "blocked": True,
                "report": report,
                "logits": None,
                "encoder_hidden": encoder_hidden,
            }

        logits = None
        if decoder_input_ids is not None:
            logits = self.decoder(
                decoder_input_ids,
                encoder_hidden,
                graph_positions=graph_positions,
                encoder_mask=encoder_mask,
            )

        return {
            "encoder_hidden": encoder_hidden,
            "logits": logits,
            "blocked": False,
            "report": report,
        }

    @torch.no_grad()
    def generate_from_text(
        self,
        raw_text: str,
        max_length: int = 128,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> Dict[str, Any]:
        """Run the full text -> abstract answer -> concrete answer pipeline.

        The raw text is abstracted, its vault is registered with the
        de-abstraction layer, the graph is encoded and (if enabled) verified,
        and the decoder's generated token ids are rendered as ``"<id>"``
        tokens before de-abstraction.

        Returns:
            On success, a dict with ``"abstract_answer"``,
            ``"concrete_answer"``, ``"report"`` and ``"vault_id"``.
            ``"report"`` is ``None`` when privacy verification is disabled.
            If verification fails, a dict with blocked placeholder answers
            and the ``"report"`` is returned instead (no ``"vault_id"``).
        """
        abstract_doc = self.abstration.abstract(raw_text)
        vault_id = abstract_doc.vault_id
        self.deabstraction.register_vault(
            vault_id,
            self.abstration.retrieve_vault(vault_id),
        )

        num_nodes = len(abstract_doc.amr_graph["nodes"])
        device = next(self.parameters()).device

        # NOTE(review): node features are random with a hard-coded feature
        # size of 10, and types/degree/spd are constants -- this looks like
        # a placeholder rather than a real featurization of the AMR graph;
        # confirm against the encoder's expected input before relying on it.
        node_features = torch.randn(1, num_nodes, 10, device=device)
        node_types = torch.zeros(1, num_nodes, dtype=torch.long, device=device)
        degree = torch.ones(1, num_nodes, dtype=torch.long, device=device) * 2
        spd = torch.zeros(
            1, num_nodes, num_nodes, dtype=torch.long, device=device
        )

        encoder_hidden = self.encoder(
            node_features, node_types=node_types, degree=degree, spd=spd
        )

        # `report` is always bound here, even when verification is disabled;
        # previously the success path raised UnboundLocalError in that case.
        report, passed = self._check_privacy(encoder_hidden)
        if not passed:
            return {
                "abstract_answer": "[BLOCKED: Privacy violation detected]",
                "concrete_answer": "[BLOCKED]",
                "report": report,
            }

        start_id = 0  # presumably the decoder's start-of-sequence id -- TODO confirm
        abstract_ids = self.decoder.generate(
            encoder_hidden,
            start_token_id=start_id,
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )
        # Only the first sequence in the batch is rendered.
        abstract_answer = " ".join(
            f"<{token.item()}>" for token in abstract_ids[0]
        )

        try:
            concrete_answer = self.deabstraction.deabstract(
                abstract_answer, vault_id
            )
        except ValueError:
            # Best-effort: report the failure in-band instead of raising.
            concrete_answer = "[DE-ABSTRUCTION FAILED: Vault not found]"

        return {
            "abstract_answer": abstract_answer,
            "concrete_answer": concrete_answer,
            "report": report,
            "vault_id": vault_id,
        }

    def load_pretrained_encoder(self, state_dict):
        """Load encoder weights non-strictly and return the load result.

        ``strict=False`` means mismatched keys are not an error; the value
        returned by ``load_state_dict`` is passed back so callers can inspect
        which keys were missing or unexpected.
        """
        return self.encoder.load_state_dict(state_dict, strict=False)

    def freeze_encoder(self):
        """Disable gradient computation for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def enable_dp_training(self):
        """Set ``config.use_dp_training`` to ``True`` (mutates the config)."""
        self.config.use_dp_training = True

    def get_stats(self) -> Dict[str, int]:
        """Return parameter counts and key size settings from the config."""
        total = 0
        trainable = 0
        for param in self.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        return {
            "total_params": total,
            "trainable_params": trainable,
            "encoder_layers": self.config.num_encoder_layers,
            "decoder_layers": self.config.num_decoder_layers,
            "hidden_dim": self.config.hidden_dim,
        }
|