Update model.py
Browse files
model.py
CHANGED
|
@@ -78,23 +78,26 @@ class RotaryEmbedding(nn.Module):
|
|
| 78 |
self.axes_dims = axes_dims # (temporal, height, width)
|
| 79 |
self.theta = theta
|
| 80 |
|
| 81 |
-
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
| 82 |
"""
|
| 83 |
ids: (B, N, 3) - temporal, height, width indices
|
|
|
|
| 84 |
Returns: (B, N, dim) rotary embeddings
|
| 85 |
"""
|
| 86 |
B, N, _ = ids.shape
|
| 87 |
device = ids.device
|
| 88 |
-
|
|
|
|
|
|
|
| 89 |
|
| 90 |
embeddings = []
|
| 91 |
dim_offset = 0
|
| 92 |
|
| 93 |
for axis_idx, axis_dim in enumerate(self.axes_dims):
|
| 94 |
# Compute frequencies for this axis
|
| 95 |
-
freqs = 1.0 / (self.theta ** (torch.arange(0, axis_dim, 2, device=device, dtype=
|
| 96 |
# Get positions for this axis
|
| 97 |
-
positions = ids[:, :, axis_idx].
|
| 98 |
# Outer product: (B, N) x (axis_dim/2) -> (B, N, axis_dim/2)
|
| 99 |
angles = positions.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0)
|
| 100 |
# Interleave sin/cos
|
|
@@ -103,7 +106,8 @@ class RotaryEmbedding(nn.Module):
|
|
| 103 |
embeddings.append(emb)
|
| 104 |
dim_offset += axis_dim
|
| 105 |
|
| 106 |
-
|
|
|
|
| 107 |
|
| 108 |
|
| 109 |
def apply_rope(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
|
|
@@ -111,7 +115,9 @@ def apply_rope(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
|
|
| 111 |
# x: (B, heads, N, head_dim)
|
| 112 |
# rope: (B, N, head_dim)
|
| 113 |
B, H, N, D = x.shape
|
| 114 |
-
|
|
|
|
|
|
|
| 115 |
|
| 116 |
# Split into pairs
|
| 117 |
x_pairs = x.reshape(B, H, N, D // 2, 2)
|
|
@@ -221,6 +227,11 @@ class Attention(nn.Module):
|
|
| 221 |
mask: Optional[torch.Tensor] = None
|
| 222 |
) -> torch.Tensor:
|
| 223 |
B, N, _ = x.shape
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
|
| 226 |
q, k, v = qkv.permute(2, 0, 3, 1, 4) # 3 x (B, heads, N, head_dim)
|
|
@@ -263,6 +274,12 @@ class JointAttention(nn.Module):
|
|
| 263 |
B, L, _ = txt.shape
|
| 264 |
_, N, _ = img.shape
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
# Compute Q, K, V for both streams
|
| 267 |
txt_qkv = self.txt_qkv(txt).reshape(B, L, 3, self.num_heads, self.head_dim)
|
| 268 |
img_qkv = self.img_qkv(img).reshape(B, N, 3, self.num_heads, self.head_dim)
|
|
@@ -506,8 +523,8 @@ class TinyFlux(nn.Module):
|
|
| 506 |
if self.config.guidance_embeds and guidance is not None:
|
| 507 |
vec = vec + self.guidance_in(guidance)
|
| 508 |
|
| 509 |
-
# RoPE for image positions
|
| 510 |
-
img_rope = self.rope(img_ids)
|
| 511 |
|
| 512 |
# Double-stream blocks
|
| 513 |
for block in self.double_blocks:
|
|
|
|
| 78 |
self.axes_dims = axes_dims # (temporal, height, width)
|
| 79 |
self.theta = theta
|
| 80 |
|
| 81 |
+
def forward(self, ids: torch.Tensor, dtype: torch.dtype = None) -> torch.Tensor:
|
| 82 |
"""
|
| 83 |
ids: (B, N, 3) - temporal, height, width indices
|
| 84 |
+
dtype: output dtype (defaults to ids.dtype; pass the model dtype, e.g. torch.bfloat16, since an integer ids dtype would truncate the sin/cos values)
|
| 85 |
Returns: (B, N, dim) rotary embeddings
|
| 86 |
"""
|
| 87 |
B, N, _ = ids.shape
|
| 88 |
device = ids.device
|
| 89 |
+
# Compute in float32 for precision, cast at the end
|
| 90 |
+
compute_dtype = torch.float32
|
| 91 |
+
output_dtype = dtype if dtype is not None else ids.dtype
|
| 92 |
|
| 93 |
embeddings = []
|
| 94 |
dim_offset = 0
|
| 95 |
|
| 96 |
for axis_idx, axis_dim in enumerate(self.axes_dims):
|
| 97 |
# Compute frequencies for this axis
|
| 98 |
+
freqs = 1.0 / (self.theta ** (torch.arange(0, axis_dim, 2, device=device, dtype=compute_dtype) / axis_dim))
|
| 99 |
# Get positions for this axis
|
| 100 |
+
positions = ids[:, :, axis_idx].to(compute_dtype) # (B, N)
|
| 101 |
# Outer product: (B, N) x (axis_dim/2) -> (B, N, axis_dim/2)
|
| 102 |
angles = positions.unsqueeze(-1) * freqs.unsqueeze(0).unsqueeze(0)
|
| 103 |
# Interleave sin/cos
|
|
|
|
| 106 |
embeddings.append(emb)
|
| 107 |
dim_offset += axis_dim
|
| 108 |
|
| 109 |
+
result = torch.cat(embeddings, dim=-1) # (B, N, dim)
|
| 110 |
+
return result.to(output_dtype)
|
| 111 |
|
| 112 |
|
| 113 |
def apply_rope(x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
|
|
|
|
| 115 |
# x: (B, heads, N, head_dim)
|
| 116 |
# rope: (B, N, head_dim)
|
| 117 |
B, H, N, D = x.shape
|
| 118 |
+
|
| 119 |
+
# Ensure rope matches x dtype
|
| 120 |
+
rope = rope.to(x.dtype).unsqueeze(1) # (B, 1, N, D)
|
| 121 |
|
| 122 |
# Split into pairs
|
| 123 |
x_pairs = x.reshape(B, H, N, D // 2, 2)
|
|
|
|
| 227 |
mask: Optional[torch.Tensor] = None
|
| 228 |
) -> torch.Tensor:
|
| 229 |
B, N, _ = x.shape
|
| 230 |
+
dtype = x.dtype
|
| 231 |
+
|
| 232 |
+
# Ensure RoPE matches input dtype
|
| 233 |
+
if rope is not None:
|
| 234 |
+
rope = rope.to(dtype)
|
| 235 |
|
| 236 |
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
|
| 237 |
q, k, v = qkv.permute(2, 0, 3, 1, 4) # 3 x (B, heads, N, head_dim)
|
|
|
|
| 274 |
B, L, _ = txt.shape
|
| 275 |
_, N, _ = img.shape
|
| 276 |
|
| 277 |
+
# Ensure consistent dtype (use img dtype as reference)
|
| 278 |
+
dtype = img.dtype
|
| 279 |
+
txt = txt.to(dtype)
|
| 280 |
+
if rope is not None:
|
| 281 |
+
rope = rope.to(dtype)
|
| 282 |
+
|
| 283 |
# Compute Q, K, V for both streams
|
| 284 |
txt_qkv = self.txt_qkv(txt).reshape(B, L, 3, self.num_heads, self.head_dim)
|
| 285 |
img_qkv = self.img_qkv(img).reshape(B, N, 3, self.num_heads, self.head_dim)
|
|
|
|
| 523 |
if self.config.guidance_embeds and guidance is not None:
|
| 524 |
vec = vec + self.guidance_in(guidance)
|
| 525 |
|
| 526 |
+
# RoPE for image positions (match model dtype)
|
| 527 |
+
img_rope = self.rope(img_ids, dtype=img.dtype)
|
| 528 |
|
| 529 |
# Double-stream blocks
|
| 530 |
for block in self.double_blocks:
|