zirobtc committed on
Commit
3c212d2
·
verified ·
1 Parent(s): 0b85347

Upload 2 files

Files changed (2)
  1. models/diffloss.py +308 -0
  2. models/llama_model.py +1894 -0
models/diffloss.py ADDED
@@ -0,0 +1,308 @@
+ # models/diffloss.py
+
+ import math
+ import torch
+ import torch.nn as nn
+ from torch.utils.checkpoint import checkpoint
+ from models.diffusion import create_diffusion
+
+
+ # ---------------- utils ----------------
+ def modulate(x, shift, scale):
+     return x * (1 + scale) + shift
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, hidden_size, frequency_embedding_size=256):
+         super().__init__()
+         self.mlp = nn.Sequential(
+             nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, hidden_size, bias=True),
+         )
+         self.frequency_embedding_size = frequency_embedding_size
+
+     @staticmethod
+     def timestep_embedding(t, dim, max_period=10000):
+         half = dim // 2
+         freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32) / half).to(t.device)
+         args = t[:, None].float() * freqs[None]
+         emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+         if dim % 2:
+             emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
+         return emb
+
+     def forward(self, t):
+         return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
+
+
+ class SinPos1D(nn.Module):
+     def __init__(self, dim):
+         super().__init__()
+         self.dim = dim
+     def forward(self, L, device, dtype):
+         pe = torch.zeros(L, self.dim, device=device, dtype=torch.float32)
+         pos = torch.arange(0, L, device=device, dtype=torch.float32).unsqueeze(1)
+         div = torch.exp(torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) * (-math.log(10000.0) / self.dim))
+         pe[:, 0::2] = torch.sin(pos * div)
+         pe[:, 1::2] = torch.cos(pos * div)
+         return pe.to(dtype)
+
+
+ # --------------- DiT block (causal) ---------------
+ class TemporalDiTBlock(nn.Module):
+     """
+     Transformer block with AdaLN (DiT-style), **causal** self-attention over time.
+     """
+     def __init__(self, dim, n_heads, mlp_ratio=4.0, dropout=0.0):
+         super().__init__()
+         self.dim = dim
+         self.n_heads = n_heads
+         self.norm1 = nn.LayerNorm(dim, eps=1e-6)
+         self.attn = nn.MultiheadAttention(dim, n_heads, dropout=dropout, batch_first=True)
+         self.norm2 = nn.LayerNorm(dim, eps=1e-6)
+         hidden = int(dim * mlp_ratio)
+         self.ffn = nn.Sequential(
+             nn.Linear(dim, 2 * hidden, bias=True),
+             nn.SiLU(),
+             nn.Linear(2 * hidden, dim, bias=True),
+         )
+         # AdaLN params: shift/scale/gate for attn and ffn
+         self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
+         nn.init.constant_(self.adaLN[-1].weight, 0)
+         nn.init.constant_(self.adaLN[-1].bias, 0)
+
+     def forward(self, x, y, causal_mask):
+         """
+         x: [B, L, D], y: [B, D], causal_mask: [L, L] bool, True = mask (disallow)
+         """
+         s1, sc1, g1, s2, sc2, g2 = self.adaLN(y).chunk(6, dim=-1)  # [B, D] each
+
+         # attn (causal)
+         h = modulate(self.norm1(x), s1.unsqueeze(1), sc1.unsqueeze(1))
+         # torch's attn expects attn_mask shape [L, L] or [B*nH, L, L]; True means -inf
+         h, _ = self.attn(h, h, h, attn_mask=causal_mask, need_weights=False)
+         x = x + g1.unsqueeze(1) * h
+
+         # ffn
+         h2 = modulate(self.norm2(x), s2.unsqueeze(1), sc2.unsqueeze(1))
+         h2 = self.ffn(h2)
+         x = x + g2.unsqueeze(1) * h2
+         return x
+
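The block above leans on nn.MultiheadAttention's boolean-mask convention (True = the query may not attend to that position). A minimal standalone check of that convention, not part of this commit and assuming PyTorch >= 1.12:

import torch
import torch.nn as nn

attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
x = torch.randn(2, 5, 16)                             # [B, L, D]
mask = torch.ones(5, 5, dtype=torch.bool).triu(1)     # True above the diagonal = disallowed
out, w = attn(x, x, x, attn_mask=mask, need_weights=True, average_attn_weights=True)
print(torch.allclose(w.triu(1), torch.zeros_like(w)))  # True: no weight ever lands on future positions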
+
+ class FinalLayer(nn.Module):
+     def __init__(self, dim, out_channels):
+         super().__init__()
+         self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+         self.linear = nn.Linear(dim, out_channels, bias=True)
+         self.adaLN = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
+         nn.init.constant_(self.adaLN[-1].weight, 0)
+         nn.init.constant_(self.adaLN[-1].bias, 0)
+         nn.init.constant_(self.linear.weight, 0)
+         nn.init.constant_(self.linear.bias, 0)
+
+     def forward(self, x, c):
+         shift, scale = self.adaLN(c).chunk(2, dim=-1)
+         x = modulate(self.norm(x), shift.unsqueeze(1), scale.unsqueeze(1))
+         return self.linear(x)
+
+
+ # --------------- Temporal DiT (sequence-aware, causal) ---------------
+ class TemporalDiTAdaLN(nn.Module):
+     """
+     DiT-like denoiser that:
+       - operates on [B, L, C]
+       - uses **causal** attention (each position sees only <= t)
+       - accepts (B, L) via set_sequence_layout for flatten↔sequence reshaping
+       - returns all positions but we usually **read only the last token** for streaming
+     """
+     def __init__(self, in_channels, model_channels, out_channels, z_channels, depth, n_heads=8,
+                  mlp_ratio=4.0, grad_checkpointing=False):
+         super().__init__()
+         self.in_channels = in_channels
+         self.model_channels = model_channels
+         self.out_channels = out_channels
+         self.z_channels = z_channels
+         self.depth = depth
+         self.n_heads = n_heads
+         self.grad_checkpointing = grad_checkpointing
+
+         self.time_embed = TimestepEmbedder(model_channels)
+         self.cond_embed = nn.Linear(z_channels, model_channels)
+         self.input_proj = nn.Linear(in_channels, model_channels)
+         self.pos = SinPos1D(model_channels)
+
+         self.blocks = nn.ModuleList([
+             TemporalDiTBlock(model_channels, n_heads=n_heads, mlp_ratio=mlp_ratio)
+             for _ in range(depth)
+         ])
+         self.final = FinalLayer(model_channels, out_channels)
+
+         self._seq_B = None
+         self._seq_L = None
+
+         self._init_weights()
+
+     def _init_weights(self):
+         def _xav(m):
+             if isinstance(m, nn.Linear):
+                 nn.init.xavier_uniform_(m.weight)
+                 if m.bias is not None: nn.init.constant_(m.bias, 0)
+         self.apply(_xav)
+         nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
+         nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
+
+     def set_sequence_layout(self, B, L):
+         self._seq_B = int(B)
+         self._seq_L = int(L)
+
+     def _flatten_to_seq(self, x_flat, c_flat):
+         if self._seq_B is None or self._seq_L is None:
+             B, L = x_flat.shape[0], 1
+         else:
+             B, L = self._seq_B, self._seq_L
+             assert B * L == x_flat.shape[0], f"set_sequence_layout({B},{L}) mismatch"
+         x = x_flat.view(B, L, -1)
+         c = c_flat.view(B, L, -1)
+         return x, c
+
+     @staticmethod
+     def _causal_mask(L, device):
+         # True where masked (disallowed)
+         m = torch.ones(L, L, device=device, dtype=torch.bool).triu(1)
+         # MultiheadAttention expects float mask with -inf where we mask.
+         # But newer PyTorch also supports bool with True=mask. We'll pass bool here.
+         return m
+
+     def forward(self, x_flat, t, c, cfg_scale: float = 1.0):  # `c` so that model_kwargs=dict(c=...) binds by keyword
+         x, c = self._flatten_to_seq(x_flat, c)  # [B, L, C], [B, L, Cz]
+         B, L, _ = x.shape
+
+         x = self.input_proj(x)
+         pos = self.pos(L, x.device, x.dtype)
+         x = x + pos.unsqueeze(0)
+
+         # pool cond to a single AdaLN vector per batch (like DiT)
+         t_emb = self.time_embed(t).view(B, L, -1).mean(dim=1)  # [B, D]
+         c_emb = self.cond_embed(c).mean(dim=1)  # [B, D]
+         y = t_emb + c_emb
+
+         causal_mask = self._causal_mask(L, x.device)
+
+         if self.grad_checkpointing and not torch.jit.is_scripting():
+             for blk in self.blocks:
+                 x = checkpoint(blk, x, y, causal_mask)
+         else:
+             for blk in self.blocks:
+                 x = blk(x, y, causal_mask)
+
+         out = self.final(x, y)  # [B, L, out_channels]
+         return out.view(B * L, -1)
+
+     def forward_with_cfg(self, x, t, c, cfg_scale):
+         half = x[: len(x) // 2]
+         combined = torch.cat([half, half], dim=0)
+         model_out = self.forward(combined, t, c, cfg_scale=cfg_scale)
+         eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
+         cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+         guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+         eps = torch.cat([guided, guided], dim=0)
+         return torch.cat([eps, rest], dim=1)
+
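For reference, forward_with_cfg above implements the standard classifier-free guidance rule: the batch is laid out as [conditional half; unconditional half] sharing the same noise, and the guided prediction is eps_uncond + cfg_scale * (eps_cond - eps_uncond), so cfg_scale = 1 reduces to the plain conditional prediction. A tiny numeric illustration with made-up values:

import torch

cond_eps = torch.tensor([1.0, 2.0])
uncond_eps = torch.tensor([0.5, 1.0])
cfg_scale = 4.0
guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
print(guided)  # tensor([2.5000, 5.0000]) -- pushed beyond the conditional prediction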
+
+ # --------------- Wrapper (same training API) + streaming helpers ---------------
+ class DiffLoss(nn.Module):
+     """
+     Diffusion loss with **causal, streamable** temporal DiT denoiser.
+     Training API unchanged; plus:
+       - set_sequence_layout(B, L)
+       - sample_next_token(z_seq, temperature=1.0, cfg=1.0) -> [B, C] (last token)
+     """
+     def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps,
+                  grad_checkpointing=False, learn_sigma=False, n_heads=8, mlp_ratio=4.0):
+         super().__init__()
+         self.in_channels = target_channels
+         self.learn_sigma = learn_sigma
+
+         self.net = TemporalDiTAdaLN(
+             in_channels=target_channels,
+             model_channels=width,
+             out_channels=target_channels * 2 if learn_sigma else target_channels,
+             z_channels=z_channels,
+             depth=depth,
+             n_heads=n_heads,
+             mlp_ratio=mlp_ratio,
+             grad_checkpointing=grad_checkpointing
+         )
+
+         self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
+         self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")
+
+         # cached (B,L) for flatten↔sequence
+         self._B = None
+         self._L = None
+
+     # --- layout for flatten<->sequence ---
+     def set_sequence_layout(self, B, L):
+         self._B, self._L = int(B), int(L)
+         self.net.set_sequence_layout(B, L)
+
+     # --- training ---
+     def forward(self, target, z, mask=None):
+         t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
+         loss_dict = self.train_diffusion.training_losses(self.net, target, t, dict(c=z))
+         loss, pred_xstart = loss_dict["loss"], loss_dict["pred_xstart"]
+         if mask is not None:
+             loss = (loss * mask).sum() / mask.sum()
+         return loss.mean(), pred_xstart
+
+     # --- full sequence sampling (kept for compatibility) ---
+     def sample(self, z, temperature=1.0, cfg=1.0):
+         if cfg != 1.0:
+             noise = torch.randn(z.shape[0] // 2, self.in_channels, device=z.device)
+             noise = torch.cat([noise, noise], dim=0)
+             sample_fn = self.net.forward_with_cfg
+             kwargs = dict(c=z, cfg_scale=cfg)
+         else:
+             noise = torch.randn(z.shape[0], self.in_channels, device=z.device)
+             sample_fn = self.net.forward
+             kwargs = dict(c=z)
+
+         return self.gen_diffusion.p_sample_loop(
+             sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=kwargs,
+             progress=False, temperature=temperature
+         )
+
+     # --- STREAMING: sample only the **last token** of current window ---
+     @torch.no_grad()
+     def sample_next_token(self, z_seq, temperature=1.0, cfg=1.0):
+         """
+         z_seq: [B, L, Cz] AR conditions for the current streaming window (history + 1 step).
+         Call set_sequence_layout(B, L) first.
+         Returns: next_token: [B, C] (the last position's denoised sample).
+         Mechanism: denoise **entire window** with causal attention and read the last index only.
+         """
+         assert self._B is not None and self._L is not None, "Call set_sequence_layout(B, L) first."
+         B, L, Cz = z_seq.shape
+         assert B == self._B and L == self._L, "z_seq shape must match set_sequence_layout."
+
+         z_flat = z_seq.reshape(B * L, Cz)
+
+         if cfg != 1.0:
+             noise = torch.randn((B * L) // 2, self.in_channels, device=z_seq.device)
+             noise = torch.cat([noise, noise], dim=0)
+             sample_fn = self.net.forward_with_cfg
+             kwargs = dict(c=z_flat, cfg_scale=cfg)
+         else:
+             noise = torch.randn(B * L, self.in_channels, device=z_seq.device)
+             sample_fn = self.net.forward
+             kwargs = dict(c=z_flat)
+
+         x = self.gen_diffusion.p_sample_loop(
+             sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=kwargs,
+             progress=False, temperature=temperature
+         )  # [B*L, C]
+
+         x_seq = x.view(B, L, self.in_channels)
+         return x_seq[:, -1, :]  # last token only
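A minimal sketch of how the streaming path is meant to be driven. Illustrative only: it assumes models.diffusion.create_diffusion is importable from this repo, and the batch size, window length, and channel widths below are made-up values.

import torch
from models.diffloss import DiffLoss

B, L, C, Cz = 2, 8, 16, 768                    # window of 8 tokens, 16-dim latents, 768-dim conditions
head = DiffLoss(target_channels=C, z_channels=Cz, depth=6, width=512,
                num_sampling_steps='50', n_heads=4, mlp_ratio=2.0)

z_seq = torch.randn(B, L, Cz)                  # AR conditions for the current window
head.set_sequence_layout(B, L)                 # tell the denoiser how to un-flatten [B*L, ...]
next_tok = head.sample_next_token(z_seq, temperature=1.0, cfg=1.0)
print(next_tok.shape)                          # torch.Size([2, 16])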
models/llama_model.py ADDED
@@ -0,0 +1,1894 @@
1
+
2
+ import math
3
+ from dataclasses import dataclass
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.nn import functional as F
7
+ from typing_extensions import Self
8
+ from typing import Optional
9
+ from transformers.modeling_utils import PreTrainedModel
10
+ from torch.distributions import Categorical
11
+
12
+
13
+ @dataclass
14
+ class LLaMAHFConfig:
15
+ block_size: int = 156
16
+ n_layer: int = 32
17
+ n_head: int = 32
18
+ n_kv_head: Optional[int] = None
19
+ n_embd: int = 4096
20
+ rope_base: int = 500000
21
+ T5_xxl_dim: int = 768
22
+
23
+ @classmethod
24
+ def from_name(cls, name: str) -> Self:
25
+ return cls(**llama_configs[name])
26
+
27
+
28
+ llama_configs = {
29
+ "Normal_size": dict(n_layer=12, n_head=12, n_embd=768)
30
+ }
31
+
32
+
33
+ class LLaMAHF(nn.Module):
34
+ def __init__(self, config: LLaMAHFConfig, num_diffusion_head_layers=6, n_diffusion_heads=4, input_token_dim=16, device=torch.device('cuda'), width=512) -> None:
35
+ super().__init__()
36
+ assert config.block_size is not None
37
+ self.config = config
38
+
39
+ cond_dim = config.T5_xxl_dim
40
+
41
+ self.transformer = nn.ModuleDict(
42
+ dict(
43
+ wte=nn.Linear(input_token_dim, config.n_embd), # vector tokens -> embeddings
44
+ cond_embed=nn.Linear(cond_dim, config.n_embd), # text feature -> context emb
45
+ h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
46
+ ln_f=RMSNorm(config.n_embd),
47
+ )
48
+ )
49
+
50
+ target_channels = input_token_dim
51
+ from models.diffloss import DiffLoss
52
+ self.diff_loss = DiffLoss(
53
+ target_channels=target_channels,
54
+ z_channels=config.n_embd,
55
+ width=width,
56
+ depth=num_diffusion_head_layers,
57
+ num_sampling_steps='50',
58
+ grad_checkpointing=False,
59
+ n_heads=n_diffusion_heads,
60
+ mlp_ratio=2.0
61
+ ).to(device)
62
+
63
+ self.out_proj = nn.Linear(config.n_embd, config.n_embd)
64
+ self.use_out_proj = True
65
+
66
+ # --- Persistent prompt cache & BOS token ---
67
+ self._prompt_cached = False
68
+ self._prompt_bsz = None
69
+ self.bos = nn.Parameter(torch.zeros(1, 1, config.n_embd))
70
+
71
+ # === Needed by several sampling/forward paths ===
72
+ # projects raw text features when they are concatenated as tokens
73
+ self.llama_proj = nn.Linear(config.T5_xxl_dim, config.n_embd)
74
+ # special boundary-of-motion token used in forward_babel
75
+ self.BOM_tag = nn.Parameter(torch.zeros(1, 1, config.n_embd))
76
+
77
+ # (Optional) only if sample_for_eval_classification() is used:
78
+ # self.classify_head = nn.Linear(config.n_embd, num_classes)
79
+
80
+
81
+
82
+ @torch.no_grad()
83
+ def set_prompt(self, feature: torch.Tensor):
84
+ """
85
+ Precompute and cache cross-attention K/V for the current prompt (feature).
86
+ Call this ONCE when you switch prompt (e.g., 'walk' -> 'crawl').
87
+ """
88
+ context = self._prepare_context(feature)
89
+ if context is None:
90
+ raise ValueError("set_prompt: feature cannot be None")
91
+
92
+ self._prompt_bsz = context.size(0)
93
+ for blk in self.transformer.h:
94
+ blk.set_context_cache(context)
95
+ self._prompt_cached = True
96
+
97
+ @torch.no_grad()
98
+ def clear_prompt(self):
99
+ for blk in self.transformer.h:
100
+ blk.clear_context_cache()
101
+ self._prompt_cached = False
102
+ self._prompt_bsz = None
103
+
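A minimal sketch of the intended prompt-cache workflow. Illustrative only: model is an LLaMAHF instance, and t5_encoder / device are hypothetical stand-ins for the caller's text encoder (returning T5_xxl_dim-wide features) and device.

feat = torch.from_numpy(t5_encoder.encode("a person walks forward")).float().to(device)
model.set_prompt(feat)                    # cache cross-attention K/V for this prompt once
logits = model.forward([], feature=None)  # later calls reuse the cached prompt K/V
model.clear_prompt()                      # drop the cache before switching prompts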
104
+ def _prepare_context(self, feature: Optional[torch.Tensor], batch_size: Optional[int] = None) -> Optional[torch.Tensor]:
105
+ if feature is None:
106
+ return None
107
+ if not torch.is_tensor(feature):
108
+ feature = torch.as_tensor(
109
+ feature,
110
+ dtype=self.transformer.cond_embed.weight.dtype,
111
+ device=self.transformer.cond_embed.weight.device,
112
+ )
113
+ else:
114
+ feature = feature.to(
115
+ dtype=self.transformer.cond_embed.weight.dtype,
116
+ device=self.transformer.cond_embed.weight.device,
117
+ )
118
+
119
+ if feature.dim() == 1:
120
+ feature = feature.unsqueeze(0)
121
+
122
+ context = self.transformer.cond_embed(feature)
123
+ if context.dim() == 2:
124
+ context = context.unsqueeze(1)
125
+
126
+ if batch_size is not None and context.size(0) != batch_size:
127
+ if context.size(0) == 1:
128
+ context = context.expand(batch_size, -1, -1)
129
+ else:
130
+ raise ValueError(
131
+ f"Condition batch ({context.size(0)}) does not match token batch ({batch_size})."
132
+ )
133
+ return context
134
+
135
+ def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
136
+ """Tie or clone module weights depending of whether we are using TorchScript or not"""
137
+ output_embeddings.weight = input_embeddings.weight
138
+
139
+ if getattr(output_embeddings, "bias", None) is not None:
140
+ output_embeddings.bias.data = nn.functional.pad(
141
+ output_embeddings.bias.data,
142
+ (
143
+ 0,
144
+ output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
145
+ ),
146
+ "constant",
147
+ 0,
148
+ )
149
+ if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
150
+ output_embeddings.out_features = input_embeddings.num_embeddings
151
+
152
+ def get_input_embeddings(self):
153
+ return self.transformer.wte
154
+
155
+ def set_input_embeddings(self, value):
156
+ self.transformer.wte = value
157
+
158
+ def get_output_embeddings(self):
159
+ return self.out_proj
160
+
161
+ def set_output_embeddings(self, new_embeddings):
162
+ self.out_proj = new_embeddings
163
+
164
+ def _init_weights(self, module: nn.Module) -> None:
165
+ if isinstance(module, nn.Linear):
166
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
167
+ elif isinstance(module, nn.Embedding):
168
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
169
+
170
+
171
+
172
+ def forward_sample(self, idx: torch.Tensor, clip_feature: torch.Tensor, y_mask) -> torch.Tensor:
173
+
174
+ text_length = clip_feature.shape[1]
175
+ context = self._prepare_context(clip_feature)
176
+ if len(idx) == 0:
177
+ x = self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :]
178
+ else:
179
+ _, t = idx.size()
180
+ assert (
181
+ t <= self.config.block_size
182
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
183
+ # forward the LLaMA model itself
184
+ x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
185
+ x = torch.cat((self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :],x), dim=1)
186
+
187
+ if context is not None and context.size(0) != x.size(0):
188
+ if context.size(0) == 1:
189
+ context = context.expand(x.size(0), -1, -1)
190
+ else:
191
+ raise ValueError("Conditioning batch size does not match token batch size.")
192
+
193
+ for block in self.transformer.h:
194
+ x = block(x, context=context)
195
+ x = self.transformer.ln_f(x)
196
+ logits = x
197
+ return logits
198
+
199
+
200
+
201
+ def sample_for_eval_CFG(self, text, length=196, tokenize_model=None, device=torch.device('cuda'), unit_length=4, cfg=4.0):
202
+ max_token_len = length // unit_length
203
+
204
+ # Prepare conditioned prompt once and cache it
205
+ feat_text = torch.from_numpy(tokenize_model.encode(text)).float().to(device)
206
+ self.set_prompt(feat_text) # <-- persist until you change it
207
+
208
+ # Prepare empty/uncond prompt once and cache it too
209
+ empty_feat_text = torch.from_numpy(tokenize_model.encode('')).float().unsqueeze(0).to(device)
210
+
211
+ # We'll flip between two caches: cond and uncond
212
+ def _use_cond_cache():
213
+ self.set_prompt(feat_text)
214
+
215
+ def _use_uncond_cache():
216
+ self.set_prompt(empty_feat_text)
217
+
218
+ xs = None
219
+ for k in range(max_token_len):
220
+ x = [] if k == 0 else xs
221
+
222
+ # conditioned next-step
223
+ _use_cond_cache()
224
+ conditions = self.forward(x, feature=None)[:, -1, :]
225
+
226
+ # unconditioned next-step
227
+ _use_uncond_cache()
228
+ empty_conditions = self.forward(x, feature=None)[:, -1, :]
229
+
230
+ temperature = 1.0
231
+ if cfg != 1:
232
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
233
+ sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
234
+ scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
235
+ else:
236
+ scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=1)
237
+
238
+ scaled_logits = scaled_logits.unsqueeze(0)
239
+ xs = scaled_logits if k == 0 else torch.cat((xs, scaled_logits), dim=1)
240
+
241
+ # re-enable the conditioned prompt cache for whatever comes next
242
+ self.set_prompt(feat_text)
243
+ return xs
244
+
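Illustrative call pattern for the sampler above; t5_encoder is a hypothetical text encoder exposing encode(text) -> np.ndarray, which is the interface sample_for_eval_CFG assumes of tokenize_model.

motion_latents = model.sample_for_eval_CFG(
    "a person walks in a circle",
    length=196,                   # the loop runs length // unit_length steps
    tokenize_model=t5_encoder,
    device=torch.device('cuda'),
    unit_length=4,
    cfg=4.0,
)
print(motion_latents.shape)       # [1, length // unit_length, input_token_dim]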
245
+
246
+
247
+ # For inference, sampling can stop early once the distance between the current token and the reference end token falls below the threshold.
248
+ def sample_for_eval_CFG_inference(self, text, length=312, tokenizer=None, device=torch.device('cuda'),
249
+ unit_length=4, reference_end_latent=None, threshold=0.1, cfg=4.0, temperature=1.0):
250
+ max_token_len = length // unit_length
251
+ feat_text = torch.from_numpy(tokenizer.encode(text)).float().to(device)
252
+ empty_feat_text = torch.from_numpy(tokenizer.encode('')).float().unsqueeze(0).to(device)
253
+
254
+ def _use_cond(): self.set_prompt(feat_text)
255
+ def _use_uncond(): self.set_prompt(empty_feat_text)
256
+
257
+ xs = None
258
+ for k in range(max_token_len):
259
+ x = [] if k == 0 else xs
260
+
261
+ _use_cond()
262
+ conditions = self.forward_inference(x, feature=None)[:, -1, :]
263
+
264
+ _use_uncond()
265
+ empty_conditions = self.forward(x, feature=None)[:, -1, :]
266
+
267
+ mix = torch.cat([conditions, empty_conditions], dim=0)
268
+ sampled = self.diff_loss.sample(mix, temperature=temperature, cfg=cfg)
269
+ scaled_logits, _ = sampled.chunk(2, dim=0) if cfg != 1 else (sampled, None)
270
+ scaled_logits = scaled_logits.unsqueeze(0)
271
+
272
+ if reference_end_latent is not None:
273
+ dist = torch.sqrt(torch.sum((scaled_logits - reference_end_latent)**2))
274
+ if dist < threshold: break
275
+
276
+ xs = scaled_logits if k == 0 else torch.cat((xs, scaled_logits), dim=1)
277
+
278
+ # leave the cond cache active
279
+ self.set_prompt(feat_text)
280
+ return xs
281
+
282
+
283
+
284
+ def sample_for_eval_CFG_inference2(self, feat_clip_text, empty_feat_clip_text, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):
285
+
286
+ import clip
287
+ max_token_len = length // unit_length
288
+
289
+ for k in range(max_token_len):
290
+ if k == 0:
291
+ x = []
292
+ else:
293
+ x = xs
294
+
295
+ try:
296
+ conditions = self.forward(x, feat_clip_text)
297
+ except Exception:  # retry with an explicit batch dimension on the text feature
298
+ conditions = self.forward(x, feat_clip_text.unsqueeze(0))
299
+
300
+
301
+ conditions = conditions[:, -1, :]
302
+
303
+
304
+
305
+ empty_conditions = self.forward(x, empty_feat_clip_text)
306
+ empty_conditions = empty_conditions[:, -1, :]
307
+
308
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
309
+ sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
310
+
311
+ # chunk
312
+ if cfg != 1:
313
+ scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
314
+ else:
315
+ scaled_logits = sampled_token_latent
316
+
317
+ scaled_logits = scaled_logits.unsqueeze(0)
318
+
319
+ if reference_end_token is not None:
320
+ distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2))
321
+ print(distance_l2)
322
+ if distance_l2 < threshold:
323
+ break
324
+
325
+ if k == 0:
326
+ xs = scaled_logits
327
+ else:
328
+ xs = torch.cat((xs, scaled_logits), dim=1)
329
+
330
+ return xs
331
+
332
+ def sample_for_eval_CFG_inference_next_one(self, current_token=[], feat_clip_text=None, empty_feat_clip_text=None, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):
333
+
334
+ import clip
335
+ max_token_len = length // unit_length
336
+
337
+
338
+ for k in range(1):
339
+
340
+ if current_token == []:
341
+ x = []
342
+ else:
343
+ x = torch.cat(current_token, dim=1)
344
+
345
+
346
+ try:
347
+ conditions = self.forward(x, feat_clip_text)
348
+ except Exception:  # retry with an explicit batch dimension on the text feature
349
+ conditions = self.forward(x, feat_clip_text.unsqueeze(0))
350
+
351
+
352
+ conditions = conditions[:, -1, :]
353
+
354
+
355
+ empty_conditions = self.forward(x, empty_feat_clip_text)
356
+ empty_conditions = empty_conditions[:, -1, :]
357
+
358
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
359
+ sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
360
+
361
+ # chunk
362
+ if cfg != 1:
363
+ scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
364
+ else:
365
+ scaled_logits = sampled_token_latent
366
+
367
+
368
+ scaled_logits = scaled_logits.unsqueeze(0)
369
+
370
+
371
+ if k == 0:
372
+ xs = scaled_logits
373
+ else:
374
+ xs = torch.cat((xs, scaled_logits), dim=1)
375
+
376
+ return xs
377
+
378
+
379
+ def sample_for_eval_CFG_babel(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):
380
+
381
+ import clip
382
+ B_token_length = length // unit_length - A_motion.shape[0]
383
+
384
+ if tokenizer == 'clip':
385
+ A_text = clip.tokenize(A_text, truncate=True).to(device)
386
+ A_feat_clip_text = clip_model.encode_text(A_text).float()
387
+ B_text = clip.tokenize(B_text, truncate=True).to(device)
388
+ B_feat_clip_text = clip_model.encode_text(B_text).float()
389
+ elif tokenizer == 't5-xxl':
390
+ A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float()
391
+ A_feat_clip_text = A_feat_clip_text.to(device)
392
+ B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
393
+ B_feat_clip_text = B_feat_clip_text.to(device)
394
+
395
+ A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0)
396
+ B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
397
+
398
+ A_motion = A_motion.unsqueeze(0)
399
+ A_motion_embeddings = self.transformer.wte(A_motion)
400
+ B_motion = torch.tensor([]).to(device)
401
+
402
+ for k in range(B_token_length):
403
+ if k == 0:
404
+ x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
405
+ else:
406
+ x = xs
407
+
408
+
409
+ conditions = self.forward_babel_eval(x)
410
+ conditions = conditions[:, -1, :]
411
+
412
+ empty_clip_text = ''
413
+ if tokenizer == 'clip':
414
+ empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
415
+ empty_feat_clip_text = clip_model.encode_text(empty_text).float()
416
+ elif tokenizer == 't5-xxl':
417
+ empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
418
+ empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
419
+ empty_feat_clip_text = empty_feat_clip_text.to(device)
420
+
421
+ empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)
422
+
423
+ if k == 0:
424
+ empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1)
425
+ empty_conditions = self.forward_babel_eval(empty_input)
426
+ else:
427
+ B_motion_embeddings = self.transformer.wte(B_motion)
428
+ empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1)
429
+ empty_conditions = self.forward_babel_eval(empty_input)
430
+
431
+ empty_conditions = empty_conditions[:, -1, :]
432
+ temperature = 1.0
433
+
434
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
435
+ sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
436
+
437
+ # chunk
438
+ if cfg != 1:
439
+ scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
440
+ else:
441
+ scaled_logits = sampled_token_latent
442
+
443
+
444
+ scaled_logits = scaled_logits.unsqueeze(0)
445
+
446
+
447
+ B_motion = torch.cat((B_motion, scaled_logits), dim=1)
448
+
449
+ scaled_logits_embedding = self.transformer.wte(scaled_logits)
450
+ xs = torch.cat((x, scaled_logits_embedding), dim=1)
451
+
452
+
453
+ return xs, B_motion
454
+
455
+ def sample_for_eval_CFG_babel_inference(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):
456
+
457
+ import clip
458
+ B_token_length = length // unit_length - A_motion.shape[0]
459
+
460
+ if tokenizer == 'clip':
461
+ A_text = clip.tokenize(A_text, truncate=True).to(device)
462
+ A_feat_clip_text = clip_model.encode_text(A_text).float()
463
+ B_text = clip.tokenize(B_text, truncate=True).to(device)
464
+ B_feat_clip_text = clip_model.encode_text(B_text).float()
465
+ elif tokenizer == 't5-xxl':
466
+ A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float()
467
+ A_feat_clip_text = A_feat_clip_text.to(device)
468
+ B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
469
+ B_feat_clip_text = B_feat_clip_text.to(device)
470
+
471
+ A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0)
472
+ A_text_embeddings = A_text_embeddings.unsqueeze(0)
473
+ B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
474
+ B_text_embeddings = B_text_embeddings.unsqueeze(0)
475
+
476
+ A_motion = A_motion.unsqueeze(0)
477
+ A_motion_embeddings = self.transformer.wte(A_motion)
478
+ B_motion = torch.tensor([]).to(device)
479
+
480
+ attention_weights = []
481
+
482
+ for k in range(B_token_length):
483
+ if k == 0:
484
+ x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
485
+
486
+ else:
487
+ x = xs
488
+
489
+
490
+
491
+ conditions = self.forward_babel_eval(x, return_attention=False)
492
+ conditions = conditions[:, -1, :]
493
+
494
+ empty_clip_text = ''
495
+ if tokenizer == 'clip':
496
+ empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
497
+ empty_feat_clip_text = clip_model.encode_text(empty_text).float()
498
+ elif tokenizer == 't5-xxl':
499
+ empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
500
+ empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
501
+ empty_feat_clip_text = empty_feat_clip_text.to(device)
502
+
503
+ empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)
504
+
505
+ if k == 0:
506
+ empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1)
507
+ empty_conditions = self.forward_babel_eval(empty_input)
508
+ else:
509
+ B_motion_embeddings = self.transformer.wte(B_motion)
510
+ empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1)
511
+ empty_conditions = self.forward_babel_eval(empty_input)
512
+
513
+ empty_conditions = empty_conditions[:, -1, :]
514
+ temperature = 1.0
515
+
516
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
517
+ sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
518
+
519
+ # chunk
520
+ if cfg != 1:
521
+ scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
522
+ else:
523
+ scaled_logits = sampled_token_latent
524
+
525
+ scaled_logits = scaled_logits.unsqueeze(0)
526
+
527
+ if reference_end_token is not None:
528
+ distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2))
529
+ print(distance_l2)
530
+ if distance_l2 < threshold:
531
+ break
532
+
533
+ B_motion = torch.cat((B_motion, scaled_logits), dim=1)
534
+
535
+ scaled_logits_embedding = self.transformer.wte(scaled_logits)
536
+ xs = torch.cat((x, scaled_logits_embedding), dim=1)
537
+
538
+
539
+
540
+ return xs, B_motion
541
+
542
+
543
+ def sample_for_eval_CFG_babel_inference_new(self, B_text, A_motion, if_categorial=False, length=78, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3):
544
+
545
+ import clip
546
+ B_token_length = length // unit_length
547
+
548
+ if tokenizer == 'clip':
549
+ B_text = clip.tokenize(B_text, truncate=True).to(device)
+ B_feat_clip_text = clip_model.encode_text(B_text).float()
553
+ elif tokenizer == 't5-xxl':
554
+ B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
555
+ B_feat_clip_text = B_feat_clip_text.to(device)
556
+
557
+ empty_clip_text = ''
558
+ if tokenizer == 'clip':
559
+ empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
560
+ empty_feat_clip_text = clip_model.encode_text(empty_text).float()
561
+ elif tokenizer == 't5-xxl':
562
+ empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
563
+ empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
564
+ empty_feat_clip_text = empty_feat_clip_text.to(device)
565
+
566
+ B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
567
+
568
+ A_motion = A_motion.unsqueeze(0)
569
+ A_motion_embeddings = self.transformer.wte(A_motion)
570
+ B_motion = torch.tensor([]).to(device)
571
+
572
+
573
+ attention_weights = []
574
+
575
+ for k in range(B_token_length):
576
+ if k == 0:
577
+ x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
578
+ else:
579
+ x = xs
580
+
581
+ conditions = self.forward_babel_eval(x, return_attention=False)
582
+ conditions = conditions[:, -1, :]
583
+
584
+
585
+ empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)
586
+
587
+ if k == 0:
588
+ empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1)
589
+
590
+ empty_conditions = self.forward_babel_eval(empty_input)
591
+ else:
592
+ B_motion_embeddings = self.transformer.wte(B_motion)
593
+ empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1)
594
+ empty_conditions = self.forward_babel_eval(empty_input)
595
+
596
+ empty_conditions = empty_conditions[:, -1, :]
597
+ temperature = 1.0
598
+
599
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
600
+ sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
601
+
602
+ # chunk
603
+ if cfg != 1:
604
+ scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
605
+ else:
606
+ scaled_logits = sampled_token_latent
607
+
608
+ scaled_logits = scaled_logits.unsqueeze(0)
609
+
610
+ if reference_end_token is not None:
611
+ distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2))
612
+ print(distance_l2)
613
+ if distance_l2 < threshold:
614
+ break
615
+
616
+ B_motion = torch.cat((B_motion, scaled_logits), dim=1)
617
+
618
+ scaled_logits_embedding = self.transformer.wte(scaled_logits)
619
+ xs = torch.cat((x, scaled_logits_embedding), dim=1)
620
+
621
+
622
+
623
+ return xs, B_motion
624
+
625
+
626
+ def sample_for_eval_CFG_babel_inference_new_demo(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0):
627
+
628
+ import clip
629
+ B_token_length = length // unit_length - A_motion.shape[0]
630
+
631
+ if tokenizer == 'clip':
632
+ B_text = clip.tokenize(B_text, truncate=True).to(device)
+ B_feat_clip_text = clip_model.encode_text(B_text).float()
636
+ elif tokenizer == 't5-xxl':
637
+ B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
638
+ B_feat_clip_text = B_feat_clip_text.to(device)
639
+
640
+ empty_clip_text = ''
641
+ if tokenizer == 'clip':
642
+ empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
643
+ empty_feat_clip_text = clip_model.encode_text(empty_text).float()
644
+ elif tokenizer == 't5-xxl':
645
+ empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
646
+ empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
647
+ empty_feat_clip_text = empty_feat_clip_text.to(device)
648
+
649
+ B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
650
+ B_text_embeddings = B_text_embeddings.unsqueeze(0)
651
+
652
+ A_motion = A_motion.unsqueeze(0)
653
+ A_motion_embeddings = self.transformer.wte(A_motion)
654
+ B_motion = torch.tensor([]).to(device)
655
+
656
+ # store the attention weights from every layer
657
+ attention_weights = []
658
+
659
+ for k in range(B_token_length):
660
+ if k == 0:
661
+ x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
662
+
663
+ else:
664
+ x = xs
665
+
666
+
667
+ conditions = self.forward_babel_eval(x, return_attention=False)
668
+ conditions = conditions[:, -1, :]
669
+
670
+
671
+ empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)
672
+
673
+ if k == 0:
674
+ empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1)
675
+ empty_conditions = self.forward_babel_eval(empty_input)
676
+ else:
677
+ B_motion_embeddings = self.transformer.wte(B_motion)
678
+ empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1)
679
+ empty_conditions = self.forward_babel_eval(empty_input)
680
+
681
+ empty_conditions = empty_conditions[:, -1, :]
682
+
683
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
684
+ sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
685
+
686
+ # chunk
687
+ if cfg != 1:
688
+ scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
689
+ else:
690
+ scaled_logits = sampled_token_latent
691
+
692
+ scaled_logits = scaled_logits.unsqueeze(0)
693
+
694
+ if reference_end_token is not None:
695
+ distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2))
696
+ print(distance_l2)
697
+ if distance_l2 < threshold and k > 10:
698
+ break
699
+
700
+ B_motion = torch.cat((B_motion, scaled_logits), dim=1)
701
+
702
+ scaled_logits_embedding = self.transformer.wte(scaled_logits)
703
+ xs = torch.cat((x, scaled_logits_embedding), dim=1)
704
+
705
+
706
+
707
+ return xs, B_motion
708
+
709
+ def sample_for_eval_CFG_babel_inference_two_forward(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0):
710
+ """
711
+ Inference loop that mimics the "Two-Forward" training strategy.
712
+ This version correctly performs two full passes over the entire sequence.
713
+ """
714
+ import clip
715
+ B_token_length = length // unit_length - A_motion.shape[0]
716
+
717
+ if tokenizer == 't5-xxl':
718
+ B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float().to(device)
719
+ else:
720
+ raise NotImplementedError("Only t5-xxl is supported for this function.")
721
+ empty_feat_clip_text = torch.from_numpy(clip_model.encode('')).float().unsqueeze(0).to(device)
722
+
723
+ # --- Create 3D embeddings [batch, seq, dim] ---
724
+ B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0).unsqueeze(0)
725
+ empty_text_embeddings = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0) # This is [1, 1, 768]
726
+
727
+ A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0))
728
+
729
+ # === 1. First Forward Pass (Generate Rough Draft) ===
730
+ rough_motion_tokens = A_motion
731
+ for k in range(B_token_length):
732
+ current_rough_embeddings = self.transformer.wte(rough_motion_tokens.unsqueeze(0))
733
+
734
+ # Conditioned
735
+ x_cond = torch.cat([B_text_embeddings, current_rough_embeddings], dim=1)
736
+ conditions = self.forward_babel_eval(x_cond, return_attention=False)[:, -1, :]
737
+
738
+ # Unconditioned
739
+ x_uncond = torch.cat([empty_text_embeddings, current_rough_embeddings], dim=1)
740
+ empty_conditions = self.forward_babel_eval(x_uncond, return_attention=False)[:, -1, :]
741
+
742
+ # Sample a rough prediction for the next token
743
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
744
+ pred_xstart_rough = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
745
+ if cfg != 1:
746
+ pred_xstart_rough, _ = pred_xstart_rough.chunk(2, dim=0)
747
+
748
+ rough_motion_tokens = torch.cat([rough_motion_tokens, pred_xstart_rough], dim=0)
749
+
750
+ # === 2. Second Forward Pass (Generate Refined Motion) ===
751
+ # Now we have the full rough draft. We use it as the input for the second pass.
752
+ refined_motion_tokens = A_motion
753
+ for k in range(B_token_length):
754
+ # The input to the transformer is the full rough sequence
755
+ rough_embeddings = self.transformer.wte(rough_motion_tokens.unsqueeze(0))
756
+
757
+ # Conditioned
758
+ x_cond_refined = torch.cat([B_text_embeddings, rough_embeddings], dim=1)
759
+ # We take the condition corresponding to the token we want to predict
760
+ conditions_refined = self.forward_babel_eval(x_cond_refined, return_attention=False)[:, A_motion.shape[0] + k, :]
761
+
762
+ # Unconditioned
763
+ x_uncond_refined = torch.cat([empty_text_embeddings, rough_embeddings], dim=1)
764
+ empty_conditions_refined = self.forward_babel_eval(x_uncond_refined, return_attention=False)[:, A_motion.shape[0] + k, :]
765
+
766
+ # Sample the final, refined token
767
+ mix_conditions_refined = torch.cat([conditions_refined, empty_conditions_refined], dim=0)
768
+ final_token, _ = self.diff_loss.sample(mix_conditions_refined, temperature=temperature, cfg=cfg).chunk(2, dim=0)
769
+
770
+ # Append the refined token to our final output history
771
+ refined_motion_tokens = torch.cat([refined_motion_tokens, final_token], dim=0)
772
+
773
+ # IMPORTANT: For the next step, we must update the "rough draft" with our new refined token
774
+ # This mimics the training where the input is a mix of GT and predictions.
775
+ # Here, it's a mix of the initial rough draft and the new refined tokens.
776
+ rough_motion_tokens[A_motion.shape[0] + k] = final_token.squeeze(0)
777
+
778
+ # Return only the newly generated tokens (B_motion)
779
+ B_motion = refined_motion_tokens[A_motion.shape[0]:, :].unsqueeze(0)
780
+ return None, B_motion
781
+
782
+
783
+ #--------------Test classification head--------------------
784
+ def sample_for_eval_classification(self, clip_text, if_categorial=False, length=196, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4):
785
+
786
+ import clip
787
+
788
+
789
+ for k in range(51):
790
+ if k == 0:
791
+ x = []
792
+ else:
793
+ x = xs
794
+
795
+ if tokenizer == 'clip':
796
+ text = clip.tokenize(clip_text, truncate=True).to(device)
797
+
798
+ feat_clip_text = clip_model.encode_text(text).float()
799
+ elif tokenizer == 't5-xxl':
800
+ feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()
801
+
802
+ conditions = self.forward(x, feat_clip_text)
803
+ conditions = conditions[:, -1, :]
804
+
805
+ empty_clip_text = ''
806
+ if tokenizer == 'clip':
807
+ empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
808
+ empty_feat_clip_text = clip_model.encode_text(empty_text).float()
809
+ elif tokenizer == 't5-xxl':
810
+ empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
811
+ empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
812
+ empty_feat_clip_text = empty_feat_clip_text.to(device)
813
+
814
+ empty_conditions = self.forward(x, empty_feat_clip_text)
815
+ empty_conditions = empty_conditions[:, -1, :]
816
+
817
+ temperature = 1.0
818
+ cfg = 7.5
819
+
820
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
821
+ sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
822
+
823
+ # chunk
824
+ if cfg != 1:
825
+ scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
826
+ else:
827
+ scaled_logits = sampled_token_latent
828
+
829
+
830
+ prediction_logits = self.classify_head(conditions)
831
+ probs = torch.sigmoid(prediction_logits)
832
+ predicted_classes = torch.argmax(probs, dim=-1)
833
+
834
+
835
+ scaled_logits = scaled_logits.unsqueeze(0)
836
+
837
+ if k == 0:
838
+ xs = scaled_logits
839
+ else:
840
+ xs = torch.cat((xs, scaled_logits), dim=1)
841
+
842
+ if predicted_classes == 1:
843
+ break
844
+
845
+ return xs
846
+
847
+
848
+ #--------------------Test CFG-----------------------
849
+ def sample_for_eval_CFG_test(self, clip_text, if_categorial=False, length=196, clip_model=None, cfg=1, device=torch.device('cuda'), tokenizer='clip', unit_length=4):
850
+
851
+ import clip
852
+ max_token_len = length // unit_length
853
+
854
+
855
+ for k in range(max_token_len):
856
+ if k == 0:
857
+ x = []
858
+ else:
859
+ x = xs
860
+
861
+
862
+ if cfg != 1:
863
+ if tokenizer == 'clip':
864
+ text = clip.tokenize(clip_text, truncate=True).to(device)
865
+
866
+ feat_clip_text = clip_model.encode_text(text).float()
867
+ elif tokenizer == 't5-xxl':
868
+ feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()
869
+
870
+ conditions = self.forward(x, feat_clip_text)
871
+
872
+ conditions = conditions[:, -1, :]
873
+ empty_clip_text = ''
874
+ if tokenizer == 'clip':
875
+ empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
876
+ empty_feat_clip_text = clip_model.encode_text(empty_text).float()
877
+ elif tokenizer == 't5-xxl':
878
+ empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
879
+ empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
880
+ empty_feat_clip_text = empty_feat_clip_text.to(device)
881
+
882
+ empty_conditions = self.forward(x, empty_feat_clip_text)
883
+ empty_conditions = empty_conditions[:, -1, :]
884
+ temperature = 1.0
885
+
886
+
887
+ mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
888
+ sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
889
+
890
+ # chunk
891
+ scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
892
+
893
+ else:
894
+ if tokenizer == 'clip':
895
+ text = clip.tokenize(clip_text, truncate=True).to(device)
896
+ feat_clip_text = clip_model.encode_text(text).float()
897
+ elif tokenizer == 't5-xxl':
898
+ feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()
899
+ feat_clip_text = feat_clip_text.to(device)
900
+
901
+
902
+ conditions = self.forward(x, feat_clip_text)
903
+
904
+ conditions = conditions[:, -1, :]
905
+ temperature = 1.0
906
+ sampled_token_latent = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
907
+ scaled_logits = sampled_token_latent
908
+
909
+ scaled_logits = scaled_logits.unsqueeze(0)
910
+
911
+ if k == 0:
912
+ xs = scaled_logits
913
+ else:
914
+ xs = torch.cat((xs, scaled_logits), dim=1)
915
+
916
+ return xs
917
+ #--------------------------------------------------
918
+
919
+ def forward_discrete(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None) -> torch.Tensor:
920
+ """
921
+ Vector-token path: idx must be shape [B, T, input_token_dim].
922
+ If you want discrete IDs instead, you must switch wte to nn.Embedding.
923
+ """
924
+ context = None
925
+ if idx.numel() == 0:
926
+ context = self._prepare_context(clip_feature)
927
+ token_embeddings = context
928
+ if token_embeddings is None:
929
+ raise ValueError("Conditioning features are required when no motion tokens are provided.")
930
+ else:
931
+ b, t, _ = idx.size()
932
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
933
+ token_embeddings = self.transformer.wte(idx) # Linear -> [B, T, n_embd]
934
+ context = self._prepare_context(clip_feature, batch_size=b)
935
+ if context is not None:
936
+ token_embeddings = torch.cat([context, token_embeddings], dim=1)
937
+
938
+ x = token_embeddings
939
+
940
+ if use_cache and past_key_values is None:
941
+ past_key_values = [None] * len(self.transformer.h)
942
+
943
+ for i, block in enumerate(self.transformer.h):
944
+ if use_cache:
945
+ last_past = past_key_values[i]
946
+ x, presents = block(x, context=context, last_past=last_past, use_cache=use_cache)
947
+ past_key_values[i] = list(presents)
948
+ else:
949
+ x = block(x, context=context)
950
+
951
+ x = self.transformer.ln_f(x)
952
+ logits = self.out_proj(x)
953
+ return logits
954
+
955
+
956
+ def forward(self, idx: torch.Tensor, feature: Optional[torch.Tensor]) -> torch.Tensor:
957
+ """
958
+ If self._prompt_cached is True, we DO NOT concat context each call.
959
+ Instead, blocks read the cached prompt KV.
960
+ Otherwise we embed and concat context as before.
961
+ """
962
+ context = None
963
+ if len(idx) == 0:
964
+ if self._prompt_cached:
965
+ if self._prompt_bsz is None:
966
+ raise ValueError("Prompt cache set but batch size unknown.")
967
+ b = self._prompt_bsz
968
+ token_embeddings = torch.empty(b, 0, self.config.n_embd, device=self.bos.device, dtype=self.bos.dtype)
969
+ else:
970
+ context = self._prepare_context(feature)
971
+ token_embeddings = context
972
+ if token_embeddings is None:
973
+ raise ValueError("Conditioning features are required when no motion tokens are provided.")
974
+ else:
975
+ b, t, c = idx.size()
976
+ idx = idx.float()
977
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
978
+ token_embeddings = self.transformer.wte(idx)
979
+ if not self._prompt_cached:
980
+ context = self._prepare_context(feature, batch_size=b)
981
+ if context is not None:
982
+ token_embeddings = torch.cat([context, token_embeddings], dim=1)
983
+
984
+ # Always prepend BOS scene token
985
+ bos = self.bos.expand(token_embeddings.size(0), 1, -1)
986
+ x = torch.cat([bos, token_embeddings], dim=1)
987
+
988
+ # blocks: if context is None -> use cached prompt kv (if set)
989
+ for block in self.transformer.h:
990
+ x = block(x, context=context)
991
+ x = self.transformer.ln_f(x)
992
+ logits = self.out_proj(x)
993
+ return logits
994
+
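A hedged sketch of how the two prompt modes of forward differ in practice (model and feat_text are assumed to exist; output widths follow the configured n_embd):

# (a) uncached: the text feature is embedded and concatenated as context on every call
h = model.forward([], feature=feat_text)   # -> [B, 1 + n_ctx, n_embd] (BOS + context positions)

# (b) cached: set the prompt once, then pass feature=None while streaming
model.set_prompt(feat_text)
h0 = model.forward([], feature=None)       # only the BOS token flows through -> [B, 1, n_embd]
z0 = h0[:, -1, :]                          # conditioning vector handed to the diffusion head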
995
+
996
+ def forward_inference(self, idx: torch.Tensor, feature: Optional[torch.Tensor]) -> torch.Tensor:
997
+ context = None
998
+ if len(idx) == 0:
999
+ if self._prompt_cached:
1000
+ if self._prompt_bsz is None:
1001
+ raise ValueError("Prompt cache set but batch size unknown.")
1002
+ b = self._prompt_bsz
1003
+ token_embeddings = torch.empty(b, 0, self.config.n_embd, device=self.bos.device, dtype=self.bos.dtype)
1004
+ else:
1005
+ context = self._prepare_context(feature)
1006
+ token_embeddings = context
1007
+ if token_embeddings is None:
1008
+ raise ValueError("Conditioning features are required when no motion tokens are provided.")
1009
+ else:
1010
+ b, t, c = idx.size()
1011
+ idx = idx.float()
1012
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
1013
+ token_embeddings = self.transformer.wte(idx)
1014
+ if not self._prompt_cached:
1015
+ context = self._prepare_context(feature, batch_size=b)
1016
+ if context is not None:
1017
+ token_embeddings = torch.cat([context, token_embeddings], dim=1)
1018
+
1019
+ x = token_embeddings
1020
+ if len(x.shape) == 2:
1021
+ x = x.unsqueeze(0)
1022
+
1023
+ # prepend BOS
1024
+ bos = self.bos.expand(x.size(0), 1, -1)
1025
+ x = torch.cat([bos, x], dim=1)
1026
+
1027
+ if context is not None and context.size(0) != x.size(0):
1028
+ if context.size(0) == 1:
1029
+ context = context.expand(x.size(0), -1, -1)
1030
+ else:
1031
+ raise ValueError("Conditioning batch size does not match token batch size.")
1032
+
1033
+ for block in self.transformer.h:
1034
+ x = block(x, context=context)
1035
+ x = self.transformer.ln_f(x)
1036
+ logits = self.out_proj(x)
1037
+ return logits
1038
+
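+ # Usage sketch (illustrative only, not called anywhere): rolling forward_inference out
+ # autoregressively. `model`, `feat` ([B, clip_dim]), `token_dim` and the sampling head
+ # `sample_next` are placeholders, not names defined in this file; the model returns one
+ # output vector per position, and the last position is what conditions the next token.
+ #
+ #   tokens = torch.empty(B, 0, token_dim, device=feat.device)
+ #   for _ in range(max_steps):
+ #       out = model.forward_inference(tokens, feat)     # [B, 1 + S + T, out_dim]
+ #       nxt = sample_next(out[:, -1])                   # e.g. a diffusion head on the last position
+ #       tokens = torch.cat([tokens, nxt.unsqueeze(1)], dim=1)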
1039
+
1040
+ def babel_long(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None, num_subseq=None, length=None) -> torch.Tensor:
1041
+
1042
+ b, t, c = idx.size()
1043
+ idx = idx.float()
1044
+ idx = self.transformer.wte(idx)
1045
+ assert (
1046
+ t <= self.config.block_size
1047
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
1048
+ for i in range(b):
1049
+ length_i = length[i][:num_subseq[i]]
1050
+ clip_feature_i = clip_feature[i][:num_subseq[i]]
1051
+
1052
+ pointer = 0
1053
+ for j in range(num_subseq[i]):
1054
+ if j > 0:
1055
+ pointer += length_i[j].item()
1056
+ pointer += 1
1057
+ pointer = int(pointer)
1058
+
1059
+ clip_feature_i_j = self.transformer.cond_embed(clip_feature_i[j].unsqueeze(0)).unsqueeze(1)
1060
+ idx[i] = torch.cat([idx[i][:pointer].unsqueeze(0), clip_feature_i_j, idx[i][pointer:-1].unsqueeze(0)], dim=1)[0]
1061
+
1062
+ x = idx
1063
+
1064
+ context = None
1065
+
1066
+
1067
+ if use_cache:
1068
+ if past_key_values is None:
1069
+ past_key_values = [None] * len(self.transformer.h)
1070
+
1071
+
1072
+ for i,block in enumerate(self.transformer.h):
1073
+ if use_cache:
1074
+ last_past = past_key_values[i]
1075
+ x, presents = block(x, context=context, last_past=last_past, use_cache=use_cache)
1076
+ past_key_values[i] = list(presents)
1077
+ else:
1078
+ x = block(x, context=context)
1079
+ x = self.transformer.ln_f(x)
1080
+
1081
+ logits = self.out_proj(x)
1082
+ return logits
1083
+
1084
+
1085
+ def forward_babel_eval(self, x, return_attention=False) -> torch.Tensor:
1086
+ layer_attentions = []
1087
+ context = None
1088
+ for block in self.transformer.h:
1089
+ if return_attention:
1090
+ x, att = block(x, context=context, return_attention=True)
1091
+ layer_attentions.append(att)
1092
+ else:
1093
+ x = block(x, context=context)
1094
+
1095
+ x = self.transformer.ln_f(x)
1096
+ if self.use_out_proj:
1097
+ logits = self.out_proj(x)
1098
+ else:
1099
+ logits = x
1100
+
1101
+ if return_attention:
1102
+ return logits, layer_attentions
1103
+ return logits
1104
+
1105
+ def forward_babel(self, idx: torch.Tensor, clip_feature: torch.Tensor, A_token_length) -> torch.Tensor:
1106
+ context = None
1107
+ if len(idx) == 0: # inference
1108
+ context = self._prepare_context(clip_feature)
1109
+ token_embeddings = context
1110
+
1111
+ else:
1112
+ b, t, c = idx.size()
1113
+ idx = idx.float()
1114
+ assert (
1115
+ t <= self.config.block_size
1116
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
1117
+
1118
+
1119
+
1120
+ A_feature = clip_feature[:, 0, :]
1121
+ B_feature = clip_feature[:, 1, :]
1122
+
1123
+
1124
+ A_text_embeddings = self.transformer.cond_embed(A_feature).unsqueeze(1)
1125
+ B_text_embeddings = self.transformer.cond_embed(B_feature).unsqueeze(1)
1126
+ context = torch.cat([A_text_embeddings, B_text_embeddings], dim=1)
1127
+
1128
+ token_embeddings = torch.zeros(b, self.config.block_size, self.config.n_embd).to(idx.device)
1129
+ for i in range(b):
1130
+ A_idx = idx[i, :A_token_length[i].item(), :]
1131
+ B_idx = idx[i, A_token_length[i].item():-2, :]
1132
+ token_embeddings[i, :, :] = torch.cat([A_text_embeddings[i], self.BOM_tag, self.transformer.wte(A_idx), B_text_embeddings[i], self.BOM_tag, self.transformer.wte(B_idx)], dim=0)  # token_embeddings[i]: (block_size, n_embd)
1133
+
1134
+ x = token_embeddings
1135
+ if context is not None and context.size(0) != x.size(0):
1136
+ if context.size(0) == 1:
1137
+ context = context.expand(x.size(0), -1, -1)
1138
+ else:
1139
+ raise ValueError("Conditioning batch size does not match token batch size.")
1140
+ for block in self.transformer.h:
1141
+ x = block(x, context=context)
1142
+ x = self.transformer.ln_f(x)
1143
+
1144
+ if self.use_out_proj:
1145
+ logits = self.out_proj(x)
1146
+ else:
1147
+ logits = x
1148
+
1149
+
1150
+ return logits
1151
+
1152
+ def forward_babel2(self, idx: torch.Tensor, clip_feature: torch.Tensor) -> torch.Tensor:
1153
+ context = None
1154
+ if idx.numel() == 0: # inference with only context
1155
+ context = self._prepare_context(clip_feature)
1156
+ token_embeddings = context
1157
+ else:
1158
+ b, t, c = idx.size()
1159
+ idx = idx.float()
1160
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
1161
+
1162
+ B_feature = clip_feature # [B, D] or [B, 1, D]
1163
+ B_text_embeddings = self.transformer.cond_embed(B_feature) # [B, D] -> [B, D]
1164
+ if B_text_embeddings.dim() == 2:
1165
+ B_text_embeddings = B_text_embeddings.unsqueeze(1) # [B, 1, D]
1166
+ context = B_text_embeddings # [B, 1, D]
1167
+
1168
+ idx_embeddings = self.transformer.wte(idx) # [B, T, D]
1169
+ token_embeddings = torch.cat([B_text_embeddings, idx_embeddings], dim=1) # [B, 1+T, D]
1170
+
1171
+ x = token_embeddings
1172
+ if context is not None:
1173
+ if context.dim() == 2:
1174
+ context = context.unsqueeze(1)
1175
+ if context.size(0) != x.size(0):
1176
+ if context.size(0) == 1:
1177
+ context = context.expand(x.size(0), -1, -1)
1178
+ else:
1179
+ raise ValueError("Conditioning batch size does not match token batch size.")
1180
+
1181
+ for block in self.transformer.h:
1182
+ x = block(x, context=context)
1183
+ x = self.transformer.ln_f(x)
1184
+
1185
+ logits = self.out_proj(x) if self.use_out_proj else x
1186
+ return logits
1187
+
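+ # Shape walk-through for forward_babel2 (illustrative): a single text condition per sequence.
+ #   clip_feature: [B, D_clip] (or [B, 1, D_clip]) --cond_embed--> context [B, 1, n_embd]
+ #   idx:          [B, T, C] continuous motion tokens --wte--> [B, T, n_embd]
+ #   blocks see:   cat([context, motion]) = [B, 1 + T, n_embd] (no BOS is prepended here)
+ #   output:       [B, 1 + T, out_dim] after out_proj (or n_embd if use_out_proj is False)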
1188
+
1189
+ def resize_token_embeddings(
1190
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, using_old_initilization: bool = False
1191
+ ) -> nn.Embedding:
1192
+ """
1193
+ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
1194
+
1195
+ Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
1196
+
1197
+ Arguments:
1198
+ new_num_tokens (`int`, *optional*):
1199
+ The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
1200
+ vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
1201
+ returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
1202
+ pad_to_multiple_of (`int`, *optional*):
1203
+ If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
1204
+ `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
1205
+
1206
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
1207
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
1208
+ details about this, or help on choosing the correct value for resizing, refer to this guide:
1209
+ https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
1210
+
1211
+ Return:
1212
+ `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
1213
+ """
1214
+ model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
1215
+ if new_num_tokens is None and pad_to_multiple_of is None:
1216
+ return model_embeds
1217
+
1218
+ # Update base model and current model config
1219
+ self.config.vocab_size = model_embeds.weight.shape[0]
1220
+ self.vocab_size = model_embeds.weight.shape[0]
1221
+
1222
+ # Tie weights again if needed
1223
+ # self.tie_weights()
1224
+
1225
+ return model_embeds
1226
+
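+ # Usage sketch (illustrative): pad the vocabulary to a Tensor-Core-friendly multiple.
+ #   emb = model.resize_token_embeddings(new_num_tokens=None, pad_to_multiple_of=64)
+ #   assert emb.weight.shape[0] % 64 == 0
+ # config.vocab_size is updated to the resized embedding's row count as a side effect.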
1227
+ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
1228
+ old_embeddings = self.get_input_embeddings()
1229
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
1230
+ old_embeddings_requires_grad = old_embeddings.weight.requires_grad
1231
+ new_embeddings.requires_grad_(old_embeddings_requires_grad)
1232
+ self.set_input_embeddings(new_embeddings)
1233
+
1234
+ # Update new_num_tokens with the actual size of new_embeddings
+ if pad_to_multiple_of is not None:
+ new_num_tokens = new_embeddings.weight.shape[0]
+
+ # word embeddings and the lm head are not tied in this model, so resize the lm head as well
+ if self.get_output_embeddings() is not None:
+ old_lm_head = self.get_output_embeddings()
+ new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
1252
+ old_lm_head_requires_grad = old_lm_head.weight.requires_grad
1253
+ new_lm_head.requires_grad_(old_lm_head_requires_grad)
1254
+ self.set_output_embeddings(new_lm_head)
1255
+
1256
+ return self.get_input_embeddings()
1257
+
1258
+ def _get_resized_embeddings(
1259
+ self,
1260
+ old_embeddings: nn.Embedding,
1261
+ new_num_tokens: Optional[int] = None,
1262
+ pad_to_multiple_of: Optional[int] = None,
1263
+ ) -> nn.Embedding:
1264
+ """
1265
+ Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
1266
+ initialized vectors at the end. Reducing the size will remove vectors from the end
1267
+
1268
+ Args:
1269
+ old_embeddings (`torch.nn.Embedding`):
1270
+ Old embeddings to be resized.
1271
+ new_num_tokens (`int`, *optional*):
1272
+ New number of tokens in the embedding matrix.
1273
+
1274
+ Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
1275
+ vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
1276
+ `torch.nn.Embedding` module of the model without doing anything.
1277
+ pad_to_multiple_of (`int`, *optional*):
1278
+ If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
1279
+ `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
1280
+
1281
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
1282
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
1283
+ details about this, or help on choosing the correct value for resizing, refer to this guide:
1284
+ https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
1285
+
1286
+
1287
+ Return:
1288
+ `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
1289
+ `new_num_tokens` is `None`
1290
+ """
1291
+
1292
+ if pad_to_multiple_of is not None:
1293
+ if not isinstance(pad_to_multiple_of, int):
1294
+ raise ValueError(
1295
+ f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
1296
+ )
1297
+ if new_num_tokens is None:
1298
+ new_num_tokens = old_embeddings.weight.shape[0]
1299
+ new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
1300
+ else:
1301
+ print(
1302
+ "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
1303
+ f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
1304
+ " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
1305
+ " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
1306
+ )
1307
+
1308
+ if new_num_tokens is None:
1309
+ return old_embeddings
1310
+
1311
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
1319
+
1320
+ if old_num_tokens == new_num_tokens:
1322
+ return old_embeddings
1323
+
1324
+ if not isinstance(old_embeddings, nn.Embedding):
1325
+ raise TypeError(
1326
+ f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
1327
+ " should either use a different resize function or make sure that `old_embeddings` are an instance of"
1328
+ f" {nn.Embedding}."
1329
+ )
1330
+
1331
+ # Build new embeddings
1332
+
1333
+ # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
1334
+ # because the shape of the new embedding layer is used across various modeling files
1335
+ # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
1336
+ # to errors when training.
1337
+ new_embeddings = nn.Embedding(
1338
+ new_num_tokens,
1339
+ old_embedding_dim,
1340
+ device=old_embeddings.weight.device,
1341
+ dtype=old_embeddings.weight.dtype,
1342
+ )
1343
+
1344
+ # initialize all new embeddings (in particular added tokens)
1345
+ self._init_weights(new_embeddings)
1346
+
1347
+ # Copy token embeddings from the previous weights
1348
+
1349
+ # numbers of tokens to copy
1350
+ n = min(old_num_tokens, new_num_tokens)
1351
+
1352
+ new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
1361
+
1362
+ return new_embeddings
1363
+
1364
+
1365
+ def _get_resized_lm_head(
1366
+ self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
1367
+ ) -> nn.Linear:
1368
+ """
1369
+ Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
1370
+ vectors at the end. Reducing the size will remove vectors from the end
1371
+
1372
+ Args:
1373
+ old_lm_head (`torch.nn.Linear`):
1374
+ Old lm head linear layer to be resized.
1375
+ new_num_tokens (`int`, *optional*):
1376
+ New number of tokens in the linear matrix.
1377
+
1378
+ Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
+ vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
+ `torch.nn.Linear` module of the model without doing anything.
+ transposed (`bool`, *optional*, defaults to `False`):
+ Whether `old_lm_head` is transposed or not. If `True`, `old_lm_head.size()` is `lm_head_dim,
+ vocab_size`, else `vocab_size, lm_head_dim`.
1383
+
1384
+ Return:
1385
+ `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
1386
+ `None`
1387
+ """
1388
+ if new_num_tokens is None:
1389
+ return old_lm_head
1390
+
1391
+ old_num_tokens, old_lm_head_dim = (
+ old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
+ )
1403
+
1404
+ if old_num_tokens == new_num_tokens:
1406
+ return old_lm_head
1407
+
1408
+ if not isinstance(old_lm_head, nn.Linear):
1409
+ raise TypeError(
1410
+ f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
1411
+ " should either use a different resize function or make sure that `old_lm_head` are an instance of"
1412
+ f" {nn.Linear}."
1413
+ )
1414
+
1415
+ # Build new lm head
1416
+ new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
1417
+ has_new_lm_head_bias = old_lm_head.bias is not None
1418
+
1419
+ # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
1420
+ # because the shape of the new embedding layer is used across various modeling files
1421
+ # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
1422
+ # to errors when training.
1423
+ new_lm_head = nn.Linear(
1424
+ *new_lm_head_shape,
1425
+ bias=has_new_lm_head_bias,
1426
+ device=old_lm_head.weight.device,
1427
+ dtype=old_lm_head.weight.dtype,
1428
+ )
1429
+
1430
+ # initialize new lm head (in particular added tokens)
1431
+ self._init_weights(new_lm_head)
1432
+
1433
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
1434
+
1435
+ self._copy_lm_head_original_to_resized(
+ new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+ )
1448
+
1449
+ return new_lm_head
1450
+
1451
+ def _copy_lm_head_original_to_resized(
1452
+ self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
1453
+ ):
1454
+ # Copy old lm head weights to new lm head
1455
+ if not transposed:
1456
+ new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
1457
+ else:
1458
+ new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
1459
+
1460
+ # Copy bias weights to new lm head
1461
+ if has_new_lm_head_bias:
1462
+ new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
1463
+
1464
+ @classmethod
1465
+ def from_name(cls, name: str) -> Self:
1466
+ return cls(LLaMAHFConfig.from_name(name))
1467
+
1468
+
1469
+ class Block(nn.Module):
1470
+ def __init__(self, config: LLaMAHFConfig) -> None:
1471
+ super().__init__()
1472
+ self.rms_1 = RMSNorm(config.n_embd)
1473
+ self.attn = CausalSelfAttention(config)
1474
+ self.rms_cross = RMSNorm(config.n_embd)
1475
+ self.cross_attn = CrossAttention(config)
1476
+ self.rms_2 = RMSNorm(config.n_embd)
1477
+ self.mlp = MLP(config)
1478
+ # cached prompt kv (precomputed by set_prompt)
1479
+ self._ctx_k_repeat = None
1480
+ self._ctx_v_repeat = None
1481
+ self._ctx_bsz = None
1482
+
1483
+ @torch.no_grad()
1484
+ def set_context_cache(self, context: torch.Tensor):
1485
+ # Precompute KV for cross attention and repeat across kv groups
1486
+ B, S, D = context.shape
1487
+ ca = self.cross_attn
1488
+ k = ca.k_proj(context).view(B, S, ca.n_kv_head, ca.head_dim).transpose(1, 2)
1489
+ v = ca.v_proj(context).view(B, S, ca.n_kv_head, ca.head_dim).transpose(1, 2)
1490
+ k = ca.k_norm(k)
1491
+ # repeat K/V to match heads
1492
+ self._ctx_k_repeat = repeat_kv(k, ca.num_kv_groups) # [B, n_head, S, d]
1493
+ self._ctx_v_repeat = repeat_kv(v, ca.num_kv_groups) # [B, n_head, S, d]
1494
+ self._ctx_bsz = B
1495
+
1496
+ @torch.no_grad()
1497
+ def clear_context_cache(self):
1498
+ self._ctx_k_repeat = None
1499
+ self._ctx_v_repeat = None
1500
+ self._ctx_bsz = None
1501
+
1502
+ def _cross_attend_cached(self, x: torch.Tensor):
1503
+ # x: [B, T, D]
1504
+ if self._ctx_k_repeat is None or self._ctx_v_repeat is None:
1505
+ return x # no-op if no cached prompt
1506
+ B, T, _ = x.size()
1507
+ if self._ctx_bsz is not None and self._ctx_bsz != B:
1508
+ # different batch: ignore cache (or you could raise)
1509
+ return x
1510
+ ca = self.cross_attn
1511
+ q = ca.q_proj(x).view(B, T, ca.n_head, ca.head_dim).transpose(1, 2)
1512
+ q = ca.q_norm(q)
1513
+ y = F.scaled_dot_product_attention(
1514
+ q, self._ctx_k_repeat, self._ctx_v_repeat,
1515
+ attn_mask=None, dropout_p=0.0, is_causal=False, scale=ca.softmax_scale,
1516
+ )
1517
+ y = y.transpose(1, 2).contiguous().view(B, T, ca.n_head * ca.head_dim)
1518
+ return x + ca.o_proj(y)
1519
+
1520
+ def forward(
1521
+ self,
1522
+ x: torch.Tensor,
1523
+ context: Optional[torch.Tensor] = None,
1524
+ last_past=None,
1525
+ use_cache: bool = False,
1526
+ return_attention: bool = False,
1527
+ ) -> torch.Tensor:
1528
+ present = None
1529
+ # self-attn
1530
+ if use_cache:
1531
+ if return_attention:
1532
+ attn_output, attn = self.attn.forward_attn(self.rms_1(x), last_past, use_cache)  # forward_attn returns no KV cache, so `present` stays None on this path
1533
+ else:
1534
+ attn_output, present = self.attn(self.rms_1(x), last_past, use_cache)
1535
+ x = x + attn_output
1536
+ else:
1537
+ if return_attention:
1538
+ attn_output, attn = self.attn.forward_attn(self.rms_1(x))
1539
+ else:
1540
+ attn_output = self.attn(self.rms_1(x))
1541
+ x = x + attn_output
1542
+
1543
+ # cross-attn: prefer live context if provided; else use cached prompt kv
1544
+ if context is not None:
1545
+ x = x + self.cross_attn(self.rms_cross(x), context)
1546
+ else:
1547
+ x = self._cross_attend_cached(self.rms_cross(x))
1548
+
1549
+ # mlp
1550
+ x = x + self.mlp(self.rms_2(x))
1551
+
1552
+ if use_cache:
1553
+ if return_attention:
1554
+ return x, present, attn
1555
+ else:
1556
+ return x, present
1557
+ else:
1558
+ if return_attention:
1559
+ return x, attn
1560
+ else:
1561
+ return x
1562
+
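+ # Minimal sketch (defined for illustration, never called) of the cached-prompt path in
+ # Block: precompute the cross-attention KV for a fixed prompt once, then run the block
+ # with context=None so _cross_attend_cached() is used on every subsequent step.
+ def _example_block_prompt_cache(block: "Block", prompt: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+     block.set_context_cache(prompt)   # prompt: [B, S, n_embd], cached as head-repeated K/V
+     y = block(x, context=None)        # x: [B, T, n_embd]; cross-attention reads the cache
+     block.clear_context_cache()       # drop the cache when the prompt changes
+     return y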
1563
+
1564
+
1565
+ class CausalSelfAttention(nn.Module):
1566
+ def __init__(self, config: LLaMAHFConfig) -> None:
1567
+ super().__init__()
1568
+ assert config.n_embd % config.n_head == 0
1569
+
1570
+ self.n_head = config.n_head
1571
+ self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4)
1572
+ assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
1573
+ self.head_dim = config.n_embd // config.n_head
1574
+ self.block_size = config.block_size
1575
+ self.rope_base = config.rope_base
1576
+ self.rope_cache = None
1577
+ self.num_kv_groups = self.n_head // self.n_kv_head
1578
+
1579
+ self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
1580
+ self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
1581
+ self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
1582
+ self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
1583
+
1584
+ self.q_norm = RMSNorm(self.head_dim)
1585
+ self.k_norm = RMSNorm(self.head_dim)
1586
+
1587
+ self.softmax_scale = self.head_dim ** -0.5
1588
+
1589
+ def forward(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
1590
+ B, T, _ = x.size()
1591
+
1592
+ q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
1593
+ k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
1594
+ v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
1595
+
1596
+ q = self.q_norm(q)
1597
+ k = self.k_norm(k)
1598
+
1599
+ if (
1600
+ self.rope_cache is None
1601
+ or self.rope_cache.dtype != x.dtype
1602
+ or self.rope_cache.device != x.device
1603
+ ):
1604
+ self.rope_cache = build_rope_cache(
1605
+ seq_len=self.block_size,
1606
+ n_elem=self.head_dim,
1607
+ dtype=x.dtype,
1608
+ device=x.device,
1609
+ base=self.rope_base,
1610
+ )
1611
+
1612
+ q = apply_rope(q, self.rope_cache)
1613
+ k = apply_rope(k, self.rope_cache)
1614
+
1615
+ if use_cache:
1616
+ if last_past is not None:
1617
+ past_key, past_value = last_past
1618
+ k = torch.cat([past_key, k], dim=-2)
1619
+ v = torch.cat([past_value, v], dim=-2)
1620
+ present = (k, v)
1621
+ else:
1622
+ present = None
1623
+
1624
+ k_repeat = repeat_kv(k, self.num_kv_groups)
1625
+ v_repeat = repeat_kv(v, self.num_kv_groups)
1626
+
1627
+ y = F.scaled_dot_product_attention(
1628
+ q,
1629
+ k_repeat,
1630
+ v_repeat,
1631
+ attn_mask=None,
1632
+ dropout_p=0.0,
1633
+ is_causal=True,
1634
+ scale=self.softmax_scale,
1635
+ )
1636
+
1637
+ y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
1638
+ y = self.o_proj(y)
1639
+
1640
+ if use_cache:
1641
+ return y, present
1642
+ return y
1643
+
1644
+ def forward_attn(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
1645
+ B, T, _ = x.size()
1646
+
1647
+ q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
1648
+ k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
1649
+ v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
1650
+
1651
+ q = self.q_norm(q)
1652
+ k = self.k_norm(k)
1653
+
1654
+ if (
1655
+ self.rope_cache is None
1656
+ or self.rope_cache.dtype != x.dtype
1657
+ or self.rope_cache.device != x.device
1658
+ ):
1659
+ self.rope_cache = build_rope_cache(
1660
+ seq_len=self.block_size,
1661
+ n_elem=self.head_dim,
1662
+ dtype=x.dtype,
1663
+ device=x.device,
1664
+ base=self.rope_base,
1665
+ )
1666
+
1667
+ q = apply_rope(q, self.rope_cache)
1668
+ k = apply_rope(k, self.rope_cache)
1669
+
1670
+ if use_cache:
1671
+ if last_past is not None:
1672
+ past_key, past_value = last_past
1673
+ k = torch.cat([past_key, k], dim=-2)
1674
+ v = torch.cat([past_value, v], dim=-2)
1675
+
1676
+ k_repeat = repeat_kv(k, self.num_kv_groups)
+ v_repeat = repeat_kv(v, self.num_kv_groups)
+
+ att = torch.matmul(q, k_repeat.transpose(-2, -1)) * self.softmax_scale
+ # apply a causal mask (bottom-right aligned so cached keys stay visible), matching the causal SDPA path in forward()
+ T_k = k_repeat.size(-2)
+ causal_mask = torch.ones(T, T_k, dtype=torch.bool, device=x.device).tril(diagonal=T_k - T)
+ att = att.masked_fill(~causal_mask, float("-inf"))
+ att = F.softmax(att, dim=-1)
1681
+
1682
+ y = torch.matmul(att, v_repeat)
1683
+ y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
1684
+ y = self.o_proj(y)
1685
+
1686
+ return y, att
1687
+
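+ # Minimal sketch (defined for illustration, never called) of incremental decoding with the
+ # KV cache above: run the prefix once with use_cache=True, then feed one new position and
+ # pass the returned (k, v) back in as last_past. Note that apply_rope indexes positions
+ # from 0 on every call, so a cached single-step call may need an explicit position offset
+ # in practice.
+ def _example_incremental_attention(attn: "CausalSelfAttention", prefix: torch.Tensor, step: torch.Tensor):
+     y, present = attn(prefix, last_past=None, use_cache=True)       # prefix: [B, T, n_embd]
+     y_new, present = attn(step, last_past=present, use_cache=True)  # step:   [B, 1, n_embd]
+     return y_new, present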
1688
+
1689
+ class CrossAttention(nn.Module):
1690
+ def __init__(self, config: LLaMAHFConfig) -> None:
1691
+ super().__init__()
1692
+ assert config.n_embd % config.n_head == 0
1693
+
1694
+ self.n_head = config.n_head
1695
+ self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4)
1696
+ assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
1697
+ self.head_dim = config.n_embd // config.n_head
1698
+ self.num_kv_groups = self.n_head // self.n_kv_head
1699
+
1700
+ self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
1701
+ self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
1702
+ self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
1703
+ self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
1704
+
1705
+ self.q_norm = RMSNorm(self.head_dim)
1706
+ self.k_norm = RMSNorm(self.head_dim)
1707
+
1708
+ self.softmax_scale = self.head_dim ** -0.5
1709
+
1710
+ def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
1711
+ B, T, _ = x.size()
1712
+ _, S, _ = context.size()
1713
+
1714
+ q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
1715
+ k = self.k_proj(context).view(B, S, self.n_kv_head, self.head_dim).transpose(1, 2)
1716
+ v = self.v_proj(context).view(B, S, self.n_kv_head, self.head_dim).transpose(1, 2)
1717
+
1718
+ q = self.q_norm(q)
1719
+ k = self.k_norm(k)
1720
+
1721
+ k_repeat = repeat_kv(k, self.num_kv_groups)
1722
+ v_repeat = repeat_kv(v, self.num_kv_groups)
1723
+
1724
+ y = F.scaled_dot_product_attention(
1725
+ q,
1726
+ k_repeat,
1727
+ v_repeat,
1728
+ attn_mask=None,
1729
+ dropout_p=0.0,
1730
+ is_causal=False,
1731
+ scale=self.softmax_scale,
1732
+ )
1733
+
1734
+ y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
1735
+ return self.o_proj(y)
1736
+
1737
+
1738
+ def repeat_kv(hidden_states: torch.Tensor, num_groups: int) -> torch.Tensor:
1739
+ if num_groups == 1:
1740
+ return hidden_states
1741
+ bsz, n_kv, seq_len, head_dim = hidden_states.shape
1742
+ hidden_states = hidden_states.unsqueeze(2).expand(bsz, n_kv, num_groups, seq_len, head_dim)
1743
+ return hidden_states.reshape(bsz, n_kv * num_groups, seq_len, head_dim)
1744
+
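+ # Example of the grouped-query expansion performed by repeat_kv: with 2 KV heads and
+ # 8 query heads (num_groups = 4), each KV head is repeated 4 times along the head axis.
+ #   kv = torch.randn(1, 2, 10, 64)              # [B, n_kv_head, S, head_dim]
+ #   assert repeat_kv(kv, 4).shape == (1, 8, 10, 64)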
1745
+
1746
+ class LengthCausalSelfAttention(nn.Module):
1747
+ def __init__(self, config: LLaMAHFConfig) -> None:
1748
+ super().__init__()
1749
+ assert config.n_embd % config.n_head == 0
1750
+
1751
+ self.n_head = config.n_head
1752
+ self.n_kv_head = config.n_kv_head or max(1, config.n_head // 4)
1753
+ assert self.n_head % self.n_kv_head == 0, "n_head must be divisible by n_kv_head"
1754
+ self.head_dim = config.n_embd // config.n_head
1755
+ self.block_size = config.block_size
1756
+ self.rope_base = config.rope_base
1757
+ self.rope_cache = None
1758
+ self.num_kv_groups = self.n_head // self.n_kv_head
1759
+
1760
+ self.q_proj = nn.Linear(config.n_embd, self.n_head * self.head_dim, bias=False)
1761
+ self.k_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
1762
+ self.v_proj = nn.Linear(config.n_embd, self.n_kv_head * self.head_dim, bias=False)
1763
+ self.o_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
1764
+
1765
+ self.q_norm = RMSNorm(self.head_dim)
1766
+ self.k_norm = RMSNorm(self.head_dim)
1767
+
1768
+ self.softmax_scale = self.head_dim ** -0.5
1769
+
1770
+ def forward(self, x: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
1771
+ B, T, _ = x.size()
1772
+
1773
+ q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
1774
+ k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
1775
+ v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
1776
+
1777
+ q = self.q_norm(q)
1778
+ k = self.k_norm(k)
1779
+
1780
+ if (
1781
+ self.rope_cache is None
1782
+ or self.rope_cache.dtype != x.dtype
1783
+ or self.rope_cache.device != x.device
1784
+ ):
1785
+ self.rope_cache = build_rope_cache(
1786
+ seq_len=self.block_size,
1787
+ n_elem=self.head_dim,
1788
+ dtype=x.dtype,
1789
+ device=x.device,
1790
+ base=self.rope_base,
1791
+ )
1792
+
1793
+ q = apply_rope(q, self.rope_cache)
1794
+ k = apply_rope(k, self.rope_cache)
1795
+
1796
+ attn_mask = torch.ones(T, T, dtype=torch.bool, device=x.device)
1797
+ attn_mask = torch.tril(attn_mask)
1798
+ attn_mask = attn_mask.unsqueeze(0).expand(B, -1, -1)
1799
+
1800
+ text_mask = y_mask.unsqueeze(2) * y_mask.unsqueeze(1)
1801
+ text_mask = F.pad(text_mask, (0, T - y_mask.shape[1], 0, T - y_mask.shape[1]), mode='constant', value=0)
1802
+ attn_mask = torch.logical_or(attn_mask, text_mask)
1803
+
1804
+ k_repeat = repeat_kv(k, self.num_kv_groups)
1805
+ v_repeat = repeat_kv(v, self.num_kv_groups)
1806
+
1807
+ y = F.scaled_dot_product_attention(
1808
+ q,
1809
+ k_repeat,
1810
+ v_repeat,
1811
+ attn_mask=attn_mask.unsqueeze(1),
1812
+ dropout_p=0.0,
1813
+ is_causal=False,
1814
+ scale=self.softmax_scale,
1815
+ )
1816
+
1817
+ y = y.transpose(1, 2).contiguous().view(B, T, self.n_head * self.head_dim)
1818
+ y = self.o_proj(y)
1819
+
1820
+ return y
1821
+
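+ # Sketch (defined for illustration, never called) of the mask built in
+ # LengthCausalSelfAttention.forward: causal over the full sequence, with the text-prompt
+ # positions marked by y_mask ([B, L], 0/1) additionally allowed to attend to each other
+ # bidirectionally.
+ def _example_length_mask(T: int, y_mask: torch.Tensor) -> torch.Tensor:
+     causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=y_mask.device))
+     text = y_mask.unsqueeze(2) * y_mask.unsqueeze(1)                        # [B, L, L]
+     text = F.pad(text, (0, T - text.size(-1), 0, T - text.size(-2)), value=0)
+     return torch.logical_or(causal.unsqueeze(0), text.bool())               # [B, T, T]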
1822
+
1823
+ class MLP(nn.Module):
1824
+ def __init__(self, config: LLaMAHFConfig) -> None:
1825
+ super().__init__()
1826
+ hidden_dim = 4 * config.n_embd
1827
+ n_hidden = int(2 * hidden_dim / 3)
1828
+ N = 256
1829
+ # round n_hidden up to the nearest multiple of N
1830
+ n_hidden = ((n_hidden - 1) // N) * N + N
1831
+
1832
+ self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
1833
+ self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
1834
+ self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
1835
+
1836
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1837
+
1838
+ x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
1839
+ x = self.c_proj(x)
1840
+ return x
1841
+
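+ # Hidden-size arithmetic for the SwiGLU MLP above (worked example): the hidden width is
+ # 2/3 of 4*n_embd, rounded up to a multiple of 256. For n_embd = 1024: 4*1024 = 4096,
+ # int(2*4096/3) = 2730, rounded up to 2816.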
1842
+
1843
+ class RMSNorm(nn.Module):
1844
+ """Root Mean Square Layer Normalization.
1845
+
1846
+ Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
1847
+ https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
1848
+ """
1849
+
1850
+ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
1851
+ super().__init__()
1852
+ self.scale = nn.Parameter(torch.ones(size))
1853
+ self.eps = eps
1854
+ self.dim = dim
1855
+
1856
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1857
+ # NOTE: the original RMSNorm paper implementation is not equivalent
1858
+ # norm_x = x.norm(2, dim=self.dim, keepdim=True)
1859
+ # rms_x = norm_x * d_x ** (-1. / 2)
1860
+ # x_normed = x / (rms_x + self.eps)
1861
+ norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
1862
+ x_normed = x * torch.rsqrt(norm_x + self.eps)
1863
+ return self.scale * x_normed
1864
+
1865
+
1866
+ def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
1867
+ """
1868
+ Rotary-position cache with safe dtype handling.
1869
+ """
1870
+ theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
1871
+ seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
1872
+ idx_theta = torch.outer(seq_idx, theta)
1873
+
1874
+ # cast to float32 for torch.polar when needed
1875
+ dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
1876
+ working_dtype = torch.float32 if dtype in dtypes_requiring_casting else dtype
1877
+ complex_dtype = torch.complex64  # complex32 support in torch is too limited to rely on
1878
+
1879
+ cache = torch.polar(torch.ones_like(idx_theta, dtype=working_dtype, device=device),
1880
+ idx_theta.to(working_dtype)).to(complex_dtype)
1881
+ return cache
1882
+
1883
+
1884
+ def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
1885
+ x = x.transpose(1, 2)
1886
+
1887
+ # truncate to support variable sizes
1888
+ T = x.size(1)
1889
+ rope_cache = rope_cache[:T]
1890
+ # cast because `view_as_complex` does not support 16 bit tensors
1891
+ xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
1892
+ rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
1893
+ x_out = torch.view_as_real(xc * rope_cache).flatten(3)
1894
+ return x_out.transpose(1, 2).type_as(x)
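+
+
+ # Small self-check sketch (defined for illustration, never called) for the RoPE helpers
+ # above: build a cache once for the maximum sequence length, then rotate a query/key
+ # tensor laid out as [B, n_head, T, head_dim]; the shape is preserved.
+ def _example_rope_roundtrip(B: int = 2, n_head: int = 4, T: int = 8, head_dim: int = 32) -> torch.Tensor:
+     q = torch.randn(B, n_head, T, head_dim)
+     cache = build_rope_cache(seq_len=16, n_elem=head_dim, dtype=q.dtype, device=q.device)
+     q_rot = apply_rope(q, cache)      # positions 0..T-1 rotated, dtype/shape unchanged
+     assert q_rot.shape == q.shape
+     return q_rot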