File size: 2,758 Bytes
7d6a683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import copy
from typing import Dict, Any

import torch

from config import CONFIG
from inference import _build_tokenizers, _resolve_device, load_model, run_inference


_STATE = {
    "loaded": False,
    "model": None,
    "cfg": None,
    "device": None,
    "src_tok": None,
    "tgt_tok": None,
}


def _load_once() -> None:
    """Populate ``_STATE`` with model, config, device and tokenizers on first use.

    Later calls are no-ops once ``_STATE["loaded"]`` is True. The shared
    ``CONFIG`` is deep-copied first so the overrides applied here never
    leak back into it.
    """
    if not _STATE["loaded"]:
        base_cfg = copy.deepcopy(CONFIG)
        base_cfg["model_type"] = "d3pm_cross_attention"
        base_cfg["data"]["include_negative_examples"] = True
        dev = _resolve_device(base_cfg)

        # load_model hands back a config alongside the model; the tokenizers
        # and the cached state use that returned config, not base_cfg.
        net, loaded_cfg = load_model("best_model.pt", base_cfg, dev)
        src, tgt = _build_tokenizers(loaded_cfg)

        _STATE.update(
            model=net,
            cfg=loaded_cfg,
            device=dev,
            src_tok=src,
            tgt_tok=tgt,
        )
        # Flip the flag last: if anything above raised, the next call retries.
        _STATE["loaded"] = True


def _clean_text(text: str) -> str:
    text = " ".join(text.split())
    if not text:
        return text
    toks = text.split()
    out = []
    prev = None
    run = 0
    for tok in toks:
        if tok == prev:
            run += 1
        else:
            prev = tok
            run = 1
        if run <= 2:
            out.append(tok)
    s = " ".join(out)
    s = s.replace(" ।", "।").replace(" ॥", "॥")
    return " ".join(s.split())


def predict(
    text: str,
    temperature: float = 0.7,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
    diversity_penalty: float = 0.0,
    num_steps: int = 64,
    clean_output: bool = True,
) -> Dict[str, Any]:
    """Run one inference pass on ``text`` with the lazily loaded model.

    Args:
        text: Source text; surrounding whitespace is stripped before encoding.
        temperature, top_k, repetition_penalty, diversity_penalty, num_steps:
            Coerced to float/int and written into a per-call copy of the
            loaded config's ``"inference"`` section passed to ``run_inference``.
        clean_output: If True, ``output`` is post-processed by ``_clean_text``;
            ``raw_output`` is always the uncleaned decode.

    Returns:
        On success: ``input``, ``output``, ``raw_output`` and ``config`` (the
        coerced sampling settings plus ``clean_output``).
        On empty/whitespace-only input: ``error`` set to ``"empty input"``,
        with ``input`` echoed and empty ``output``/``raw_output``.
    """
    _load_once()
    if not text or not text.strip():
        # Same text keys as the success path so callers can index
        # "input"/"output"/"raw_output" without special-casing the error.
        return {"error": "empty input", "input": text, "output": "", "raw_output": ""}

    # Coerce once; reused for the inference config and the echoed config.
    params = {
        "temperature": float(temperature),
        "top_k": int(top_k),
        "repetition_penalty": float(repetition_penalty),
        "diversity_penalty": float(diversity_penalty),
        "num_steps": int(num_steps),
    }
    # Per-call copy so these overrides never mutate the cached config.
    cfg = copy.deepcopy(_STATE["cfg"])
    cfg["inference"].update(params)

    src_tok = _STATE["src_tok"]
    tgt_tok = _STATE["tgt_tok"]
    device = _STATE["device"]

    input_ids = torch.tensor(
        [src_tok.encode(text.strip())], dtype=torch.long, device=device
    )
    # Pure inference: no autograd graph is needed (redundant but harmless if
    # run_inference already disables gradients internally).
    with torch.no_grad():
        out = run_inference(_STATE["model"], input_ids, cfg)
    # NOTE(review): ids 0-4 are presumably reserved special tokens
    # (pad/bos/eos/...) and are dropped before decoding — confirm against
    # the target tokenizer's vocabulary.
    decoded_ids = [x for x in out[0].tolist() if x > 4]
    raw = tgt_tok.decode(decoded_ids).strip()
    output = _clean_text(raw) if clean_output else raw

    return {
        "input": text,
        "output": output,
        "raw_output": raw,
        "config": {**params, "clean_output": bool(clean_output)},
    }