import os

import torch
import sentencepiece as spm
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# =========================================================
# CONFIG
# =========================================================
ARTIFACTS_DIR = "artifacts"
MODEL_PATH = os.path.join(ARTIFACTS_DIR, "responder.pt")
SPM_MODEL_PATH = os.path.join(ARTIFACTS_DIR, "spm.model")
TEXT_PATH = os.path.join(ARTIFACTS_DIR, "all_text.txt")
MAX_LEN = 64
DEVICE = "cpu"

app = FastAPI(title="Responder API", version="1.1")

# =========================================================
# HYBRID TOKENIZER
# =========================================================
class HybridTokenizer:
    """Wraps a SentencePiece model, falling back to safe defaults when
    special token ids are undefined (SentencePiece returns -1 for those)."""

    def __init__(self, sp_model_path):
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(sp_model_path)
        self.vocab_size = self.sp.get_piece_size()
        self.pad_id = self.sp.pad_id() if self.sp.pad_id() >= 0 else 0
        self.unk_id = self.sp.unk_id() if self.sp.unk_id() >= 0 else 0
        self.bos_id = self.sp.bos_id() if self.sp.bos_id() >= 0 else None
        self.eos_id = self.sp.eos_id() if self.sp.eos_id() >= 0 else None

    def encode(self, text):
        # Tokenize, wrap with BOS/EOS when the model defines them,
        # and clamp any out-of-range ids to UNK.
        ids = self.sp.encode(text, out_type=int)
        if self.bos_id is not None:
            ids = [self.bos_id] + ids
        if self.eos_id is not None:
            ids = ids + [self.eos_id]
        ids = self._sanitize_ids(ids)
        return ids

    def decode(self, ids):
        # Skip padding and BOS; stop at the first EOS.
        cleaned = []
        for i in ids:
            if i in {self.pad_id, self.bos_id}:
                continue
            if i == self.eos_id:
                break
            cleaned.append(i)
        return self.sp.decode(cleaned)

    def pad(self, ids, max_len):
        # Right-pad with pad_id, or truncate, to a fixed length.
        if len(ids) < max_len:
            return ids + [self.pad_id] * (max_len - len(ids))
        return ids[:max_len]

    def _sanitize_ids(self, ids):
        return [i if 0 <= i < self.vocab_size else self.unk_id for i in ids]

# =========================================================
# LOAD MODEL
# =========================================================
def load_model(path):
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model not found: {path}")
    # weights_only=False is required to unpickle a full nn.Module on
    # PyTorch >= 2.6; only do this with checkpoints you trust.
    model = torch.load(path, map_location=DEVICE, weights_only=False)
    if isinstance(model, dict):
        raise ValueError("responder.pt is a state_dict. Define the architecture to load it.")
    model.eval()
    return model

# =========================================================
# GLOBALS
# =========================================================
tokenizer = HybridTokenizer(SPM_MODEL_PATH)
model = load_model(MODEL_PATH)

# =========================================================
# HEALTH CHECK
# =========================================================
@app.get("/health")
def health():
    return {"status": "ok"}

# =========================================================
# SINGLE PREDICTION ENDPOINT
# =========================================================
class PredictRequest(BaseModel):
    text: str

@app.get("/predict")
def predict_get(text: str = Query(..., description="Input text to generate a response for")):
    try:
        return {"response": predict(text)}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

@app.post("/predict")
def predict_post(request: PredictRequest):
    try:
        return {"response": predict(request.text)}
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

# =========================================================
# BATCH PREDICTIONS FROM FILE
# =========================================================
@app.get("/batch_predict")
def batch_predict():
    if not os.path.exists(TEXT_PATH):
        return JSONResponse(status_code=404, content={"error": "all_text.txt not found"})
    with open(TEXT_PATH, "r", encoding="utf-8") as f:
        lines = [line.strip() for line in f if line.strip()]
    results = []
    for line in lines:
        try:
            resp = predict(line)
            results.append({"input": line, "response": resp})
        except Exception as e:
            results.append({"input": line, "response": None, "error": str(e)})
    return {"count": len(results), "results": results}

# =========================================================
# INFERENCE HELPER
# =========================================================
def predict(text: str):
    # Encode, pad to a fixed length, and run a single forward pass.
    ids = tokenizer.encode(text)
    ids = tokenizer.pad(ids, MAX_LEN)
    input_tensor = torch.tensor([ids], dtype=torch.long).to(DEVICE)
    with torch.no_grad():
        output = model(input_tensor)
    if isinstance(output, tuple):
        output = output[0]
    # Greedy decoding: take the argmax token at each position.
    if output.dim() == 3:  # (batch, seq_len, vocab)
        pred_ids = torch.argmax(output, dim=-1)[0].tolist()
    else:  # (seq_len, vocab)
        pred_ids = torch.argmax(output, dim=-1).tolist()
    return tokenizer.decode(pred_ids)

# =========================================================
# RUN WITH:
#   uvicorn app:app --host 0.0.0.0 --port 8000
# =========================================================
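# =========================================================
# EXAMPLE REQUESTS
# A quick sketch for manual testing; assumes the server is running
# locally on port 8000 as in the uvicorn command above:
#   curl "http://localhost:8000/health"
#   curl "http://localhost:8000/predict?text=hello%20there"
#   curl -X POST "http://localhost:8000/predict" \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hello there"}'
#   curl "http://localhost:8000/batch_predict"
# =========================================================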
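# =========================================================
# NOTE ON STATE_DICT CHECKPOINTS
# load_model() above only accepts a fully pickled nn.Module. If
# responder.pt is instead a state_dict, the architecture has to be
# rebuilt before loading the weights. A minimal sketch (ResponderNet
# and its constructor arguments are hypothetical placeholders for
# whatever model class produced the checkpoint):
#
#   model = ResponderNet(vocab_size=tokenizer.vocab_size)
#   state = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
#   model.load_state_dict(state)
#   model.eval()
# =========================================================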