""" ArtiGen V1.0 — Main Model CARTEL backbone with PHI-SCAN, AdaLN conditioning, ASDL heads. """ import math import torch import torch.nn as nn import torch.nn.functional as F try: from .cartel_block import CARTELBlock from .asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead from .phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern except ImportError: from cartel_block import CARTELBlock from asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead from phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern class PatchEmbed(nn.Module): def __init__(self, in_ch, embed_dim, patch_size=2): super().__init__() self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = nn.LayerNorm(embed_dim) def forward(self, x): x = self.proj(x) B, C, H, W = x.shape x = x.permute(0, 2, 3, 1).reshape(B, H * W, C) return self.norm(x), H, W class AdaLN(nn.Module): def __init__(self, dim, cond_dim=512): super().__init__() self.modulation = nn.Sequential( nn.SiLU(), nn.Linear(cond_dim, dim * 2), ) def forward(self, x, cond): scale, shift = self.modulation(cond).chunk(2, dim=-1) return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) class ArtiGen(nn.Module): def __init__( self, latent_ch=4, latent_h=32, latent_w=32, embed_dim=256, num_layers=12, d_state=16, expand=2, text_dim=768, style_classes=128, content_objects=1024, mood_classes=64, ): super().__init__() self.embed_dim = embed_dim self.num_layers = num_layers self.latent_h = latent_h self.latent_w = latent_w self.patch_embed = PatchEmbed(latent_ch, embed_dim, patch_size=1) self.t_embed = nn.Sequential( nn.Linear(1, text_dim), nn.SiLU(), nn.Linear(text_dim, text_dim), ) self.cond_proj = nn.Linear(text_dim, text_dim) self.cond_transform = nn.Sequential( nn.SiLU(), nn.Linear(text_dim, text_dim), ) self.token_pos = nn.Parameter(torch.randn(1, latent_h * latent_w, embed_dim) * 0.02) self.scans = build_scan_permutations(latent_h, latent_w) self.blocks = nn.ModuleList([ CARTELBlock(embed_dim, d_state=d_state, expand=expand) for _ in range(num_layers) ]) self.adalns = nn.ModuleList([ AdaLN(embed_dim, cond_dim=text_dim) for _ in range(num_layers) ]) self.skip_connect = nn.Sequential( nn.Linear(embed_dim, embed_dim), nn.SiLU(), nn.Linear(embed_dim, embed_dim), ) self.final_proj = nn.Sequential( nn.LayerNorm(embed_dim), nn.Linear(embed_dim, embed_dim * 4), nn.SiLU(), nn.Linear(embed_dim * 4, embed_dim), nn.Linear(embed_dim, latent_ch), ) self.style_head = StyleHead(embed_dim, num_style_classes=style_classes) self.content_head = ContentHead(embed_dim, num_objects=content_objects) self.concept_head = ConceptHead(embed_dim) self.mood_head = MoodHead(embed_dim, num_moods=mood_classes) self.comp_head = CompositionHead(embed_dim) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) def forward(self, z_t, t, text_embed, return_asdl=False): B = z_t.shape[0] x, H, W = self.patch_embed(z_t) x = x + self.token_pos[:, :x.shape[1], :] t_emb = self.t_embed(t.view(B, 1).float()) cond = self.cond_proj(text_embed) + t_emb cond = self.cond_transform(cond) x_shallow = x for i, (block, adaln) in enumerate(zip(self.blocks, self.adalns)): x = adaln(x, cond) scan_name = get_scan_pattern(i) perm, inv = self.scans[scan_name] x_scanned = apply_scan(x, perm) x_scanned = block(x_scanned) x = unscan(x_scanned, inv) if i == self.num_layers // 4: 
x_shallow = x x = x + self.skip_connect(x_shallow) v = self.final_proj(x).transpose(1, 2).reshape(B, -1, H, W) asdl = {} s, s_logits = self.style_head(x) c, c_logits = self.content_head(x) n = self.concept_head(x) m, m_logits = self.mood_head(x) p = self.comp_head(x) asdl = { "style_vec": s, "style_logits": s_logits, "content_vec": c, "content_logits": c_logits, "concept_vec": n, "mood_vec": m, "mood_logits": m_logits, "comp_vec": p, } if return_asdl: return v, asdl return v, None
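

if __name__ == "__main__":
    # Minimal smoke test: a sketch of the expected call pattern, not part of the
    # original module. It assumes the CARTEL/PHI-SCAN/ASDL modules imported above
    # are available, and that text_embed is a pooled (B, text_dim) text-encoder
    # output; the batch size, timestep range, and latent shape are illustrative.
    model = ArtiGen(latent_ch=4, latent_h=32, latent_w=32, embed_dim=256, num_layers=12)
    z_t = torch.randn(2, 4, 32, 32)       # noised latents
    t = torch.randint(0, 1000, (2,))      # diffusion timesteps
    text_embed = torch.randn(2, 768)      # pooled text conditioning, text_dim=768
    v, asdl = model(z_t, t, text_embed, return_asdl=True)
    print(v.shape)                        # expected: torch.Size([2, 4, 32, 32])
    print(sorted(asdl.keys()))            # the eight ASDL output keys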