Spaces:
Running on Zero
Running on Zero
| import torch | |
| import torch.nn as nn | |
| from typing import Dict, Tuple | |
| from dataclasses import dataclass | |
| import math | |
| # ========================= | |
| # Config | |
| # ========================= | |
@dataclass
class ModelConfig:
    """Hyper-parameters for the denoiser model.

    Declared as a dataclass so individual fields can be overridden at
    construction time, e.g. ``ModelConfig(d_model=128)``; calling
    ``ModelConfig()`` with no arguments yields the defaults below.
    """

    # problem sizes
    n_conditions: int = 17  # true inverse-design sidebar parameter vector
    n_materials: int = 4
    n_vf_categories: int = 5  # Volume fraction categories: 0.1000, 0.2000, 0.3000, 0.4000, 0.5000
    n_max_layer: int = 5  # Quarter layers (max 5 for quarter-angle dataset)
    # model architecture
    d_model: int = 256
    n_heads: int = 4
    n_layers: int = 6
    dropout: float = 0.0
| # ========================= | |
| # Model | |
| # ========================= | |
def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Return sinusoidal embeddings for a batch of timesteps.

    Args:
        t: Tensor of shape (B,), one timestep per batch element. It is cast
            to float before use.
        dim: Width of the returned embedding.

    Returns:
        Tensor of shape (B, dim) laid out as ``dim // 2`` cosine features
        followed by ``dim // 2`` sine features; when ``dim`` is odd a single
        zero column is appended so the width is exactly ``dim``.
    """
    half = dim // 2
    # Frequencies decay geometrically: exp(-ln(10000) * k / half), k = 0..half-1.
    freqs = torch.exp(-math.log(10000) * torch.arange(0, half, device=t.device) / half)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=1)
    if dim % 2 == 1:
        # Build the pad column explicitly. Deriving it from ``emb[:, :1]``
        # yields an empty slice when dim == 1 (half == 0), which would leave
        # the output at width 0 instead of 1.
        emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
    return emb  # (B, dim)
class SelfCrossAttnBlock(nn.Module):
    """Post-norm block: self-attention, cross-attention, then an MLP.

    Each of the three sub-layers is wrapped as ``LayerNorm(x + sublayer(x))``.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        # Both attention modules share the same construction options.
        attn_options = {"dropout": dropout, "batch_first": True}
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, **attn_options)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, **attn_options)

        # Position-wise MLP with a 4x hidden expansion.
        hidden_width = 4 * d_model
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden_width),
            nn.SiLU(),
            nn.Linear(hidden_width, d_model),
        )

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)

    def forward(self, x, cond_tokens, key_padding_mask=None):
        """Run one block over the token sequence.

        x:                (B, N, d)  material + vf_category + angle tokens
        cond_tokens:      (B, M, d)  condition tokens (M = n_conditions)
        key_padding_mask: (B, N)     optional; True = mask out, False = keep.
                                     Only applied to the self-attention keys.

        Returns a tensor with the same shape as ``x``.
        """
        # 1) tokens attend to each other
        self_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.ln1(x + self_out)

        # 2) tokens attend to the condition tokens (no padding mask here)
        cross_out, _ = self.cross_attn(x, cond_tokens, cond_tokens)
        x = self.ln2(x + cross_out)

        # 3) position-wise feed-forward
        return self.ln3(x + self.ff(x))
class MaterialHybridDenoiser(nn.Module):
    """Transformer denoiser over a (material, vf_category, layers...) token sequence.

    Inputs:
        material_t:    (B,)   in [0..n_materials-1] or MASK
        vf_category_t: (B,)   in [0..n_vf_categories-1] volume fraction category or MASK
        layer_t:       (B,L)  in {0,1} or MASK (always used for the (B, L) shape;
                              only embedded when not use_discrete_angles)
        angle_t:       (B,L)  discrete category indices [0..n_angle_categories-1] or MASK
                              (if use_discrete_angles)
                       OR (B,L,1) continuous (if not use_discrete_angles)
                       When discrete: category n_angle_categories = dead layer,
                       n_angle_categories+1 = MASK
        cond:          (B,C)  continuous, C = n_conditions; a tensor that is not
                              2-D is passed to the blocks unchanged (assumed to be
                              already-projected (B, C, d) tokens)
        t:             (B,)   timestep

    Outputs (dict):
        material_logits:    (B, n_materials)
        vf_category_logits: (B, n_vf_categories)
        angle_logits:       (B,L,n_angle_categories+1)  discrete angle categories + dead
                            (if use_discrete_angles)
        OR layer_logits:    (B,L,2) and angle: (B,L,1) in radians, in (0, pi/2)
                            (if not use_discrete_angles)
    """

    def __init__(self, cfg: ModelConfig, mask_ids: Dict[str, int], use_discrete_angles: bool = True, n_angle_categories: int = 7):
        super().__init__()
        self.cfg = cfg
        self.L = cfg.n_max_layer
        d = cfg.d_model
        self.mask_ids = mask_ids
        self.use_discrete_angles = use_discrete_angles
        self.n_angle_categories = n_angle_categories

        # +1 to include mask token for material
        self.material_emb = nn.Embedding(cfg.n_materials + 1, d)
        # vf_category: n_vf_categories categories plus one mask token
        self.vf_category_emb = nn.Embedding(cfg.n_vf_categories + 1, d)

        if use_discrete_angles:
            # Category n_angle_categories = dead layer, n_angle_categories+1 = mask token
            self.angle_emb = nn.Embedding(n_angle_categories + 2, d)
            self.layer_emb = None
        else:
            # layer token: {MASK, 0, 1} => 3 (only needed for continuous angles)
            self.layer_emb = nn.Embedding(3, d)
            self.angle_in = nn.Linear(1, d)

        # Condition projection: each scalar condition coefficient gets its own Linear(1, d).
        # NOTE(review): an earlier comment here said "n_conditions = 7 * degree", but the
        # ModelConfig default is 17 -- confirm which is intended.
        self.cond_proj = nn.ModuleList([
            nn.Linear(1, d) for _ in range(cfg.n_conditions)
        ])

        self.blocks = nn.ModuleList([
            SelfCrossAttnBlock(d, cfg.n_heads, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])

        # Positional embeddings: pos 0 = material, pos 1 = vf_category, pos 2..2+L-1 = layers
        self.pos_emb = nn.Embedding(2 + cfg.n_max_layer, d)
        self.t_proj = nn.Linear(d, d)

        # NOTE(review): self.encoder is never called in forward(). It is kept so the
        # set of parameters (state_dict keys) and the order of random initialisation
        # stay exactly as before; removing it is a separate, checkpoint-breaking change.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d,
            nhead=cfg.n_heads,
            dropout=cfg.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.n_layers)

        self.ln = nn.LayerNorm(d)
        self.material_head = nn.Linear(d, cfg.n_materials)
        self.vf_category_head = nn.Linear(d, cfg.n_vf_categories)
        if use_discrete_angles:
            # n_angle_categories for angles + 1 for dead layer
            self.angle_head = nn.Linear(d, n_angle_categories + 1)
            self.layer_head = None
        else:
            self.layer_head = nn.Linear(d, 2)  # alive/dead
            self.angle_head = nn.Linear(d, 1)

    def _embed_conditions(self, cond):
        """Project raw (B, C) condition scalars to (B, C, d) tokens, one Linear per column."""
        tokens = [
            self.cond_proj[i](cond[:, i:i + 1].unsqueeze(-1))  # (B, 1, d)
            for i in range(cond.shape[1])
        ]
        return torch.cat(tokens, dim=1)

    def _dead_tail_padding_mask(self, angle_t, device):
        """Return a (B, 2+L) bool mask that is True from the first dead layer onwards.

        The two global tokens (material, vf_category) are never masked. Rows
        without any dead layer are entirely False.
        """
        is_dead = angle_t == self.n_angle_categories  # (B, L)
        # A running count > 0 marks the first dead slot and everything after it,
        # which avoids per-sample Python loops and host synchronisation.
        dead_or_after = is_dead.to(torch.long).cumsum(dim=1) > 0  # (B, L)
        keep_globals = torch.zeros(angle_t.size(0), 2, dtype=torch.bool, device=device)
        return torch.cat([keep_globals, dead_or_after.to(device)], dim=1)

    def forward(self, material_t, vf_category_t, layer_t, angle_t, cond, t):
        """Predict clean material / vf-category / per-layer outputs from noisy inputs.

        See the class docstring for argument and output shapes.
        """
        B, L = layer_t.shape
        assert L == self.L, f"expected {self.L} layer slots, got {L}"

        # Project conditions if provided as raw scalars (B, C)
        if cond.dim() == 2:
            cond = self._embed_conditions(cond)  # (B, C, d)

        # global tokens as a 2-token "prefix"
        g_mat = self.material_emb(material_t).unsqueeze(1)  # (B,1,d)
        g_vf = self.vf_category_emb(vf_category_t).unsqueeze(1)  # (B,1,d)

        # per-layer tokens
        if self.use_discrete_angles:
            layer_h = self.angle_emb(angle_t)  # (B, L, d)
        else:
            layer_h = self.layer_emb(layer_t) + self.angle_in(angle_t)  # (B,L,d)

        h = torch.cat([g_mat, g_vf, layer_h], dim=1)  # (B, 2+L, d)

        # Add positional embeddings to entire sequence
        pos_indices = torch.arange(2 + self.L, device=h.device)  # (2+L,)
        h = h + self.pos_emb(pos_indices).unsqueeze(0)  # (B, 2+L, d)

        # add timestep (broadcast over the sequence dimension)
        t_emb = timestep_embedding(t, h.size(-1))  # (B,d)
        h = h + self.t_proj(t_emb).unsqueeze(1)

        # Key padding mask: hide the first dead layer and everything after it from
        # self-attention keys. Only defined for discrete angles.
        key_padding_mask = None
        if self.use_discrete_angles:
            key_padding_mask = self._dead_tail_padding_mask(angle_t, h.device)

        for block in self.blocks:
            h = block(h, cond, key_padding_mask=key_padding_mask)
        h = self.ln(h)

        out = {
            "material_logits": self.material_head(h[:, 0]),  # (B, n_materials)
            "vf_category_logits": self.vf_category_head(h[:, 1]),  # (B, n_vf_categories)
        }
        layer_tokens = h[:, 2:]  # (B, L, d)
        if self.use_discrete_angles:
            out["angle_logits"] = self.angle_head(layer_tokens)  # (B,L,n_angle_categories+1)
        else:
            out["layer_logits"] = self.layer_head(layer_tokens)  # (B,L,2)
            # Squash the raw prediction into (0, pi/2) radians.
            out["angle"] = torch.sigmoid(self.angle_head(layer_tokens)) * (math.pi / 2)  # (B,L,1)
        return out