import copy

import torch
import torch.nn.functional as F

from config import CONFIG


def _resolve_device(cfg: dict) -> torch.device:
    """Resolve the configured device, falling back to CPU when unavailable."""
    requested = cfg["training"]["device"]
    if requested == "cuda" and not torch.cuda.is_available():
        requested = "cpu"
    if requested == "mps" and not torch.backends.mps.is_available():
        requested = "cpu"
    cfg["training"]["device"] = requested
    return torch.device(requested)


def _build_tokenizers(cfg):
    """Instantiate source/target tokenizers from the model config."""
    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg["model"].get("src_vocab_size", 16000),
        max_len=cfg["model"]["max_seq_len"],
    )
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg["model"].get("tgt_vocab_size", 16000),
        max_len=cfg["model"]["max_seq_len"],
    )
    return src_tok, tgt_tok


def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
    """Load a checkpoint, inferring architecture hyperparameters from its shapes."""
    from model.sanskrit_model import SanskritModel

    cfg = copy.deepcopy(base_cfg)
    state = torch.load(ckpt_path, map_location="cpu")

    # Infer vocab size and model width from the source embedding matrix.
    emb_key = "model.src_embed.token_emb.weight"
    if emb_key in state:
        vocab, d_model = state[emb_key].shape
        cfg["model"]["src_vocab_size"] = vocab
        cfg["model"]["d_model"] = d_model
        cfg["model"]["d_ff"] = d_model * 4

    # Infer layer count from the highest encoder block index in the state dict.
    layer_ids = {
        int(k.split(".")[2]) for k in state if k.startswith("model.encoder_blocks.")
    }
    if layer_ids:
        cfg["model"]["n_layers"] = max(layer_ids) + 1

    # Infer max sequence length from the positional-encoding buffer.
    pos_key = "model.src_embed.pos_enc.pe"
    if pos_key in state:
        cfg["model"]["max_seq_len"] = state[pos_key].shape[1]

    # Ensure the head count divides d_model; fall back to the largest count that does.
    d_model = cfg["model"]["d_model"]
    n_heads = cfg["model"].get("n_heads", 8)
    if d_model % n_heads != 0:
        n_heads = next(h for h in [8, 6, 4, 2, 1] if d_model % h == 0)
    cfg["model"]["n_heads"] = n_heads

    model = SanskritModel(cfg).to(device)
    # Reuse the already-loaded state dict instead of reading the file a second time;
    # load_state_dict copies tensors onto the model's device.
    model.load_state_dict(state, strict=False)
    model.eval()
    return model, cfg


def run_inference(model, input_ids, cfg):
    """Iterative D3PM-style denoising: start fully masked, refine across timesteps."""
    from model.d3pm_model_cross_attention import (
        _apply_diversity_penalty_fixed,
        _apply_repetition_penalty,
        _batch_multinomial,
        _top_k_filter,
    )

    inf = cfg["inference"]
    device = input_ids.device
    bsz, seqlen = input_ids.shape
    inner = model.model

    # Build a decreasing timestep schedule that always ends at t = 0.
    total_steps = inner.scheduler.num_timesteps
    steps = int(inf["num_steps"])
    step_size = max(1, total_steps // max(steps, 1))
    timesteps = list(range(total_steps - 1, -1, -step_size))
    if timesteps[-1] != 0:
        timesteps.append(0)

    # Start from an all-[MASK] target estimate.
    x0_est = torch.full(
        (bsz, seqlen), inner.mask_token_id, dtype=torch.long, device=device
    )
    hint = None

    with torch.no_grad():
        for i, t_val in enumerate(timesteps):
            is_last = i == len(timesteps) - 1
            t = torch.full((bsz,), t_val, dtype=torch.long, device=device)
            logits, _ = model(input_ids, x0_est, t, x0_hint=hint, inference_mode=True)

            # Optional logit shaping before sampling.
            if inf["repetition_penalty"] != 1.0:
                logits = _apply_repetition_penalty(
                    logits, x0_est, float(inf["repetition_penalty"])
                )
            if inf["diversity_penalty"] > 0.0:
                logits = _apply_diversity_penalty_fixed(
                    logits, float(inf["diversity_penalty"])
                )
            logits = logits / max(float(inf["temperature"]), 1e-5)
            if int(inf["top_k"]) > 0:
                logits = _top_k_filter(logits, int(inf["top_k"]))

            probs = F.softmax(logits, dim=-1)
            # Sample intermediate estimates; take the argmax on the final step.
            if is_last:
                x0_est = torch.argmax(probs, dim=-1)
            else:
                x0_est = _batch_multinomial(probs)
            hint = x0_est

    return x0_est


__all__ = [
    "CONFIG",
    "_resolve_device",
    "_build_tokenizers",
    "load_model",
    "run_inference",
]
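

# Usage sketch (illustrative, not part of the module's public API): wires the
# helpers above into a minimal end-to-end decode. The checkpoint path and the
# tokenizers' encode()/decode() signatures are assumptions, not confirmed by
# this module -- adapt them to the actual SanskritSourceTokenizer /
# SanskritTargetTokenizer interfaces.
if __name__ == "__main__":
    device = _resolve_device(CONFIG)
    model, cfg = load_model("checkpoints/best.pt", CONFIG, device)  # hypothetical path
    src_tok, tgt_tok = _build_tokenizers(cfg)

    # encode() returning a list of token ids is an assumed interface.
    ids = src_tok.encode("some transliterated Sanskrit input")
    input_ids = torch.tensor([ids], dtype=torch.long, device=device)

    out_ids = run_inference(model, input_ids, cfg)
    # decode() over a list of ids is likewise an assumed interface.
    print(tgt_tok.decode(out_ids[0].tolist()))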