| import copy |
| from typing import Dict, Any |
|
|
| import torch |
|
|
| from config import CONFIG |
| from inference import _build_tokenizers, _resolve_device, load_model, run_inference |
|
|
|
|
# Module-level lazy cache populated exactly once by _load_once().
# Kept global so repeated predict() calls reuse the same loaded model.
_STATE = {
    "loaded": False,  # flipped to True after _load_once() completes
    "model": None,    # model returned by load_model()
    "cfg": None,      # config dict returned alongside the model
    "device": None,   # device chosen by _resolve_device()
    "src_tok": None,  # source tokenizer from _build_tokenizers()
    "tgt_tok": None,  # target tokenizer from _build_tokenizers()
}
|
|
|
|
def _load_once() -> None:
    """Initialize model, tokenizers, and device on first call; no-op afterwards."""
    if _STATE["loaded"]:
        return

    # Deep-copy so the shared CONFIG template is never mutated.
    config = copy.deepcopy(CONFIG)
    config["model_type"] = "d3pm_cross_attention"
    config["data"]["include_negative_examples"] = True
    dev = _resolve_device(config)

    # load_model returns a (possibly updated) config alongside the model,
    # so tokenizers are built from the returned config.
    model, config = load_model("best_model.pt", config, dev)
    source_tokenizer, target_tokenizer = _build_tokenizers(config)

    _STATE.update(
        model=model,
        cfg=config,
        device=dev,
        src_tok=source_tokenizer,
        tgt_tok=target_tokenizer,
        loaded=True,
    )
|
|
|
|
| def _clean_text(text: str) -> str: |
| text = " ".join(text.split()) |
| if not text: |
| return text |
| toks = text.split() |
| out = [] |
| prev = None |
| run = 0 |
| for tok in toks: |
| if tok == prev: |
| run += 1 |
| else: |
| prev = tok |
| run = 1 |
| if run <= 2: |
| out.append(tok) |
| s = " ".join(out) |
| s = s.replace(" ।", "।").replace(" ॥", "॥") |
| return " ".join(s.split()) |
|
|
|
|
def predict(
    text: str,
    temperature: float = 0.7,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
    diversity_penalty: float = 0.0,
    num_steps: int = 64,
    clean_output: bool = True,
) -> Dict[str, Any]:
    """Run inference on *text* and return the decoded output with the settings used.

    Returns a dict with "input", "output" (optionally cleaned), "raw_output",
    and the effective sampling "config". Empty/blank input yields an error dict.
    """
    _load_once()
    if not text or not text.strip():
        return {"error": "empty input", "output": ""}

    # Per-call copy of the cached config so concurrent callers never share state.
    cfg = copy.deepcopy(_STATE["cfg"])
    overrides = {
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
    }
    cfg["inference"].update(overrides)

    device = _STATE["device"]
    encoded = _STATE["src_tok"].encode(text.strip())
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)
    out = run_inference(_STATE["model"], input_ids, cfg)

    # Strip low ids before decoding — presumably ids 0-4 are special tokens
    # (pad/bos/eos/...); verify against the tokenizer vocabulary.
    kept_ids = [tok_id for tok_id in out[0].tolist() if tok_id > 4]
    raw = _STATE["tgt_tok"].decode(kept_ids).strip()
    output = _clean_text(raw) if clean_output else raw

    return {
        "input": text,
        "output": output,
        "raw_output": raw,
        # Echo back the sampling settings actually applied.
        "config": dict(overrides, clean_output=bool(clean_output)),
    }
|
|