FerrellSyntheticIntelligence commited on
Commit
6c284cd
·
1 Parent(s): 3746aeb

feat: implement sovereign teacher-forcing training loop

Browse files
Files changed (1) hide show
  1. brain.py +126 -0
brain.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import numpy as np
4
+ from typing import Iterable, Tuple
5
+ from fluid_transformer import FluidTransformer
6
+ from abstract_reasoner import AbstractReasoner
7
+ from science_reasoner import ScienceReasoner
8
+ from concept_graph import ConceptGraph
9
+ from free_energy import FreeEnergyEngine
10
+ from ledger import Ledger
11
+ from CyberCore import CyberCore
12
+
13
+ # ----------------------------------------------------------------------
14
+ # Tokenizer Helpers
15
+ # ----------------------------------------------------------------------
16
+ def _tokenize(text: str) -> torch.Tensor:
17
+ tokenizer = _init_transformer().tokenizer
18
+ ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt")
19
+ return ids.squeeze(0)
20
+
21
+ def _decode(ids: torch.Tensor) -> str:
22
+ tokenizer = _init_transformer().tokenizer
23
+ return tokenizer.decode(ids.tolist(), skip_special_tokens=True)
24
+
25
+ # ----------------------------------------------------------------------
26
+ # Teacher Forcing Loop
27
+ # ----------------------------------------------------------------------
28
+ def execute_teacher_forcing(batch: Iterable[Tuple[str, str]], max_steps: int = 12, learning_rate: float = 1e-4) -> float:
29
+ transformer = _init_transformer()
30
+ transformer.train()
31
+ optimizer = torch.optim.SGD(transformer.parameters(), lr=learning_rate)
32
+ total_loss, n_examples = 0.0, 0
33
+
34
+ for inp, tgt in batch:
35
+ try:
36
+ final_node, _ = process_input(inp, max_depth=max_steps)
37
+ except Exception as e:
38
+ _init_ledger().record(action="train_reject", payload={"input": inp, "error": str(e)})
39
+ continue
40
+
41
+ pred_text = final_node.label
42
+ pred_ids = _tokenize(pred_text).unsqueeze(0)
43
+ tgt_ids = _tokenize(tgt).unsqueeze(0)
44
+
45
+ pad_id = transformer.tokenizer.pad_token_id
46
+ max_len = max(pred_ids.size(1), tgt_ids.size(1))
47
+ pred_ids = F.pad(pred_ids, (0, max_len - pred_ids.size(1)), value=pad_id)
48
+ tgt_ids = F.pad(tgt_ids, (0, max_len - tgt_ids.size(1)), value=pad_id)
49
+
50
+ logits = transformer(pred_ids)
51
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), tgt_ids.view(-1), ignore_index=pad_id)
52
+
53
+ optimizer.zero_grad()
54
+ loss.backward()
55
+ optimizer.step()
56
+
57
+ _init_ledger().record(action="train_step", payload={"input": inp, "target": tgt, "prediction": pred_text, "loss": float(loss.item())})
58
+ total_loss += loss.item()
59
+ n_examples += 1
60
+
61
+ transformer.eval()
62
+ return total_loss / max(1, n_examples)
63
+
64
+ # ----------------------------------------------------------------------
65
+ # Singletons
66
+ # ----------------------------------------------------------------------
67
+ _transformer = _abstract_reasoner = _science_reasoner = _concept_graph = _free_energy = _ledger = _cybercore = None
68
+
69
+ def _init_transformer():
70
+ global _transformer
71
+ if _transformer is None:
72
+ _transformer = FluidTransformer.from_pretrained("models/fluid_transformer")
73
+ _transformer.eval()
74
+ return _transformer
75
+
76
+ def _init_abstract():
77
+ global _abstract_reasoner
78
+ if _abstract_reasoner is None: _abstract_reasoner = AbstractReasoner(chunk_size=8, top_k=6)
79
+ return _abstract_reasoner
80
+
81
+ def _init_science():
82
+ global _science_reasoner, _concept_graph
83
+ if _science_reasoner is None: _science_reasoner = ScienceReasoner(_init_graph())
84
+ return _science_reasoner
85
+
86
+ def _init_graph():
87
+ global _concept_graph
88
+ if _concept_graph is None: _concept_graph = ConceptGraph(dim=768, persist_dir="data/concept_graph")
89
+ return _concept_graph
90
+
91
+ def _init_free_energy():
92
+ global _free_energy
93
+ if _free_energy is None: _free_energy = FreeEnergyEngine()
94
+ return _free_energy
95
+
96
+ def _init_ledger():
97
+ global _ledger
98
+ if _ledger is None: _ledger = Ledger(chain_name="FSI_Sovereign")
99
+ return _ledger
100
+
101
+ def _init_cybercore():
102
+ global _cybercore
103
+ if _cybercore is None: _cybercore = CyberCore()
104
+ return _cybercore
105
+
106
+ def process_input(text: str, max_depth: int = 12):
107
+ if _init_cybercore().is_malicious(text): raise ValueError("Input rejected")
108
+ hidden = _init_transformer().encode(text).squeeze(0).cpu().numpy()
109
+ hidden /= np.linalg.norm(hidden)
110
+ propositions, raw_steps = _init_abstract().reason(hidden)
111
+ final_node = _init_science().infer(propositions, raw_steps, max_depth=max_depth)
112
+
113
+ involved_cids = {final_node.cid}
114
+ for src, tgt, _w in final_node.edges: involved_cids.add(tgt)
115
+ embeddings = [final_node.embedding] + [_init_graph().get_node(cid).embedding for cid in involved_cids if cid != final_node.cid]
116
+
117
+ loss = _init_free_energy().compute_loss(embeddings)
118
+ avg = sum(_init_ledger().get_recent_losses(100)) / (len(_init_ledger().get_recent_losses(100)) + 1e-9)
119
+
120
+ if loss <= avg * 1.10:
121
+ _init_ledger().record(action="commit", payload={"text": text, "final_node": final_node.cid, "label": final_node.label, "confidence": final_node.confidence, "free_energy": loss})
122
+ _init_graph().persist()
123
+ else:
124
+ _init_ledger().record(action="reject", payload={"text": text, "reason": "free_energy_too_high"})
125
+ raise RuntimeError("Knowledge rejected")
126
+ return final_node, raw_steps