AbstractPhil committed on
Commit
27fa1b7
·
verified ·
1 Parent(s): 4adc3d8

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +25 -8
model.py CHANGED
@@ -78,23 +78,26 @@ class RotaryEmbedding(nn.Module):
78
  self.axes_dims = axes_dims # (temporal, height, width)
79
  self.theta = theta
80
 
81
- def forward(self, ids: torch.Tensor) -> torch.Tensor:
82
  """
83
  ids: (B, N, 3) - temporal, height, width indices
 
84
  Returns: (B, N, dim) rotary embeddings
85
  """
86
  B, N, _ = ids.shape
87
  device = ids.device
88
- dtype = torch.float32
 
 
89
 
90
  embeddings = []
91
  dim_offset = 0
92
 
93
  for axis_idx, axis_dim in enumerate(self.axes_dims):
94
  # Compute frequencies for this axis
95
- freqs = 1.0 / (self.theta ** (torch.arange(0, axis_dim, 2, device=device, dtype=dtype) / axis_dim))
96
  # Get positions for this axis
97
- positions = ids[:, :, axis_idx].float() # (B, N)
98
  # Outer product: (B, N) x (axis_dim/2) -> (B, N, axis_dim/2)
99
  angles = positions.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0)
100
  # Interleave sin/cos
@@ -103,7 +106,8 @@ class RotaryEmbedding(nn.Module):
103
  embeddings.append(emb)
104
  dim_offset += axis_dim
105
 
106
- return torch.cat(embeddings, dim=-1) # (B, N, dim)
 
107
 
108
 
109
  def apply_rope(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
@@ -111,7 +115,9 @@ def apply_rope(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
111
  # x: (B, heads, N, head_dim)
112
  # rope: (B, N, head_dim)
113
  B, H, N, D = x.shape
114
- rope = rope.unsqueeze(1) # (B, 1, N, D)
 
 
115
 
116
  # Split into pairs
117
  x_pairs = x.reshape(B, H, N, D // 2, 2)
@@ -221,6 +227,11 @@ class Attention(nn.Module):
221
  mask: Optional[torch.Tensor] = None
222
  ) -> torch.Tensor:
223
  B, N, _ = x.shape
 
 
 
 
 
224
 
225
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
226
  q, k, v = qkv.permute(2, 0, 3, 1, 4) # 3 x (B, heads, N, head_dim)
@@ -263,6 +274,12 @@ class JointAttention(nn.Module):
263
  B, L, _ = txt.shape
264
  _, N, _ = img.shape
265
 
 
 
 
 
 
 
266
  # Compute Q, K, V for both streams
267
  txt_qkv = self.txt_qkv(txt).reshape(B, L, 3, self.num_heads, self.head_dim)
268
  img_qkv = self.img_qkv(img).reshape(B, N, 3, self.num_heads, self.head_dim)
@@ -506,8 +523,8 @@ class TinyFlux(nn.Module):
506
  if self.config.guidance_embeds and guidance is not None:
507
  vec = vec + self.guidance_in(guidance)
508
 
509
- # RoPE for image positions
510
- img_rope = self.rope(img_ids)
511
 
512
  # Double-stream blocks
513
  for block in self.double_blocks:
 
78
  self.axes_dims = axes_dims # (temporal, height, width)
79
  self.theta = theta
80
 
81
+ def forward(self, ids: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
82
  """
83
  ids: (B, N, 3) - temporal, height, width indices
84
+ dtype: output dtype (defaults to ids.dtype, but use model dtype for bf16)
85
  Returns: (B, N, dim) rotary embeddings
86
  """
87
  B, N, _ = ids.shape
88
  device = ids.device
89
+ # Compute in float32 for precision, cast at the end
90
+ compute_dtype = torch.float32
91
+ output_dtype = dtype if dtype is not None else ids.dtype
92
 
93
  embeddings = []
94
  dim_offset = 0
95
 
96
  for axis_idx, axis_dim in enumerate(self.axes_dims):
97
  # Compute frequencies for this axis
98
+ freqs = 1.0 / (self.theta ** (torch.arange(0, axis_dim, 2, device=device, dtype=compute_dtype) / axis_dim))
99
  # Get positions for this axis
100
+ positions = ids[:, :, axis_idx].to(compute_dtype) # (B, N)
101
  # Outer product: (B, N) x (axis_dim/2) -> (B, N, axis_dim/2)
102
  angles = positions.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0)
103
  # Interleave sin/cos
 
106
  embeddings.append(emb)
107
  dim_offset += axis_dim
108
 
109
+ result = torch.cat(embeddings, dim=-1) # (B, N, dim)
110
+ return result.to(output_dtype)
111
 
112
 
113
  def apply_rope(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
 
115
  # x: (B, heads, N, head_dim)
116
  # rope: (B, N, head_dim)
117
  B, H, N, D = x.shape
118
+
119
+ # Ensure rope matches x dtype
120
+ rope = rope.to(x.dtype).unsqueeze(1) # (B, 1, N, D)
121
 
122
  # Split into pairs
123
  x_pairs = x.reshape(B, H, N, D // 2, 2)
 
227
  mask: Optional[torch.Tensor] = None
228
  ) -> torch.Tensor:
229
  B, N, _ = x.shape
230
+ dtype = x.dtype
231
+
232
+ # Ensure RoPE matches input dtype
233
+ if rope is not None:
234
+ rope = rope.to(dtype)
235
 
236
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
237
  q, k, v = qkv.permute(2, 0, 3, 1, 4) # 3 x (B, heads, N, head_dim)
 
274
  B, L, _ = txt.shape
275
  _, N, _ = img.shape
276
 
277
+ # Ensure consistent dtype (use img dtype as reference)
278
+ dtype = img.dtype
279
+ txt = txt.to(dtype)
280
+ if rope is not None:
281
+ rope = rope.to(dtype)
282
+
283
  # Compute Q, K, V for both streams
284
  txt_qkv = self.txt_qkv(txt).reshape(B, L, 3, self.num_heads, self.head_dim)
285
  img_qkv = self.img_qkv(img).reshape(B, N, 3, self.num_heads, self.head_dim)
 
523
  if self.config.guidance_embeds and guidance is not None:
524
  vec = vec + self.guidance_in(guidance)
525
 
526
+ # RoPE for image positions (match model dtype)
527
+ img_rope = self.rope(img_ids, dtype=img.dtype)
528
 
529
  # Double-stream blocks
530
  for block in self.double_blocks: