File size: 3,643 Bytes
6d34b0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import copy
import json
import os
from typing import Dict, Any

import torch

from config import CONFIG
from inference import _build_tokenizers, _resolve_device, load_model, run_inference


# Module-level singleton cache. Populated exactly once by _load_once();
# predict() reads from it on every call. Kept as a plain dict so all
# mutation happens in one place without `global` rebinding.
_STATE = {
    "loaded": False,  # flipped to True after a successful load
    "model": None,  # model object returned by load_model()
    "cfg": None,  # config dict (deep copy of CONFIG with overrides applied)
    "device": None,  # device chosen by _resolve_device(cfg)
    "src_tok": None,  # source tokenizer from _build_tokenizers(cfg)
    "tgt_tok": None,  # target tokenizer from _build_tokenizers(cfg)
}


def _read_model_settings() -> Dict[str, Any]:
    if not os.path.exists("model_settings.json"):
        return {}
    try:
        with open("model_settings.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        return data if isinstance(data, dict) else {}
    except Exception:
        return {}


def _load_once() -> None:
    """Lazily initialize the module-level model cache.

    Idempotent: the first call loads the model, tokenizers, device and
    config into ``_STATE``; subsequent calls return immediately.

    Override precedence (highest first): environment variables
    (``HF_MODEL_TYPE``, ``HF_INCLUDE_NEG``, ``HF_NUM_STEPS``), then
    ``model_settings.json``, then hard-coded defaults.
    """
    if _STATE["loaded"]:
        return

    settings = _read_model_settings()
    cfg = copy.deepcopy(CONFIG)

    # Model type: env var beats settings file beats built-in default.
    fallback_model_type = settings.get("model_type", "d3pm_cross_attention")
    cfg["model_type"] = os.environ.get("HF_MODEL_TYPE", fallback_model_type)

    # Negative-example flag is compared as a lowercase "true"/"false" string
    # so both env-var text and the settings-file boolean normalize the same way.
    neg_fallback = str(settings.get("include_negative_examples", True)).lower()
    neg_value = os.environ.get("HF_INCLUDE_NEG", neg_fallback).lower()
    cfg["data"]["include_negative_examples"] = neg_value == "true"

    # Step count overrides both the model's diffusion schedule and inference.
    raw_steps = os.environ.get("HF_NUM_STEPS", settings.get("num_steps"))
    if raw_steps is not None:
        steps = int(raw_steps)
        cfg["model"]["diffusion_steps"] = steps
        cfg["inference"]["num_steps"] = steps

    device = _resolve_device(cfg)
    model, cfg = load_model("best_model.pt", cfg, device)
    src_tok, tgt_tok = _build_tokenizers(cfg)

    _STATE.update(
        model=model,
        cfg=cfg,
        device=device,
        src_tok=src_tok,
        tgt_tok=tgt_tok,
        loaded=True,
    )


def _clean_text(text: str) -> str:
    text = " ".join(text.split())
    if not text:
        return text
    toks = text.split()
    out = []
    prev = None
    run = 0
    for tok in toks:
        if tok == prev:
            run += 1
        else:
            prev = tok
            run = 1
        if run <= 2:
            out.append(tok)
    s = " ".join(out)
    s = s.replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(s.split())


def predict(
    text: str,
    temperature: float = 0.7,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
    diversity_penalty: float = 0.0,
    num_steps: int = 64,
    clean_output: bool = True,
) -> Dict[str, Any]:
    """Run the cached model on ``text`` and return the decoded result.

    Loads the model on first use, overrides the inference section of the
    cached config with the sampling arguments, and decodes the generated
    token ids.

    Returns:
        On success: dict with keys ``input``, ``output`` (cleaned when
        ``clean_output`` is True), ``raw_output`` and ``config`` (the
        effective sampling settings). For empty/whitespace input:
        ``{"error": "empty input", "output": ""}``.
    """
    _load_once()

    if not text or not text.strip():
        return {"error": "empty input", "output": ""}

    # One dict drives both the inference config and the echoed "config" field.
    sampling = {
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
    }
    cfg = copy.deepcopy(_STATE["cfg"])
    cfg["inference"].update(sampling)

    encoded = _STATE["src_tok"].encode(text.strip())
    input_ids = torch.tensor(
        [encoded], dtype=torch.long, device=_STATE["device"]
    )
    generated = run_inference(_STATE["model"], input_ids, cfg)

    # Drop low token ids (<= 4) — presumably padding/special tokens in the
    # target vocabulary; confirm against the tokenizer definition.
    kept_ids = [tok_id for tok_id in generated[0].tolist() if tok_id > 4]
    raw = _STATE["tgt_tok"].decode(kept_ids).strip()
    output = _clean_text(raw) if clean_output else raw

    return {
        "input": text,
        "output": output,
        "raw_output": raw,
        "config": {**sampling, "clean_output": bool(clean_output)},
    }