File size: 3,998 Bytes
1cc095d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import copy

import torch
import torch.nn.functional as F

from config import CONFIG


def _resolve_device(cfg: dict) -> torch.device:
    requested = cfg["training"]["device"]
    if requested == "cuda" and not torch.cuda.is_available():
        requested = "cpu"
    if requested == "mps" and not torch.backends.mps.is_available():
        requested = "cpu"
    cfg["training"]["device"] = requested
    return torch.device(requested)


def _build_tokenizers(cfg):
    """Build the (source, target) Sanskrit tokenizer pair from *cfg*.

    Vocab sizes default to 16000 when not present in ``cfg["model"]``; both
    tokenizers share the model's ``max_seq_len``.
    """
    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    model_cfg = cfg["model"]
    shared_max_len = model_cfg["max_seq_len"]
    source = SanskritSourceTokenizer(
        vocab_size=model_cfg.get("src_vocab_size", 16000),
        max_len=shared_max_len,
    )
    target = SanskritTargetTokenizer(
        vocab_size=model_cfg.get("tgt_vocab_size", 16000),
        max_len=shared_max_len,
    )
    return source, target


def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
    """Load a SanskritModel checkpoint, inferring architecture from weights.

    The checkpoint's state dict is inspected to recover the hyperparameters
    it was trained with (vocab size, d_model, layer count, max sequence
    length), overriding *base_cfg* so the rebuilt model matches the weights.

    Args:
        ckpt_path: Path to a ``torch.save``-d state dict.
        base_cfg: Base configuration; deep-copied, never mutated.
        device: Device to place the model on.

    Returns:
        ``(model, cfg)`` — the model in eval mode and the adjusted config.
    """
    from model.sanskrit_model import SanskritModel

    cfg = copy.deepcopy(base_cfg)
    state = torch.load(ckpt_path, map_location="cpu")

    # Width and source vocab come from the embedding matrix shape.
    emb_key = "model.src_embed.token_emb.weight"
    if emb_key in state:
        vocab, d_model = state[emb_key].shape
        cfg["model"]["src_vocab_size"] = vocab
        cfg["model"]["d_model"] = d_model
        cfg["model"]["d_ff"] = d_model * 4  # conventional 4x FFN expansion

    # Depth = highest encoder-block index present in the state dict + 1.
    layer_ids = {int(k.split(".")[2]) for k in state if k.startswith("model.encoder_blocks.")}
    if layer_ids:
        cfg["model"]["n_layers"] = max(layer_ids) + 1

    # Max sequence length from the positional-encoding buffer's time axis.
    pos_key = "model.src_embed.pos_enc.pe"
    if pos_key in state:
        cfg["model"]["max_seq_len"] = state[pos_key].shape[1]

    # Head count must divide d_model; otherwise pick the largest that does.
    d_model = cfg["model"]["d_model"]
    n_heads = cfg["model"].get("n_heads", 8)
    if d_model % n_heads != 0:
        n_heads = next(h for h in [8, 6, 4, 2, 1] if d_model % h == 0)
    cfg["model"]["n_heads"] = n_heads

    model = SanskritModel(cfg).to(device)
    # Bug fix: reuse the state dict loaded above instead of calling
    # torch.load a second time (the original deserialized the checkpoint
    # from disk twice). load_state_dict copies tensors onto the model's
    # device, so the CPU-mapped state is fine here.
    model.load_state_dict(state, strict=False)
    model.eval()
    return model, cfg


def run_inference(model, input_ids, cfg):
    """Iteratively denoise from a fully-masked sequence to token ids.

    Runs the model over a decreasing timestep schedule of roughly
    ``cfg["inference"]["num_steps"]`` entries (always ending at t=0),
    sampling at intermediate steps and taking the argmax on the final one.

    Args:
        model: Wrapper whose ``model`` attribute exposes ``scheduler`` and
            ``mask_token_id``; called as
            ``model(input_ids, x0_est, t, x0_hint=..., inference_mode=True)``.
        input_ids: Long tensor of shape (batch, seq) on the target device.
        cfg: Config dict; ``cfg["inference"]`` supplies sampling knobs.

    Returns:
        Long tensor of shape (batch, seq) with the predicted token ids.
    """
    # Hoisted out of the sampling loop: these imports were previously
    # executed on every timestep iteration.
    from model.d3pm_model_cross_attention import (
        _apply_diversity_penalty_fixed,
        _apply_repetition_penalty,
        _batch_multinomial,
        _top_k_filter,
    )

    inf = cfg["inference"]
    device = input_ids.device
    bsz, seqlen = input_ids.shape
    inner = model.model

    # Build the timestep schedule: stride through [total-1 .. 0] and make
    # sure the final step is exactly 0.
    total_steps = inner.scheduler.num_timesteps
    steps = int(inf["num_steps"])
    step_size = max(1, total_steps // max(steps, 1))
    timesteps = list(range(total_steps - 1, -1, -step_size))
    if timesteps[-1] != 0:
        timesteps.append(0)

    # Loop-invariant config conversions, also hoisted.
    rep_pen = float(inf["repetition_penalty"])
    div_pen = float(inf["diversity_penalty"])
    temperature = max(float(inf["temperature"]), 1e-5)  # guard div-by-zero
    top_k = int(inf["top_k"])

    # Start from an all-mask estimate and refine it at every timestep.
    x0_est = torch.full((bsz, seqlen), inner.mask_token_id, dtype=torch.long, device=device)
    hint = None

    with torch.no_grad():
        for i, t_val in enumerate(timesteps):
            is_last = i == len(timesteps) - 1
            t = torch.full((bsz,), t_val, dtype=torch.long, device=device)

            logits, _ = model(input_ids, x0_est, t, x0_hint=hint, inference_mode=True)

            if rep_pen != 1.0:
                logits = _apply_repetition_penalty(logits, x0_est, rep_pen)
            if div_pen > 0.0:
                logits = _apply_diversity_penalty_fixed(logits, div_pen)

            logits = logits / temperature
            if top_k > 0:
                logits = _top_k_filter(logits, top_k)

            probs = F.softmax(logits, dim=-1)
            # Greedy decode on the final step; sample otherwise.
            if is_last:
                x0_est = torch.argmax(probs, dim=-1)
            else:
                x0_est = _batch_multinomial(probs)
            hint = x0_est  # feed the current estimate back as a hint

    return x0_est


# Public API of this module; the leading-underscore helpers are exported
# deliberately for use by sibling scripts.
__all__ = [
    "CONFIG",
    "_resolve_device",
    "_build_tokenizers",
    "load_model",
    "run_inference",
]