import copy
import json
import os
from typing import Dict, Any
import torch
from config import CONFIG
from inference import _build_tokenizers, _resolve_device, load_model, run_inference
_STATE = {
"loaded": False,
"model": None,
"cfg": None,
"device": None,
"src_tok": None,
"tgt_tok": None,
}
def _read_model_settings() -> Dict[str, Any]:
if not os.path.exists("model_settings.json"):
return {}
try:
with open("model_settings.json", "r", encoding="utf-8") as f:
data = json.load(f)
return data if isinstance(data, dict) else {}
except Exception:
return {}
def _load_once() -> None:
    """Populate the module-level ``_STATE`` cache exactly once.

    Override precedence: environment variable > model_settings.json >
    hard-coded default. Idempotent — subsequent calls return immediately.
    """
    if _STATE["loaded"]:
        return

    overrides = _read_model_settings()
    cfg = copy.deepcopy(CONFIG)

    # Architecture selection.
    cfg["model_type"] = os.environ.get(
        "HF_MODEL_TYPE",
        overrides.get("model_type", "d3pm_cross_attention"),
    )

    # Boolean flag arrives as a "true"/"false" string from either source.
    neg_default = str(overrides.get("include_negative_examples", True)).lower()
    neg_raw = os.environ.get("HF_INCLUDE_NEG", neg_default)
    cfg["data"]["include_negative_examples"] = neg_raw.lower() == "true"

    # Diffusion step count is mirrored into both model and inference config.
    steps_raw = os.environ.get("HF_NUM_STEPS", overrides.get("num_steps"))
    if steps_raw is not None:
        steps = int(steps_raw)
        cfg["model"]["diffusion_steps"] = steps
        cfg["inference"]["num_steps"] = steps

    device = _resolve_device(cfg)
    model, cfg = load_model("best_model.pt", cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)

    _STATE.update(
        model=model,
        cfg=cfg,
        device=device,
        src_tok=src_tok,
        tgt_tok=tgt_tok,
        loaded=True,
    )
def _clean_text(text: str) -> str:
text = " ".join(text.split())
if not text:
return text
toks = text.split()
out = []
prev = None
run = 0
for tok in toks:
if tok == prev:
run += 1
else:
prev = tok
run = 1
if run <= 2:
out.append(tok)
s = " ".join(out)
s = s.replace(" ।", "।").replace(" ॥", "॥")
return " ".join(s.split())
def predict(
    text: str,
    temperature: float = 0.7,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
    diversity_penalty: float = 0.0,
    num_steps: int = 64,
    clean_output: bool = True,
) -> Dict[str, Any]:
    """Run inference on *text* using the lazily-loaded cached model.

    Args:
        text: Source text. Blank/empty input short-circuits with an error dict.
        temperature, top_k, repetition_penalty, diversity_penalty, num_steps:
            Per-call sampling overrides, written into a private copy of the
            cached config so the shared _STATE["cfg"] is never mutated.
        clean_output: When True, post-process decoded text with _clean_text().

    Returns:
        Dict with "input", "output", "raw_output" and the effective "config",
        or ``{"error": "empty input", "output": ""}`` for blank input.
    """
    # Fix: validate input BEFORE triggering the expensive lazy model load —
    # the original loaded the model (and could fail on a missing checkpoint)
    # even for requests that were going to be rejected anyway.
    if not text or not text.strip():
        return {"error": "empty input", "output": ""}
    _load_once()

    # Per-call sampling settings on a private copy of the cached config.
    cfg = copy.deepcopy(_STATE["cfg"])
    cfg["inference"]["temperature"] = float(temperature)
    cfg["inference"]["top_k"] = int(top_k)
    cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
    cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
    cfg["inference"]["num_steps"] = int(num_steps)

    src_tok = _STATE["src_tok"]
    tgt_tok = _STATE["tgt_tok"]
    device = _STATE["device"]

    input_ids = torch.tensor(
        [src_tok.encode(text.strip())], dtype=torch.long, device=device
    )
    out = run_inference(_STATE["model"], input_ids, cfg)
    # Drop low ids before decoding; ids 0-4 appear to be reserved special
    # tokens (pad/bos/eos/...) — TODO confirm against the tokenizer vocab.
    decoded_ids = [tok_id for tok_id in out[0].tolist() if tok_id > 4]
    raw = tgt_tok.decode(decoded_ids).strip()
    output = _clean_text(raw) if clean_output else raw

    return {
        "input": text,
        "output": output,
        "raw_output": raw,
        "config": {
            "temperature": float(temperature),
            "top_k": int(top_k),
            "repetition_penalty": float(repetition_penalty),
            "diversity_penalty": float(diversity_penalty),
            "num_steps": int(num_steps),
            "clean_output": bool(clean_output),
        },
    }