# StruCTA / structa/model.py
# YOUSSEF88's picture
# Upload structa/model.py
# 3238dca verified
"""
StruCTA: Full model composition.
Abstraction → Structured Encoder → Privacy Verification → Reasoning Decoder → De-Abstraction
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Any, Tuple
from .config import StruCTAConfig
from .encoder import PrivacyGraphTransformer
from .decoder import StructuredReasoningDecoder
from .privacy import PrivacyVerificationModule
from .abstraction import AbstractionLayer, AbstractDocument
from .deabstraction import DeAbstractionLayer
class StruCTA(nn.Module):
    """
    Full StruCTA model for end-to-end privacy-preserving reasoning.

    The model wires together a graph encoder, an optional privacy
    verification step that can block decoding, a reasoning decoder, and
    the abstraction / de-abstraction layers used by the text entry point
    ``generate_from_text``.
    """

    def __init__(self, config: "StruCTAConfig"):
        """Build all sub-components from ``config``.

        Args:
            config: Model configuration. ``use_privacy_verification`` decides
                whether a ``PrivacyVerificationModule`` is instantiated.
        """
        super().__init__()
        self.config = config
        self.encoder = PrivacyGraphTransformer(config)
        self.decoder = StructuredReasoningDecoder(config)
        # ``None`` means "verification disabled"; every consumer below checks
        # for it explicitly.
        if config.use_privacy_verification:
            self.privacy_module = PrivacyVerificationModule(config)
        else:
            self.privacy_module = None
        # NOTE: the attribute name is misspelled ("abstration"). It is kept
        # as-is because it is part of the object's public surface; renaming
        # it could break external callers.
        self.abstration = AbstractionLayer(use_ner_model=False)
        self.deabstraction = DeAbstractionLayer()

    def forward(self, node_features, node_types=None, degree=None, spd=None,
                edge_index=None, edge_types=None, decoder_input_ids=None,
                graph_positions=None, attention_mask=None,
                encoder_mask=None) -> Dict[str, Any]:
        """Encode a graph, verify it, and optionally decode.

        Args:
            node_features: Positional input to the encoder.
            node_types, degree, spd, edge_index, edge_types, attention_mask:
                Forwarded unchanged to the encoder as keyword arguments.
            decoder_input_ids: When ``None`` the decoder is skipped and
                ``logits`` is ``None`` in the result.
            graph_positions: Forwarded to the decoder.
            encoder_mask: Passed to the encoder as ``key_padding_mask`` and
                to the decoder as ``encoder_mask``.

        Returns:
            A dict with keys ``encoder_hidden``, ``logits``, ``blocked`` and
            ``report``. If the privacy report has a falsy ``struct_ok`` or
            ``tokens_ok``, ``blocked`` is ``True`` and ``logits`` is ``None``.
            ``report`` is ``None`` when privacy verification is disabled.
        """
        encoder_hidden = self.encoder(
            node_features, node_types=node_types, degree=degree, spd=spd,
            edge_index=edge_index, edge_types=edge_types,
            attention_mask=attention_mask, key_padding_mask=encoder_mask,
        )

        # Bind ``report`` up front so the final return never depends on
        # whether the verification branch ran.
        report = None
        if self.privacy_module is not None:
            report = self.privacy_module.verify(encoder_hidden)
            if not report.struct_ok or not report.tokens_ok:
                # Verification failed: never run the decoder.
                return {
                    "blocked": True, "report": report,
                    "logits": None, "encoder_hidden": encoder_hidden,
                }

        logits = None
        if decoder_input_ids is not None:
            logits = self.decoder(
                decoder_input_ids, encoder_hidden,
                graph_positions=graph_positions, encoder_mask=encoder_mask,
            )

        return {
            "encoder_hidden": encoder_hidden, "logits": logits,
            "blocked": False, "report": report,
        }

    @torch.no_grad()
    def generate_from_text(self, raw_text, max_length=128, temperature=0.8,
                           top_k=50, top_p=0.9) -> Dict[str, Any]:
        """Run the full text -> abstract answer -> concrete answer pipeline.

        Args:
            raw_text: Input handed to the abstraction layer's ``abstract``.
            max_length, temperature, top_k, top_p: Forwarded to
                ``self.decoder.generate``.

        Returns:
            On success, a dict with ``abstract_answer``, ``concrete_answer``,
            ``report`` (``None`` when verification is disabled) and
            ``vault_id``. When verification blocks the request, a dict with
            placeholder ``abstract_answer`` / ``concrete_answer`` strings and
            the ``report`` (no ``vault_id`` key).
        """
        abstract_doc = self.abstration.abstract(raw_text)
        vault_id = abstract_doc.vault_id
        # Hand the vault over to the de-abstraction layer so the generated
        # answer can be mapped back at the end.
        self.deabstraction.register_vault(
            vault_id,
            self.abstration.retrieve_vault(vault_id),
        )

        amr = abstract_doc.amr_graph
        num_nodes = len(amr["nodes"])
        device = next(self.parameters()).device

        # NOTE(review): only the node *count* of the AMR graph is used here.
        # Node features are random noise of width 10, node types and SPD are
        # all zeros, and every degree is 2 -- presumably placeholders until a
        # real graph featurizer is wired in; confirm. Output is therefore
        # not deterministic across calls.
        node_features = torch.randn(1, num_nodes, 10, device=device)
        node_types = torch.zeros(1, num_nodes, dtype=torch.long, device=device)
        degree = torch.ones(1, num_nodes, dtype=torch.long, device=device) * 2
        spd = torch.zeros(1, num_nodes, num_nodes, dtype=torch.long, device=device)

        encoder_hidden = self.encoder(node_features, node_types=node_types,
                                      degree=degree, spd=spd)

        # Fix: ``report`` used to be bound only inside the branch below, so
        # the final return raised NameError when verification was disabled.
        report = None
        if self.privacy_module is not None:
            report = self.privacy_module.verify(encoder_hidden)
            if not report.struct_ok or not report.tokens_ok:
                return {
                    "abstract_answer": "[BLOCKED: Privacy violation detected]",
                    "concrete_answer": "[BLOCKED]",
                    "report": report,
                }

        start_id = 0  # hard-coded start token id for generation
        abstract_ids = self.decoder.generate(
            encoder_hidden, start_token_id=start_id,
            max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p,
        )
        # Render the first generated sequence as space-separated "<id>" tokens.
        abstract_answer = " ".join(f"<{token.item()}>" for token in abstract_ids[0])

        try:
            concrete_answer = self.deabstraction.deabstract(
                abstract_answer, vault_id)
        except ValueError:
            # Best effort: report the failure in-band instead of raising.
            concrete_answer = "[DE-ABSTRUCTION FAILED: Vault not found]"

        return {
            "abstract_answer": abstract_answer,
            "concrete_answer": concrete_answer,
            "report": report,
            "vault_id": vault_id,
        }

    def load_pretrained_encoder(self, state_dict):
        """Load ``state_dict`` into the encoder non-strictly.

        Because ``strict=False`` is used, keys that do not match are skipped
        rather than raising.

        Returns:
            The value returned by ``load_state_dict`` so callers can inspect
            which keys were missing or unexpected.
        """
        return self.encoder.load_state_dict(state_dict, strict=False)

    def freeze_encoder(self):
        """Set ``requires_grad = False`` on every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def enable_dp_training(self):
        """Set ``config.use_dp_training`` to ``True``.

        This only flips the flag on the shared config object; nothing in
        this class reads it.
        """
        self.config.use_dp_training = True

    def get_stats(self) -> Dict[str, int]:
        """Return parameter counts and the main size settings from the config."""
        total = 0
        trainable = 0
        for param in self.parameters():
            count = param.numel()
            total += count
            if param.requires_grad:
                trainable += count
        return {
            "total_params": total,
            "trainable_params": trainable,
            "encoder_layers": self.config.num_encoder_layers,
            "decoder_layers": self.config.num_decoder_layers,
            "hidden_dim": self.config.hidden_dim,
        }