# artigen/model.py
# Uploaded by krystv ("Upload model.py"), commit 7aed37b (verified).
"""
ArtiGen V1.0 — Main Model
CARTEL backbone with PHI-SCAN, AdaLN conditioning, ASDL heads.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from .cartel_block import CARTELBlock
from .asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead
from .phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern
except ImportError:
from cartel_block import CARTELBlock
from asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead
from phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern
class PatchEmbed(nn.Module):
    """Turn a feature map into a sequence of normalized patch tokens.

    A ``Conv2d`` whose kernel size and stride both equal ``patch_size`` maps
    each non-overlapping ``patch_size x patch_size`` patch to an
    ``embed_dim`` vector. The resulting grid is flattened row-major into a
    token sequence, and ``LayerNorm`` is applied to every token.
    """

    def __init__(self, in_ch, embed_dim, patch_size=2):
        super().__init__()
        self.proj = nn.Conv2d(
            in_ch,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Embed a batched 4-D input ``x`` of shape ``(B, in_ch, H_in, W_in)``.

        Returns:
            A tuple ``(tokens, H, W)``. ``tokens`` has shape
            ``(B, H * W, embed_dim)`` and ``H``, ``W`` are the height and
            width of the patch grid.
        """
        feats = self.proj(x)
        # Unpacking all four dims keeps the original's failure on non-4-D input.
        _, _, grid_h, grid_w = feats.shape
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C): row-major token order.
        tokens = feats.flatten(start_dim=2).transpose(1, 2)
        return self.norm(tokens), grid_h, grid_w
class AdaLN(nn.Module):
    """Apply a condition-dependent per-channel scale and shift.

    ``cond`` is passed through SiLU and a linear layer that produces
    ``2 * dim`` values. These are split into ``scale`` and ``shift``, and the
    output is ``x * (1 + scale) + shift``, with both factors broadcast across
    dimension 1 of ``x``.

    Note: despite the name, this module does not normalize ``x`` itself.
    """

    def __init__(self, dim, cond_dim=512):
        super().__init__()
        self.modulation = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))

    def forward(self, x, cond):
        """Modulate ``x`` with scale/shift predicted from ``cond``.

        Presumably ``x`` is ``(B, N, dim)`` and ``cond`` is ``(B, cond_dim)``
        (TODO confirm). A singleton axis is inserted at position 1 of the
        predicted factors so they broadcast over the token axis.
        """
        gamma, beta = torch.chunk(self.modulation(cond), 2, dim=-1)
        gamma = gamma[:, None]
        beta = beta[:, None]
        return x * (1 + gamma) + beta
class ArtiGen(nn.Module):
    """Latent denoiser built on a CARTEL backbone, with AdaLN conditioning and ASDL heads.

    Overview of the forward pass:

    - Each spatial position of the latent ``z_t`` becomes one token
      (``PatchEmbed`` with ``patch_size=1``), and a learned positional
      embedding is added.
    - The tokens pass through ``num_layers`` CARTEL blocks. Before each
      block they are modulated by an ``AdaLN`` whose conditioning vector is
      built from the text embedding and the timestep.
    - Around each block the tokens are permuted by the PHI-SCAN pattern
      ``get_scan_pattern(i)`` and un-permuted afterwards.
    - A learned projection of the activations after block index
      ``num_layers // 4`` is added back before the output projection.
    - Optionally, five ASDL heads (style, content, concept, mood,
      composition) read the final tokens.

    Args:
        latent_ch: Channels of the input latent and of the predicted output.
        latent_h, latent_w: Spatial size of the latent. The positional
            embedding and the scan permutations are built for exactly this
            size.
        embed_dim: Token width inside the backbone.
        num_layers: Number of CARTEL blocks.
        d_state, expand: Forwarded to every ``CARTELBlock``.
        text_dim: Width of the text embedding and of the conditioning vector.
        style_classes, content_objects, mood_classes: Forwarded to the
            corresponding ASDL heads.
    """

    def __init__(
        self,
        latent_ch=4,
        latent_h=32,
        latent_w=32,
        embed_dim=256,
        num_layers=12,
        d_state=16,
        expand=2,
        text_dim=768,
        style_classes=128,
        content_objects=1024,
        mood_classes=64,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.latent_h = latent_h
        self.latent_w = latent_w
        # patch_size=1: every latent pixel becomes one token, so the token grid
        # is latent_h x latent_w.
        self.patch_embed = PatchEmbed(latent_ch, embed_dim, patch_size=1)
        # Timestep embedding: an MLP applied directly to the raw scalar
        # timestep (no sinusoidal features), producing a text_dim vector.
        self.t_embed = nn.Sequential(
            nn.Linear(1, text_dim),
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )
        self.cond_proj = nn.Linear(text_dim, text_dim)
        self.cond_transform = nn.Sequential(
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )
        # Learned absolute positional embedding, one vector per token.
        self.token_pos = nn.Parameter(torch.randn(1, latent_h * latent_w, embed_dim) * 0.02)
        # Scan permutations for this fixed grid, looked up per layer in
        # forward().
        # NOTE(review): this is assigned as a plain attribute rather than via
        # register_buffer. Unless build_scan_permutations returns a Module,
        # .to(device) and state_dict() will not see it. Confirm that
        # apply_scan/unscan handle device placement.
        self.scans = build_scan_permutations(latent_h, latent_w)
        self.blocks = nn.ModuleList([
            CARTELBlock(embed_dim, d_state=d_state, expand=expand)
            for _ in range(num_layers)
        ])
        # One conditioning modulation per block.
        self.adalns = nn.ModuleList([
            AdaLN(embed_dim, cond_dim=text_dim)
            for _ in range(num_layers)
        ])
        # Projects the shallow (early-layer) activations that are added back
        # onto the final tokens.
        self.skip_connect = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.final_proj = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim * 4),
            nn.SiLU(),
            nn.Linear(embed_dim * 4, embed_dim),
            # final_proj.3 and final_proj.4 are back-to-back Linear layers with
            # no activation between them, so together they form a single linear
            # map. Both are kept so that existing checkpoints still load.
            nn.Linear(embed_dim, latent_ch),
        )
        self.style_head = StyleHead(embed_dim, num_style_classes=style_classes)
        self.content_head = ContentHead(embed_dim, num_objects=content_objects)
        self.concept_head = ConceptHead(embed_dim)
        self.mood_head = MoodHead(embed_dim, num_moods=mood_classes)
        self.comp_head = CompositionHead(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize an nn.Linear with Xavier-uniform weights and a zero bias.

        This runs through ``self.apply``, so it also overrides the default
        initialization of every ``nn.Linear`` inside the CARTEL blocks, AdaLN
        modules and ASDL heads.

        NOTE(review): if CARTELBlock relies on a custom Linear initialization
        (for example a special bias), this overwrites it. Confirm.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def _timestep_column(self, t, batch_size):
        """Return timestep(s) ``t`` as a ``(batch_size, 1)`` input for ``t_embed``.

        Handling of ``t``:

        - A single value (Python number or one-element tensor) is broadcast
          to the whole batch.
        - Otherwise ``t`` must contain exactly ``batch_size`` elements.
        - The result takes the dtype and device of the first ``t_embed``
          weight, so the model also runs after ``.half()``/``.bfloat16()``.
          A hard-coded float32 input would not match those weights.

        Raises:
            ValueError: If ``t`` has neither 1 nor ``batch_size`` elements.
        """
        weight = self.t_embed[0].weight
        t = torch.as_tensor(t, device=weight.device)
        if t.numel() == 1:
            t = t.reshape(1).expand(batch_size)
        elif t.numel() != batch_size:
            raise ValueError(
                f"expected 1 or {batch_size} timesteps, got {t.numel()}"
            )
        return t.reshape(batch_size, 1).to(weight.dtype)

    def _asdl_outputs(self, x):
        """Run the five ASDL heads on the final tokens and collect their outputs in a dict."""
        s, s_logits = self.style_head(x)
        c, c_logits = self.content_head(x)
        n = self.concept_head(x)
        m, m_logits = self.mood_head(x)
        p = self.comp_head(x)
        return {
            "style_vec": s, "style_logits": s_logits,
            "content_vec": c, "content_logits": c_logits,
            "concept_vec": n,
            "mood_vec": m, "mood_logits": m_logits,
            "comp_vec": p,
        }

    def forward(self, z_t, t, text_embed, return_asdl=False):
        """Predict the output latent for ``z_t`` at timestep ``t``.

        Args:
            z_t: Latent of shape ``(B, latent_ch, latent_h, latent_w)``.
            t: Timestep(s). Either a tensor with ``B`` elements, or a single
                value (Python number or 0-d tensor) broadcast to the batch.
            text_embed: Text conditioning fed to ``cond_proj``. Presumably a
                pooled ``(B, text_dim)`` vector, since it is summed with the
                ``(B, text_dim)`` timestep embedding. TODO confirm against
                callers.
            return_asdl: If True, also run the ASDL heads and return their
                outputs.

        Returns:
            A tuple ``(v, asdl)``. ``v`` has shape ``(B, latent_ch, H, W)``.
            ``asdl`` is a dict of head outputs when ``return_asdl`` is True;
            otherwise it is ``None`` and the heads are not evaluated at all.

        Raises:
            ValueError: If the latent's spatial size differs from the
                ``(latent_h, latent_w)`` that the positional embedding and
                scan permutations were built for, or if ``t`` cannot be
                matched to the batch size.
        """
        B = z_t.shape[0]
        x, H, W = self.patch_embed(z_t)
        # token_pos and the scan permutations only fit the configured grid.
        # Fail loudly instead of mis-scanning a differently shaped latent.
        if (H, W) != (self.latent_h, self.latent_w):
            raise ValueError(
                f"ArtiGen was built for {self.latent_h}x{self.latent_w} latents, "
                f"got {H}x{W}"
            )
        x = x + self.token_pos

        t_emb = self.t_embed(self._timestep_column(t, B))
        cond = self.cond_transform(self.cond_proj(text_embed) + t_emb)

        # Skip source: tokens after block index num_layers // 4. Falls back to
        # the embedded input tokens if that index is never reached
        # (num_layers == 0).
        x_shallow = x
        for i, (block, adaln) in enumerate(zip(self.blocks, self.adalns)):
            # NOTE(review): no residual is added around the block here;
            # presumably CARTELBlock applies its own. Confirm.
            x = adaln(x, cond)
            perm, inv = self.scans[get_scan_pattern(i)]
            x = unscan(block(apply_scan(x, perm)), inv)
            if i == self.num_layers // 4:
                x_shallow = x
        x = x + self.skip_connect(x_shallow)

        # (B, H*W, latent_ch) -> (B, latent_ch, H, W)
        v = self.final_proj(x).transpose(1, 2).reshape(B, -1, H, W)
        if not return_asdl:
            return v, None
        return v, self._asdl_outputs(x)