"""
ArtiGen V1.0 — Main Model
CARTEL backbone with PHI-SCAN, AdaLN conditioning, ASDL heads.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from .cartel_block import CARTELBlock
    from .asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead
    from .phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern
except ImportError:
    from cartel_block import CARTELBlock
    from asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead
    from phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern

class PatchEmbed(nn.Module):
    """Project a latent feature map into a flattened token sequence."""
    def __init__(self, in_ch, embed_dim, patch_size=2):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # (B, in_ch, H, W) -> (B, H'*W', embed_dim), plus the post-patch grid size
        x = self.proj(x)
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        return self.norm(x), H, W

class AdaLN(nn.Module):
    """Adaptive modulation: per-channel scale/shift predicted from the conditioning vector."""
    def __init__(self, dim, cond_dim=512):
        super().__init__()
        self.modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(cond_dim, dim * 2),
        )

    def forward(self, x, cond):
        # cond: (B, cond_dim) -> scale, shift: (B, dim), broadcast over the token axis
        scale, shift = self.modulation(cond).chunk(2, dim=-1)
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

class ArtiGen(nn.Module):
    """Patchified latent -> AdaLN-conditioned CARTEL blocks over PHI-SCAN token
    orderings -> per-token latent prediction, plus ASDL auxiliary heads."""
    def __init__(
        self,
        latent_ch=4,
        latent_h=32,
        latent_w=32,
        embed_dim=256,
        num_layers=12,
        d_state=16,
        expand=2,
        text_dim=768,
        style_classes=128,
        content_objects=1024,
        mood_classes=64,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_layers = num_layers
        self.latent_h = latent_h
        self.latent_w = latent_w
        self.patch_embed = PatchEmbed(latent_ch, embed_dim, patch_size=1)
        self.t_embed = nn.Sequential(
            nn.Linear(1, text_dim),
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )
        self.cond_proj = nn.Linear(text_dim, text_dim)
        self.cond_transform = nn.Sequential(
            nn.SiLU(),
            nn.Linear(text_dim, text_dim),
        )
        # Learned positional embedding over the flattened latent grid, and the
        # precomputed PHI-SCAN permutations (a forward/inverse index pair per pattern).
        self.token_pos = nn.Parameter(torch.randn(1, latent_h * latent_w, embed_dim) * 0.02)
        self.scans = build_scan_permutations(latent_h, latent_w)
        self.blocks = nn.ModuleList([
            CARTELBlock(embed_dim, d_state=d_state, expand=expand)
            for _ in range(num_layers)
        ])
        self.adalns = nn.ModuleList([
            AdaLN(embed_dim, cond_dim=text_dim)
            for _ in range(num_layers)
        ])
        self.skip_connect = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )
        self.final_proj = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim * 4),
            nn.SiLU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Linear(embed_dim, latent_ch),
        )
        self.style_head = StyleHead(embed_dim, num_style_classes=style_classes)
        self.content_head = ContentHead(embed_dim, num_objects=content_objects)
        self.concept_head = ConceptHead(embed_dim)
        self.mood_head = MoodHead(embed_dim, num_moods=mood_classes)
        self.comp_head = CompositionHead(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, z_t, t, text_embed, return_asdl=False):
        B = z_t.shape[0]
        # Patchify the noisy latent and add positional embeddings.
        x, H, W = self.patch_embed(z_t)
        x = x + self.token_pos[:, :x.shape[1], :]
        # Fuse the timestep and text conditioning into a single vector for AdaLN.
        t_emb = self.t_embed(t.view(B, 1).float())
        cond = self.cond_proj(text_embed) + t_emb
        cond = self.cond_transform(cond)
        x_shallow = x
        for i, (block, adaln) in enumerate(zip(self.blocks, self.adalns)):
            # Condition, reorder tokens with this layer's PHI-SCAN pattern,
            # run the CARTEL block, then restore the original token order.
            x = adaln(x, cond)
            scan_name = get_scan_pattern(i)
            perm, inv = self.scans[scan_name]
            x_scanned = apply_scan(x, perm)
            x_scanned = block(x_scanned)
            x = unscan(x_scanned, inv)
            if i == self.num_layers // 4:
                x_shallow = x  # keep early-layer features for the long skip connection
        x = x + self.skip_connect(x_shallow)
        # Project tokens back to latent channels and restore the spatial grid.
        v = self.final_proj(x).transpose(1, 2).reshape(B, -1, H, W)
        if not return_asdl:
            return v, None
        # ASDL auxiliary heads, computed only when requested.
        s, s_logits = self.style_head(x)
        c, c_logits = self.content_head(x)
        n = self.concept_head(x)
        m, m_logits = self.mood_head(x)
        p = self.comp_head(x)
        asdl = {
            "style_vec": s, "style_logits": s_logits,
            "content_vec": c, "content_logits": c_logits,
            "concept_vec": n,
            "mood_vec": m, "mood_logits": m_logits,
            "comp_vec": p,
        }
        return v, asdl
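
if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original pipeline. It assumes the
    # sibling cartel_block / asdl_head / phi_scan modules are importable and that the
    # default constructor arguments above (latent 4x32x32, text_dim 768) are valid.
    model = ArtiGen()
    z_t = torch.randn(2, 4, 32, 32)       # (B, latent_ch, latent_h, latent_w)
    t = torch.randint(0, 1000, (2,))      # dummy per-sample timesteps
    text_embed = torch.randn(2, 768)      # (B, text_dim) pooled text conditioning
    with torch.no_grad():
        v, asdl = model(z_t, t, text_embed, return_asdl=True)
    print("output:", tuple(v.shape))
    print("asdl keys:", sorted(asdl.keys()))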