| """ |
| ArtiGen V1.0 — Main Model |
| CARTEL backbone with PHI-SCAN, AdaLN conditioning, ASDL heads. |
| """ |
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| try: |
| from .cartel_block import CARTELBlock |
| from .asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead |
| from .phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern |
| except ImportError: |
| from cartel_block import CARTELBlock |
| from asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead |
| from phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern |
|
|
class PatchEmbed(nn.Module):
    """Project a latent image into a sequence of normalized patch tokens.

    A strided convolution carves the (B, C, H, W) input into non-overlapping
    patches and embeds each one; LayerNorm is then applied per token.
    """

    def __init__(self, in_ch, embed_dim, patch_size=2):
        super().__init__()
        # Stride == kernel size, so each output position is exactly one patch.
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return (tokens, H, W) where tokens has shape (B, H*W, embed_dim)."""
        patches = self.proj(x)
        height, width = patches.shape[-2:]
        # (B, C, H, W) -> (B, H*W, C): flatten the spatial grid, channels last.
        tokens = patches.flatten(2).transpose(1, 2)
        return self.norm(tokens), height, width
|
|
class AdaLN(nn.Module):
    """Adaptive modulation of token features from a conditioning vector.

    Maps ``cond`` to per-channel (scale, shift) pairs and applies
    ``x * (1 + scale) + shift``, broadcast over the sequence dimension.
    """

    def __init__(self, dim, cond_dim=512):
        super().__init__()
        # SiLU + Linear produces 2*dim modulation parameters per sample.
        self.modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, dim * 2),
        )

    def forward(self, x, cond):
        """Modulate x of shape (B, N, dim) with cond of shape (B, cond_dim)."""
        params = self.modulation(cond)
        scale, shift = torch.chunk(params, 2, dim=-1)
        # Insert a length-1 sequence axis so one (scale, shift) pair per
        # sample is broadcast across every token position.
        scale = scale.unsqueeze(1)
        shift = shift.unsqueeze(1)
        return x * (1.0 + scale) + shift
|
|
class ArtiGen(nn.Module):
    """ArtiGen V1.0 backbone: CARTEL blocks over PHI-SCAN token orderings.

    The latent is embedded into tokens (one per latent pixel), modulated each
    layer by AdaLN conditioning built from the text embedding plus a timestep
    embedding, processed by CARTEL blocks along per-layer scan permutations,
    merged with a long skip from a shallow layer, and projected back to a
    latent-space prediction. Optional ASDL heads produce auxiliary
    style/content/concept/mood/composition outputs.
    """

    def __init__(
        self,
        latent_ch=4,           # channels of the diffusion latent
        latent_h=32,           # latent spatial height (tokens per column)
        latent_w=32,           # latent spatial width (tokens per row)
        embed_dim=256,         # token embedding width
        num_layers=12,         # number of CARTEL blocks
        d_state=16,            # per-block state size (passed to CARTELBlock)
        expand=2,              # per-block expansion factor
        text_dim=768,          # conditioning vector width (text + timestep)
        style_classes=128,     # classes for the style head
        content_objects=1024,  # object vocabulary for the content head
        mood_classes=64,       # classes for the mood head
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.latent_h = latent_h
        self.latent_w = latent_w
        # patch_size=1: every latent pixel becomes its own token.
        self.patch_embed = PatchEmbed(latent_ch, embed_dim, patch_size=1)
        # Timestep embedding: scalar t -> text_dim vector, added to the
        # projected text embedding before the shared conditioning transform.
        self.t_embed = nn.Sequential(
            nn.Linear(1, text_dim),
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )
        self.cond_proj = nn.Linear(text_dim, text_dim)
        self.cond_transform = nn.Sequential(
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )
        # Learned absolute positional embedding over the full token grid.
        self.token_pos = nn.Parameter(torch.randn(1, latent_h * latent_w, embed_dim) * 0.02)
        # NOTE(review): if the scan permutations are tensors, they are held in
        # a plain attribute and will not follow .to(device)/state_dict —
        # confirm build_scan_permutations returns device-agnostic indices.
        self.scans = build_scan_permutations(latent_h, latent_w)
        self.blocks = nn.ModuleList([
            CARTELBlock(embed_dim, d_state=d_state, expand=expand)
            for _ in range(num_layers)
        ])
        self.adalns = nn.ModuleList([
            AdaLN(embed_dim, cond_dim=text_dim)
            for _ in range(num_layers)
        ])
        # MLP applied to a shallow layer's activations for a long skip path.
        self.skip_connect = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )
        # Token features -> latent channels for the final prediction.
        self.final_proj = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim * 4),
            nn.SiLU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Linear(embed_dim, latent_ch),
        )
        # Auxiliary ASDL heads (only evaluated when forward(return_asdl=True)).
        self.style_head = StyleHead(embed_dim, num_style_classes=style_classes)
        self.content_head = ContentHead(embed_dim, num_objects=content_objects)
        self.concept_head = ConceptHead(embed_dim)
        self.mood_head = MoodHead(embed_dim, num_moods=mood_classes)
        self.comp_head = CompositionHead(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform weights and zero biases for every Linear layer."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, z_t, t, text_embed, return_asdl=False):
        """Run the backbone on a noisy latent.

        Args:
            z_t: noisy latent, (B, latent_ch, H, W).
            t: timestep, any shape reshapeable to (B, 1); cast to float.
            text_embed: conditioning embedding, (B, text_dim).
            return_asdl: when True, also evaluate the ASDL heads.

        Returns:
            (v, asdl) where v is the latent-space prediction (B, latent_ch,
            H, W) and asdl is a dict of head outputs, or None when
            return_asdl is False.
        """
        B = z_t.shape[0]
        x, H, W = self.patch_embed(z_t)
        x = x + self.token_pos[:, :x.shape[1], :]
        t_emb = self.t_embed(t.view(B, 1).float())
        cond = self.cond_proj(text_embed) + t_emb
        cond = self.cond_transform(cond)
        x_shallow = x
        for i, (block, adaln) in enumerate(zip(self.blocks, self.adalns)):
            x = adaln(x, cond)
            # Each layer processes tokens in its own scan order, then restores
            # the canonical grid ordering.
            scan_name = get_scan_pattern(i)
            perm, inv = self.scans[scan_name]
            x_scanned = apply_scan(x, perm)
            x_scanned = block(x_scanned)
            x = unscan(x_scanned, inv)
            # Capture activations a quarter of the way through for the long skip.
            if i == self.num_layers // 4:
                x_shallow = x
        x = x + self.skip_connect(x_shallow)
        v = self.final_proj(x).transpose(1, 2).reshape(B, -1, H, W)
        if not return_asdl:
            # Fix: previously all five ASDL heads ran unconditionally and the
            # result was discarded — skip them entirely on the common path.
            return v, None
        s, s_logits = self.style_head(x)
        c, c_logits = self.content_head(x)
        n = self.concept_head(x)
        m, m_logits = self.mood_head(x)
        p = self.comp_head(x)
        asdl = {
            "style_vec": s, "style_logits": s_logits,
            "content_vec": c, "content_logits": c_logits,
            "concept_vec": n,
            "mood_vec": m, "mood_logits": m_logits,
            "comp_vec": p,
        }
        return v, asdl
|
|