import copy

import torch
import torch.nn.functional as F

from config import CONFIG

def _resolve_device(cfg: dict) -> torch.device:
    """Resolve the configured device, falling back to CPU when the requested
    backend is unavailable, and write the result back into the config."""
    requested = cfg["training"]["device"]
    if requested == "cuda" and not torch.cuda.is_available():
        requested = "cpu"
    if requested == "mps" and not torch.backends.mps.is_available():
        requested = "cpu"
    cfg["training"]["device"] = requested
    return torch.device(requested)

def _build_tokenizers(cfg):
    """Build source and target tokenizers from the model config."""
    # Local import: the tokenizer classes are only needed here.
    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg["model"].get("src_vocab_size", 16000),
        max_len=cfg["model"]["max_seq_len"],
    )
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg["model"].get("tgt_vocab_size", 16000),
        max_len=cfg["model"]["max_seq_len"],
    )
    return src_tok, tgt_tok

def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
    """Load a checkpoint, inferring architecture hyperparameters from the
    shapes and key names in its state dict so the config need not match."""
    from model.sanskrit_model import SanskritModel

    cfg = copy.deepcopy(base_cfg)
    state = torch.load(ckpt_path, map_location="cpu")

    # Vocabulary size and model width come from the source embedding matrix.
    emb_key = "model.src_embed.token_emb.weight"
    if emb_key in state:
        vocab, d_model = state[emb_key].shape
        cfg["model"]["src_vocab_size"] = vocab
        cfg["model"]["d_model"] = d_model
        cfg["model"]["d_ff"] = d_model * 4

    # Depth comes from the highest encoder-block index present in the keys.
    layer_ids = {int(k.split(".")[2]) for k in state if k.startswith("model.encoder_blocks.")}
    if layer_ids:
        cfg["model"]["n_layers"] = max(layer_ids) + 1

    # Maximum sequence length comes from the positional-encoding buffer.
    pos_key = "model.src_embed.pos_enc.pe"
    if pos_key in state:
        cfg["model"]["max_seq_len"] = state[pos_key].shape[1]

    # Ensure the head count divides d_model; otherwise fall back to the
    # largest divisor among the common choices.
    d_model = cfg["model"]["d_model"]
    n_heads = cfg["model"].get("n_heads", 8)
    if d_model % n_heads != 0:
        n_heads = next(h for h in [8, 6, 4, 2, 1] if d_model % h == 0)
    cfg["model"]["n_heads"] = n_heads

    model = SanskritModel(cfg).to(device)
    # Reuse the state dict already loaded above instead of reading the file a
    # second time; strict=False tolerates keys the rebuilt model lacks.
    model.load_state_dict(state, strict=False)
    model.eval()
    return model, cfg

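# Illustration only (not executed): how the shape sniffing in load_model
# recovers hyperparameters from a toy state dict. All shapes and key suffixes
# below are made up for the example.
#
#   state = {
#       "model.src_embed.token_emb.weight": torch.zeros(16000, 512),
#       "model.encoder_blocks.0.attn.weight": torch.zeros(512, 512),
#       "model.encoder_blocks.5.attn.weight": torch.zeros(512, 512),
#       "model.src_embed.pos_enc.pe": torch.zeros(1, 256, 512),
#   }
#   # -> src_vocab_size=16000, d_model=512, d_ff=2048, n_layers=6, max_seq_len=256
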
def run_inference(model, input_ids, cfg):
    """Iterative D3PM-style denoising: start from an all-mask sequence and
    refine the x0 estimate over a subsampled reverse-time schedule."""
    # Lazy import, hoisted out of the loop so it runs once per call rather
    # than once per denoising step.
    from model.d3pm_model_cross_attention import (
        _apply_diversity_penalty_fixed,
        _apply_repetition_penalty,
        _batch_multinomial,
        _top_k_filter,
    )

    inf = cfg["inference"]
    device = input_ids.device
    bsz, seqlen = input_ids.shape
    inner = model.model

    # Subsample the full training schedule down to num_steps reverse-time
    # points, always finishing with a pass at t = 0.
    total_steps = inner.scheduler.num_timesteps
    steps = int(inf["num_steps"])
    step_size = max(1, total_steps // max(steps, 1))
    timesteps = list(range(total_steps - 1, -1, -step_size))
    if timesteps[-1] != 0:
        timesteps.append(0)

    # Start fully masked; every pass re-predicts all positions.
    x0_est = torch.full((bsz, seqlen), inner.mask_token_id, dtype=torch.long, device=device)
    hint = None

    with torch.no_grad():
        for i, t_val in enumerate(timesteps):
            is_last = i == len(timesteps) - 1
            t = torch.full((bsz,), t_val, dtype=torch.long, device=device)
            logits, _ = model(input_ids, x0_est, t, x0_hint=hint, inference_mode=True)

            # Optional logit shaping: repetition/diversity penalties,
            # temperature, and top-k filtering.
            if inf["repetition_penalty"] != 1.0:
                logits = _apply_repetition_penalty(logits, x0_est, float(inf["repetition_penalty"]))
            if inf["diversity_penalty"] > 0.0:
                logits = _apply_diversity_penalty_fixed(logits, float(inf["diversity_penalty"]))
            logits = logits / max(float(inf["temperature"]), 1e-5)
            if int(inf["top_k"]) > 0:
                logits = _top_k_filter(logits, int(inf["top_k"]))

            probs = F.softmax(logits, dim=-1)
            # Sample on intermediate steps; take the argmax on the final one.
            if is_last:
                x0_est = torch.argmax(probs, dim=-1)
            else:
                x0_est = _batch_multinomial(probs)
            hint = x0_est

    return x0_est

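# Worked example of the schedule above (figures illustrative, not from
# CONFIG): with total_steps=1000 and num_steps=10, step_size=100 and the
# loop visits t = 999, 899, ..., 99, plus the appended final pass at t = 0,
# for 11 forward passes in total.
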
__all__ = [
    "CONFIG",
    "_resolve_device",
    "_build_tokenizers",
    "load_model",
    "run_inference",
]
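
# --- Usage sketch -----------------------------------------------------------
# A minimal end-to-end flow through this module, kept as comments because it
# leans on assumptions: the checkpoint path is hypothetical, and the
# tokenizers' encode()/decode() methods are assumed, not defined in this file.
#
# if __name__ == "__main__":
#     device = _resolve_device(CONFIG)
#     src_tok, tgt_tok = _build_tokenizers(CONFIG)
#     model, cfg = load_model("checkpoints/best.pt", CONFIG, device)  # hypothetical path
#     ids = src_tok.encode("...")  # assumed tokenizer API
#     input_ids = torch.tensor([ids], dtype=torch.long, device=device)
#     out_ids = run_inference(model, input_ids, cfg)
#     print(tgt_tok.decode(out_ids[0].tolist()))  # assumed tokenizer API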