""" StruCTA: Full model composition. Abstraction → Structured Encoder → Privacy Verification → Reasoning Decoder → De-Abstraction """ import torch import torch.nn as nn from typing import Optional, Dict, Any, Tuple from .config import StruCTAConfig from .encoder import PrivacyGraphTransformer from .decoder import StructuredReasoningDecoder from .privacy import PrivacyVerificationModule from .abstraction import AbstractionLayer, AbstractDocument from .deabstraction import DeAbstractionLayer class StruCTA(nn.Module): """ Full StruCTA model for end-to-end privacy-preserving reasoning. """ def __init__(self, config: StruCTAConfig): super().__init__() self.config = config self.encoder = PrivacyGraphTransformer(config) self.decoder = StructuredReasoningDecoder(config) if config.use_privacy_verification: self.privacy_module = PrivacyVerificationModule(config) else: self.privacy_module = None self.abstration = AbstractionLayer(use_ner_model=False) self.deabstraction = DeAbstractionLayer() def forward(self, node_features, node_types=None, degree=None, spd=None, edge_index=None, edge_types=None, decoder_input_ids=None, graph_positions=None, attention_mask=None, encoder_mask=None): encoder_hidden = self.encoder( node_features, node_types=node_types, degree=degree, spd=spd, edge_index=edge_index, edge_types=edge_types, attention_mask=attention_mask, key_padding_mask=encoder_mask, ) if self.privacy_module is not None: report = self.privacy_module.verify(encoder_hidden) if not report.struct_ok or not report.tokens_ok: return { "blocked": True, "report": report, "logits": None, "encoder_hidden": encoder_hidden, } if decoder_input_ids is not None: logits = self.decoder( decoder_input_ids, encoder_hidden, graph_positions=graph_positions, encoder_mask=encoder_mask, ) else: logits = None return { "encoder_hidden": encoder_hidden, "logits": logits, "blocked": False, "report": report if self.privacy_module else None, } @torch.no_grad() def generate_from_text(self, raw_text, max_length=128, temperature=0.8, top_k=50, top_p=0.9): abstract_doc = self.abstration.abstract(raw_text) self.deabstraction.register_vault( abstract_doc.vault_id, self.abstration.retrieve_vault(abstract_doc.vault_id) ) amr = abstract_doc.amr_graph num_nodes = len(amr["nodes"]) device = next(self.parameters()).device node_features = torch.randn(1, num_nodes, 10, device=device) node_types = torch.zeros(1, num_nodes, dtype=torch.long, device=device) degree = torch.ones(1, num_nodes, dtype=torch.long, device=device) * 2 spd = torch.zeros(1, num_nodes, num_nodes, dtype=torch.long, device=device) encoder_hidden = self.encoder(node_features, node_types=node_types, degree=degree, spd=spd) if self.privacy_module is not None: report = self.privacy_module.verify(encoder_hidden) if not report.struct_ok or not report.tokens_ok: return { "abstract_answer": "[BLOCKED: Privacy violation detected]", "concrete_answer": "[BLOCKED]", "report": report, } start_id = 0 abstract_ids = self.decoder.generate( encoder_hidden, start_token_id=start_id, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, ) abstract_answer = " ".join([f"<{token.item()}>" for token in abstract_ids[0]]) try: concrete_answer = self.deabstraction.deabstract( abstract_answer, abstract_doc.vault_id) except ValueError: concrete_answer = "[DE-ABSTRUCTION FAILED: Vault not found]" return { "abstract_answer": abstract_answer, "concrete_answer": concrete_answer, "report": report, "vault_id": abstract_doc.vault_id, } def load_pretrained_encoder(self, state_dict): self.encoder.load_state_dict(state_dict, strict=False) def freeze_encoder(self): for param in self.encoder.parameters(): param.requires_grad = False def enable_dp_training(self): self.config.use_dp_training = True def get_stats(self) -> Dict[str, int]: total = sum(p.numel() for p in self.parameters()) trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) return { "total_params": total, "trainable_params": trainable, "encoder_layers": self.config.num_encoder_layers, "decoder_layers": self.config.num_decoder_layers, "hidden_dim": self.config.hidden_dim, }