krystv commited on
Commit
7aed37b
·
verified ·
1 Parent(s): a89ce99

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +140 -0
model.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ArtiGen V1.0 — Main Model
3
+ CARTEL backbone with PHI-SCAN, AdaLN conditioning, ASDL heads.
4
+ """
5
+ import math
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ try:
10
+ from .cartel_block import CARTELBlock
11
+ from .asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead
12
+ from .phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern
13
+ except ImportError:
14
+ from cartel_block import CARTELBlock
15
+ from asdl_head import StyleHead, ContentHead, ConceptHead, MoodHead, CompositionHead
16
+ from phi_scan import build_scan_permutations, apply_scan, unscan, get_scan_pattern
17
+
18
+ class PatchEmbed(nn.Module):
19
+ def __init__(self, in_ch, embed_dim, patch_size=2):
20
+ super().__init__()
21
+ self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)
22
+ self.norm = nn.LayerNorm(embed_dim)
23
+ def forward(self, x):
24
+ x = self.proj(x)
25
+ B, C, H, W = x.shape
26
+ x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
27
+ return self.norm(x), H, W
28
+
29
+ class AdaLN(nn.Module):
30
+ def __init__(self, dim, cond_dim=512):
31
+ super().__init__()
32
+ self.modulation = nn.Sequential(
33
+ nn.SiLU(),
34
+ nn.Linear(cond_dim, dim * 2),
35
+ )
36
+ def forward(self, x, cond):
37
+ scale, shift = self.modulation(cond).chunk(2, dim=-1)
38
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
39
+
40
+ class ArtiGen(nn.Module):
41
+ def __init__(
42
+ self,
43
+ latent_ch=4,
44
+ latent_h=32,
45
+ latent_w=32,
46
+ embed_dim=256,
47
+ num_layers=12,
48
+ d_state=16,
49
+ expand=2,
50
+ text_dim=768,
51
+ style_classes=128,
52
+ content_objects=1024,
53
+ mood_classes=64,
54
+ ):
55
+ super().__init__()
56
+ self.embed_dim = embed_dim
57
+ self.num_layers = num_layers
58
+ self.latent_h = latent_h
59
+ self.latent_w = latent_w
60
+ self.patch_embed = PatchEmbed(latent_ch, embed_dim, patch_size=1)
61
+ self.t_embed = nn.Sequential(
62
+ nn.Linear(1, text_dim),
63
+ nn.SiLU(),
64
+ nn.Linear(text_dim, text_dim),
65
+ )
66
+ self.cond_proj = nn.Linear(text_dim, text_dim)
67
+ self.cond_transform = nn.Sequential(
68
+ nn.SiLU(),
69
+ nn.Linear(text_dim, text_dim),
70
+ )
71
+ self.token_pos = nn.Parameter(torch.randn(1, latent_h * latent_w, embed_dim) * 0.02)
72
+ self.scans = build_scan_permutations(latent_h, latent_w)
73
+ self.blocks = nn.ModuleList([
74
+ CARTELBlock(embed_dim, d_state=d_state, expand=expand)
75
+ for _ in range(num_layers)
76
+ ])
77
+ self.adalns = nn.ModuleList([
78
+ AdaLN(embed_dim, cond_dim=text_dim)
79
+ for _ in range(num_layers)
80
+ ])
81
+ self.skip_connect = nn.Sequential(
82
+ nn.Linear(embed_dim, embed_dim),
83
+ nn.SiLU(),
84
+ nn.Linear(embed_dim, embed_dim),
85
+ )
86
+ self.final_proj = nn.Sequential(
87
+ nn.LayerNorm(embed_dim),
88
+ nn.Linear(embed_dim, embed_dim * 4),
89
+ nn.SiLU(),
90
+ nn.Linear(embed_dim * 4, embed_dim),
91
+ nn.Linear(embed_dim, latent_ch),
92
+ )
93
+ self.style_head = StyleHead(embed_dim, num_style_classes=style_classes)
94
+ self.content_head = ContentHead(embed_dim, num_objects=content_objects)
95
+ self.concept_head = ConceptHead(embed_dim)
96
+ self.mood_head = MoodHead(embed_dim, num_moods=mood_classes)
97
+ self.comp_head = CompositionHead(embed_dim)
98
+ self.apply(self._init_weights)
99
+
100
+ def _init_weights(self, m):
101
+ if isinstance(m, nn.Linear):
102
+ nn.init.xavier_uniform_(m.weight)
103
+ if m.bias is not None:
104
+ nn.init.zeros_(m.bias)
105
+
106
+ def forward(self, z_t, t, text_embed, return_asdl=False):
107
+ B = z_t.shape[0]
108
+ x, H, W = self.patch_embed(z_t)
109
+ x = x + self.token_pos[:, :x.shape[1], :]
110
+ t_emb = self.t_embed(t.view(B, 1).float())
111
+ cond = self.cond_proj(text_embed) + t_emb
112
+ cond = self.cond_transform(cond)
113
+ x_shallow = x
114
+ for i, (block, adaln) in enumerate(zip(self.blocks, self.adalns)):
115
+ x = adaln(x, cond)
116
+ scan_name = get_scan_pattern(i)
117
+ perm, inv = self.scans[scan_name]
118
+ x_scanned = apply_scan(x, perm)
119
+ x_scanned = block(x_scanned)
120
+ x = unscan(x_scanned, inv)
121
+ if i == self.num_layers // 4:
122
+ x_shallow = x
123
+ x = x + self.skip_connect(x_shallow)
124
+ v = self.final_proj(x).transpose(1, 2).reshape(B, -1, H, W)
125
+ asdl = {}
126
+ s, s_logits = self.style_head(x)
127
+ c, c_logits = self.content_head(x)
128
+ n = self.concept_head(x)
129
+ m, m_logits = self.mood_head(x)
130
+ p = self.comp_head(x)
131
+ asdl = {
132
+ "style_vec": s, "style_logits": s_logits,
133
+ "content_vec": c, "content_logits": c_logits,
134
+ "concept_vec": n,
135
+ "mood_vec": m, "mood_logits": m_logits,
136
+ "comp_vec": p,
137
+ }
138
+ if return_asdl:
139
+ return v, asdl
140
+ return v, None