FerrellSyntheticIntelligence committed on
Commit ·
6c284cd
1
Parent(s): 3746aeb
feat: implement sovereign teacher-forcing training loop
Browse files
brain.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import numpy as np
|
| 4 |
+
from typing import Iterable, Tuple
|
| 5 |
+
from fluid_transformer import FluidTransformer
|
| 6 |
+
from abstract_reasoner import AbstractReasoner
|
| 7 |
+
from science_reasoner import ScienceReasoner
|
| 8 |
+
from concept_graph import ConceptGraph
|
| 9 |
+
from free_energy import FreeEnergyEngine
|
| 10 |
+
from ledger import Ledger
|
| 11 |
+
from CyberCore import CyberCore
|
| 12 |
+
|
| 13 |
+
# ----------------------------------------------------------------------
|
| 14 |
+
# Tokenizer Helpers
|
| 15 |
+
# ----------------------------------------------------------------------
|
| 16 |
+
def _tokenize(text: str) -> torch.Tensor:
    """Encode *text* with the shared transformer's tokenizer.

    Special tokens are added and the tokenizer is asked for a PyTorch
    tensor; dimension 0 of that tensor is squeezed before it is returned.
    """
    encoded = _init_transformer().tokenizer.encode(
        text,
        add_special_tokens=True,
        return_tensors="pt",
    )
    return encoded.squeeze(0)
|
| 20 |
+
|
| 21 |
+
def _decode(ids: torch.Tensor) -> str:
    """Turn a tensor of token ids back into text, dropping special tokens."""
    decoder = _init_transformer().tokenizer
    token_list = ids.tolist()
    return decoder.decode(token_list, skip_special_tokens=True)
|
| 24 |
+
|
| 25 |
+
# ----------------------------------------------------------------------
|
| 26 |
+
# Teacher Forcing Loop
|
| 27 |
+
# ----------------------------------------------------------------------
|
| 28 |
+
def execute_teacher_forcing(batch: Iterable[Tuple[str, str]], max_steps: int = 12, learning_rate: float = 1e-4) -> float:
    """Run one SGD pass over *batch*, fitting the transformer to the targets.

    For every ``(input, target)`` pair the input is pushed through
    ``process_input``; the label of the resulting node is tokenized and fed
    to the transformer, and the logits are scored against the tokenized
    target with cross-entropy. Pairs for which ``process_input`` raises, or
    whose loss is not finite, are written to the ledger as ``train_reject``
    and skipped. Every optimizer step is written as ``train_step``.

    Args:
        batch: Iterable of ``(input_text, target_text)`` pairs.
        max_steps: Forwarded to ``process_input`` as ``max_depth``.
        learning_rate: Learning rate of the SGD optimizer created per call.

    Returns:
        Mean loss over the pairs that produced an optimizer step, or ``0.0``
        when no pair did.
    """
    transformer = _init_transformer()
    transformer.train()
    optimizer = torch.optim.SGD(transformer.parameters(), lr=learning_rate)
    total_loss, n_examples = 0.0, 0

    try:
        for inp, tgt in batch:
            try:
                final_node, _ = process_input(inp, max_depth=max_steps)
            except Exception as e:
                # Boundary handler: any failure while reasoning about one
                # example is logged and must not abort the whole batch.
                _init_ledger().record(action="train_reject", payload={"input": inp, "error": str(e)})
                continue

            pred_text = final_node.label
            pred_ids = _tokenize(pred_text).unsqueeze(0)
            tgt_ids = _tokenize(tgt).unsqueeze(0)

            pad_id = transformer.tokenizer.pad_token_id
            if pad_id is None:
                # NOTE(review): presumably the tokenizer reports None when it
                # defines no pad token (as Hugging Face tokenizers do) --
                # confirm for FluidTransformer. cross_entropy needs an int
                # ignore_index, so fall back to PyTorch's default of -100.
                input_pad, ignore_id = 0, -100
            else:
                input_pad = ignore_id = pad_id

            # Right-pad both sequences to a common length so that logits and
            # targets line up position by position.
            max_len = max(pred_ids.size(1), tgt_ids.size(1))
            pred_ids = F.pad(pred_ids, (0, max_len - pred_ids.size(1)), value=input_pad)
            tgt_ids = F.pad(tgt_ids, (0, max_len - tgt_ids.size(1)), value=ignore_id)

            logits = transformer(pred_ids)
            # reshape (not view) also works when the model output is not
            # contiguous in memory.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                tgt_ids.reshape(-1),
                ignore_index=ignore_id,
            )

            if not torch.isfinite(loss):
                # E.g. every target position was ignored. Stepping on a
                # NaN/inf loss would write NaNs into the model weights.
                _init_ledger().record(action="train_reject", payload={"input": inp, "error": "non-finite loss"})
                continue

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = float(loss.item())
            _init_ledger().record(action="train_step", payload={"input": inp, "target": tgt, "prediction": pred_text, "loss": loss_value})
            total_loss += loss_value
            n_examples += 1
    finally:
        # Always hand the shared singleton back in eval mode, even when the
        # forward/backward pass or the batch iterator raised.
        transformer.eval()

    return total_loss / max(1, n_examples)
|
| 63 |
+
|
| 64 |
+
# ----------------------------------------------------------------------
|
| 65 |
+
# Singletons
|
| 66 |
+
# ----------------------------------------------------------------------
|
| 67 |
+
# Lazily-created shared instances; each stays None until its _init_*
# accessor builds it on first use.
_transformer = None
_abstract_reasoner = None
_science_reasoner = None
_concept_graph = None
_free_energy = None
_ledger = None
_cybercore = None
|
| 68 |
+
|
| 69 |
+
def _init_transformer():
    """Return the shared FluidTransformer, loading it on first use.

    The model is loaded from ``models/fluid_transformer`` and switched to
    eval mode once, at the moment the singleton is created.
    """
    global _transformer
    if _transformer is not None:
        return _transformer
    _transformer = FluidTransformer.from_pretrained("models/fluid_transformer")
    _transformer.eval()
    return _transformer
|
| 75 |
+
|
| 76 |
+
def _init_abstract():
    """Return the shared AbstractReasoner, creating it on first use."""
    global _abstract_reasoner
    if _abstract_reasoner is not None:
        return _abstract_reasoner
    _abstract_reasoner = AbstractReasoner(chunk_size=8, top_k=6)
    return _abstract_reasoner
|
| 80 |
+
|
| 81 |
+
def _init_science():
    """Return the shared ScienceReasoner, creating it on first use.

    The reasoner is constructed around the shared concept graph obtained
    from ``_init_graph``.
    """
    global _science_reasoner
    if _science_reasoner is not None:
        return _science_reasoner
    graph = _init_graph()
    _science_reasoner = ScienceReasoner(graph)
    return _science_reasoner
|
| 85 |
+
|
| 86 |
+
def _init_graph():
    """Return the shared ConceptGraph, creating it on first use.

    The graph is constructed with ``dim=768`` and
    ``persist_dir="data/concept_graph"``.
    """
    global _concept_graph
    if _concept_graph is not None:
        return _concept_graph
    _concept_graph = ConceptGraph(dim=768, persist_dir="data/concept_graph")
    return _concept_graph
|
| 90 |
+
|
| 91 |
+
def _init_free_energy():
    """Return the shared FreeEnergyEngine, creating it on first use."""
    global _free_energy
    if _free_energy is not None:
        return _free_energy
    _free_energy = FreeEnergyEngine()
    return _free_energy
|
| 95 |
+
|
| 96 |
+
def _init_ledger():
    """Return the shared Ledger (chain ``FSI_Sovereign``), creating it on first use."""
    global _ledger
    if _ledger is not None:
        return _ledger
    _ledger = Ledger(chain_name="FSI_Sovereign")
    return _ledger
|
| 100 |
+
|
| 101 |
+
def _init_cybercore():
    """Return the shared CyberCore, creating it on first use."""
    global _cybercore
    if _cybercore is not None:
        return _cybercore
    _cybercore = CyberCore()
    return _cybercore
|
| 105 |
+
|
| 106 |
+
def process_input(text: str, max_depth: int = 12):
    """Reason over *text* and commit the result if its free energy is acceptable.

    Steps: screen the text with CyberCore, encode it with the transformer,
    L2-normalise the encoding, run the abstract and science reasoners, then
    compute a free-energy loss over the embeddings of the final node and of
    the nodes its edges point to. The result is committed (ledger entry plus
    graph persist) when that loss is at most 110% of the mean of the last
    100 ledger losses, or when the ledger has no losses yet.

    Args:
        text: Raw input text.
        max_depth: Forwarded to the science reasoner's ``infer``.

    Returns:
        ``(final_node, raw_steps)`` from the science and abstract reasoners.

    Raises:
        ValueError: CyberCore flagged *text* as malicious.
        RuntimeError: The free-energy loss exceeded the acceptance
            threshold; a ``reject`` ledger entry is written first.
    """
    if _init_cybercore().is_malicious(text):
        raise ValueError("Input rejected")

    # detach(): Tensor.numpy() raises on a tensor that requires grad, which
    # can be the case when this runs inside the training loop.
    encoded = _init_transformer().encode(text)
    hidden = encoded.detach().squeeze(0).cpu().numpy()
    norm = np.linalg.norm(hidden)
    if norm > 0:
        # Out-of-place division: the array may share memory with the tensor
        # returned by encode(), which must not be modified. An all-zero
        # encoding is left as is rather than turned into NaNs.
        hidden = hidden / norm

    propositions, raw_steps = _init_abstract().reason(hidden)
    final_node = _init_science().infer(propositions, raw_steps, max_depth=max_depth)

    # Unique edge targets other than the final node itself, in first-seen
    # order so the embedding list is deterministic.
    neighbour_cids = dict.fromkeys(
        tgt for _src, tgt, _w in final_node.edges if tgt != final_node.cid
    )
    graph = _init_graph()
    embeddings = [final_node.embedding]
    embeddings.extend(graph.get_node(cid).embedding for cid in neighbour_cids)

    loss = _init_free_energy().compute_loss(embeddings)

    ledger = _init_ledger()
    recent_losses = list(ledger.get_recent_losses(100))
    if recent_losses:
        average = sum(recent_losses) / len(recent_losses)
        accepted = loss <= average * 1.10
    else:
        # Cold start: with no history the average would be 0 and every
        # input with a positive loss would be rejected. NOTE(review): this
        # assumes losses only reach the ledger via "commit" entries, in
        # which case nothing could ever be accepted -- confirm in Ledger.
        accepted = True

    if not accepted:
        ledger.record(action="reject", payload={"text": text, "reason": "free_energy_too_high"})
        raise RuntimeError("Knowledge rejected")

    ledger.record(action="commit", payload={"text": text, "final_node": final_node.cid, "label": final_node.label, "confidence": final_node.confidence, "free_energy": loss})
    graph.persist()
    return final_node, raw_steps
|