File size: 2,245 Bytes
7d6a683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
sanskrit_model.py — Fixed
=========================
forward() now takes an inference_mode parameter, so reverse_process.py can
pass inference_mode=True without triggering a TypeError.

The wrapper introspects each inner model's method signature and passes only
the keyword arguments that model actually accepts, which keeps it safe across
all four architectures.
"""

import torch
import torch.nn as nn
import inspect


class SanskritModel(nn.Module):
    """Uniform front-end over the four supported model architectures.

    ``cfg['model_type']`` selects the inner model, which is stored as
    ``self.model``.  ``forward`` and ``generate`` delegate to it and pass
    along only the keyword arguments the inner method can accept, so callers
    can use one calling convention for every architecture.
    """

    def __init__(self, cfg):
        """Build the inner model selected by ``cfg['model_type']``.

        Args:
            cfg: configuration mapping.  It must contain ``'model_type'``
                and is passed unchanged to the chosen model's constructor.

        Raises:
            ValueError: if ``model_type`` is not one of the four supported
                names.
        """
        super().__init__()
        model_type = cfg['model_type']

        # Imports are deferred so only the selected architecture's module
        # is loaded.
        if model_type == 'd3pm_cross_attention':
            from model.d3pm_model_cross_attention import D3PMCrossAttention
            self.model = D3PMCrossAttention(cfg)

        elif model_type == 'd3pm_encoder_decoder':
            from model.d3pm_model_encoder_decoder import D3PMEncoderDecoder
            self.model = D3PMEncoderDecoder(cfg)

        elif model_type == 'baseline_cross_attention':
            from model.d3pm_model_cross_attention import BaselineCrossAttention
            self.model = BaselineCrossAttention(cfg)

        elif model_type == 'baseline_encoder_decoder':
            from model.d3pm_model_encoder_decoder import BaselineEncoderDecoder
            self.model = BaselineEncoderDecoder(cfg)

        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    @staticmethod
    def _accepted_kwargs(params, candidates):
        """Return the entries of ``candidates`` that can be passed by keyword.

        Args:
            params: ``inspect.signature(fn).parameters`` of the target callable.
            candidates: mapping of keyword name -> value to filter.

        Returns:
            A new dict.  If the target declares ``**kwargs`` every candidate
            is kept, because such a callable accepts any keyword.  Otherwise
            only names of positional-or-keyword and keyword-only parameters
            are kept.  Positional-only and ``*args`` names are excluded:
            passing them by keyword would raise ``TypeError``.
        """
        kinds = inspect.Parameter
        if any(p.kind is kinds.VAR_KEYWORD for p in params.values()):
            return dict(candidates)
        keyword_names = {
            name for name, p in params.items()
            if p.kind in (kinds.POSITIONAL_OR_KEYWORD, kinds.KEYWORD_ONLY)
        }
        return {k: v for k, v in candidates.items() if k in keyword_names}

    def forward(self, input_ids, target_ids, t, x0_hint=None, inference_mode=False):
        """Run the inner model's forward pass.

        Args:
            input_ids: first positional argument to the inner model.
            target_ids: second positional argument to the inner model.
            t: passed as the third positional argument only when the inner
                ``forward`` declares a parameter named ``t``.  Otherwise it
                is not passed at all.
            x0_hint: forwarded by keyword only if the inner ``forward``
                accepts it.
            inference_mode: forwarded by keyword only if the inner
                ``forward`` accepts it.

        Returns:
            Whatever the inner model returns.
        """
        params = inspect.signature(self.model.forward).parameters
        kwargs = self._accepted_kwargs(
            params, {'x0_hint': x0_hint, 'inference_mode': inference_mode}
        )

        if 't' in params:
            return self.model(input_ids, target_ids, t, **kwargs)
        return self.model(input_ids, target_ids, **kwargs)

    @torch.no_grad()
    def generate(self, src, **kwargs):
        """Delegate to the inner model's ``generate`` with gradients disabled.

        Keyword arguments the inner ``generate`` cannot accept are dropped
        silently.  This lets one set of generation options be used for every
        architecture.

        Returns:
            Whatever the inner model's ``generate`` returns.
        """
        params = inspect.signature(self.model.generate).parameters
        return self.model.generate(src, **self._accepted_kwargs(params, kwargs))