# feiyang-cai — commit 01a8278: "Initial ZeroGPU Gradio Space" (9.09 kB).
# (File-viewer page residue removed: Raw | History | Blame | Contribute | Delete)
import torch
import torch.nn as nn
from typing import Dict, Tuple
from dataclasses import dataclass
import math
# =========================
# Config
# =========================
@dataclass
class ModelConfig:
    """Hyper-parameters consumed by MaterialHybridDenoiser.

    The first group fixes the problem dimensions (condition-vector length,
    number of materials, volume-fraction buckets and layers); the second
    group sizes the transformer.
    """

    # ---- problem dimensions ----
    n_conditions: int = 17  # length of the inverse-design sidebar parameter vector
    n_materials: int = 4  # number of material classes
    n_vf_categories: int = 5  # volume-fraction buckets: 0.1, 0.2, 0.3, 0.4, 0.5
    n_max_layer: int = 5  # maximum number of quarter layers (quarter-angle dataset)

    # ---- transformer architecture ----
    d_model: int = 256  # token embedding width
    n_heads: int = 4  # attention heads per layer
    n_layers: int = 6  # number of stacked attention blocks
    dropout: float = 0.0  # dropout probability passed to the attention layers
# =========================
# Model
# =========================
def timestep_embedding(t: torch.Tensor, dim: int, max_period: float = 10000.0) -> torch.Tensor:
    """
    Sinusoidal timestep embedding.

    Args:
        t: (B,) tensor of timesteps (integer or float; cast with ``.float()``).
            A 0-d scalar tensor is treated as a batch of one.
        dim: output width. Columns ``[0, dim // 2)`` hold cosines and columns
            ``[dim // 2, 2 * (dim // 2))`` hold sines; an odd ``dim`` gets one
            trailing all-zero column.
        max_period: sets the lowest frequency. Frequency ``k`` is
            ``max_period ** (-k / (dim // 2))`` for ``k = 0 .. dim // 2 - 1``.
            The default 10000 reproduces the previously hard-coded value.

    Returns:
        (B, dim) floating tensor on ``t``'s device.
    """
    if t.dim() == 0:
        # A scalar timestep previously crashed on unsqueeze(1); promote to (1,).
        t = t.unsqueeze(0)
    half = dim // 2
    exponents = torch.arange(0, half, device=t.device)
    freqs = torch.exp(-math.log(max_period) * exponents / half)  # (half,)
    args = t.float().unsqueeze(1) * freqs.unsqueeze(0)  # (B, half)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=1)  # (B, 2 * half)
    if dim % 2 == 1:
        # Pad one zero column so the output width is exactly `dim`.
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
    return emb  # (B, dim)
class SelfCrossAttnBlock(nn.Module):
    """Post-norm transformer block: self-attention, cross-attention, feed-forward.

    Every sub-layer is residual and followed by its own LayerNorm:
        x <- ln1(x + SelfAttn(x))
        x <- ln2(x + CrossAttn(query=x, key/value=cond_tokens))
        x <- ln3(x + FF(x))
    """

    def __init__(self, d_model, n_heads, dropout=0.0):
        super().__init__()
        # Attribute names and creation order define the state_dict keys and the
        # order in which parameters are randomly initialised; keep them stable.
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.cross_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        hidden = 4 * d_model
        self.ff = nn.Sequential(nn.Linear(d_model, hidden), nn.SiLU(), nn.Linear(hidden, d_model))
        self.ln1, self.ln2, self.ln3 = (nn.LayerNorm(d_model) for _ in range(3))

    def forward(self, x, cond_tokens, key_padding_mask=None):
        """
        Args:
            x: (B, N, d) token sequence (material, vf_category and per-layer tokens).
            cond_tokens: (B, M, d) condition tokens, used as keys/values for
                cross-attention.
            key_padding_mask: optional (B, N) bool mask for the self-attention
                keys; True means "ignore this key". It is not applied to
                cross-attention.

        Returns:
            (B, N, d) updated tokens.
        """
        attended, _ = self.self_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.ln1(x + attended)

        conditioned, _ = self.cross_attn(x, cond_tokens, cond_tokens)
        x = self.ln2(x + conditioned)

        return self.ln3(x + self.ff(x))
class MaterialHybridDenoiser(nn.Module):
    """
    Transformer denoiser over one material token, one volume-fraction token and
    L = cfg.n_max_layer per-layer tokens, conditioned on a condition vector
    (via cross-attention) and a diffusion timestep (added to every token).

    Sequence layout (length 2 + L):
        position 0       -> material token
        position 1       -> vf_category token
        positions 2..2+L -> one token per layer

    Inputs:
        material_t:    (B,) index in [0, n_materials - 1] or a mask id
                       (the embedding reserves one extra row for the mask).
        vf_category_t: (B,) index in [0, n_vf_categories - 1] or a mask id
                       (one extra embedding row is reserved for the mask).
        layer_t:       (B, L) alive indicator {0, 1} or a mask id. With
                       use_discrete_angles=True only its shape is used.
        angle_t:       discrete mode: (B, L) indices where
                           0 .. n_angle_categories - 1 = angle categories,
                           n_angle_categories        = dead layer,
                           n_angle_categories + 1    = mask.
                       continuous mode: (B, L, 1) float angles.
        cond:          (B, C) raw condition scalars with C <= n_conditions (each
                       column is projected by its own Linear(1, d)), or any
                       non-2-D tensor, which is used as-is as cross-attention
                       tokens (presumably already (B, M, d) — confirm with callers).
        t:             (B,) timestep.

    Outputs (dict):
        discrete mode:
            "material_logits":    (B, n_materials)
            "vf_category_logits": (B, n_vf_categories)
            "angle_logits":       (B, L, n_angle_categories + 1)  # last class = dead
        continuous mode:
            "material_logits":    (B, n_materials)
            "vf_category_logits": (B, n_vf_categories)
            "layer_logits":       (B, L, 2)  # alive / dead
            "angle":              (B, L, 1)  # radians, in (0, pi/2) via sigmoid
    """

    def __init__(self, cfg: ModelConfig, mask_ids: Dict[str, int], use_discrete_angles: bool = True, n_angle_categories: int = 7):
        super().__init__()
        self.cfg = cfg
        self.L = cfg.n_max_layer
        d = cfg.d_model
        self.mask_ids = mask_ids  # stored on the module; not read inside this class
        self.use_discrete_angles = use_discrete_angles
        self.n_angle_categories = n_angle_categories

        # NOTE: submodule names and creation order below fix the state_dict keys
        # and the RNG order of parameter initialisation; keep them stable so
        # existing checkpoints and seeded runs are unaffected.

        # +1 row each for the mask token.
        self.material_emb = nn.Embedding(cfg.n_materials + 1, d)
        self.vf_category_emb = nn.Embedding(cfg.n_vf_categories + 1, d)
        if use_discrete_angles:
            # Rows: angle categories, then dead layer (n_angle_categories),
            # then mask (n_angle_categories + 1).
            self.angle_emb = nn.Embedding(n_angle_categories + 2, d)
            self.layer_emb = None
        else:
            # Layer vocabulary {0, 1, MASK} -> 3 rows; only needed for continuous angles.
            self.layer_emb = nn.Embedding(3, d)
            self.angle_in = nn.Linear(1, d)

        # One Linear(1, d) per scalar condition -> one token per condition
        # (cfg.n_conditions of them).
        self.cond_proj = nn.ModuleList([nn.Linear(1, d) for _ in range(cfg.n_conditions)])

        self.blocks = nn.ModuleList([
            SelfCrossAttnBlock(d, cfg.n_heads, cfg.dropout)
            for _ in range(cfg.n_layers)
        ])

        # Positions: 0 = material, 1 = vf_category, 2..2+L-1 = layers.
        self.pos_emb = nn.Embedding(2 + cfg.n_max_layer, d)
        self.t_proj = nn.Linear(d, d)

        # NOTE(review): self.encoder is never used in forward(). It is presumably
        # kept so checkpoints containing "encoder.*" keys still load with
        # strict=True — confirm before removing.
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d,
            nhead=cfg.n_heads,
            dropout=cfg.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=cfg.n_layers)

        self.ln = nn.LayerNorm(d)
        self.material_head = nn.Linear(d, cfg.n_materials)
        self.vf_category_head = nn.Linear(d, cfg.n_vf_categories)
        if use_discrete_angles:
            # n_angle_categories angle classes + 1 dead-layer class (mask is not predicted).
            self.angle_head = nn.Linear(d, n_angle_categories + 1)
            self.layer_head = None
        else:
            self.layer_head = nn.Linear(d, 2)  # alive / dead
            self.angle_head = nn.Linear(d, 1)

    def forward(self, material_t, vf_category_t, layer_t, angle_t, cond, t):
        """Run the denoiser once; see the class docstring for input/output shapes."""
        _, L = layer_t.shape
        assert L == self.L, f"layer_t has {L} layers, model expects {self.L}"

        cond = self._embed_conditions(cond)
        h = self._embed_tokens(material_t, vf_category_t, layer_t, angle_t)  # (B, 2+L, d)

        # The same timestep embedding is broadcast onto every token.
        t_emb = timestep_embedding(t, h.size(-1))  # (B, d)
        h = h + self.t_proj(t_emb).unsqueeze(1)

        # Only the discrete mode can tell which layers are dead, so only it masks.
        key_padding_mask = self._dead_layer_mask(angle_t) if self.use_discrete_angles else None

        for block in self.blocks:
            h = block(h, cond, key_padding_mask=key_padding_mask)
        h = self.ln(h)

        return self._heads(h)

    def _embed_conditions(self, cond):
        """Project raw (B, C) condition scalars to (B, C, d) tokens; pass other ranks through."""
        if cond.dim() != 2:
            return cond
        # cond[:, i:i+1] is (B, 1); unsqueeze -> (B, 1, 1); Linear(1, d) -> (B, 1, d).
        # Indexing cond_proj[i] (rather than zip) keeps the IndexError when C > n_conditions.
        tokens = [self.cond_proj[i](cond[:, i:i + 1].unsqueeze(-1)) for i in range(cond.shape[1])]
        return torch.cat(tokens, dim=1)  # (B, C, d)

    def _embed_tokens(self, material_t, vf_category_t, layer_t, angle_t):
        """Build the (B, 2 + L, d) token sequence and add positional embeddings."""
        g_mat = self.material_emb(material_t).unsqueeze(1)  # (B, 1, d)
        g_vf = self.vf_category_emb(vf_category_t).unsqueeze(1)  # (B, 1, d)
        if self.use_discrete_angles:
            layer_h = self.angle_emb(angle_t)  # (B, L, d)
        else:
            layer_h = self.layer_emb(layer_t) + self.angle_in(angle_t)  # (B, L, d)
        h = torch.cat([g_mat, g_vf, layer_h], dim=1)  # (B, 2+L, d)
        positions = torch.arange(2 + self.L, device=h.device)
        return h + self.pos_emb(positions).unsqueeze(0)

    def _dead_layer_mask(self, angle_t):
        """Key-padding mask hiding every layer token from the first dead layer onward.

        Returns a (B, 2 + L) bool tensor (True = ignore as an attention key).
        The two global tokens are never masked, so no row is fully masked.
        """
        is_dead = angle_t == self.n_angle_categories  # (B, L)
        # cumsum > 0 is True at the first dead layer and at every later position;
        # this replaces a per-sample Python loop (and its .item() host syncs).
        from_first_dead = is_dead.long().cumsum(dim=1) > 0  # (B, L)
        globals_visible = torch.zeros(
            angle_t.shape[0], 2, dtype=torch.bool, device=angle_t.device
        )
        return torch.cat([globals_visible, from_first_dead], dim=1)

    def _heads(self, h):
        """Map the normalised tokens to the output dict for the active angle mode."""
        layer_tokens = h[:, 2:]  # (B, L, d)
        out = {
            "material_logits": self.material_head(h[:, 0]),  # (B, n_materials)
            "vf_category_logits": self.vf_category_head(h[:, 1]),  # (B, n_vf_categories)
        }
        if self.use_discrete_angles:
            out["angle_logits"] = self.angle_head(layer_tokens)  # (B, L, n_angle_categories + 1)
        else:
            out["layer_logits"] = self.layer_head(layer_tokens)  # (B, L, 2)
            # Sigmoid squashes the raw output into (0, pi/2) radians.
            out["angle"] = torch.sigmoid(self.angle_head(layer_tokens)) * (math.pi / 2)  # (B, L, 1)
        return out