"""
bart_server.py  β€”  The Language Bridge
Run with: uvicorn bart_server:app --host 0.0.0.0 --port 7860

Responsibilities:
  POST /encode            : text  β†’ serialized hidden-state tensor (A)
  POST /decode            : serialized hidden-state tensor β†’ text
  POST /inject_generate   : raw latent vector (C) β†’ project β†’ decode β†’ text

BART is an Encoder-Decoder, so both directions are possible.
The Engine server calls /encode to get A, then calls /inject_generate
to speak the settled truth back out through BART's decoder.
"""

from fastapi import FastAPI
from pydantic import BaseModel
import torch
import torch.nn as nn
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import time

# ──────────────────────────────────────────────────────────────
# 1. STARTUP: Load BART
# ──────────────────────────────────────────────────────────────

DEVICE     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BART_DIM   = 1024          # bart-large hidden dimension
LATENT_DIM = 2             # Perspective Engine output dimension (supply, demand)
MAX_SEQ    = 32            # decoder max tokens for generated explanation

print(f"[BART SERVER] Loading facebook/bart-large on {DEVICE}...")
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
bart      = BartForConditionalGeneration.from_pretrained(
                "facebook/bart-large",
                torch_dtype=torch.float32
            ).to(DEVICE).eval()
print("[BART SERVER] Model ready.")

# ──────────────────────────────────────────────────────────────
# 2. PROJECTION LAYER  (C → BART hidden space)
#    A small Linear that is saved alongside the engine weights.
#    Both servers share the same projection checkpoint.
# ──────────────────────────────────────────────────────────────

class LatentProjector(nn.Module):
    """Projects Engine output C [batch, LATENT_DIM]
       β†’ BART hidden state  [batch, 1, BART_DIM].
       Trained in the engine script; loaded here for decoding."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(LATENT_DIM, 256),
            nn.GELU(),
            nn.Linear(256, BART_DIM)
        )
    def forward(self, c):
        # c: [batch, LATENT_DIM]  →  [batch, 1, BART_DIM]
        return self.net(c).unsqueeze(1)
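
# Shape sanity check (illustrative): a [1, LATENT_DIM] vector should map to
# a single pseudo-token of BART width.
#
#   _p = LatentProjector()
#   assert _p(torch.zeros(1, LATENT_DIM)).shape == (1, 1, BART_DIM)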

projector = LatentProjector().to(DEVICE).eval()

# Try to load pre-trained projection weights (saved by engine_server after
# training). Track success in a flag so /health can report it honestly.
PROJECTOR_LOADED = False
try:
    projector.load_state_dict(torch.load("projector.pt", map_location=DEVICE))
    PROJECTOR_LOADED = True
    print("[BART SERVER] Projector weights loaded from projector.pt")
except FileNotFoundError:
    print("[BART SERVER] WARNING: projector.pt not found. /inject_generate will use untrained projection.")
    print("              Run engine_server.py training first, then projector.pt will be auto-saved.")

# ──────────────────────────────────────────────────────────────
# 3. HELPER FUNCTIONS
# ──────────────────────────────────────────────────────────────

def tensor_to_list(t: torch.Tensor) -> list:
    return t.detach().cpu().tolist()

def list_to_tensor(data: list) -> torch.Tensor:
    return torch.tensor(data, dtype=torch.float32, device=DEVICE)

def mean_pool(hidden_state: torch.Tensor) -> torch.Tensor:
    """Collapses [batch, seq_len, 1024] β†’ [batch, 1024]."""
    return hidden_state.mean(dim=1)

# ──────────────────────────────────────────────────────────────
# 4. API MODELS
# ──────────────────────────────────────────────────────────────

app = FastAPI(title="BART Language Bridge", version="1.0")

class TextRequest(BaseModel):
    text: str

class TensorRequest(BaseModel):
    # Serialized as nested list: shape [batch, seq_len, hidden] or [batch, hidden]
    tensor: list
    shape:  list[int]

class LatentRequest(BaseModel):
    # The raw output from the Perspective Engine before decode
    latent_vector: list[float]   # shape [LATENT_DIM]  e.g. [supply_norm, demand_norm]
    context_prompt: str = "The market analysis reveals:"  # Primes the decoder

class DecodeResponse(BaseModel):
    text: str
    decode_time_ms: float

# ──────────────────────────────────────────────────────────────
# 5. ENDPOINTS
# ──────────────────────────────────────────────────────────────

@app.post("/encode")
def encode_text(req: TextRequest):
    """
    Text → BART encoder hidden state (the 'A' constraint vector).
    Returns:
      - full_hidden: [1, seq_len, 1024]  (full sequence, for engine conditioning)
      - pooled:      [1, 1024]           (mean-pooled single vector)
    """
    inputs = tokenizer(
        req.text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=64
    ).to(DEVICE)

    with torch.no_grad():
        enc_out = bart.model.encoder(**inputs)

    hidden = enc_out.last_hidden_state          # [1, seq_len, 1024]
    pooled = mean_pool(hidden)                  # [1, 1024]

    return {
        "full_hidden": tensor_to_list(hidden),
        "full_shape":  list(hidden.shape),
        "pooled":      tensor_to_list(pooled),
        "pooled_shape": list(pooled.shape),
    }
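
# Example client call for /encode (a sketch assuming the `requests` package
# and a server on localhost:7860; field names match TextRequest above):
#
#   import requests
#   enc = requests.post("http://localhost:7860/encode",
#                       json={"text": "Oil supply tightened this quarter."}).json()
#   a_pooled = enc["pooled"]   # [1, 1024] constraint vector A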


@app.post("/decode", response_model=DecodeResponse)
def decode_hidden(req: TensorRequest):
    """
    Arbitrary hidden-state tensor → text via BART decoder.
    Pass in a tensor of shape [1, seq_len, 1024].
    Used to verify that a round-trip encode→decode works.
    """
    t0 = time.time()
    hidden = list_to_tensor(req.tensor).reshape(req.shape)  # already on DEVICE

    enc_wrapped = BaseModelOutput(last_hidden_state=hidden)

    with torch.no_grad():
        gen_ids = bart.generate(
            encoder_outputs=enc_wrapped,
            max_new_tokens=MAX_SEQ,
            num_beams=4,
            early_stopping=True,
        )

    text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return {"text": text, "decode_time_ms": (time.time() - t0) * 1000}


@app.post("/inject_generate", response_model=DecodeResponse)
def inject_and_generate(req: LatentRequest):
    """
    THE BYPASS ENDPOINT.
    Takes raw Engine output C (a small latent vector),
    projects it into BART's 1024-dim hidden space,
    and lets BART's decoder "speak" the logical truth.

    This is the Latent-to-Language bridge.
    """
    t0 = time.time()

    c_tensor = torch.tensor(
        [req.latent_vector], dtype=torch.float32, device=DEVICE
    )  # [1, LATENT_DIM]

    with torch.no_grad():
        # Project C → [1, 1, 1024]
        projected = projector(c_tensor)        # [1, 1, 1024]

        # Optional: prime the decoder with a context token sequence
        # so BART "knows" what kind of sentence to generate. Drop the
        # trailing </s> the tokenizer appends; leaving it in tells the
        # decoder the sequence is already finished.
        prime_ids = tokenizer(
            req.context_prompt,
            return_tensors="pt"
        ).input_ids[:, :-1].to(DEVICE)

        enc_wrapped = BaseModelOutput(last_hidden_state=projected)

        gen_ids = bart.generate(
            encoder_outputs=enc_wrapped,
            decoder_input_ids=prime_ids,
            max_new_tokens=MAX_SEQ,
            num_beams=4,
            early_stopping=True,
        )

    text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    return {"text": text, "decode_time_ms": (time.time() - t0) * 1000}


@app.get("/health")
def health():
    return {
        "status": "online",
        "device": str(DEVICE),
        # A nonzero parameter sum proves nothing (random init is nonzero too),
        # so report the explicit load flag instead.
        "projector_loaded": PROJECTOR_LOADED,
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)