ZhouZJ36DL commited on
Commit
41696aa
·
1 Parent(s): 72ddbe6

modified: src/flux/model.py

Browse files
src/flux/__pycache__/__init__.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/__init__.cpython-310.pyc and b/src/flux/__pycache__/__init__.cpython-310.pyc differ
 
src/flux/__pycache__/_version.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/_version.cpython-310.pyc and b/src/flux/__pycache__/_version.cpython-310.pyc differ
 
src/flux/__pycache__/math.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/math.cpython-310.pyc and b/src/flux/__pycache__/math.cpython-310.pyc differ
 
src/flux/__pycache__/model.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/model.cpython-310.pyc and b/src/flux/__pycache__/model.cpython-310.pyc differ
 
src/flux/__pycache__/sampling.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/sampling.cpython-310.pyc and b/src/flux/__pycache__/sampling.cpython-310.pyc differ
 
src/flux/__pycache__/util.cpython-310.pyc CHANGED
Binary files a/src/flux/__pycache__/util.cpython-310.pyc and b/src/flux/__pycache__/util.cpython-310.pyc differ
 
src/flux/model.py CHANGED
@@ -90,6 +90,27 @@ class Flux(nn.Module):
90
  if img.ndim != 3 or txt.ndim != 3:
91
  raise ValueError("Input img and txt tensors must have 3 dimensions.")
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  # running on sequences img
94
  img = self.img_in(img)
95
  vec = self.time_in(timestep_embedding(timesteps, 256))
 
90
  if img.ndim != 3 or txt.ndim != 3:
91
  raise ValueError("Input img and txt tensors must have 3 dimensions.")
92
 
93
+ # --- CRITICAL DEBUG: Check the device of self.img_in's parameters ---
94
+ weight_device = self.img_in.weight.device
95
+ bias_device = self.img_in.bias.device if self.img_in.bias is not None else "N/A (None)"
96
+ print(f"self.img_in.weight device: {weight_device}")
97
+ print(f"self.img_in.bias device: {bias_device}")
98
+
99
+ # --- FIX: Explicitly move img to the device of img_in's weight if they differ ---
100
+ # This is the core fix if the mismatch is here
101
+ if img.device != weight_device:
102
+ print(f"!!! Mismatch detected: img on {img.device}, img_in.weight on {weight_device}. Moving img to {weight_device} !!!")
103
+ img = img.to(weight_device)
104
+ # It's also good practice to ensure other inputs are on the same device if they aren't already
105
+ # However, based on your previous check, they should be. Let's double-check one key one:
106
+ if txt.device != weight_device:
107
+ print(f"!!! Also moving 'txt' from {txt.device} to {weight_device} !!!")
108
+ txt = txt.to(weight_device)
109
+ # Add similar checks/moves for txt_ids, y, timesteps, guidance if needed,
110
+ # but based on your previous debug, they were on cuda:0.
111
+
112
+ print("--- End of Critical Debug ---")
113
+
114
  # running on sequences img
115
  img = self.img_in(img)
116
  vec = self.time_in(timestep_embedding(timesteps, 256))
src/flux/modules/__pycache__/autoencoder.cpython-310.pyc CHANGED
Binary files a/src/flux/modules/__pycache__/autoencoder.cpython-310.pyc and b/src/flux/modules/__pycache__/autoencoder.cpython-310.pyc differ
 
src/flux/modules/__pycache__/conditioner.cpython-310.pyc CHANGED
Binary files a/src/flux/modules/__pycache__/conditioner.cpython-310.pyc and b/src/flux/modules/__pycache__/conditioner.cpython-310.pyc differ
 
src/flux/modules/__pycache__/layers.cpython-310.pyc CHANGED
Binary files a/src/flux/modules/__pycache__/layers.cpython-310.pyc and b/src/flux/modules/__pycache__/layers.cpython-310.pyc differ
 
src/flux/sampling.py CHANGED
@@ -522,23 +522,6 @@ def denoise_multi_turn_consistent(
522
  info['second_order'] = False
523
  info['inject'] = inject_list[i]
524
 
525
- # Check and print devices of all input tensors
526
- tensor_info = {
527
- 'img': img,
528
- 'img_ids': img_ids,
529
- 'txt': txt,
530
- 'txt_ids': txt_ids,
531
- 'y (vec)': vec, # Renamed for clarity in print
532
- 'timesteps (t_vec)': t_vec, # Renamed for clarity in print
533
- 'guidance (guidance_vec)': guidance_vec # Renamed for clarity in print
534
- }
535
- for name, tensor in tensor_info.items():
536
- if torch.is_tensor(tensor):
537
- print(f"{name} device: {tensor.device}, shape: {tensor.shape}")
538
- else:
539
- print(f"{name} is not a tensor: {type(tensor)} (value: {tensor})")
540
-
541
-
542
  if next_step_velocity is None:
543
  pred, info = model(
544
  img=img,
 
522
  info['second_order'] = False
523
  info['inject'] = inject_list[i]
524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  if next_step_velocity is None:
526
  pred, info = model(
527
  img=img,