# DevaFlow / inference_api.py
# Uploaded by bhsinghgrid using the upload-large-folder tool
# (commit 7d6a683, verified)
import copy
from typing import Dict, Any
import torch
from config import CONFIG
from inference import _build_tokenizers, _resolve_device, load_model, run_inference
# Process-wide lazy singleton holding the loaded model and its companions.
# Populated exactly once by _load_once(); "loaded" guards re-initialization.
_STATE = {
    "loaded": False,
    "model": None,
    "cfg": None,
    "device": None,
    "src_tok": None,
    "tgt_tok": None,
}
def _load_once() -> None:
    """Lazily initialize the shared model state on first use (idempotent)."""
    if _STATE["loaded"]:
        return
    # Work on a private copy so the global CONFIG is never mutated.
    cfg = copy.deepcopy(CONFIG)
    cfg["model_type"] = "d3pm_cross_attention"
    cfg["data"]["include_negative_examples"] = True
    device = _resolve_device(cfg)
    model, cfg = load_model("best_model.pt", cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)
    _STATE.update(
        model=model,
        cfg=cfg,
        device=device,
        src_tok=src_tok,
        tgt_tok=tgt_tok,
        loaded=True,
    )
def _clean_text(text: str) -> str:
text = " ".join(text.split())
if not text:
return text
toks = text.split()
out = []
prev = None
run = 0
for tok in toks:
if tok == prev:
run += 1
else:
prev = tok
run = 1
if run <= 2:
out.append(tok)
s = " ".join(out)
s = s.replace(" ।", "।").replace(" ॥", "॥")
return " ".join(s.split())
def predict(
    text: str,
    temperature: float = 0.7,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
    diversity_penalty: float = 0.0,
    num_steps: int = 64,
    clean_output: bool = True,
) -> Dict[str, Any]:
    """Run diffusion inference on *text* and return the decoded output.

    Parameters
    ----------
    text : source-language input string.
    temperature, top_k, repetition_penalty, diversity_penalty, num_steps :
        per-call sampling overrides written into a copy of the cached config.
    clean_output : when True, post-process the decoded text with _clean_text.

    Returns a dict with the input, (optionally cleaned) output, raw output,
    and the effective sampling config. On empty/whitespace input returns
    {"error": "empty input", "output": ""} without touching the model.
    """
    # Fix: reject empty input *before* triggering the expensive lazy model
    # load — the original loaded the model even for invalid requests.
    if not text or not text.strip():
        return {"error": "empty input", "output": ""}
    _load_once()
    # Copy the cached config so per-request overrides never leak between calls.
    cfg = copy.deepcopy(_STATE["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)
    src_tok = _STATE["src_tok"]
    tgt_tok = _STATE["tgt_tok"]
    device = _STATE["device"]
    input_ids = torch.tensor([src_tok.encode(text.strip())], dtype=torch.long, device=device)
    out = run_inference(_STATE["model"], input_ids, cfg)
    # IDs <= 4 are filtered out before decoding — presumably special tokens
    # (pad/bos/eos/etc.). TODO(review): confirm the id range against the
    # target tokenizer's vocabulary.
    decoded_ids = [x for x in out[0].tolist() if x > 4]
    raw = tgt_tok.decode(decoded_ids).strip()
    output = _clean_text(raw) if clean_output else raw
    return {
        "input": text,
        "output": output,
        "raw_output": raw,
        "config": {
            "temperature": float(temperature),
            "top_k": int(top_k),
            "repetition_penalty": float(repetition_penalty),
            "diversity_penalty": float(diversity_penalty),
            "num_steps": int(num_steps),
            "clean_output": bool(clean_output),
        },
    }