"""
sanskrit_model.py — Fixed
===========================
Added inference_mode parameter to forward() so reverse_process.py can
pass inference_mode=True without a TypeError.
The wrapper introspects each inner model's signature and only passes
kwargs that model actually accepts — safe across all four architectures.
"""
import torch
import torch.nn as nn
import inspect


class SanskritModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        model_type = cfg['model_type']
        # Lazy imports: only the module for the selected architecture is loaded.
        if model_type == 'd3pm_cross_attention':
            from model.d3pm_model_cross_attention import D3PMCrossAttention
            self.model = D3PMCrossAttention(cfg)
        elif model_type == 'd3pm_encoder_decoder':
            from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder
            self.model = D3PMEncoderDecoder(cfg)
        elif model_type == 'baseline_cross_attention':
            from model.d3pm_model_cross_attention import BaselineCrossAttention
            self.model = BaselineCrossAttention(cfg)
        elif model_type == 'baseline_encoder_decoder':
            from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder
            self.model = BaselineEncoderDecoder(cfg)
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False):
        """
        Forward pass. Introspects the inner model's signature so only
        supported kwargs are passed — works with all four architectures.
        """
        sig = inspect.signature(self.model.forward).parameters
        kwargs = {}
        if 'x0_hint' in sig:
            kwargs['x0_hint'] = x0_hint
        if 'inference_mode' in sig:
            kwargs['inference_mode'] = inference_mode
        if 't' in sig:
            return self.model(input_ids, target_ids, t, **kwargs)
        else:
            return self.model(input_ids, target_ids, **kwargs)

    @torch.no_grad()
    def generate(self, src, **kwargs):
        """Generation entry point; filters kwargs the same way as forward()."""
        sig = inspect.signature(self.model.generate).parameters
        filtered = {k: v for k, v in kwargs.items() if k in sig}
        return self.model.generate(src, **filtered)
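

# ---------------------------------------------------------------------------
# Minimal sketch (assumption): the toy forward() functions below are
# hypothetical stand-ins; they only illustrate the signature-filtering idea
# described above, since the real d3pm_*/baseline_* configs are not shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    def forward_with_hint(input_ids, target_ids, t, x0_hint=None, inference_mode=False):
        return 'diffusion-style forward'

    def forward_plain(input_ids, target_ids):
        return 'baseline-style forward'

    for fn in (forward_with_hint, forward_plain):
        sig = inspect.signature(fn).parameters
        candidate = {'x0_hint': None, 'inference_mode': True}
        # Only kwargs named in the target signature survive the filter.
        kwargs = {k: v for k, v in candidate.items() if k in sig}
        print(f"{fn.__name__}: forwards {sorted(kwargs) or 'no extra kwargs'}")
    # forward_with_hint: forwards ['inference_mode', 'x0_hint']
    # forward_plain: forwards no extra kwargs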