# modeling_transolver.py

from typing import Optional, Tuple

import torch
from physicsnemo.models.transolver import Transolver as TransolverBase
from transformers import PretrainedConfig, PreTrainedModel


class TransolverConfig(PretrainedConfig):
    model_type = "transolver"

    def __init__(
        self,
        functional_dim: int = 5,
        out_dim: int = 1,
        embedding_dim: Optional[int] = 3,
        n_layers: int = 4,
        n_hidden: int = 128,
        dropout: float = 0.0,
        n_head: int = 8,
        act: str = "gelu",
        mlp_ratio: int = 4,
        slice_num: int = 32,
        unified_pos: bool = False,
        ref: int = 8,
        structured_shape: Optional[Tuple[int, ...]] = None,
        use_te: bool = False,
        time_input: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.functional_dim = functional_dim
        self.out_dim = out_dim
        self.embedding_dim = embedding_dim
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.dropout = dropout
        self.n_head = n_head
        self.act = act
        self.mlp_ratio = mlp_ratio
        self.slice_num = slice_num
        self.unified_pos = unified_pos
        self.ref = ref
        self.structured_shape = structured_shape
        self.use_te = use_te
        self.time_input = time_input


class TransolverModel(PreTrainedModel):
    config_class = TransolverConfig

    def __init__(self, config: TransolverConfig):
        super().__init__(config)

        self.transolver = TransolverBase(
            functional_dim=config.functional_dim,
            out_dim=config.out_dim,
            embedding_dim=config.embedding_dim,
            n_layers=config.n_layers,
            n_hidden=config.n_hidden,
            dropout=config.dropout,
            n_head=config.n_head,
            act=config.act,
            mlp_ratio=config.mlp_ratio,
            slice_num=config.slice_num,
            unified_pos=config.unified_pos,
            ref=config.ref,
            structured_shape=config.structured_shape,
            use_te=config.use_te,
            time_input=config.time_input,
        )

        # Finalize weight initialization and other Transformers bookkeeping
        # (required so save_pretrained / from_pretrained work correctly)
        self.post_init()

    def forward(
        self,
        fx: torch.Tensor,
        embedding: Optional[torch.Tensor] = None,
        time: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Thin wrapper around TransolverBase.forward.

        Args:
            fx: [B, N, functional_dim] or [B, *structure, functional_dim]
            embedding: position / embeddings
            time: optional time tensor
        """
        return self.transolver(fx, embedding=embedding, time=time)
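

# A minimal usage sketch (assumes `physicsnemo` and `transformers` are
# installed; the batch size and point count below are illustrative only):
#
#     config = TransolverConfig(functional_dim=5, out_dim=1, embedding_dim=3)
#     model = TransolverModel(config)
#     fx = torch.randn(2, 1024, config.functional_dim)        # [B, N, functional_dim]
#     embedding = torch.randn(2, 1024, config.embedding_dim)  # [B, N, embedding_dim]
#     out = model(fx, embedding=embedding)
#
# Because the wrapper subclasses `PreTrainedModel`, it also inherits
# `save_pretrained` / `from_pretrained`, so a trained model can be written
# out with `model.save_pretrained(path)` and reloaded with
# `TransolverModel.from_pretrained(path)`.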