"""
bart_server.py — The Language Bridge

Run with: uvicorn bart_server:app --host 0.0.0.0 --port 7860

Responsibilities:
    POST /encode          : text → serialized hidden-state tensor (A)
    POST /decode          : serialized hidden-state tensor → text
    POST /inject_generate : raw latent vector (C) → project → decode → text

BART is an Encoder-Decoder, so both directions are possible.
The Engine server calls /encode to get A, then calls /inject_generate
to speak the settled truth back out through BART's decoder.
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torch.nn as nn
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import json
import time
# ──────────────────────────────────────────────────────────────
# 1. STARTUP: Load BART
# ──────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BART_DIM = 1024   # bart-large hidden dimension
LATENT_DIM = 2    # Perspective Engine output dimension (supply, demand)
MAX_SEQ = 32      # decoder max tokens for generated explanation

print(f"[BART SERVER] Loading facebook/bart-large on {DEVICE}...")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
bart = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-large",
    torch_dtype=torch.float32
).to(DEVICE).eval()  # eval(): inference mode, disables dropout
print("[BART SERVER] Model ready.")
# ──────────────────────────────────────────────────────────────
# 2. PROJECTION LAYER (C → BART hidden space)
#    A small Linear that is saved alongside the engine weights.
#    Both servers share the same projection checkpoint.
# ──────────────────────────────────────────────────────────────
class LatentProjector(nn.Module):
"""Projects Engine output C [batch, LATENT_DIM]
β BART hidden state [batch, 1, BART_DIM].
Trained in the engine script; loaded here for decoding."""
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(LATENT_DIM, 256),
nn.GELU(),
nn.Linear(256, BART_DIM)
)
def forward(self, c):
# c: [batch, LATENT_DIM] β [batch, 1, BART_DIM]
return self.net(c).unsqueeze(1)
# Shared projection head; architecture must match what engine_server trains.
projector = LatentProjector().to(DEVICE).eval()

# Try to load pre-trained projection weights (saved by engine_server after training).
# NOTE(review): newer torch versions changed torch.load's weights_only default;
# if loading ever fails on upgrade, pass weights_only=True here — verify.
try:
    projector.load_state_dict(torch.load("projector.pt", map_location=DEVICE))
    print("[BART SERVER] Projector weights loaded from projector.pt")
except FileNotFoundError:
    # Deliberate best-effort: the server still starts, but /inject_generate
    # output is meaningless until projector.pt exists.
    print("[BART SERVER] WARNING: projector.pt not found. /inject_generate will use untrained projection.")
    print("               Run engine_server.py training first, then projector.pt will be auto-saved.")
# ──────────────────────────────────────────────────────────────
# 3. HELPER FUNCTIONS
# ──────────────────────────────────────────────────────────────
def tensor_to_list(t: torch.Tensor) -> list:
    """Serialize a tensor into nested Python lists (JSON-transportable)."""
    detached = t.detach()
    return detached.cpu().tolist()
def list_to_tensor(data: list) -> torch.Tensor:
    """Deserialize nested lists back into a float32 tensor on DEVICE."""
    return torch.tensor(data, device=DEVICE, dtype=torch.float32)
def mean_pool(hidden_state: torch.Tensor) -> torch.Tensor:
    """Average over the sequence axis: [batch, seq_len, 1024] → [batch, 1024]."""
    return torch.mean(hidden_state, dim=1)
# ──────────────────────────────────────────────────────────────
# 4. API MODELS
# ──────────────────────────────────────────────────────────────
app = FastAPI(title="BART Language Bridge", version="1.0")


class TextRequest(BaseModel):
    """Body for /encode: plain text to push through BART's encoder."""
    text: str


class TensorRequest(BaseModel):
    """Body for /decode: a hidden-state tensor plus its shape."""
    # Serialized as nested list: shape [batch, seq_len, hidden] or [batch, hidden]
    tensor: list
    shape: list[int]


class LatentRequest(BaseModel):
    """Body for /inject_generate: raw Engine latent plus a decoder primer."""
    # The raw output from the Perspective Engine before decode
    latent_vector: list[float]  # shape [LATENT_DIM] e.g. [supply_norm, demand_norm]
    context_prompt: str = "The market analysis reveals:"  # Primes the decoder


class DecodeResponse(BaseModel):
    """Response for /decode and /inject_generate."""
    text: str
    decode_time_ms: float  # wall-clock latency of the decode call
# ──────────────────────────────────────────────────────────────
# 5. ENDPOINTS
# ──────────────────────────────────────────────────────────────
@app.post("/encode")
def encode_text(req: TextRequest):
    """
    Text → BART encoder hidden state (the 'A' constraint vector).

    Returns:
        - full_hidden: [1, seq_len, 1024]  (full sequence, for engine conditioning)
        - pooled:      [1, 1024]           (mean-pooled single vector)
    """
    # Tokenize; truncated to 64 tokens to bound encoder cost.
    inputs = tokenizer(
        req.text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64
    ).to(DEVICE)
    with torch.no_grad():  # inference only — no autograd graph
        enc_out = bart.model.encoder(**inputs)
        hidden = enc_out.last_hidden_state  # [1, seq_len, 1024]
        pooled = mean_pool(hidden)          # [1, 1024]
    # Tensors are serialized as nested lists so they survive JSON transport;
    # shapes are returned separately so the caller can reshape.
    return {
        "full_hidden": tensor_to_list(hidden),
        "full_shape": list(hidden.shape),
        "pooled": tensor_to_list(pooled),
        "pooled_shape": list(pooled.shape),
    }
@app.post("/decode", response_model=DecodeResponse)
def decode_hidden(req: TensorRequest):
    """
    Arbitrary hidden-state tensor → text via BART decoder.
    Pass in a tensor of shape [1, seq_len, 1024].
    Used to verify that a round-trip encode→decode works.
    """
    started = time.time()
    # Rebuild the hidden state from its serialized form and declared shape.
    state = list_to_tensor(req.tensor).reshape(req.shape).to(DEVICE)
    # Wrap so generate() accepts it in place of a real encoder pass.
    wrapped = BaseModelOutput(last_hidden_state=state)
    with torch.no_grad():
        output_ids = bart.generate(
            encoder_outputs=wrapped,
            num_beams=4,
            max_new_tokens=MAX_SEQ,
            early_stopping=True,
        )
    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    elapsed_ms = (time.time() - started) * 1000
    return {"text": decoded, "decode_time_ms": elapsed_ms}
@app.post("/inject_generate", response_model=DecodeResponse)
def inject_and_generate(req: LatentRequest):
    """
    THE BYPASS ENDPOINT.

    Takes raw Engine output C (a small latent vector),
    projects it into BART's 1024-dim hidden space,
    and lets BART's decoder "speak" the logical truth.
    This is the Latent-to-Language bridge.
    """
    t0 = time.time()
    # Batch of one: [1, LATENT_DIM]
    c_tensor = torch.tensor(
        [req.latent_vector], dtype=torch.float32, device=DEVICE
    )  # [1, LATENT_DIM]
    with torch.no_grad():
        # Project C → [1, 1, 1024]: a one-token pseudo encoder output.
        projected = projector(c_tensor)  # [1, 1, 1024]
        # Optional: prime the decoder with a context token sequence
        # so BART "knows" what kind of sentence to generate.
        # NOTE(review): tokenizer() adds special tokens (<s> ... </s>) around
        # the prompt, so decoder_input_ids ends with EOS; with beam search +
        # early_stopping this may end generation prematurely. Consider
        # add_special_tokens=False here — verify against actual outputs.
        prime_ids = tokenizer(
            req.context_prompt,
            return_tensors="pt"
        ).input_ids.to(DEVICE)
        # Present the projection to generate() as if it were encoder output.
        enc_wrapped = BaseModelOutput(last_hidden_state=projected)
        gen_ids = bart.generate(
            encoder_outputs=enc_wrapped,
            decoder_input_ids=prime_ids,
            max_new_tokens=MAX_SEQ,
            num_beams=4,
            early_stopping=True,
        )
    text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return {"text": text, "decode_time_ms": (time.time() - t0) * 1000}
@app.get("/health")
def health():
    """Liveness probe: device info plus a rough projector-weights check."""
    return {
        "status": "online",
        "device": str(DEVICE),
        # NOTE(review): nonzero parameter sums do NOT prove projector.pt was
        # loaded — randomly initialized Linear layers also have nonzero sums,
        # so this heuristic reports True even for an untrained projector.
        # A boolean flag set in the checkpoint-load try/except would be the
        # reliable signal — verify before relying on this field.
        "projector_loaded": all(
            p.sum().item() != 0 for p in projector.parameters()
        )
    }
if __name__ == "__main__":
    import uvicorn

    # Direct-launch fallback; matches the uvicorn command in the module
    # docstring. (Removed a trailing "|" scrape artifact that made this
    # line a syntax error.)
    uvicorn.run(app, host="0.0.0.0", port=7860)