Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import sentencepiece as spm | |
| from fastapi import FastAPI, Query | |
| from pydantic import BaseModel | |
| from fastapi.responses import JSONResponse | |
# =========================================================
# CONFIG
# =========================================================
# All artifacts (model weights, tokenizer model, batch input file) live
# under a single directory so the service can be shipped as one bundle.
ARTIFACTS_DIR = "artifacts"
MODEL_PATH = os.path.join(ARTIFACTS_DIR, "responder.pt")    # fully pickled torch model
SPM_MODEL_PATH = os.path.join(ARTIFACTS_DIR, "spm.model")   # trained SentencePiece model
TEXT_PATH = os.path.join(ARTIFACTS_DIR, "all_text.txt")     # input lines for batch prediction
MAX_LEN = 64    # fixed sequence length fed to the model (inputs are padded/truncated)
DEVICE = "cpu"  # inference device; this deployment is CPU-only
app = FastAPI(title="Responder API", version="1.1")
| # ========================================================= | |
| # HYBRID TOKENIZER | |
| # ========================================================= | |
class HybridTokenizer:
    """Thin wrapper around a SentencePiece processor.

    Adds BOS/EOS framing (when the model defines those pieces), fixed-length
    padding/truncation, and defensive clamping of out-of-range ids to UNK.
    """

    def __init__(self, sp_model_path):
        processor = spm.SentencePieceProcessor()
        processor.Load(sp_model_path)
        self.sp = processor
        self.vocab_size = processor.get_piece_size()
        # SentencePiece reports -1 for pieces that were not trained in.
        # pad/unk fall back to id 0; bos/eos fall back to "absent" (None).
        self.pad_id = processor.pad_id() if processor.pad_id() >= 0 else 0
        self.unk_id = processor.unk_id() if processor.unk_id() >= 0 else 0
        self.bos_id = processor.bos_id() if processor.bos_id() >= 0 else None
        self.eos_id = processor.eos_id() if processor.eos_id() >= 0 else None

    def encode(self, text):
        """Encode *text* to ids, framed with BOS/EOS where available."""
        pieces = self.sp.encode(text, out_type=int)
        prefix = [] if self.bos_id is None else [self.bos_id]
        suffix = [] if self.eos_id is None else [self.eos_id]
        return self._sanitize_ids(prefix + pieces + suffix)

    def decode(self, ids):
        """Decode ids to text, dropping PAD/BOS and stopping at EOS."""
        skip = {self.pad_id, self.bos_id}
        kept = []
        for tok in ids:
            # Skip-check first, then EOS: preserves behavior even if the
            # EOS id happens to collide with the pad/bos fallbacks.
            if tok in skip:
                continue
            if tok == self.eos_id:
                break
            kept.append(tok)
        return self.sp.decode(kept)

    def pad(self, ids, max_len):
        """Right-pad with PAD up to *max_len*, or truncate to it."""
        shortfall = max_len - len(ids)
        if shortfall > 0:
            return ids + [self.pad_id] * shortfall
        return ids[:max_len]

    def _sanitize_ids(self, ids):
        """Clamp any id outside [0, vocab_size) to UNK."""
        limit = self.vocab_size
        return [tok if 0 <= tok < limit else self.unk_id for tok in ids]
| # ========================================================= | |
| # LOAD MODEL | |
| # ========================================================= | |
def load_model(path, device="cpu"):
    """Load a fully pickled torch model from *path* and put it in eval mode.

    Args:
        path: Filesystem path to the serialized model (a whole nn.Module,
            not a state_dict).
        device: torch map_location target; defaults to "cpu", matching the
            module-level DEVICE used by this service.

    Returns:
        The loaded model with eval() applied.

    Raises:
        FileNotFoundError: if *path* does not exist.
        ValueError: if the file holds a bare state_dict instead of a model.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Model not found: {path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects —
    # only load artifacts you trust. It is required here because responder.pt
    # stores the whole nn.Module; torch >= 2.6 defaults to weights_only=True,
    # which would reject this file.
    model = torch.load(path, map_location=device, weights_only=False)
    if isinstance(model, dict):
        raise ValueError("responder.pt is a state_dict. Define architecture to load it.")
    model.eval()
    return model
# =========================================================
# GLOBALS
# =========================================================
# Loaded once at import/startup time and shared by all request handlers.
# Any failure here (missing artifact files) aborts service startup early.
tokenizer = HybridTokenizer(SPM_MODEL_PATH)
model = load_model(MODEL_PATH)
| # ========================================================= | |
| # HEALTH CHECK | |
| # ========================================================= | |
def health():
    """Liveness probe: report that the process is up and serving."""
    return dict(status="ok")
| # ========================================================= | |
| # SINGLE PREDICTION ENDPOINT | |
| # ========================================================= | |
class PredictRequest(BaseModel):
    """Request body for the POST prediction endpoint."""

    text: str  # raw input text forwarded to predict()
def predict_get(text: str = Query(..., description="Input text to generate response")):
    """GET variant of the single-prediction endpoint.

    Runs predict() on the query-string text; any inference failure is
    surfaced as an HTTP 500 JSON payload instead of propagating.
    """
    try:
        generated = predict(text)
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    return {"response": generated}
def predict_post(request: PredictRequest):
    """POST variant of the single-prediction endpoint.

    Same contract as predict_get, but reads the text from the validated
    request body; inference errors become an HTTP 500 JSON payload.
    """
    try:
        generated = predict(request.text)
    except Exception as exc:
        return JSONResponse(status_code=500, content={"error": str(exc)})
    return {"response": generated}
| # ========================================================= | |
| # BATCH PREDICTIONS FROM FILE | |
| # ========================================================= | |
def batch_predict():
    """Run predict() over every non-blank line of the batch input file.

    Returns a 404 JSON response when the file is missing; otherwise a dict
    with the result count and one entry per input line. Per-line failures
    are recorded (response=None plus the error text) rather than aborting
    the whole batch.
    """
    if not os.path.exists(TEXT_PATH):
        return JSONResponse(status_code=404, content={"error": "all_text.txt not found"})
    with open(TEXT_PATH, "r", encoding="utf-8") as fh:
        stripped = [raw.strip() for raw in fh]
    results = []
    for text in filter(None, stripped):
        entry = {"input": text}
        try:
            entry["response"] = predict(text)
        except Exception as exc:
            entry["response"] = None
            entry["error"] = str(exc)
        results.append(entry)
    return {"count": len(results), "results": results}
| # ========================================================= | |
| # INFERENCE HELPER | |
| # ========================================================= | |
def predict(text: str):
    """Tokenize *text*, run the model, and greedily decode the output.

    The input is encoded and padded/truncated to MAX_LEN, run through the
    model with gradients disabled, and the argmax token ids are decoded
    back to text. Handles both (batch, seq, vocab) and 2-D model outputs,
    and models that return a tuple whose first element is the logits.
    """
    token_ids = tokenizer.pad(tokenizer.encode(text), MAX_LEN)
    batch = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        logits = model(batch)
    if isinstance(logits, tuple):
        logits = logits[0]
    if logits.dim() == 3:
        # (batch, seq, vocab): per-position argmax of the first batch row.
        best = torch.argmax(logits, dim=-1)[0].tolist()
    else:
        # 2-D output: argmax along the last dimension directly.
        best = torch.argmax(logits, dim=-1).tolist()
    return tokenizer.decode(best)
| # ========================================================= | |
| # RUN WITH: | |
| # uvicorn app:app --host 0.0.0.0 --port 8000 | |
| # ========================================================= |