# Embeddings_e / app.py
# everydaytok's picture
# Update app.py
# f1263de verified
"""
bart_server.py — The Language Bridge
Run with: uvicorn bart_server:app --host 0.0.0.0 --port 7860

Responsibilities:
    POST /encode          : text → serialized hidden-state tensor (A)
    POST /decode          : serialized hidden-state tensor → text
    POST /inject_generate : raw latent vector (C) → project → decode → text

BART is an Encoder-Decoder, so both directions are possible.
The Engine server calls /encode to get A, then calls /inject_generate
to speak the settled truth back out through BART's decoder.
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torch.nn as nn
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import json
import time
# ──────────────────────────────────────────────────────────────
# 1. STARTUP: Load BART
# ──────────────────────────────────────────────────────────────
# Run on the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

BART_DIM = 1024   # bart-large hidden dimension
LATENT_DIM = 2    # Perspective Engine output dimension (supply, demand)
MAX_SEQ = 32      # decoder max tokens for generated explanation

print(f"[BART SERVER] Loading facebook/bart-large on {DEVICE}...")

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

# Load in full precision, move to the chosen device, and switch to
# inference mode (this server never trains BART).
bart = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large",
    torch_dtype=torch.float32,
)
bart = bart.to(DEVICE)
bart.eval()

print("[BART SERVER] Model ready.")
# ──────────────────────────────────────────────────────────────
# 2. PROJECTION LAYER (C → BART hidden space)
# A small Linear-GELU-Linear network that is saved alongside the engine weights.
# Both servers share the same projection checkpoint.
# ──────────────────────────────────────────────────────────────
class LatentProjector(nn.Module):
    """Map Engine output C into BART's hidden space.

    Input  : [batch, LATENT_DIM]
    Output : [batch, 1, BART_DIM]

    Trained in the engine script; loaded here for decoding.
    """

    def __init__(self):
        super().__init__()
        # The attribute name `net` and the layer positions determine the
        # state-dict keys (net.0.*, net.2.*), so both are kept stable.
        layers = [
            nn.Linear(LATENT_DIM, 256),
            nn.GELU(),
            nn.Linear(256, BART_DIM),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, c):
        """Project `c` and insert a singleton axis at position 1."""
        hidden = self.net(c)
        return hidden.unsqueeze(dim=1)
# Build the projector on the serving device in inference mode.
projector = LatentProjector().to(DEVICE).eval()

# Try to load pre-trained projection weights (saved by engine_server after training).
# A missing checkpoint is tolerated; any other loading error still propagates.
# NOTE(review): torch.load unpickles the file — make sure projector.pt only
# ever comes from a trusted source.
try:
    projector.load_state_dict(torch.load("projector.pt", map_location=DEVICE))
except FileNotFoundError:
    print("[BART SERVER] WARNING: projector.pt not found. /inject_generate will use untrained projection.")
    print(" Run engine_server.py training first, then projector.pt will be auto-saved.")
else:
    print("[BART SERVER] Projector weights loaded from projector.pt")
# ──────────────────────────────────────────────────────────────
# 3. HELPER FUNCTIONS
# ──────────────────────────────────────────────────────────────
def tensor_to_list(t: torch.Tensor) -> list:
    """Return the contents of `t` as (nested) Python lists.

    The tensor is detached from autograd and copied to host memory first.
    """
    host_copy = t.detach().cpu()
    return host_copy.tolist()
def list_to_tensor(data: list) -> torch.Tensor:
    """Build a float32 tensor on DEVICE from (nested) Python lists."""
    tensor_kwargs = {"dtype": torch.float32, "device": DEVICE}
    return torch.tensor(data, **tensor_kwargs)
def mean_pool(hidden_state: torch.Tensor) -> torch.Tensor:
    """Average over axis 1: [batch, seq_len, 1024] → [batch, 1024]."""
    return torch.mean(hidden_state, dim=1)
# ──────────────────────────────────────────────────────────────
# 4. API MODELS
# ──────────────────────────────────────────────────────────────
# The ASGI application object; every endpoint below registers on it.
app = FastAPI(
    title="BART Language Bridge",
    version="1.0",
)
class TextRequest(BaseModel):
    # Request body for POST /encode: the raw text handed to the tokenizer.
    text: str
class TensorRequest(BaseModel):
    # Request body for POST /decode.
    # `tensor` is serialized as a nested list: shape [batch, seq_len, hidden]
    # or [batch, hidden]; `shape` is the shape the endpoint reshapes it to.
    tensor: list
    shape: list[int]
class LatentRequest(BaseModel):
    # Request body for POST /inject_generate: the raw output from the
    # Perspective Engine before decode.

    # Shape [LATENT_DIM], e.g. [supply_norm, demand_norm].
    latent_vector: list[float]

    # Primes the decoder.
    context_prompt: str = "The market analysis reveals:"
class DecodeResponse(BaseModel):
    # Response body shared by POST /decode and POST /inject_generate.
    text: str               # decoded text
    decode_time_ms: float   # elapsed handler time in milliseconds
# ──────────────────────────────────────────────────────────────
# 5. ENDPOINTS
# ──────────────────────────────────────────────────────────────
@app.post("/encode")
def encode_text(req: TextRequest):
    """
    Text → BART encoder hidden state (the 'A' constraint vector).

    The text is tokenized (truncated to at most 64 tokens) and passed through
    BART's encoder only, with gradients disabled.

    Returns:
      - full_hidden / full_shape : [1, seq_len, 1024] (full sequence, for engine conditioning)
      - pooled / pooled_shape    : [1, 1024] (mean-pooled single vector)
    """
    batch = tokenizer(
        req.text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64,
    )
    batch = batch.to(DEVICE)

    with torch.no_grad():
        hidden = bart.model.encoder(**batch).last_hidden_state  # [1, seq_len, 1024]
        pooled = mean_pool(hidden)                              # [1, 1024]

    # Tensors are not JSON-serialisable, so ship nested lists plus shapes.
    payload = {}
    payload["full_hidden"] = tensor_to_list(hidden)
    payload["full_shape"] = list(hidden.shape)
    payload["pooled"] = tensor_to_list(pooled)
    payload["pooled_shape"] = list(pooled.shape)
    return payload
@app.post("/decode", response_model=DecodeResponse)
def decode_hidden(req: TensorRequest):
    """
    Arbitrary hidden-state tensor → text via BART decoder.

    Pass in a tensor of shape [1, seq_len, 1024]. A 2-D tensor
    [batch, 1024] (e.g. the "pooled" output of /encode) is also accepted and
    treated as a sequence of length 1.
    Used to verify that a round-trip encode→decode works.

    Only the first sequence of the batch is decoded into the response.

    Raises:
        HTTPException(422): if `tensor` cannot be turned into a numeric
            tensor, cannot be reshaped to `shape`, or the result is not a
            non-empty [batch, seq_len, BART_DIM] tensor.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()

    # Reject malformed payloads up front instead of letting torch raise and
    # surface as an opaque HTTP 500.
    try:
        hidden = list_to_tensor(req.tensor).reshape(req.shape)
    except (TypeError, ValueError, RuntimeError) as exc:
        raise HTTPException(
            status_code=422,
            detail=f"tensor could not be reshaped to {req.shape}: {exc}",
        ) from exc

    # [batch, hidden] → [batch, 1, hidden]: a single-position encoder sequence.
    if hidden.dim() == 2:
        hidden = hidden.unsqueeze(1)

    if hidden.dim() != 3 or hidden.shape[-1] != BART_DIM or hidden.numel() == 0:
        raise HTTPException(
            status_code=422,
            detail=(
                f"expected a non-empty tensor of shape [batch, seq_len, {BART_DIM}] "
                f"or [batch, {BART_DIM}], got {list(hidden.shape)}"
            ),
        )

    # generate() accepts pre-computed encoder states wrapped in a model output.
    enc_wrapped = BaseModelOutput(last_hidden_state=hidden)
    with torch.no_grad():
        gen_ids = bart.generate(
            encoder_outputs=enc_wrapped,
            max_new_tokens=MAX_SEQ,
            num_beams=4,
            early_stopping=True,
        )

    text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return {"text": text, "decode_time_ms": (time.perf_counter() - t0) * 1000}
@app.post("/inject_generate", response_model=DecodeResponse)
def inject_and_generate(req: LatentRequest):
    """
    THE BYPASS ENDPOINT.

    Takes raw Engine output C (a small latent vector),
    projects it into BART's 1024-dim hidden space,
    and lets BART's decoder "speak" the logical truth.
    This is the Latent-to-Language bridge.

    The decoder is primed with `context_prompt`; the returned text therefore
    starts with the prompt followed by the generated continuation.

    Raises:
        HTTPException(422): if `latent_vector` does not contain exactly
            LATENT_DIM values.
    """
    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments.
    t0 = time.perf_counter()

    # A wrong-sized vector would otherwise fail inside the projector's first
    # Linear layer and surface as an opaque HTTP 500.
    if len(req.latent_vector) != LATENT_DIM:
        raise HTTPException(
            status_code=422,
            detail=(
                f"latent_vector must contain exactly {LATENT_DIM} values, "
                f"got {len(req.latent_vector)}"
            ),
        )

    c_tensor = torch.tensor(
        [req.latent_vector], dtype=torch.float32, device=DEVICE
    )  # [1, LATENT_DIM]

    # Prime the decoder with a context token sequence so BART "knows" what
    # kind of sentence to generate.
    prime_ids = tokenizer(req.context_prompt, return_tensors="pt").input_ids
    # The tokenizer appends an end-of-sequence token. Left in place, the
    # decoder would be primed with an already-terminated sequence, so drop it
    # and let generation continue directly after the prompt.
    if prime_ids.shape[1] > 1 and prime_ids[0, -1].item() == tokenizer.eos_token_id:
        prime_ids = prime_ids[:, :-1]
    prime_ids = prime_ids.to(DEVICE)

    with torch.no_grad():
        # Project C → [1, 1, 1024]: a single-position "encoder output".
        projected = projector(c_tensor)
        enc_wrapped = BaseModelOutput(last_hidden_state=projected)
        gen_ids = bart.generate(
            encoder_outputs=enc_wrapped,
            decoder_input_ids=prime_ids,
            max_new_tokens=MAX_SEQ,
            num_beams=4,
            early_stopping=True,
        )

    text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return {"text": text, "decode_time_ms": (time.perf_counter() - t0) * 1000}
@app.get("/health")
def health():
    """Report service status, the active torch device and a projector flag."""
    # NOTE(review): this flag is True whenever no parameter tensor sums to
    # exactly zero. Freshly initialised (untrained) weights presumably satisfy
    # that too, so it may not actually tell whether projector.pt was loaded —
    # confirm, and consider recording the load result at startup instead.
    no_zero_sum_param = not any(
        p.sum().item() == 0 for p in projector.parameters()
    )
    report = {
        "status": "online",
        "device": str(DEVICE),
        "projector_loaded": no_zero_sum_param,
    }
    return report
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 7860
    uvicorn.run(app, host=bind_host, port=bind_port)