from dataclasses import dataclass
import torch
from torch import Tensor, nn
from src.flux.modules.layers import (DoubleStreamBlock, EmbedND, LastLayer,
                                     MLPEmbedder, SingleStreamBlock,
                                     timestep_embedding)

@dataclass
class FluxParams:
    in_channels: int
    vec_in_dim: int
    context_in_dim: int
    hidden_size: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    theta: int
    qkv_bias: bool
    guidance_embed: bool
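
# Illustrative only (not part of the original file): a configuration in the
# style of the published flux-dev checkpoint. Exact values depend on the
# checkpoint you load; treat these as a sketch, not authoritative numbers.
#
#   params = FluxParams(
#       in_channels=64,
#       vec_in_dim=768,
#       context_in_dim=4096,
#       hidden_size=3072,
#       mlp_ratio=4.0,
#       num_heads=24,
#       depth=19,
#       depth_single_blocks=38,
#       axes_dim=[16, 56, 56],
#       theta=10_000,
#       qkv_bias=True,
#       guidance_embed=True,
#   )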

class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.
    """

    def __init__(self, params: FluxParams):
        super().__init__()
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = self.in_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(
                f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
            )
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                    cur_block=i,
                )
                for i in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio)
                for _ in range(params.depth_single_blocks)
            ]
        )

        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
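
    # The forward pass runs `depth` double-stream blocks (image and text tokens
    # kept in separate streams) followed by `depth_single_blocks` single-stream
    # blocks over the concatenated (txt, img) sequence.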

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,  # T5 text embeddings
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,  # CLIP text embedding
        cur_step: int,
        guidance: Tensor | None = None,
        info: dict | None = None,
    ) -> tuple[Tensor, dict]:
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")
        if info is None:
            info = {}

        # Guard against device mismatches between the inputs and this module's
        # parameters (e.g. under model offloading): move img and txt to the
        # device of img_in's weight if they differ. The same check can be
        # extended to txt_ids, y, timesteps, and guidance if needed.
        weight_device = self.img_in.weight.device
        if img.device != weight_device:
            print(f"Device mismatch: img on {img.device}, img_in.weight on {weight_device}; moving img to {weight_device}")
            img = img.to(weight_device)
        if txt.device != weight_device:
            print(f"Device mismatch: moving txt from {txt.device} to {weight_device}")
            txt = txt.to(weight_device)
        # running on sequences img
        img = self.img_in(img)
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Text tokens come first in the joint sequence; the image part is
        # sliced back out after the single-stream blocks.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe, cur_step=cur_step, info=info)

        img = torch.cat((txt, img), 1)
        info['type'] = 'single'
        for cnt, block in enumerate(self.single_blocks):
            info['id'] = cnt
            img, info = block(img, vec=vec, pe=pe, info=info)
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img, info
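

if __name__ == "__main__":
    # Illustrative smoke test, not part of the original file. Shapes follow the
    # usual Flux convention (one id coordinate per entry of axes_dim); the
    # small parameter values here are arbitrary, and the test assumes the
    # modified blocks in src.flux.modules.layers accept the cur_step/info
    # arguments exactly as forward() uses them above.
    params = FluxParams(
        in_channels=16,
        vec_in_dim=32,
        context_in_dim=64,
        hidden_size=128,
        mlp_ratio=4.0,
        num_heads=4,
        depth=2,
        depth_single_blocks=2,
        axes_dim=[8, 12, 12],  # sums to pe_dim = 128 // 4 = 32
        theta=10_000,
        qkv_bias=True,
        guidance_embed=False,
    )
    model = Flux(params)
    n, l_img, l_txt = 1, 64, 16
    out, info = model(
        img=torch.randn(n, l_img, params.in_channels),
        img_ids=torch.zeros(n, l_img, 3),
        txt=torch.randn(n, l_txt, params.context_in_dim),
        txt_ids=torch.zeros(n, l_txt, 3),
        timesteps=torch.rand(n),
        y=torch.randn(n, params.vec_in_dim),
        cur_step=0,
        info={},
    )
    print(out.shape)  # expected: (n, l_img, params.in_channels)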