"""Serving wrapper around the D3PM cross-attention inference pipeline.

Lazily loads the model, config, device, and tokenizers once per process
and exposes :func:`predict`, which runs diffusion inference on a single
input string with tunable sampling settings.
"""

import copy
from typing import Any, Dict

import torch

from config import CONFIG
from inference import _build_tokenizers, _resolve_device, load_model, run_inference

# Process-wide cache populated by _load_once(); avoids reloading the
# checkpoint and tokenizers on every predict() call.
_STATE = {
    "loaded": False,
    "model": None,
    "cfg": None,
    "device": None,
    "src_tok": None,
    "tgt_tok": None,
}


def _load_once() -> None:
    """Populate ``_STATE`` on first call; subsequent calls are no-ops.

    Deep-copies the global CONFIG so the module-level template is never
    mutated, then forces the d3pm_cross_attention model type and
    negative-example data setting before loading the checkpoint.
    NOTE(review): not thread-safe — concurrent first calls may load twice;
    confirm single-threaded serving or add a lock.
    """
    if _STATE["loaded"]:
        return
    cfg = copy.deepcopy(CONFIG)
    cfg["model_type"] = "d3pm_cross_attention"
    cfg["data"]["include_negative_examples"] = True
    device = _resolve_device(cfg)
    # load_model may rewrite cfg (e.g. from checkpoint metadata), so keep
    # the returned cfg rather than the local copy.
    model, cfg = load_model("best_model.pt", cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)
    _STATE["model"] = model
    _STATE["cfg"] = cfg
    _STATE["device"] = device
    _STATE["src_tok"] = src_tok
    _STATE["tgt_tok"] = tgt_tok
    _STATE["loaded"] = True


def _clean_text(text: str) -> str:
    """Normalize decoded text for presentation.

    Collapses whitespace, caps any run of an identical token at two
    occurrences (drops the 3rd+ repeat), and removes the space before
    danda punctuation ("।", "॥").
    """
    text = " ".join(text.split())
    if not text:
        return text
    toks = text.split()
    out = []
    prev = None
    run = 0
    for tok in toks:
        if tok == prev:
            run += 1
        else:
            prev = tok
            run = 1
        # Keep at most two consecutive copies of the same token.
        if run <= 2:
            out.append(tok)
    s = " ".join(out)
    s = s.replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(s.split())


def predict(
    text: str,
    temperature: float = 0.7,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
    diversity_penalty: float = 0.0,
    num_steps: int = 64,
    clean_output: bool = True,
) -> Dict[str, Any]:
    """Run inference on ``text`` and return the decoded output.

    Args:
        text: Source string; blank input short-circuits with an error dict.
        temperature, top_k, repetition_penalty, diversity_penalty,
        num_steps: Sampling knobs copied into the per-call inference config.
        clean_output: If True, post-process the raw decode via _clean_text.

    Returns:
        Dict with ``input``, ``output``, ``raw_output``, and the effective
        ``config``; or ``{"error": ..., "output": ""}`` for empty input.
    """
    _load_once()
    if not text or not text.strip():
        return {"error": "empty input", "output": ""}
    # Per-call copy so concurrent/successive calls don't clobber settings.
    cfg = copy.deepcopy(_STATE["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)
    src_tok = _STATE["src_tok"]
    tgt_tok = _STATE["tgt_tok"]
    device = _STATE["device"]
    input_ids = torch.tensor(
        [src_tok.encode(text.strip())], dtype=torch.long, device=device
    )
    out = run_inference(_STATE["model"], input_ids, cfg)
    # ids 0-4 are dropped — presumably special tokens (pad/bos/eos/unk/mask);
    # TODO confirm against the tokenizer's vocabulary layout.
    decoded_ids = [x for x in out[0].tolist() if x > 4]
    raw = tgt_tok.decode(decoded_ids).strip()
    output = _clean_text(raw) if clean_output else raw
    return {
        "input": text,
        "output": output,
        "raw_output": raw,
        "config": {
            "temperature": float(temperature),
            "top_k": int(top_k),
            "repetition_penalty": float(repetition_penalty),
            "diversity_penalty": float(diversity_penalty),
            "num_steps": int(num_steps),
            "clean_output": bool(clean_output),
        },
    }