File size: 9,094 Bytes
01a8278
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import torch
import torch.nn as nn
from typing import Dict, Tuple
from dataclasses import dataclass
import math


# =========================
# Config
# =========================

@dataclass
class ModelConfig:
    """Hyperparameters for the denoiser: problem sizes and transformer dimensions.

    Field order is part of the positional constructor signature, so new
    fields should be appended rather than inserted.
    """

    # --- problem sizes ---
    # Length of the conditioning vector; one projection layer is created per entry.
    # NOTE(review): described as the "inverse-design sidebar parameter vector" — confirm meaning.
    n_conditions: int = 17
    # Number of material classes (a MASK id is added on top by the model).
    n_materials: int = 4
    # Volume fraction categories: 0.1000, 0.2000, 0.3000, 0.4000, 0.5000
    n_vf_categories: int = 5
    # Maximum number of layer slots per sample (quarter layers; max 5 for the quarter-angle dataset).
    n_max_layer: int = 5

    # --- model architecture ---
    d_model: int = 256    # token / hidden width
    n_heads: int = 4      # attention heads per attention module
    n_layers: int = 6     # number of stacked attention blocks
    dropout: float = 0.0  # dropout probability passed to the attention modules


# =========================
# Model
# =========================

def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Sinusoidal timestep embedding.

    Args:
        t:   (B,) tensor of timesteps (any numeric dtype; cast to float).
        dim: width of the returned embedding.

    Returns:
        (B, dim) tensor laid out as [cos(t * f_0..f_{h-1}), sin(t * f_0..f_{h-1})]
        with h = dim // 2 and f_k = 10000 ** (-k / h). When ``dim`` is odd a
        single trailing zero column is appended so the width is exactly ``dim``.
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down towards 1/10000.
    freqs = torch.exp(-math.log(10000) * torch.arange(0, half, device=t.device) / half)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)  # (B, half)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=1)  # (B, 2 * half)
    if dim % 2 == 1:
        # Build the pad column from the batch size rather than slicing `emb`:
        # for dim == 1 `emb` has zero columns, so `emb[:, :1]` would be empty
        # and the result would come out as (B, 0) instead of (B, 1).
        pad = emb.new_zeros((emb.shape[0], 1))
        emb = torch.cat([emb, pad], dim=1)
    return emb  # (B, dim)

class SelfCrossAttnBlock(nn.Module):
    """
    One post-norm transformer block with three residual sub-layers:
    self-attention over the sequence, cross-attention into the condition
    tokens, and a position-wise MLP. Each residual sum is followed by its
    own LayerNorm.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        # Both attention modules share the same hyperparameters and use
        # batch-first (B, N, d) layout.
        attn_kwargs = dict(dropout=dropout, batch_first=True)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, **attn_kwargs)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, **attn_kwargs)

        # Position-wise MLP with a 4x hidden expansion.
        hidden = 4 * d_model
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.SiLU(),
            nn.Linear(hidden, d_model),
        )

        # One LayerNorm per residual sub-layer, registered as ln1, ln2, ln3.
        self.ln1, self.ln2, self.ln3 = (nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, x, cond_tokens, key_padding_mask=None):
        """
        Args:
            x:                (B, N, d) sequence tokens (material + vf_category + angle tokens)
            cond_tokens:      (B, M, d) condition tokens (M = n_conditions)
            key_padding_mask: optional (B, N) bool mask applied to the
                              self-attention keys (True = mask out, False = keep)

        Returns:
            (B, N, d) updated sequence tokens.
        """
        # 1) tokens attend to each other (masked keys are ignored)
        sa_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.ln1(x + sa_out)

        # 2) tokens attend to the condition tokens (no mask on the conditions)
        ca_out, _ = self.cross_attn(x, cond_tokens, cond_tokens)
        x = self.ln2(x + ca_out)

        # 3) position-wise MLP
        return self.ln3(x + self.ff(x))

class MaterialHybridDenoiser(nn.Module):
    """
    Denoiser over a token sequence [material, vf_category, layer_0 .. layer_{L-1}]
    that self-attends within the sequence and cross-attends to per-scalar
    condition tokens.

    Inputs:
      material_t: (B,)  in [0..n_materials-1] or MASK
      vf_category_t: (B,)  in [0..n_vf_categories-1] volume fraction category or MASK
      layer_t:    (B,L) in {0,1} or MASK (its shape is always read; its values
                  are only embedded when not use_discrete_angles)
      angle_t:    (B,L) discrete category indices [0..n_angle_categories-1] or MASK (if use_discrete_angles)
                  OR (B,L,1) continuous (if not use_discrete_angles)
                  When discrete: category n_angle_categories = dead layer, n_angle_categories+1 = MASK
      cond:       (B,C) continuous, C = n_conditions
                  (a 3-D tensor is passed through as already-projected (B,C,d) tokens)
      t:          (B,) timestep

    Outputs (dict):
      material_logits:    (B, n_materials)
      vf_category_logits: (B, n_vf_categories)
      angle_logits:       (B,L,n_angle_categories+1)  # discrete angle categories + dead (if use_discrete_angles)
      OR
      layer_logits:       (B,L,2)                     # alive/dead (if not use_discrete_angles)
      angle:              (B,L,1)                     # sigmoid-scaled to (0, pi/2) radians (if not use_discrete_angles)
    """

    def __init__(self, cfg: "ModelConfig", mask_ids: Dict[str, int], use_discrete_angles: bool = True, n_angle_categories: int = 7):
        super().__init__()
        self.cfg = cfg
        self.L = cfg.n_max_layer
        d = cfg.d_model
        # Stored for callers; not read anywhere inside this class.
        self.mask_ids = mask_ids
        self.use_discrete_angles = use_discrete_angles
        self.n_angle_categories = n_angle_categories

        # NOTE: submodule names and creation order are kept stable on purpose:
        # they determine state_dict keys and the weights drawn under a fixed seed.

        # +1 to include mask token for material
        self.material_emb = nn.Embedding(cfg.n_materials + 1, d)
        # vf_category: n_vf_categories real categories plus one mask id
        self.vf_category_emb = nn.Embedding(cfg.n_vf_categories + 1, d)

        if use_discrete_angles:
            # Category n_angle_categories = dead layer, n_angle_categories+1 = mask token
            self.angle_emb = nn.Embedding(n_angle_categories + 2, d)
            self.layer_emb = None
        else:
            # layer token: {MASK, 0, 1} => 3 (only needed for continuous angles)
            self.layer_emb = nn.Embedding(3, d)
            self.angle_in = nn.Linear(1, d)

        # Condition projection: each scalar condition coefficient gets its own Linear(1, d)
        # NOTE(review): this was annotated "n_conditions = 7 * degree", but
        # ModelConfig defaults n_conditions to 17 (not a multiple of 7) —
        # confirm which is intended.
        self.cond_proj = nn.ModuleList([
            nn.Linear(1, d) for _ in range(cfg.n_conditions)
        ])

        self.blocks = nn.ModuleList([
            SelfCrossAttnBlock(d, cfg.n_heads, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])

        # Positional embeddings: pos 0 = material, pos 1 = vf_category, pos 2..2+L-1 = layers
        self.pos_emb = nn.Embedding(2 + cfg.n_max_layer, d)

        self.t_proj = nn.Linear(d, d)

        # NOTE(review): `self.encoder` is never called in forward(); only
        # `self.blocks` is used. It is kept so state_dict keys and seeded
        # initialisation stay unchanged, but its parameters receive no
        # gradient. Consider removing it together with a checkpoint migration.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d,
            nhead=cfg.n_heads,
            dropout=cfg.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.n_layers)
        self.ln = nn.LayerNorm(d)

        self.material_head = nn.Linear(d, cfg.n_materials)
        self.vf_category_head = nn.Linear(d, cfg.n_vf_categories)
        if use_discrete_angles:
            # n_angle_categories for angles + 1 for dead layer
            self.angle_head = nn.Linear(d, n_angle_categories + 1)
            self.layer_head = None
        else:
            self.layer_head = nn.Linear(d, 2)  # alive/dead
            self.angle_head = nn.Linear(d, 1)

    @staticmethod
    def _dead_suffix_mask(angle_t, dead_category, n_prefix=2):
        """Return a (B, n_prefix + L) bool mask that is True from each row's first dead layer onward.

        The first ``n_prefix`` columns (the global tokens) are always False.
        Rows without any dead layer are entirely False. Computed without
        Python loops or ``.item()`` calls, so it does not force a host sync
        per sample.
        """
        is_dead = angle_t == dead_category  # (B, L)
        # A running count > 0 marks the first dead slot and everything after it.
        dead_or_after = torch.cumsum(is_dead.to(torch.long), dim=1) > 0  # (B, L)
        prefix = dead_or_after.new_zeros((angle_t.shape[0], n_prefix))   # (B, n_prefix), all False
        return torch.cat([prefix, dead_or_after], dim=1)

    def forward(self, material_t, vf_category_t, layer_t, angle_t, cond, t):
        """Run the denoiser; see the class docstring for input and output shapes."""
        B, L = layer_t.shape
        assert L == self.L, f"expected {self.L} layer slots, got {L}"

        # Project conditions if provided as raw scalars (B, C): scalar i goes
        # through its own Linear(1, d) and becomes one condition token.
        if cond.dim() == 2:
            cond = torch.cat(
                [
                    self.cond_proj[i](cond[:, i:i + 1].unsqueeze(-1))  # (B, 1, d)
                    for i in range(cond.shape[1])
                ],
                dim=1,
            )  # (B, C, d)

        # global tokens as a 2-token "prefix"
        g_mat = self.material_emb(material_t).unsqueeze(1)  # (B,1,d)
        g_vf = self.vf_category_emb(vf_category_t).unsqueeze(1)  # (B,1,d)

        # per-layer tokens
        if self.use_discrete_angles:
            layer_h = self.angle_emb(angle_t)  # (B, L, d)
        else:
            layer_h = self.layer_emb(layer_t) + self.angle_in(angle_t)  # (B,L,d)

        h = torch.cat([g_mat, g_vf, layer_h], dim=1)  # (B, 2+L, d)

        # Add positional embeddings to entire sequence
        pos_indices = torch.arange(2 + self.L, device=h.device)  # (2+L,)
        h = h + self.pos_emb(pos_indices).unsqueeze(0)  # (B, 2+L, d)

        # add timestep (broadcast over all sequence positions)
        t_emb = timestep_embedding(t, h.size(-1))  # (B,d)
        h = h + self.t_proj(t_emb).unsqueeze(1)

        # Key padding mask: hide every layer token from the first dead layer
        # onward, so dead tokens are treated as a trailing suffix. The two
        # global tokens are never masked.
        key_padding_mask = None
        if self.use_discrete_angles:
            key_padding_mask = self._dead_suffix_mask(angle_t, self.n_angle_categories)

        for block in self.blocks:
            h = block(h, cond, key_padding_mask=key_padding_mask)

        h = self.ln(h)

        # Heads read from fixed positions: 0 = material, 1 = vf_category, 2.. = layers.
        if self.use_discrete_angles:
            angle_logits = self.angle_head(h[:, 2:])  # (B, L, n_angle_categories + 1)
            out = {
                "material_logits": self.material_head(h[:, 0]),     # (B, n_materials)
                "vf_category_logits": self.vf_category_head(h[:, 1]),  # (B, n_vf_categories)
                "angle_logits":    angle_logits,                     # (B,L,n_angle_categories+1)
            }
        else:
            angle_raw = self.angle_head(h[:, 2:])  # (B,L,1)
            angle = torch.sigmoid(angle_raw) * (math.pi / 2)  # (B,L,1) in radians, range (0, pi/2)
            out = {
                "material_logits": self.material_head(h[:, 0]),     # (B, n_materials)
                "vf_category_logits": self.vf_category_head(h[:, 1]),  # (B, n_vf_categories)
                "layer_logits":    self.layer_head(h[:, 2:]),       # (B,L,2)
                "angle":           angle,                            # (B,L,1) in radians
            }
        return out