File size: 5,186 Bytes
3238dca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
"""
StruCTA: Full model composition.
Abstraction → Structured Encoder → Privacy Verification → Reasoning Decoder → De-Abstraction
"""

import torch
import torch.nn as nn
from typing import Optional, Dict, Any, Tuple

from .config import StruCTAConfig
from .encoder import PrivacyGraphTransformer
from .decoder import StructuredReasoningDecoder
from .privacy import PrivacyVerificationModule
from .abstraction import AbstractionLayer, AbstractDocument
from .deabstraction import DeAbstractionLayer


class StruCTA(nn.Module):
    """
    Full StruCTA model for end-to-end privacy-preserving reasoning.

    Wires together the graph encoder, the optional privacy verification
    module, the reasoning decoder and the abstraction / de-abstraction
    layers. When privacy verification is enabled and its report fails the
    structural or token check, decoding is skipped and the result is marked
    as blocked.
    """

    def __init__(self, config: "StruCTAConfig"):
        super().__init__()
        self.config = config
        self.encoder = PrivacyGraphTransformer(config)
        self.decoder = StructuredReasoningDecoder(config)
        # Privacy verification is optional; ``None`` disables the gate in
        # both ``forward`` and ``generate_from_text``.
        if config.use_privacy_verification:
            self.privacy_module = PrivacyVerificationModule(config)
        else:
            self.privacy_module = None
        # NOTE(review): "abstration" is misspelled, but the attribute name is
        # kept because external code (and, if AbstractionLayer is an
        # nn.Module, saved state_dict keys) may depend on it.
        self.abstration = AbstractionLayer(use_ner_model=False)
        self.deabstraction = DeAbstractionLayer()

    @staticmethod
    def _is_blocked(report) -> bool:
        """Return True when a privacy report fails the structural or token check."""
        return not report.struct_ok or not report.tokens_ok

    def forward(self, node_features, node_types=None, degree=None, spd=None,
                edge_index=None, edge_types=None, decoder_input_ids=None,
                graph_positions=None, attention_mask=None, encoder_mask=None):
        """
        Encode a graph batch, gate it through privacy verification, then decode.

        Args:
            node_features: Node feature tensor passed straight to the encoder.
            node_types, degree, spd, edge_index, edge_types: Optional graph
                structure inputs forwarded to the encoder unchanged.
            decoder_input_ids: Decoder input ids; when ``None`` the decoder
                is not run and ``"logits"`` is ``None``.
            graph_positions: Forwarded to the decoder.
            attention_mask: Forwarded to the encoder as ``attention_mask``.
            encoder_mask: Forwarded to the encoder as ``key_padding_mask`` and
                to the decoder as ``encoder_mask``.

        Returns:
            A dict with keys ``"encoder_hidden"``, ``"logits"``, ``"blocked"``
            and ``"report"``. ``"report"`` is ``None`` when privacy
            verification is disabled. When verification fails, ``"blocked"``
            is ``True`` and ``"logits"`` is ``None``.
        """
        encoder_hidden = self.encoder(
            node_features, node_types=node_types, degree=degree, spd=spd,
            edge_index=edge_index, edge_types=edge_types,
            attention_mask=attention_mask, key_padding_mask=encoder_mask,
        )
        report = None
        if self.privacy_module is not None:
            report = self.privacy_module.verify(encoder_hidden)
            if self._is_blocked(report):
                return {
                    "blocked": True, "report": report,
                    "logits": None, "encoder_hidden": encoder_hidden,
                }
        if decoder_input_ids is not None:
            logits = self.decoder(
                decoder_input_ids, encoder_hidden,
                graph_positions=graph_positions, encoder_mask=encoder_mask,
            )
        else:
            logits = None
        return {
            "encoder_hidden": encoder_hidden, "logits": logits,
            "blocked": False, "report": report,
        }

    @torch.no_grad()
    def generate_from_text(self, raw_text, max_length=128, temperature=0.8,
                           top_k=50, top_p=0.9):
        """
        Run the text-in / text-out pipeline on ``raw_text``.

        The text is abstracted and its vault is registered with the
        de-abstraction layer. The AMR graph is encoded and optionally
        privacy-checked. The decoder's generated ids, rendered as ``<id>``
        placeholders, are then de-abstracted into a concrete answer.

        Args:
            raw_text: Input text handed to ``self.abstration.abstract``.
            max_length, temperature, top_k, top_p: Sampling options forwarded
                to ``self.decoder.generate``.

        Returns:
            On success, a dict with ``"abstract_answer"``,
            ``"concrete_answer"``, ``"report"`` (``None`` when privacy
            verification is disabled) and ``"vault_id"``. When the privacy
            check fails, a dict with blocked placeholder answers and the
            ``"report"`` (no ``"vault_id"``).

        Note:
            Runs under ``torch.no_grad()``, but does not switch the module
            to eval mode.
        """
        abstract_doc = self.abstration.abstract(raw_text)
        self.deabstraction.register_vault(
            abstract_doc.vault_id,
            self.abstration.retrieve_vault(abstract_doc.vault_id)
        )
        amr = abstract_doc.amr_graph
        num_nodes = len(amr["nodes"])
        device = next(self.parameters()).device
        # NOTE(review): placeholder graph inputs. The node features are random
        # (hard-coded width 10) and the structure is synthetic: every node has
        # type 0, degree 2 and shortest-path distance 0. The AMR content is
        # therefore not featurized, and outputs are not reproducible without a
        # fixed RNG seed. TODO: replace with real AMR featurization.
        node_features = torch.randn(1, num_nodes, 10, device=device)
        node_types = torch.zeros(1, num_nodes, dtype=torch.long, device=device)
        degree = torch.full((1, num_nodes), 2, dtype=torch.long, device=device)
        spd = torch.zeros(1, num_nodes, num_nodes, dtype=torch.long, device=device)
        encoder_hidden = self.encoder(node_features, node_types=node_types,
                                      degree=degree, spd=spd)

        # ``report`` must be bound even when verification is disabled; the
        # success return below reads it.
        report = None
        if self.privacy_module is not None:
            report = self.privacy_module.verify(encoder_hidden)
            if self._is_blocked(report):
                return {
                    "abstract_answer": "[BLOCKED: Privacy violation detected]",
                    "concrete_answer": "[BLOCKED]",
                    "report": report,
                }

        # NOTE(review): assumes token id 0 is the decoder's start token.
        start_id = 0
        abstract_ids = self.decoder.generate(
            encoder_hidden, start_token_id=start_id,
            max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p,
        )
        # Only the first sequence of the returned batch is rendered.
        abstract_answer = " ".join(f"<{token.item()}>" for token in abstract_ids[0])
        try:
            concrete_answer = self.deabstraction.deabstract(
                abstract_answer, abstract_doc.vault_id)
        except ValueError:
            concrete_answer = "[DE-ABSTRUCTION FAILED: Vault not found]"
        return {
            "abstract_answer": abstract_answer,
            "concrete_answer": concrete_answer,
            "report": report,
            "vault_id": abstract_doc.vault_id,
        }

    def load_pretrained_encoder(self, state_dict):
        """
        Load encoder weights non-strictly (``strict=False``).

        Keys missing from, or unexpected in, ``state_dict`` are tolerated.
        The ``load_state_dict`` result is returned so callers can inspect its
        ``missing_keys`` / ``unexpected_keys`` instead of failures being
        silent.
        """
        return self.encoder.load_state_dict(state_dict, strict=False)

    def freeze_encoder(self):
        """Disable gradients for every encoder parameter."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def enable_dp_training(self):
        """
        Set ``config.use_dp_training`` to True.

        Nothing in this class reads the flag; presumably it is consumed by
        the training code (verify against callers).
        """
        self.config.use_dp_training = True

    def get_stats(self) -> Dict[str, int]:
        """Return total/trainable parameter counts plus key config sizes."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            "total_params": total,
            "trainable_params": trainable,
            "encoder_layers": self.config.num_encoder_layers,
            "decoder_layers": self.config.num_decoder_layers,
            "hidden_dim": self.config.hidden_dim,
        }