MetaCortex-Dynamics committed
Commit
9ba436b
verified
1 Parent(s): 175b7be

Create pipeline/mdlm/decoder.py

Files changed (1)
  1. pipeline/mdlm/decoder.py +293 -0
pipeline/mdlm/decoder.py ADDED
@@ -0,0 +1,293 @@
+ """
+ Phase 5: Constrained Decoder (EXECUTE phase of the GGP).
+
+ Takes a committed governed structure (from PROMOTE) and generates
+ natural language within the validity envelope.
+
+ Architecture: small transformer decoder conditioned on governed operator
+ tokens. The governed structure is the prompt; the output is prose that
+ expresses the structure in natural language.
+
+ This is NOT a general-purpose LLM. It generates governed prose:
+ text whose semantic content is constrained to what the governed structure
+ permits. The decoder cannot introduce implicit authority structures
+ because the governed frame doesn't encode them.
+
+ Training data: (structure tokens, source text) pairs extracted from
+ the decomposition pipeline.
+ """
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+
+ try:
+     import torch
+     import torch.nn as nn
+     import torch.nn.functional as F
+     HAS_TORCH = True
+ except ImportError:
+     HAS_TORCH = False
+
+ from pipeline.mdlm.tokenizer import (
+     encode as encode_gov, VOCAB_SIZE as STRUCT_VOCAB_SIZE,
+     BOS, EOS, PAD, TOKEN_NAMES,
+ )
+
+
+ # ═══════════════════════════════════════════════════════════════════════════
+ # PAIRED DATA EXTRACTION
+ # ═══════════════════════════════════════════════════════════════════════════
+
+ @dataclass
+ class FrameProsePair:
+     """A (governed structure, source prose) pair for decoder training."""
+     gov_tokens: list[int]  # Encoded governed structure
+     prose: str             # Original source text
+     source_id: str
+
+
+ def extract_pairs_from_pipeline(
+     corpus_dir: str | Path,
+     theory_dir: str | Path | None = None,
+ ) -> list[FrameProsePair]:
+     """Extract structure-prose pairs by re-running the pipeline with text capture.
+
+     Since the emitted JSONL doesn't store the original text, we re-run
+     the decomposition and capture both the governed structure and the
+     source segment. Currently only ``theory_dir`` is scanned; ``corpus_dir``
+     is accepted but unused here.
+     """
+     import sys
+     sys.path.insert(0, ".")
+
+     from pipeline.ingest.chat_archive import ingest_conversation_file
+     from pipeline.stages.s2_classify import classify, Classification
+     from pipeline.stages.s3_decompose import decompose
+     from pipeline.stages.s4_validate import validate_and_score, TigStatus
+
+     pairs = []
+
+     if theory_dir:
+         theory_path = Path(theory_dir)
+         for conv_file in sorted(theory_path.glob("conv_*.json")):
+             try:
+                 for seg in ingest_conversation_file(conv_file):
+                     c = classify(seg)
+                     if c.classification != Classification.TECHNICAL:
+                         continue
+                     ex = decompose(c)
+                     if ex is None:
+                         continue
+                     r = validate_and_score(ex)
+                     if r.tig_status != TigStatus.TRUE:
+                         continue
+
+                     # Build the (structure, prose) pair
+                     struct_dict = {
+                         "channel_a": {"operators": [
+                             {"operator": e.operator.canonical_name, "evidence": e.evidence}
+                             for e in ex.channel_a.operators.expressions
+                         ]},
+                         "channel_b": {"operators": [
+                             {"operator": e.operator.canonical_name, "evidence": e.evidence}
+                             for e in ex.channel_b.operators.expressions
+                         ]},
+                         "channel_c": {"operators": [
+                             {"operator": e.operator.canonical_name, "evidence": e.evidence}
+                             for e in ex.channel_c.operators.expressions
+                         ]},
+                         "witnesses": {
+                             w.canonical_name: {"attested": a.attested, "evidence": a.evidence}
+                             for w, a in ex.witnesses.attestations.items()
+                         },
+                     }
+                     gov_tokens = encode_gov(struct_dict)
+
+                     pairs.append(FrameProsePair(
+                         gov_tokens=gov_tokens,
+                         prose=seg.text[:512],  # Cap at 512 chars for training
+                         source_id=ex.provenance.source_id,
+                     ))
+             except Exception:
+                 continue
+
+     return pairs
+
+
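+ # Example call (illustrative paths, not pipeline defaults):
+ #     pairs = extract_pairs_from_pipeline("data/corpus", theory_dir="data/theory")
+ #     pairs[0]  # FrameProsePair(gov_tokens=[...], prose="...", source_id="...")
+
+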
+ # ═══════════════════════════════════════════════════════════════════════════
+ # PROSE TOKENIZER (character-level for simplicity)
+ # ═══════════════════════════════════════════════════════════════════════════
+
+ PROSE_PAD = 0
+ PROSE_BOS = 1
+ PROSE_EOS = 2
+ PROSE_UNK = 3
+ PROSE_VOCAB_OFFSET = 4
+
+ # Build vocab from the printable ASCII range; anything else maps to PROSE_UNK
+ PROSE_CHARS = (
+     " !\"#$%&'()*+,-./0123456789:;<=>?@"
+     "ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`"
+     "abcdefghijklmnopqrstuvwxyz{|}~"
+ )
+ PROSE_VOCAB_SIZE = PROSE_VOCAB_OFFSET + len(PROSE_CHARS)
+ _CHAR_TO_ID = {c: i + PROSE_VOCAB_OFFSET for i, c in enumerate(PROSE_CHARS)}
+ _ID_TO_CHAR = {v: k for k, v in _CHAR_TO_ID.items()}
+
+
+ def encode_prose(text: str, max_len: int = 256) -> list[int]:
+     """Encode prose as character-level token IDs (BOS ... EOS, <= max_len)."""
+     tokens = [PROSE_BOS]
+     for ch in text[:max_len - 2]:
+         tokens.append(_CHAR_TO_ID.get(ch, PROSE_UNK))
+     tokens.append(PROSE_EOS)
+     return tokens
+
+
+ def decode_prose(token_ids: list[int]) -> str:
+     """Decode character-level token IDs back to text."""
+     chars = []
+     for tid in token_ids:
+         if tid in (PROSE_PAD, PROSE_BOS, PROSE_EOS):
+             continue
+         chars.append("?" if tid == PROSE_UNK else _ID_TO_CHAR.get(tid, "?"))
+     return "".join(chars)
+
+
+ def pad_prose(tokens: list[int], max_len: int) -> list[int]:
+     """Pad or truncate prose tokens to fixed length."""
+     if len(tokens) >= max_len:
+         return tokens[:max_len]
+     return tokens + [PROSE_PAD] * (max_len - len(tokens))
+
+
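+ # Round-trip property of the character tokenizer (for printable-ASCII input):
+ #     decode_prose(encode_prose("Hi!")) == "Hi!"
+ # Non-ASCII characters encode to PROSE_UNK and decode as "?".
+
+ def collate_pairs(
+     pairs: list[FrameProsePair],
+     max_struct_len: int = 40,
+     max_prose_len: int = 256,
+ ):
+     """Batch FrameProsePair records into (gov, prose) LongTensors.
+
+     Illustrative helper, not referenced elsewhere in the pipeline: pads or
+     truncates each side to a fixed length so pairs can be stacked into a
+     training batch. Length defaults mirror the ConstrainedDecoder defaults
+     below. Assumes torch is importable (see the HAS_TORCH guard).
+     """
+     gov = [(p.gov_tokens + [PAD] * max_struct_len)[:max_struct_len] for p in pairs]
+     prose = [pad_prose(encode_prose(p.prose, max_prose_len), max_prose_len) for p in pairs]
+     return torch.tensor(gov, dtype=torch.long), torch.tensor(prose, dtype=torch.long)
+
+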
+ # ═══════════════════════════════════════════════════════════════════════════
+ # CONSTRAINED DECODER MODEL
+ # ═══════════════════════════════════════════════════════════════════════════
+
+ if HAS_TORCH:
+
+     class ConstrainedDecoder(nn.Module):
+         """Transformer decoder conditioned on governed structure.
+
+         Encoder: processes the governed token sequence (the committed structure).
+         Decoder: generates prose character-by-character within the envelope.
+
+         The structure tokens serve as cross-attention keys: the decoder
+         can only attend to the committed structure, not to arbitrary context.
+         """
+
+         def __init__(
+             self,
+             gov_vocab: int = STRUCT_VOCAB_SIZE,
+             prose_vocab: int = PROSE_VOCAB_SIZE,
+             d_model: int = 128,
+             nhead: int = 4,
+             num_encoder_layers: int = 2,
+             num_decoder_layers: int = 4,
+             max_struct_len: int = 40,
+             max_prose_len: int = 256,
+             dropout: float = 0.1,
+         ):
+             super().__init__()
+             self.d_model = d_model
+             self.max_prose_len = max_prose_len
+
+             # Encoder (governed structure)
+             self.struct_embedding = nn.Embedding(gov_vocab, d_model, padding_idx=PAD)
+             self.struct_pos = nn.Embedding(max_struct_len, d_model)
+             encoder_layer = nn.TransformerEncoderLayer(
+                 d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
+                 dropout=dropout, batch_first=True,
+             )
+             self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)
+
+             # Decoder (prose generation)
+             self.prose_embedding = nn.Embedding(prose_vocab, d_model, padding_idx=PROSE_PAD)
+             self.prose_pos = nn.Embedding(max_prose_len, d_model)
+             decoder_layer = nn.TransformerDecoderLayer(
+                 d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
+                 dropout=dropout, batch_first=True,
+             )
+             self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)
+             self.output_proj = nn.Linear(d_model, prose_vocab)
+
+         def forward(
+             self,
+             gov_tokens: torch.Tensor,    # (B, struct_len)
+             prose_tokens: torch.Tensor,  # (B, prose_len)
+         ) -> torch.Tensor:
+             """Forward pass. Returns logits (B, prose_len, prose_vocab)."""
+             B = gov_tokens.size(0)
+
+             # Encode governed structure
+             struct_len = gov_tokens.size(1)
+             struct_pos = torch.arange(struct_len, device=gov_tokens.device).unsqueeze(0).expand(B, -1)
+             struct_h = self.struct_embedding(gov_tokens) + self.struct_pos(struct_pos)
+             struct_pad_mask = (gov_tokens == PAD)
+             memory = self.encoder(struct_h, src_key_padding_mask=struct_pad_mask)
+
+             # Decode prose (teacher-forced during training)
+             prose_len = prose_tokens.size(1)
+             prose_pos = torch.arange(prose_len, device=prose_tokens.device).unsqueeze(0).expand(B, -1)
+             prose_h = self.prose_embedding(prose_tokens) + self.prose_pos(prose_pos)
+
+             # Causal mask for autoregressive generation
+             causal_mask = nn.Transformer.generate_square_subsequent_mask(prose_len, device=prose_tokens.device)
+             prose_pad_mask = (prose_tokens == PROSE_PAD)
+
+             decoded = self.decoder(
+                 prose_h, memory,
+                 tgt_mask=causal_mask,
+                 tgt_key_padding_mask=prose_pad_mask,
+                 memory_key_padding_mask=struct_pad_mask,
+             )
+             return self.output_proj(decoded)
+
+         def generate(
+             self,
+             gov_tokens: torch.Tensor,  # (B, struct_len)
+             max_len: int = 200,
+             temperature: float = 0.8,
+         ) -> list[str]:
+             """Generate prose from governed structure via temperature sampling."""
+             self.eval()
+             B = gov_tokens.size(0)
+             device = gov_tokens.device
+
+             # Encode the governed structure once; reuse as cross-attention memory
+             struct_len = gov_tokens.size(1)
+             struct_pos = torch.arange(struct_len, device=device).unsqueeze(0).expand(B, -1)
+             struct_h = self.struct_embedding(gov_tokens) + self.struct_pos(struct_pos)
+             struct_pad_mask = (gov_tokens == PAD)
+             memory = self.encoder(struct_h, src_key_padding_mask=struct_pad_mask)
+
+             # Autoregressive generation, one character per step
+             generated = torch.full((B, 1), PROSE_BOS, dtype=torch.long, device=device)
+
+             with torch.no_grad():
+                 for _ in range(max_len):
+                     prose_len = generated.size(1)
+                     prose_pos = torch.arange(prose_len, device=device).unsqueeze(0).expand(B, -1)
+                     prose_h = self.prose_embedding(generated) + self.prose_pos(prose_pos)
+                     causal_mask = nn.Transformer.generate_square_subsequent_mask(prose_len, device=device)
+
+                     decoded = self.decoder(prose_h, memory, tgt_mask=causal_mask, memory_key_padding_mask=struct_pad_mask)
+                     logits = self.output_proj(decoded[:, -1, :]) / temperature
+                     probs = F.softmax(logits, dim=-1)
+                     next_token = torch.multinomial(probs, 1)
+                     generated = torch.cat([generated, next_token], dim=1)
+
+                     # Simple stopping rule: halt only when every sequence in the
+                     # batch emits EOS at the same step
+                     if (next_token == PROSE_EOS).all():
+                         break
+
+             return [decode_prose(generated[b].tolist()) for b in range(B)]
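+
+     # Usage sketch (illustrative; `_train_step_sketch` is a hypothetical helper,
+     # not part of the pipeline API). Standard teacher forcing: feed
+     # prose[:, :-1], score prose[:, 1:], and ignore PROSE_PAD positions.
+     def _train_step_sketch(model: ConstrainedDecoder, gov_batch, prose_batch, optimizer) -> float:
+         """One teacher-forced training step (minimal sketch)."""
+         model.train()
+         logits = model(gov_batch, prose_batch[:, :-1])  # (B, L-1, prose_vocab)
+         loss = F.cross_entropy(
+             logits.reshape(-1, logits.size(-1)),
+             prose_batch[:, 1:].reshape(-1),
+             ignore_index=PROSE_PAD,  # padding doesn't contribute to the loss
+         )
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         return loss.item()
+
+     # Generating prose from one committed structure (illustrative):
+     #     model = ConstrainedDecoder()
+     #     gov = torch.tensor([encode_gov(struct_dict)], dtype=torch.long)  # (1, struct_len)
+     #     [text] = model.generate(gov, max_len=200, temperature=0.8)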