import torch
import torch.nn as nn
from typing import Dict, Tuple
from dataclasses import dataclass
import math


# =========================
# Config
# =========================
@dataclass
class ModelConfig:
    """Hyper-parameters for ``MaterialHybridDenoiser``."""

    # problem sizes
    n_conditions: int = 17  # length of the inverse-design condition vector `cond` (B, C)
    n_materials: int = 4
    n_vf_categories: int = 5  # Volume fraction categories: 0.1000, 0.2000, 0.3000, 0.4000, 0.5000
    n_max_layer: int = 5  # Quarter layers (max 5 for quarter-angle dataset)

    # model architecture
    d_model: int = 256
    n_heads: int = 4
    n_layers: int = 6
    dropout: float = 0.0


# =========================
# Model
# =========================
def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Sinusoidal timestep embedding.

    Args:
        t: (B,) timesteps; cast to float before use.
        dim: output width. When ``dim`` is odd, the last column is zero padding.

    Returns:
        (B, dim) tensor laid out as ``[cos(t * f), sin(t * f), (zero pad)]`` with
        ``f_k = 10000 ** (-k / (dim // 2))`` for ``k = 0 .. dim // 2 - 1``.
    """
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000) * torch.arange(0, half, device=t.device, dtype=torch.float32) / half
    )
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=1)
    if dim % 2 == 1:
        # Use an explicit (B, 1) zero column: slicing emb[:, :1] is empty when
        # dim == 1 (half == 0), which previously produced a (B, 0) result.
        emb = torch.cat([emb, emb.new_zeros(emb.shape[0], 1)], dim=1)
    return emb  # (B, dim)


class SelfCrossAttnBlock(nn.Module):
    """Post-LN transformer block: self-attention, cross-attention to conditions, feed-forward."""

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.cross_attn = nn.MultiheadAttention(
            d_model, n_heads, dropout=dropout, batch_first=True
        )
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.SiLU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)

    def forward(self, x, cond_tokens, key_padding_mask=None):
        """
        x:                (B, N, d) <- material + vf_category + per-layer tokens
        cond_tokens:      (B, M, d) <- condition tokens (M = n_conditions)
        key_padding_mask: (B, N) optional, applied to self-attention keys only
                          (True = mask out, False = keep)

        Returns a (B, N, d) tensor.
        """
        # self-attention (within tokens); residual then LayerNorm (post-LN)
        x = self.ln1(x + self.self_attn(x, x, x, key_padding_mask=key_padding_mask)[0])
        # cross-attention (tokens attend to conditions; no mask on condition tokens)
        x = self.ln2(x + self.cross_attn(x, cond_tokens, cond_tokens)[0])
        # feed-forward
        x = self.ln3(x + self.ff(x))
        return x


class MaterialHybridDenoiser(nn.Module):
    """
    Denoiser over a (material, volume-fraction category, per-layer angle) design,
    conditioned on a continuous condition vector and a diffusion timestep.

    Inputs:
        material_t:    (B,) long in [0..n_materials-1] or MASK
                       (material_emb has n_materials + 1 rows; presumably MASK == n_materials,
                       TODO confirm against mask_ids)
        vf_category_t: (B,) long in [0..n_vf_categories-1] or MASK
                       (vf_category_emb has n_vf_categories + 1 rows)
        layer_t:       (B, L) in {0, 1} or MASK. Embedded only when not use_discrete_angles;
                       in discrete mode it is used only for its shape.
        angle_t:       (B, L) discrete category indices if use_discrete_angles, where
                           0..n_angle_categories-1 = angle categories,
                           n_angle_categories      = dead layer,
                           n_angle_categories + 1  = MASK;
                       OR (B, L, 1) continuous values if not use_discrete_angles.
        cond:          (B, C) raw scalars, each projected by its own Linear(1, d), C <= n_conditions;
                       OR (B, M, d) already-projected condition tokens, used as-is.
        t:             (B,) timestep

    Outputs (dict):
        material_logits:    (B, n_materials)
        vf_category_logits: (B, n_vf_categories)
        angle_logits:       (B, L, n_angle_categories + 1)  angle categories + dead (discrete mode)
        OR
        layer_logits:       (B, L, 2)  alive/dead (continuous mode)
        angle:              (B, L, 1)  sigmoid-squashed to [0, pi/2] radians (continuous mode)

    In discrete mode, every layer token at or after the first dead token is hidden
    as a self-attention key. Such tokens are treated as padding whatever their category.
    """

    def __init__(self, cfg: ModelConfig, mask_ids: Dict[str, int],
                 use_discrete_angles: bool = True, n_angle_categories: int = 7):
        super().__init__()
        self.cfg = cfg
        self.L = cfg.n_max_layer
        d = cfg.d_model
        self.mask_ids = mask_ids
        self.use_discrete_angles = use_discrete_angles
        self.n_angle_categories = n_angle_categories

        # +1 to include mask token for material
        self.material_emb = nn.Embedding(cfg.n_materials + 1, d)
        # vf_category: n_vf_categories categories plus one mask token
        self.vf_category_emb = nn.Embedding(cfg.n_vf_categories + 1, d)

        if use_discrete_angles:
            # Category n_angle_categories = dead layer, n_angle_categories+1 = mask token
            self.angle_emb = nn.Embedding(n_angle_categories + 2, d)
            self.layer_emb = None
        else:
            # layer token: {MASK, 0, 1} => 3 (only needed for continuous angles)
            self.layer_emb = nn.Embedding(3, d)
            self.angle_in = nn.Linear(1, d)

        # Condition projection: each scalar condition entry gets its own Linear(1, d),
        # one module per entry of the (B, n_conditions) condition vector.
        self.cond_proj = nn.ModuleList([
            nn.Linear(1, d) for _ in range(cfg.n_conditions)
        ])

        self.blocks = nn.ModuleList([
            SelfCrossAttnBlock(d, cfg.n_heads, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])

        # Positional embeddings: pos 0 = material, pos 1 = vf_category, pos 2..2+L-1 = layers
        self.pos_emb = nn.Embedding(2 + cfg.n_max_layer, d)

        self.t_proj = nn.Linear(d, d)

        # NOTE(review): `encoder` is never called in forward(). It is kept only so that
        # existing checkpoints (state_dict keys) and optimizer state still load; removing
        # it requires a checkpoint migration.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d,
            nhead=cfg.n_heads,
            dropout=cfg.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.n_layers)
        self.ln = nn.LayerNorm(d)

        self.material_head = nn.Linear(d, cfg.n_materials)
        self.vf_category_head = nn.Linear(d, cfg.n_vf_categories)

        if use_discrete_angles:
            # n_angle_categories for angles + 1 for dead layer
            self.angle_head = nn.Linear(d, n_angle_categories + 1)
            self.layer_head = None
        else:
            self.layer_head = nn.Linear(d, 2)  # alive/dead
            self.angle_head = nn.Linear(d, 1)

    def _dead_key_padding_mask(self, angle_t: torch.Tensor) -> torch.Tensor:
        """Build the (B, 2 + L) self-attention key padding mask for discrete angles.

        Layer positions at or after the first dead token (category
        ``n_angle_categories``) are True (masked out). The two global prefix tokens
        are never masked, so no attention row is ever fully masked.
        """
        is_dead = angle_t == self.n_angle_categories  # (B, L)
        # A running count of dead tokens is > 0 exactly from the first dead one onward.
        after_first_dead = is_dead.long().cumsum(dim=1) > 0  # (B, L)
        prefix = torch.zeros(angle_t.shape[0], 2, dtype=torch.bool, device=angle_t.device)
        return torch.cat([prefix, after_first_dead], dim=1)  # (B, 2 + L)

    def forward(self, material_t, vf_category_t, layer_t, angle_t, cond, t):
        _, L = layer_t.shape
        assert L == self.L, f"expected layer_t with {self.L} layers, got {L}"

        # Project conditions if provided as raw scalars (B, C) -> (B, C, d).
        # A 3-D `cond` is assumed to be pre-projected tokens and is used unchanged.
        if cond.dim() == 2:
            cond = torch.cat(
                [self.cond_proj[i](cond[:, i:i + 1].unsqueeze(-1))  # (B, 1, d)
                 for i in range(cond.shape[1])],
                dim=1,
            )  # (B, C, d)

        # global tokens as a 2-token "prefix"
        g_mat = self.material_emb(material_t).unsqueeze(1)  # (B, 1, d)
        g_vf = self.vf_category_emb(vf_category_t).unsqueeze(1)  # (B, 1, d)

        # per-layer tokens
        if self.use_discrete_angles:
            layer_h = self.angle_emb(angle_t)  # (B, L, d)
        else:
            layer_h = self.layer_emb(layer_t) + self.angle_in(angle_t)  # (B, L, d)

        h = torch.cat([g_mat, g_vf, layer_h], dim=1)  # (B, 2+L, d)

        # Add positional embeddings to entire sequence
        pos_indices = torch.arange(2 + self.L, device=h.device)  # (2+L,)
        h = h + self.pos_emb(pos_indices).unsqueeze(0)  # (B, 2+L, d)

        # add timestep (broadcast over all tokens)
        t_emb = timestep_embedding(t, h.size(-1))  # (B, d)
        h = h + self.t_proj(t_emb).unsqueeze(1)

        # Hide everything from the first dead layer onward (discrete mode only).
        key_padding_mask = None
        if self.use_discrete_angles:
            key_padding_mask = self._dead_key_padding_mask(angle_t)

        for block in self.blocks:
            h = block(h, cond, key_padding_mask=key_padding_mask)

        h = self.ln(h)

        if self.use_discrete_angles:
            angle_logits = self.angle_head(h[:, 2:])  # (B, L, n_angle_categories + 1)
            out = {
                "material_logits": self.material_head(h[:, 0]),  # (B, n_materials)
                "vf_category_logits": self.vf_category_head(h[:, 1]),  # (B, n_vf_categories)
                "angle_logits": angle_logits,  # (B, L, n_angle_categories + 1)
            }
        else:
            angle_raw = self.angle_head(h[:, 2:])  # (B, L, 1)
            angle = torch.sigmoid(angle_raw) * (math.pi / 2)  # (B, L, 1) in radians
            out = {
                "material_logits": self.material_head(h[:, 0]),  # (B, n_materials)
                "vf_category_logits": self.vf_category_head(h[:, 1]),  # (B, n_vf_categories)
                "layer_logits": self.layer_head(h[:, 2:]),  # (B, L, 2)
                "angle": angle,  # (B, L, 1) in radians
            }
        return out