| import torch |
| import torch.nn as nn |
| from typing import Dict, Tuple |
| from dataclasses import dataclass |
| import math |
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ModelConfig:
    """Hyperparameters for the denoiser: problem sizes plus transformer sizing."""

    # --- Problem dimensions -------------------------------------------------
    # Number of scalar conditioning values; the model builds one projection
    # (and therefore one condition token) per value.
    n_conditions: int = 17
    # Number of material classes predicted by the material head.
    n_materials: int = 4
    # Number of volume-fraction categories predicted by the vf head.
    n_vf_categories: int = 5
    # Maximum number of layer slots (layer tokens) per sample.
    n_max_layer: int = 5

    # --- Transformer sizing -------------------------------------------------
    # Token / embedding width.
    d_model: int = 256
    # Attention heads per attention module.
    n_heads: int = 4
    # Number of stacked attention blocks.
    n_layers: int = 6
    # Dropout probability handed to the attention modules.
    dropout: float = 0.0
|
|
|
|
| |
| |
| |
|
|
def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """
    Sinusoidal timestep embedding.

    Args:
        t: (B,) tensor of timesteps; cast to float before use.
        dim: width of the returned embedding.

    Returns:
        (B, dim) tensor laid out as ``[cos(t * f) | sin(t * f) | pad]`` with
        ``dim // 2`` frequencies ``f``. When ``dim`` is odd a single zero
        column is appended so the width is exactly ``dim``.
    """
    half = dim // 2
    # Frequencies decay geometrically from 1 towards 1/10000.
    freqs = torch.exp(-math.log(10000) * torch.arange(0, half, device=t.device) / half)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=1)
    if dim % 2 == 1:
        # Build the pad column explicitly. Slicing ``emb[:, :1]`` yields an
        # empty tensor when half == 0 (dim == 1), which would leave the
        # result with 0 columns instead of 1.
        emb = torch.cat([emb, emb.new_zeros(emb.size(0), 1)], dim=1)
    return emb
|
|
class SelfCrossAttnBlock(nn.Module):
    """Post-norm block: self-attention, cross-attention to condition tokens, then MLP.

    Each sub-layer is applied as ``LayerNorm(x + sublayer(x))``.
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        # Both attention modules share the same construction options.
        attn_kwargs = dict(dropout=dropout, batch_first=True)
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, **attn_kwargs)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, **attn_kwargs)

        # Position-wise MLP with a 4x hidden expansion.
        hidden = 4 * d_model
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.SiLU(),
            nn.Linear(hidden, d_model),
        )

        # One LayerNorm per residual branch.
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ln3 = nn.LayerNorm(d_model)

    def forward(self, x, cond_tokens, key_padding_mask=None):
        """
        x:               (B, N, d) sequence tokens (queries of both attentions)
        cond_tokens:     (B, M, d) condition tokens (keys/values of the cross-attention)
        key_padding_mask:(B, N) optional bool mask for the self-attention keys
                         (True = mask out, False = keep)
        Returns a tensor with the same shape as x.
        """
        # 1) Self-attention among the sequence tokens.
        self_out, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.ln1(x + self_out)

        # 2) Cross-attention: sequence tokens query the condition tokens.
        cross_out, _ = self.cross_attn(x, cond_tokens, cond_tokens)
        x = self.ln2(x + cross_out)

        # 3) Feed-forward.
        return self.ln3(x + self.ff(x))
|
|
class MaterialHybridDenoiser(nn.Module):
    """
    Denoiser over the token sequence [material, vf_category, layer_1 .. layer_L].
    The sequence is refined by a stack of self/cross-attention blocks that
    attend to one token per conditioning value.

    Inputs:
        material_t:   (B,)  in [0..n_materials-1] or MASK
        vf_category_t:(B,)  in [0..4] volume fraction category or MASK
        layer_t:      (B,L) in {0,1} or MASK (only its shape is used when use_discrete_angles)
        angle_t:      (B,L) discrete category indices [0..n_angle_categories-1] or MASK (if use_discrete_angles)
                      OR (B,L,1) continuous (if not use_discrete_angles)
                      When discrete: category n_angle_categories = dead layer, n_angle_categories+1 = MASK
        cond:         (B,C) continuous, C = n_conditions; any other rank is passed to the
                      blocks unchanged (presumably already-projected (B,M,d) tokens)
        t:            (B,)  timestep

    Outputs (dict):
        material_logits:    (B, n_materials)
        vf_category_logits: (B, n_vf_categories)
        angle_logits:       (B,L,n_angle_categories+1)  # discrete angle categories + dead (if use_discrete_angles)
        OR layer_logits:    (B,L,2) and
           angle:           (B,L,1)  # sigmoid-squashed into (0, pi/2) radians (if not use_discrete_angles)
    """

    def __init__(
        self,
        cfg: "ModelConfig",
        mask_ids: Dict[str, int],
        use_discrete_angles: bool = True,
        n_angle_categories: int = 7,
    ):
        super().__init__()
        self.cfg = cfg
        self.L = cfg.n_max_layer
        d = cfg.d_model
        self.mask_ids = mask_ids
        self.use_discrete_angles = use_discrete_angles
        self.n_angle_categories = n_angle_categories

        # Global-token embeddings; the extra row is presumably the MASK id.
        self.material_emb = nn.Embedding(cfg.n_materials + 1, d)
        self.vf_category_emb = nn.Embedding(cfg.n_vf_categories + 1, d)

        if use_discrete_angles:
            # Angle categories + dead-layer category + MASK.
            self.angle_emb = nn.Embedding(n_angle_categories + 2, d)
            self.layer_emb = None
        else:
            # Layer state embedding (3 ids) plus a linear lift of the scalar angle.
            self.layer_emb = nn.Embedding(3, d)
            self.angle_in = nn.Linear(1, d)

        # One independent scalar->token projection per conditioning value.
        self.cond_proj = nn.ModuleList([
            nn.Linear(1, d) for _ in range(cfg.n_conditions)
        ])

        self.blocks = nn.ModuleList([
            SelfCrossAttnBlock(d, cfg.n_heads, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])

        # Learned positions for the 2 global tokens followed by the L layer tokens.
        self.pos_emb = nn.Embedding(2 + cfg.n_max_layer, d)

        self.t_proj = nn.Linear(d, d)

        # NOTE(review): `self.encoder` is never used in forward(). It is kept
        # because removing it would change state_dict keys (checkpoint
        # loading) and the seeded initialisation of the heads below. Its
        # parameters receive no gradient, which matters for wrappers that
        # reject unused parameters.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d,
            nhead=cfg.n_heads,
            dropout=cfg.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.n_layers)
        self.ln = nn.LayerNorm(d)

        self.material_head = nn.Linear(d, cfg.n_materials)
        self.vf_category_head = nn.Linear(d, cfg.n_vf_categories)
        if use_discrete_angles:
            # Angle categories + dead; MASK is not a prediction target.
            self.angle_head = nn.Linear(d, n_angle_categories + 1)
            self.layer_head = None
        else:
            self.layer_head = nn.Linear(d, 2)
            self.angle_head = nn.Linear(d, 1)

    @staticmethod
    def _dead_layer_key_padding_mask(angle_t: torch.Tensor, dead_category: int) -> torch.Tensor:
        """Return a (B, 2 + L) bool mask, True from each row's first dead layer onward.

        The two leading (global-token) columns are always False. Rows without
        a dead layer are all False.
        """
        is_dead = angle_t == dead_category
        # A running count > 0 marks the first dead slot and everything after
        # it, with no per-sample Python loop or host sync.
        dead_seen = is_dead.long().cumsum(dim=1) > 0
        keep_globals = torch.zeros(
            angle_t.size(0), 2, dtype=torch.bool, device=angle_t.device
        )
        return torch.cat([keep_globals, dead_seen], dim=1)

    def forward(self, material_t, vf_category_t, layer_t, angle_t, cond, t):
        """Run the denoiser; see the class docstring for input and output shapes."""
        _, L = layer_t.shape
        assert L == self.L, f"expected {self.L} layer slots, got {L}"

        # Lift raw (B, C) conditions to (B, C, d): one token per condition,
        # each through its own projection.
        if cond.dim() == 2:
            cond = torch.cat(
                [
                    self.cond_proj[i](cond[:, i:i + 1].unsqueeze(-1))
                    for i in range(cond.shape[1])
                ],
                dim=1,
            )

        # Global tokens, each (B, 1, d).
        g_mat = self.material_emb(material_t).unsqueeze(1)
        g_vf = self.vf_category_emb(vf_category_t).unsqueeze(1)

        # Per-layer tokens, (B, L, d).
        if self.use_discrete_angles:
            layer_h = self.angle_emb(angle_t)
        else:
            layer_h = self.layer_emb(layer_t) + self.angle_in(angle_t)

        h = torch.cat([g_mat, g_vf, layer_h], dim=1)

        # Positional embedding over all 2 + L tokens.
        pos_indices = torch.arange(2 + self.L, device=h.device)
        h = h + self.pos_emb(pos_indices).unsqueeze(0)

        # Timestep embedding, broadcast to every token.
        t_emb = timestep_embedding(t, h.size(-1))
        h = h + self.t_proj(t_emb).unsqueeze(1)

        # In discrete mode, hide the first dead layer and everything after it
        # from the self-attention keys; the two global tokens stay visible.
        key_padding_mask = None
        if self.use_discrete_angles:
            key_padding_mask = self._dead_layer_key_padding_mask(
                angle_t, self.n_angle_categories
            )

        for block in self.blocks:
            h = block(h, cond, key_padding_mask=key_padding_mask)

        h = self.ln(h)

        layer_tokens = h[:, 2:]
        out = {
            "material_logits": self.material_head(h[:, 0]),
            "vf_category_logits": self.vf_category_head(h[:, 1]),
        }
        if self.use_discrete_angles:
            out["angle_logits"] = self.angle_head(layer_tokens)
        else:
            out["layer_logits"] = self.layer_head(layer_tokens)
            # Squash the raw angle into (0, pi/2).
            out["angle"] = torch.sigmoid(self.angle_head(layer_tokens)) * (math.pi / 2)
        return out
|
|