"""Lazy-loading inference wrapper: loads the model once and exposes `predict`.

Configuration is layered: built-in CONFIG < model_settings.json < environment
variables (HF_MODEL_TYPE, HF_INCLUDE_NEG, HF_NUM_STEPS).
"""

import copy
import json
import os
from typing import Any, Dict

import torch

from config import CONFIG
from inference import _build_tokenizers, _resolve_device, load_model, run_inference

# Module-level singleton cache so the model is loaded at most once per process.
_STATE = {
    "loaded": False,
    "model": None,
    "cfg": None,
    "device": None,
    "src_tok": None,
    "tgt_tok": None,
}


def _read_model_settings() -> Dict[str, Any]:
    """Best-effort read of ``model_settings.json``.

    Returns:
        The parsed settings dict, or ``{}`` when the file is missing,
        unreadable, malformed JSON, or not a JSON object.
    """
    try:
        # EAFP: opening directly avoids the exists()/open() race.
        with open("model_settings.json", "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError covers missing/unreadable file; json.JSONDecodeError is a
        # ValueError subclass. Anything else is a real bug and should surface.
        return {}
    return data if isinstance(data, dict) else {}


def _load_once() -> None:
    """Load model, config, and tokenizers into ``_STATE`` (idempotent).

    Environment variables override values from model_settings.json, which in
    turn override the built-in CONFIG defaults.
    NOTE(review): not thread-safe; assumes single-threaded first call.
    """
    if _STATE["loaded"]:
        return

    settings = _read_model_settings()
    cfg = copy.deepcopy(CONFIG)

    cfg["model_type"] = os.environ.get(
        "HF_MODEL_TYPE",
        settings.get("model_type", "d3pm_cross_attention"),
    )
    # Truthiness is decided by case-insensitive string comparison to "true";
    # the env var (if set) wins over the settings-file value.
    cfg["data"]["include_negative_examples"] = (
        os.environ.get(
            "HF_INCLUDE_NEG",
            str(settings.get("include_negative_examples", True)),
        ).lower()
        == "true"
    )

    # num_steps must be kept in sync between the model's diffusion schedule
    # and the inference loop.
    num_steps_raw = os.environ.get("HF_NUM_STEPS", settings.get("num_steps"))
    if num_steps_raw is not None:
        num_steps = int(num_steps_raw)
        cfg["model"]["diffusion_steps"] = num_steps
        cfg["inference"]["num_steps"] = num_steps

    device = _resolve_device(cfg)
    # load_model may return an updated cfg (e.g. from the checkpoint).
    model, cfg = load_model("best_model.pt", cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)

    _STATE["model"] = model
    _STATE["cfg"] = cfg
    _STATE["device"] = device
    _STATE["src_tok"] = src_tok
    _STATE["tgt_tok"] = tgt_tok
    _STATE["loaded"] = True


def _clean_text(text: str) -> str:
    """Normalize whitespace and cap consecutive duplicate tokens at two.

    Also re-attaches Devanagari danda punctuation ("।", "॥") to the
    preceding token. Returns the input unchanged if it collapses to empty.
    """
    text = " ".join(text.split())
    if not text:
        return text
    toks = text.split()
    out = []
    prev = None
    run = 0
    for tok in toks:
        if tok == prev:
            run += 1
        else:
            prev = tok
            run = 1
        # Keep at most 2 consecutive copies of the same token; drops
        # degenerate repetition loops from sampling.
        if run <= 2:
            out.append(tok)
    s = " ".join(out)
    s = s.replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(s.split())


def predict(
    text: str,
    temperature: float = 0.7,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
    diversity_penalty: float = 0.0,
    num_steps: int = 64,
    clean_output: bool = True,
) -> Dict[str, Any]:
    """Run one inference pass on ``text`` and return the decoded output.

    Args:
        text: Source text to translate/generate from; must be non-empty.
        temperature: Sampling temperature.
        top_k: Top-k cutoff for sampling.
        repetition_penalty: Penalty applied to repeated tokens.
        diversity_penalty: Diversity penalty for sampling.
        num_steps: Number of diffusion/inference steps for this call.
        clean_output: If True, apply ``_clean_text`` to the decoded text.

    Returns:
        Dict with ``input``, ``output``, ``raw_output``, and the effective
        ``config``; or ``{"error": ..., "output": ""}`` for empty input.
    """
    _load_once()

    if not text or not text.strip():
        return {"error": "empty input", "output": ""}

    # Per-call sampling overrides on a copy, so the cached cfg stays pristine.
    cfg = copy.deepcopy(_STATE["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)

    src_tok = _STATE["src_tok"]
    tgt_tok = _STATE["tgt_tok"]
    device = _STATE["device"]

    input_ids = torch.tensor(
        [src_tok.encode(text.strip())], dtype=torch.long, device=device
    )
    out = run_inference(_STATE["model"], input_ids, cfg)

    # IDs 0-4 are presumably special tokens (pad/bos/eos/etc.) -- TODO confirm
    # against the tokenizer's vocabulary.
    decoded_ids = [x for x in out[0].tolist() if x > 4]
    raw = tgt_tok.decode(decoded_ids).strip()
    output = _clean_text(raw) if clean_output else raw

    return {
        "input": text,
        "output": output,
        "raw_output": raw,
        "config": {
            "temperature": float(temperature),
            "top_k": int(top_k),
            "repetition_penalty": float(repetition_penalty),
            "diversity_penalty": float(diversity_penalty),
            "num_steps": int(num_steps),
            "clean_output": bool(clean_output),
        },
    }