# NOTE: Hugging Face Spaces UI residue removed from the top of this paste.
"""
bart_server.py — The Language Bridge

Run with: uvicorn bart_server:app --host 0.0.0.0 --port 7860

Responsibilities:
    POST /encode          : text → serialized hidden-state tensor (A)
    POST /decode          : serialized hidden-state tensor → text
    POST /inject_generate : raw latent vector (C) → project → decode → text

BART is an Encoder-Decoder, so both directions are possible.  The Engine
server calls /encode to get A, then calls /inject_generate to speak the
settled truth back out through BART's decoder.
"""
import json
import os
import time

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
# ──────────────────────────────────────────────────────────────
# 1. STARTUP: Load BART
# ──────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Bridge dimensions shared with the engine server.
BART_DIM = 1024   # bart-large hidden dimension
LATENT_DIM = 2    # Perspective Engine output dimension (supply, demand)
MAX_SEQ = 32      # decoder max tokens for generated explanation

print(f"[BART SERVER] Loading facebook/bart-large on {DEVICE}...")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
bart = (
    BartForConditionalGeneration
    .from_pretrained("facebook/bart-large", torch_dtype=torch.float32)
    .to(DEVICE)
    .eval()  # inference only; no dropout
)
print("[BART SERVER] Model ready.")
# ──────────────────────────────────────────────────────────────
# 2. PROJECTION LAYER (C → BART hidden space)
#    A small Linear stack that is saved alongside the engine
#    weights.  Both servers share the same projection checkpoint.
# ──────────────────────────────────────────────────────────────
class LatentProjector(nn.Module):
    """Projects Engine output C [batch, latent_dim] → a one-token
    BART "encoder output" [batch, 1, bart_dim].

    Trained in the engine script; loaded here for decoding.

    Args:
        latent_dim: input width. Defaults to module-level LATENT_DIM.
        bart_dim: output width. Defaults to module-level BART_DIM.
        hidden_dim: intermediate layer width (default 256, matching
            the layout of the shared projector.pt checkpoint).
    """

    def __init__(self, latent_dim=None, bart_dim=None, hidden_dim=256):
        super().__init__()
        # Resolve defaults at call time (not def time) so the default
        # behavior is byte-identical to the original hard-coded version.
        latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        bart_dim = BART_DIM if bart_dim is None else bart_dim
        self.net = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, bart_dim),
        )

    def forward(self, c: torch.Tensor) -> torch.Tensor:
        # c: [batch, latent_dim] → [batch, 1, bart_dim]
        # unsqueeze(1) adds the seq_len=1 axis the decoder attends over.
        return self.net(c).unsqueeze(1)
projector = LatentProjector().to(DEVICE).eval()

# Try to load pre-trained projection weights (saved by engine_server
# after training).  weights_only=True stops torch.load from unpickling
# arbitrary objects — the checkpoint is a plain tensor state_dict, so
# nothing is lost and a tampered file cannot execute code.
try:
    state = torch.load("projector.pt", map_location=DEVICE, weights_only=True)
    projector.load_state_dict(state)
    print("[BART SERVER] Projector weights loaded from projector.pt")
except FileNotFoundError:
    print("[BART SERVER] WARNING: projector.pt not found. /inject_generate will use untrained projection.")
    print(" Run engine_server.py training first, then projector.pt will be auto-saved.")
# ──────────────────────────────────────────────────────────────
# 3. HELPER FUNCTIONS
# ──────────────────────────────────────────────────────────────
def tensor_to_list(t: torch.Tensor) -> list:
    """Serialize a tensor into nested Python lists (JSON-safe)."""
    detached = t.detach().cpu()
    return detached.tolist()
def list_to_tensor(data: list) -> torch.Tensor:
    """Deserialize nested lists into a float32 tensor placed on DEVICE."""
    tensor = torch.tensor(data, dtype=torch.float32)
    return tensor.to(DEVICE)
def mean_pool(hidden_state: torch.Tensor) -> torch.Tensor:
    """Collapse the sequence axis: [batch, seq_len, 1024] → [batch, 1024]."""
    return torch.mean(hidden_state, dim=1)
# ──────────────────────────────────────────────────────────────
# 4. API MODELS
# ──────────────────────────────────────────────────────────────
app = FastAPI(title="BART Language Bridge", version="1.0")

class TextRequest(BaseModel):
    # Plain text to be encoded by /encode.
    text: str

class TensorRequest(BaseModel):
    # Serialized as nested list: shape [batch, seq_len, hidden] or [batch, hidden]
    tensor: list
    shape: list[int]  # used to reshape the flattened/nested list back into a tensor

class LatentRequest(BaseModel):
    # The raw output from the Perspective Engine before decode
    latent_vector: list[float]  # shape [LATENT_DIM] e.g. [supply_norm, demand_norm]
    context_prompt: str = "The market analysis reveals:"  # Primes the decoder

class DecodeResponse(BaseModel):
    # Response shape returned by /decode and /inject_generate.
    text: str
    decode_time_ms: float  # wall-clock generation latency in milliseconds
# ──────────────────────────────────────────────────────────────
# 5. ENDPOINTS
# ──────────────────────────────────────────────────────────────
@app.post("/encode")  # FIX: route was never registered on the app,
def encode_text(req: TextRequest):  # despite the module docstring promising POST /encode
    """
    Text → BART encoder hidden state (the 'A' constraint vector).

    Returns:
        full_hidden / full_shape: [1, seq_len, 1024] — full encoder
            sequence, for engine conditioning.
        pooled / pooled_shape: [1, 1024] — mean-pooled single vector.
    """
    inputs = tokenizer(
        req.text,
        return_tensors="pt",
        padding=True,
        truncation=True,   # hard cap keeps encode latency bounded
        max_length=64,
    ).to(DEVICE)
    with torch.no_grad():
        enc_out = bart.model.encoder(**inputs)
    hidden = enc_out.last_hidden_state  # [1, seq_len, 1024]
    pooled = mean_pool(hidden)          # [1, 1024]
    return {
        "full_hidden": tensor_to_list(hidden),
        "full_shape": list(hidden.shape),
        "pooled": tensor_to_list(pooled),
        "pooled_shape": list(pooled.shape),
    }
@app.post("/decode", response_model=DecodeResponse)  # FIX: route was never registered
def decode_hidden(req: TensorRequest):
    """
    Arbitrary hidden-state tensor → text via BART decoder.

    Pass a tensor of shape [1, seq_len, 1024].  Used to verify that a
    round-trip encode→decode works.
    """
    t0 = time.time()
    # list_to_tensor already places the tensor on DEVICE; no extra .to() needed.
    hidden = list_to_tensor(req.tensor).reshape(req.shape)
    # Wrap so generate() treats our tensor as the encoder's output.
    enc_wrapped = BaseModelOutput(last_hidden_state=hidden)
    with torch.no_grad():
        gen_ids = bart.generate(
            encoder_outputs=enc_wrapped,
            max_new_tokens=MAX_SEQ,
            num_beams=4,
            early_stopping=True,
        )
    text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return {"text": text, "decode_time_ms": (time.time() - t0) * 1000}
@app.post("/inject_generate", response_model=DecodeResponse)  # FIX: route was never registered
def inject_and_generate(req: LatentRequest):
    """
    THE BYPASS ENDPOINT.

    Takes raw Engine output C (a small latent vector), projects it into
    BART's 1024-dim hidden space, and lets BART's decoder "speak" the
    logical truth.  This is the Latent-to-Language bridge.

    Raises:
        HTTPException(400): if latent_vector is not LATENT_DIM long
            (the projector's first Linear would fail opaquely otherwise).
    """
    if len(req.latent_vector) != LATENT_DIM:
        raise HTTPException(
            status_code=400,
            detail=f"latent_vector must have length {LATENT_DIM}, got {len(req.latent_vector)}",
        )
    t0 = time.time()
    c_tensor = torch.tensor(
        [req.latent_vector], dtype=torch.float32, device=DEVICE
    )  # [1, LATENT_DIM]
    with torch.no_grad():
        # Project C → [1, 1, 1024]
        projected = projector(c_tensor)
        # Prime the decoder with a context token sequence so BART
        # "knows" what kind of sentence to generate.
        # NOTE(review): tokenizer() wraps the prompt as <s>...</s>; the
        # trailing </s> inside decoder_input_ids may cut generation
        # short — consider stripping it. TODO confirm against outputs.
        prime_ids = tokenizer(
            req.context_prompt,
            return_tensors="pt",
        ).input_ids.to(DEVICE)
        enc_wrapped = BaseModelOutput(last_hidden_state=projected)
        gen_ids = bart.generate(
            encoder_outputs=enc_wrapped,
            decoder_input_ids=prime_ids,
            max_new_tokens=MAX_SEQ,
            num_beams=4,
            early_stopping=True,
        )
    text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return {"text": text, "decode_time_ms": (time.time() - t0) * 1000}
@app.get("/health")  # FIX: route was never registered on the app
def health():
    """Liveness probe: serving device and whether projector.pt was loaded."""
    # Startup only swallows FileNotFoundError, so if this server is up
    # AND the file exists, load_state_dict must have succeeded.  The old
    # nonzero-weight-sum heuristic was meaningless: randomly initialized
    # weights are nonzero too, so it reported "loaded" either way.
    return {
        "status": "online",
        "device": str(DEVICE),
        "projector_loaded": os.path.exists("projector.pt"),
    }
if __name__ == "__main__":
    # Dev entry point; in production run:
    #   uvicorn bart_server:app --host 0.0.0.0 --port 7860
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)