"""Inference helpers for the DevaFlow Gradio Space (inference.py)."""

import copy
import torch
import torch.nn.functional as F
from config import CONFIG


def _resolve_device(cfg: dict) -> torch.device:
    """Resolve the configured device, falling back to CPU when unavailable.

    Mutates cfg["training"]["device"] in place so downstream code sees the
    device that was actually selected.
    """
    requested = cfg["training"]["device"]
    if requested == "cuda" and not torch.cuda.is_available():
        requested = "cpu"
    if requested == "mps" and not torch.backends.mps.is_available():
        requested = "cpu"
    cfg["training"]["device"] = requested
    return torch.device(requested)


def _build_tokenizers(cfg):
    """Construct the source and target tokenizers from the model config."""
    from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer

    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg["model"].get("src_vocab_size", 16000),
        max_len=cfg["model"]["max_seq_len"],
    )
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg["model"].get("tgt_vocab_size", 16000),
        max_len=cfg["model"]["max_seq_len"],
    )
    return src_tok, tgt_tok
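

# Usage sketch (the exact tokenizer API lives in model.tokenizer):
#   src_tok, tgt_tok = _build_tokenizers(CONFIG)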


def load_model(ckpt_path: str, base_cfg: dict, device: torch.device):
    """Load a checkpoint, inferring architecture hyperparameters from its shapes."""
    from model.sanskrit_model import SanskritModel

    cfg = copy.deepcopy(base_cfg)
    state = torch.load(ckpt_path, map_location="cpu")
    # Recover the vocab size and model width from the source embedding matrix.
    emb_key = "model.src_embed.token_emb.weight"
    if emb_key in state:
        vocab, d_model = state[emb_key].shape
        cfg["model"]["src_vocab_size"] = vocab
        cfg["model"]["d_model"] = d_model
        cfg["model"]["d_ff"] = d_model * 4
    # Count encoder blocks ("model.encoder_blocks.<i>...") to recover n_layers.
    layer_ids = {int(k.split(".")[2]) for k in state if k.startswith("model.encoder_blocks.")}
    if layer_ids:
        cfg["model"]["n_layers"] = max(layer_ids) + 1
    # The positional-encoding buffer has shape (1, max_seq_len, d_model).
    pos_key = "model.src_embed.pos_enc.pe"
    if pos_key in state:
        cfg["model"]["max_seq_len"] = state[pos_key].shape[1]
    # Keep n_heads compatible with the (possibly updated) d_model.
    d_model = cfg["model"]["d_model"]
    n_heads = cfg["model"].get("n_heads", 8)
    if d_model % n_heads != 0:
        n_heads = next(h for h in [8, 6, 4, 2, 1] if d_model % h == 0)
    cfg["model"]["n_heads"] = n_heads
    model = SanskritModel(cfg).to(device)
    # Reuse the already-loaded state dict instead of reading the file twice.
    model.load_state_dict(state, strict=False)
    model.eval()
    return model, cfg
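

# Typical call (the checkpoint path here is hypothetical; the Gradio app
# supplies the real one):
#   device = _resolve_device(CONFIG)
#   model, cfg = load_model("checkpoints/model.pt", CONFIG, device)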


def run_inference(model, input_ids, cfg):
    """Run the reverse-diffusion sampling loop and return predicted token ids."""
    inf = cfg["inference"]
    device = input_ids.device
    bsz, seqlen = input_ids.shape
    inner = model.model
    # Subsample the training timesteps down to roughly `num_steps` steps,
    # always ending at t = 0.
    total_steps = inner.scheduler.num_timesteps
    steps = int(inf["num_steps"])
    step_size = max(1, total_steps // max(steps, 1))
    timesteps = list(range(total_steps - 1, -1, -step_size))
    if timesteps[-1] != 0:
        timesteps.append(0)
    # Start from a fully masked target estimate.
    x0_est = torch.full((bsz, seqlen), inner.mask_token_id, dtype=torch.long, device=device)
    hint = None
    # Import the sampling helpers once, outside the denoising loop.
    from model.d3pm_model_cross_attention import (
        _apply_diversity_penalty_fixed,
        _apply_repetition_penalty,
        _batch_multinomial,
        _top_k_filter,
    )

    with torch.no_grad():
        for i, t_val in enumerate(timesteps):
            is_last = i == len(timesteps) - 1
            t = torch.full((bsz,), t_val, dtype=torch.long, device=device)
            logits, _ = model(input_ids, x0_est, t, x0_hint=hint, inference_mode=True)
            if inf["repetition_penalty"] != 1.0:
                logits = _apply_repetition_penalty(logits, x0_est, float(inf["repetition_penalty"]))
            if inf["diversity_penalty"] > 0.0:
                logits = _apply_diversity_penalty_fixed(logits, float(inf["diversity_penalty"]))
            logits = logits / max(float(inf["temperature"]), 1e-5)
            if int(inf["top_k"]) > 0:
                logits = _top_k_filter(logits, int(inf["top_k"]))
            probs = F.softmax(logits, dim=-1)
            if is_last:
                # Final step: take the greedy argmax.
                x0_est = torch.argmax(probs, dim=-1)
            else:
                # Intermediate steps: sample, then feed the estimate back as a hint.
                x0_est = _batch_multinomial(probs)
            hint = x0_est
    return x0_est


__all__ = [
    "CONFIG",
    "_resolve_device",
    "_build_tokenizers",
    "load_model",
    "run_inference",
]
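

if __name__ == "__main__":
    # Minimal smoke-test sketch. The checkpoint path is hypothetical, and in
    # the Space these helpers are driven by the Gradio app; real runs should
    # pass tokenized source ids rather than the random ids used here.
    device = _resolve_device(CONFIG)
    model, cfg = load_model("checkpoints/model.pt", CONFIG, device)
    dummy_src = torch.randint(
        0, cfg["model"]["src_vocab_size"], (1, 16), device=device
    )
    out = run_inference(model, dummy_src, cfg)
    print("predicted target ids:", out.tolist())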