""" bart_server.py — The Language Bridge Run with: uvicorn bart_server:app --host 0.0.0.0 --port 7860 Responsibilities: POST /encode : text → serialized hidden-state tensor (A) POST /decode : serialized hidden-state tensor → text POST /inject_generate : raw latent vector (C) → project → decode → text BART is an Encoder-Decoder, so both directions are possible. The Engine server calls /encode to get A, then calls /inject_generate to speak the settled truth back out through BART's decoder. """ from fastapi import FastAPI, HTTPException from pydantic import BaseModel import torch import torch.nn as nn from transformers import BartTokenizer, BartForConditionalGeneration from transformers.modeling_outputs import BaseModelOutput import json import time # ────────────────────────────────────────────────────────────── # 1. STARTUP: Load BART # ────────────────────────────────────────────────────────────── DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") BART_DIM = 1024 # bart-large hidden dimension LATENT_DIM = 2 # Perspective Engine output dimension (supply, demand) MAX_SEQ = 32 # decoder max tokens for generated explanation print(f"[BART SERVER] Loading facebook/bart-large on {DEVICE}...") tokenizer = BartTokenizer.from_pretrained("facebook/bart-large") bart = BartForConditionalGeneration.from_pretrained( "facebook/bart-large", torch_dtype=torch.float32 ).to(DEVICE).eval() print("[BART SERVER] Model ready.") # ────────────────────────────────────────────────────────────── # 2. PROJECTION LAYER (C → BART hidden space) # A small Linear that is saved alongside the engine weights. # Both servers share the same projection checkpoint. # ────────────────────────────────────────────────────────────── class LatentProjector(nn.Module): """Projects Engine output C [batch, LATENT_DIM] → BART hidden state [batch, 1, BART_DIM]. Trained in the engine script; loaded here for decoding.""" def __init__(self): super().__init__() self.net = nn.Sequential( nn.Linear(LATENT_DIM, 256), nn.GELU(), nn.Linear(256, BART_DIM) ) def forward(self, c): # c: [batch, LATENT_DIM] → [batch, 1, BART_DIM] return self.net(c).unsqueeze(1) projector = LatentProjector().to(DEVICE).eval() # Try to load pre-trained projection weights (saved by engine_server after training) try: projector.load_state_dict(torch.load("projector.pt", map_location=DEVICE)) print("[BART SERVER] Projector weights loaded from projector.pt") except FileNotFoundError: print("[BART SERVER] WARNING: projector.pt not found. /inject_generate will use untrained projection.") print(" Run engine_server.py training first, then projector.pt will be auto-saved.") # ────────────────────────────────────────────────────────────── # 3. HELPER FUNCTIONS # ────────────────────────────────────────────────────────────── def tensor_to_list(t: torch.Tensor) -> list: return t.detach().cpu().tolist() def list_to_tensor(data: list) -> torch.Tensor: return torch.tensor(data, dtype=torch.float32, device=DEVICE) def mean_pool(hidden_state: torch.Tensor) -> torch.Tensor: """Collapses [batch, seq_len, 1024] → [batch, 1024].""" return hidden_state.mean(dim=1) # ────────────────────────────────────────────────────────────── # 4. API MODELS # ────────────────────────────────────────────────────────────── app = FastAPI(title="BART Language Bridge", version="1.0") class TextRequest(BaseModel): text: str class TensorRequest(BaseModel): # Serialized as nested list: shape [batch, seq_len, hidden] or [batch, hidden] tensor: list shape: list[int] class LatentRequest(BaseModel): # The raw output from the Perspective Engine before decode latent_vector: list[float] # shape [LATENT_DIM] e.g. [supply_norm, demand_norm] context_prompt: str = "The market analysis reveals:" # Primes the decoder class DecodeResponse(BaseModel): text: str decode_time_ms: float # ────────────────────────────────────────────────────────────── # 5. ENDPOINTS # ────────────────────────────────────────────────────────────── @app.post("/encode") def encode_text(req: TextRequest): """ Text → BART encoder hidden state (the 'A' constraint vector). Returns: - full_hidden: [1, seq_len, 1024] (full sequence, for engine conditioning) - pooled: [1, 1024] (mean-pooled single vector) """ inputs = tokenizer( req.text, return_tensors="pt", padding=True, truncation=True, max_length=64 ).to(DEVICE) with torch.no_grad(): enc_out = bart.model.encoder(**inputs) hidden = enc_out.last_hidden_state # [1, seq_len, 1024] pooled = mean_pool(hidden) # [1, 1024] return { "full_hidden": tensor_to_list(hidden), "full_shape": list(hidden.shape), "pooled": tensor_to_list(pooled), "pooled_shape": list(pooled.shape), } @app.post("/decode", response_model=DecodeResponse) def decode_hidden(req: TensorRequest): """ Arbitrary hidden-state tensor → text via BART decoder. Pass in a tensor of shape [1, seq_len, 1024]. Used to verify that a round-trip encode→decode works. """ t0 = time.time() hidden = list_to_tensor(req.tensor).reshape(req.shape).to(DEVICE) enc_wrapped = BaseModelOutput(last_hidden_state=hidden) with torch.no_grad(): gen_ids = bart.generate( encoder_outputs=enc_wrapped, max_new_tokens=MAX_SEQ, num_beams=4, early_stopping=True, ) text = tokenizer.decode(gen_ids[0], skip_special_tokens=True) return {"text": text, "decode_time_ms": (time.time() - t0) * 1000} @app.post("/inject_generate", response_model=DecodeResponse) def inject_and_generate(req: LatentRequest): """ THE BYPASS ENDPOINT. Takes raw Engine output C (a small latent vector), projects it into BART's 1024-dim hidden space, and lets BART's decoder "speak" the logical truth. This is the Latent-to-Language bridge. """ t0 = time.time() c_tensor = torch.tensor( [req.latent_vector], dtype=torch.float32, device=DEVICE ) # [1, LATENT_DIM] with torch.no_grad(): # Project C → [1, 1, 1024] projected = projector(c_tensor) # [1, 1, 1024] # Optional: prime the decoder with a context token sequence # so BART "knows" what kind of sentence to generate prime_ids = tokenizer( req.context_prompt, return_tensors="pt" ).input_ids.to(DEVICE) enc_wrapped = BaseModelOutput(last_hidden_state=projected) gen_ids = bart.generate( encoder_outputs=enc_wrapped, decoder_input_ids=prime_ids, max_new_tokens=MAX_SEQ, num_beams=4, early_stopping=True, ) text = tokenizer.decode(gen_ids[0], skip_special_tokens=True) return {"text": text, "decode_time_ms": (time.time() - t0) * 1000} @app.get("/health") def health(): return { "status": "online", "device": str(DEVICE), "projector_loaded": all( p.sum().item() != 0 for p in projector.parameters() ) } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)