| """ |
| sanskrit_model.py — Fixed |
| =========================== |
| Added inference_mode parameter to forward() so reverse_process.py can |
| pass inference_mode=True without a TypeError. |
| |
| The wrapper introspects each inner model's signature and only passes |
| kwargs that model actually accepts — safe across all four architectures. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import inspect |
|
|
|
|
| class SanskritModel(nn.Module): |
| def __init__(self, cfg): |
| super().__init__() |
| model_type = cfg['model_type'] |
|
|
| if model_type == 'd3pm_cross_attention': |
| from model.d3pm_model_cross_attention import D3PMCrossAttention |
| self.model = D3PMCrossAttention(cfg) |
|
|
| elif model_type == 'd3pm_encoder_decoder': |
| from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder |
| self.model = D3PMEncoderDecoder(cfg) |
|
|
| elif model_type == 'baseline_cross_attention': |
| from model.d3pm_model_cross_attention import BaselineCrossAttention |
| self.model = BaselineCrossAttention(cfg) |
|
|
| elif model_type == 'baseline_encoder_decoder': |
| from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder |
| self.model = BaselineEncoderDecoder(cfg) |
|
|
| else: |
| raise ValueError(f"Unknown model_type: {model_type}") |
|
|
| def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False): |
| """ |
| Forward pass. Introspects the inner model's signature so only |
| supported kwargs are passed — works with all four architectures. |
| """ |
| sig = inspect.signature(self.model.forward).parameters |
| kwargs = {} |
| if 'x0_hint' in sig: |
| kwargs['x0_hint'] = x0_hint |
| if 'inference_mode' in sig: |
| kwargs['inference_mode'] = inference_mode |
|
|
| if 't' in sig: |
| return self.model(input_ids, target_ids, t, **kwargs) |
| else: |
| return self.model(input_ids, target_ids, **kwargs) |
|
|
| @torch.no_grad() |
| def generate(self, src, **kwargs): |
| sig = inspect.signature(self.model.generate).parameters |
| filtered = {k: v for k, v in kwargs.items() if k in sig} |
| return self.model.generate(src, **filtered) |