| """ |
| StruCTA: Full model composition. |
| Abstraction → Structured Encoder → Privacy Verification → Reasoning Decoder → De-Abstraction |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Optional, Dict, Any, Tuple |
|
|
| from .config import StruCTAConfig |
| from .encoder import PrivacyGraphTransformer |
| from .decoder import StructuredReasoningDecoder |
| from .privacy import PrivacyVerificationModule |
| from .abstraction import AbstractionLayer, AbstractDocument |
| from .deabstraction import DeAbstractionLayer |
|
|
|
|
class StruCTA(nn.Module):
    """
    Full StruCTA model for end-to-end privacy-preserving reasoning.

    Composes the pipeline stages in order: abstraction of raw text, the
    structured graph encoder, an optional privacy verification gate, the
    reasoning decoder, and de-abstraction of the generated answer.
    """

    def __init__(self, config: "StruCTAConfig"):
        """Build all sub-modules from ``config``.

        The privacy verifier is only created when
        ``config.use_privacy_verification`` is truthy; otherwise
        ``self.privacy_module`` is ``None`` and the gate is skipped.
        """
        super().__init__()
        self.config = config
        self.encoder = PrivacyGraphTransformer(config)
        self.decoder = StructuredReasoningDecoder(config)
        if config.use_privacy_verification:
            self.privacy_module = PrivacyVerificationModule(config)
        else:
            self.privacy_module = None
        # NOTE(review): "abstration" is a misspelling of "abstraction". The
        # name is kept because external code may already read this attribute.
        self.abstration = AbstractionLayer(use_ner_model=False)
        self.deabstraction = DeAbstractionLayer()

    def _check_privacy(self, encoder_hidden):
        """Run the optional privacy gate on the encoder output.

        Returns:
            ``(report, blocked)``. ``report`` is ``None`` and ``blocked`` is
            ``False`` when no privacy module is configured. Otherwise
            ``blocked`` is ``True`` if either ``report.struct_ok`` or
            ``report.tokens_ok`` is falsy.
        """
        if self.privacy_module is None:
            return None, False
        report = self.privacy_module.verify(encoder_hidden)
        blocked = not report.struct_ok or not report.tokens_ok
        return report, blocked

    def forward(self, node_features, node_types=None, degree=None, spd=None,
                edge_index=None, edge_types=None, decoder_input_ids=None,
                graph_positions=None, attention_mask=None, encoder_mask=None):
        """Encode the graph, apply the privacy gate, and optionally decode.

        All graph arguments are passed through to the encoder unchanged;
        ``encoder_mask`` is given to the encoder as ``key_padding_mask`` and
        to the decoder as ``encoder_mask``.

        Returns:
            A dict with keys ``"encoder_hidden"``, ``"logits"``,
            ``"blocked"`` and ``"report"``. ``"logits"`` is ``None`` when the
            privacy gate blocks the input or when ``decoder_input_ids`` is
            ``None``. ``"report"`` is ``None`` when no privacy module is
            configured.
        """
        encoder_hidden = self.encoder(
            node_features, node_types=node_types, degree=degree, spd=spd,
            edge_index=edge_index, edge_types=edge_types,
            attention_mask=attention_mask, key_padding_mask=encoder_mask,
        )

        report, blocked = self._check_privacy(encoder_hidden)
        if blocked:
            # Never run the decoder on representations that failed the gate.
            return {
                "blocked": True, "report": report,
                "logits": None, "encoder_hidden": encoder_hidden,
            }

        logits = None
        if decoder_input_ids is not None:
            logits = self.decoder(
                decoder_input_ids, encoder_hidden,
                graph_positions=graph_positions, encoder_mask=encoder_mask,
            )
        return {
            "encoder_hidden": encoder_hidden, "logits": logits,
            "blocked": False, "report": report,
        }

    @torch.no_grad()
    def generate_from_text(self, raw_text, max_length=128, temperature=0.8,
                           top_k=50, top_p=0.9):
        """Abstract ``raw_text``, generate an answer, and de-abstract it.

        Returns:
            A dict with keys ``"abstract_answer"``, ``"concrete_answer"``,
            ``"report"`` and ``"vault_id"``. When the privacy gate blocks the
            input, both answers are fixed ``[BLOCKED...]`` marker strings.
            ``"report"`` is ``None`` when no privacy module is configured.
        """
        abstract_doc = self.abstration.abstract(raw_text)
        vault_id = abstract_doc.vault_id
        self.deabstraction.register_vault(
            vault_id, self.abstration.retrieve_vault(vault_id)
        )

        num_nodes = len(abstract_doc.amr_graph["nodes"])
        # Fall back to CPU instead of raising StopIteration if the model
        # happens to own no parameters.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

        # NOTE(review): node features are random placeholders with a
        # hard-coded width of 10, and node types / degrees / shortest-path
        # distances are constants; none of them are derived from the AMR
        # graph. Confirm this is intentional before relying on the output.
        node_features = torch.randn(1, num_nodes, 10, device=device)
        node_types = torch.zeros(1, num_nodes, dtype=torch.long, device=device)
        degree = torch.ones(1, num_nodes, dtype=torch.long, device=device) * 2
        spd = torch.zeros(1, num_nodes, num_nodes, dtype=torch.long, device=device)
        encoder_hidden = self.encoder(node_features, node_types=node_types,
                                      degree=degree, spd=spd)

        # ``report`` is always bound here, including when privacy
        # verification is disabled (it is then ``None``).
        report, blocked = self._check_privacy(encoder_hidden)
        if blocked:
            return {
                "abstract_answer": "[BLOCKED: Privacy violation detected]",
                "concrete_answer": "[BLOCKED]",
                "report": report,
                "vault_id": vault_id,
            }

        abstract_ids = self.decoder.generate(
            encoder_hidden, start_token_id=0,
            max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p,
        )
        # Render each generated id of the first sequence as "<id>".
        abstract_answer = " ".join(f"<{token.item()}>" for token in abstract_ids[0])
        try:
            concrete_answer = self.deabstraction.deabstract(abstract_answer, vault_id)
        except ValueError:
            # Best-effort: report the failure in-band rather than raising.
            concrete_answer = "[DE-ABSTRUCTION FAILED: Vault not found]"
        return {
            "abstract_answer": abstract_answer,
            "concrete_answer": concrete_answer,
            "report": report,
            "vault_id": vault_id,
        }

    def load_pretrained_encoder(self, state_dict):
        """Load ``state_dict`` into the encoder non-strictly.

        Returns:
            The result of ``load_state_dict`` (missing and unexpected keys),
            so callers can detect a checkpoint that did not match.
        """
        return self.encoder.load_state_dict(state_dict, strict=False)

    def freeze_encoder(self):
        """Disable gradients for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def enable_dp_training(self):
        """Set ``config.use_dp_training`` to ``True``."""
        self.config.use_dp_training = True

    def get_stats(self) -> Dict[str, int]:
        """Return parameter counts and the main size settings from config."""
        total = 0
        trainable = 0
        for param in self.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        return {
            "total_params": total,
            "trainable_params": trainable,
            "encoder_layers": self.config.num_encoder_layers,
            "decoder_layers": self.config.num_decoder_layers,
            "hidden_dim": self.config.hidden_dim,
        }
|
|