import os
import torch
import sentencepiece as spm
from fastapi import FastAPI, Query
from pydantic import BaseModel
from fastapi.responses import JSONResponse
# =========================================================
# CONFIG
# =========================================================
ARTIFACTS_DIR = "artifacts"
MODEL_PATH = os.path.join(ARTIFACTS_DIR, "responder.pt")
SPM_MODEL_PATH = os.path.join(ARTIFACTS_DIR, "spm.model")
TEXT_PATH = os.path.join(ARTIFACTS_DIR, "all_text.txt")
MAX_LEN = 64
DEVICE = "cpu"
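# Expected artifacts/ layout for the paths above (inferred from this config;
# file contents are assumptions, not verified here):
#   responder.pt  -> pickled model (or a state_dict; see load_model below)
#   spm.model     -> trained SentencePiece model
#   all_text.txt  -> one input per line, consumed by /batch_predict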
app = FastAPI(title="Responder API", version="1.1")
# =========================================================
# HYBRID TOKENIZER
# =========================================================
class HybridTokenizer:
    """Thin wrapper around SentencePiece with safe fallbacks for special token ids."""

    def __init__(self, sp_model_path):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(sp_model_path)
        self.vocab_size = self.sp.get_piece_size()
        # Fall back to 0 (pad/unk) or None (bos/eos) when the SentencePiece
        # model was trained without these special ids.
        self.pad_id = self.sp.pad_id() if self.sp.pad_id() >= 0 else 0
        self.unk_id = self.sp.unk_id() if self.sp.unk_id() >= 0 else 0
        self.bos_id = self.sp.bos_id() if self.sp.bos_id() >= 0 else None
        self.eos_id = self.sp.eos_id() if self.sp.eos_id() >= 0 else None

    def encode(self, text):
        ids = self.sp.encode(text, out_type=int)
        if self.bos_id is not None:
            ids = [self.bos_id] + ids
        if self.eos_id is not None:
            ids = ids + [self.eos_id]
        return self._sanitize_ids(ids)

    def decode(self, ids):
        # Skip pad/bos and stop at the first eos. Token ids are ints, so a
        # None bos_id/eos_id in the comparisons below simply never matches.
        cleaned = []
        for i in ids:
            if i in {self.pad_id, self.bos_id}:
                continue
            if i == self.eos_id:
                break
            cleaned.append(i)
        return self.sp.decode(cleaned)

    def pad(self, ids, max_len):
        # Right-pad short sequences; truncate long ones (truncation can drop eos).
        if len(ids) < max_len:
            return ids + [self.pad_id] * (max_len - len(ids))
        return ids[:max_len]

    def _sanitize_ids(self, ids):
        # Clamp out-of-range ids to unk so embedding lookups never fail.
        return [i if 0 <= i < self.vocab_size else self.unk_id for i in ids]
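# Illustrative round trip (the exact ids are hypothetical; they depend on the
# vocabulary inside spm.model):
#   tok = HybridTokenizer(SPM_MODEL_PATH)
#   ids = tok.pad(tok.encode("hello"), MAX_LEN)  # e.g. [1, 534, 2, 0, 0, ...]
#   tok.decode(ids)                              # -> "hello"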
# =========================================================
# LOAD MODEL
# =========================================================
def load_model(path):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model not found: {path}")
    # weights_only=False is required to unpickle a full nn.Module on
    # PyTorch >= 2.6, where torch.load defaults to weights_only=True.
    model = torch.load(path, map_location=DEVICE, weights_only=False)
    if isinstance(model, dict):
        raise ValueError("responder.pt is a state_dict. Define the architecture to load it.")
    model.eval()
    return model
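# If responder.pt were saved as a state_dict rather than a pickled module,
# the network would have to be rebuilt before loading. A hypothetical sketch
# (ResponderNet is an assumed class name, not defined in this repo):
#   model = ResponderNet(vocab_size=tokenizer.vocab_size)
#   model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
#   model.eval()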
# =========================================================
# GLOBALS
# =========================================================
# Load artifacts once at import time so the app fails fast if either is missing.
tokenizer = HybridTokenizer(SPM_MODEL_PATH)
model = load_model(MODEL_PATH)
# =========================================================
# HEALTH CHECK
# =========================================================
@app.get("/health")
def health():
return {"status": "ok"}
# =========================================================
# SINGLE PREDICTION ENDPOINT
# =========================================================
class PredictRequest(BaseModel):
text: str
@app.get("/predict")
def predict_get(text: str = Query(..., description="Input text to generate a response for")):
try:
return {"response": predict(text)}
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/predict")
def predict_post(request: PredictRequest):
try:
return {"response": predict(request.text)}
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
# =========================================================
# BATCH PREDICTIONS FROM FILE
# =========================================================
@app.get("/batch_predict")
def batch_predict():
if not os.path.exists(TEXT_PATH):
return JSONResponse(status_code=404, content={"error": "all_text.txt not found"})
with open(TEXT_PATH, "r", encoding="utf-8") as f:
lines = [line.strip() for line in f if line.strip()]
results = []
for line in lines:
try:
resp = predict(line)
results.append({"input": line, "response": resp})
except Exception as e:
results.append({"input": line, "response": None, "error": str(e)})
return {"count": len(results), "results": results}
# =========================================================
# INFERENCE HELPER
# =========================================================
def predict(text: str):
    ids = tokenizer.encode(text)
    ids = tokenizer.pad(ids, MAX_LEN)
    input_tensor = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        output = model(input_tensor)
    # Some models return a (logits, ...) tuple; keep only the logits.
    if isinstance(output, tuple):
        output = output[0]
    if output.dim() == 3:
        # Sequence output (batch, seq_len, vocab): greedy-pick each position.
        pred_ids = torch.argmax(output, dim=-1)[0].tolist()
    else:
        # Flat output (batch, vocab): one predicted token id per input.
        pred_ids = torch.argmax(output, dim=-1).tolist()
    return tokenizer.decode(pred_ids)
# =========================================================
# RUN WITH:
# uvicorn app:app --host 0.0.0.0 --port 8000
# =========================================================
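# Example requests once the server is up (host/port taken from the command above):
#   curl "http://localhost:8000/health"
#   curl "http://localhost:8000/predict?text=hello"
#   curl -X POST "http://localhost:8000/predict" \
#        -H "Content-Type: application/json" -d '{"text": "hello"}'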