OrlandoHugBot commited on
Commit
96e3744
·
verified ·
1 Parent(s): 3fe273f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -2
app.py CHANGED
@@ -144,6 +144,51 @@ pipe.to('cuda')
144
  print("✅ Models loaded successfully!")
145
 
146
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  # ============================================================
148
  # GPU 推理函数(只包含实际的推理逻辑)
149
  # ============================================================
@@ -164,10 +209,23 @@ def generate_image(
164
  print(f" Prompt: {prompt[:50]}...")
165
  print(f" Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}")
166
 
167
- # 关键修复:在真实 GPU 环境中确保所有张量(包括 buffer)都在 GPU
168
- # ZeroGPU 可能没有正确移动 register_buffer 注册的张量
169
  pipe.to('cuda')
170
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # 调试信息:检查模型设备
172
  print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
173
  print(f" [DEBUG] transformer device: {next(pipe.transformer.parameters()).device}")
 
144
  print("✅ Models loaded successfully!")
145
 
146
 
147
+ def fix_rope_buffers(model):
148
+ """
149
+ 修复 RoPE (Rotary Position Embedding) 中的 buffer 张量
150
+ ZeroGPU 环境下,register_buffer 注册的张量可能不会被正确移动到 GPU
151
+
152
+ 这个函数会遍历模型的所有子模块,检查并修复以下 buffer:
153
+ - inv_freq: RoPE 的核心频率 buffer
154
+ - cos_cached / sin_cached: 某些实现会缓存的 cos/sin 值
155
+ - 其他所有未在 CUDA 上的 buffer
156
+ """
157
+ device = 'cuda'
158
+ fixed_count = 0
159
+
160
+ for name, module in model.named_modules():
161
+ # 修复 inv_freq buffer (RoPE 的核心 buffer)
162
+ if hasattr(module, 'inv_freq') and module.inv_freq is not None:
163
+ if module.inv_freq.device.type != 'cuda':
164
+ module.inv_freq = module.inv_freq.to(device)
165
+ fixed_count += 1
166
+ print(f" [FIX] Moved {name}.inv_freq to {device}")
167
+
168
+ # 修复 cos_cached 和 sin_cached (某些 RoPE 实现会缓存这些)
169
+ if hasattr(module, 'cos_cached') and module.cos_cached is not None:
170
+ if module.cos_cached.device.type != 'cuda':
171
+ module.cos_cached = module.cos_cached.to(device)
172
+ fixed_count += 1
173
+ print(f" [FIX] Moved {name}.cos_cached to {device}")
174
+
175
+ if hasattr(module, 'sin_cached') and module.sin_cached is not None:
176
+ if module.sin_cached.device.type != 'cuda':
177
+ module.sin_cached = module.sin_cached.to(device)
178
+ fixed_count += 1
179
+ print(f" [FIX] Moved {name}.sin_cached to {device}")
180
+
181
+ # 通用:修复所有 buffer(更全面的修复)
182
+ for buf_name, buf in module.named_buffers(recurse=False):
183
+ if buf is not None and buf.device.type != 'cuda':
184
+ setattr(module, buf_name, buf.to(device))
185
+ fixed_count += 1
186
+ print(f" [FIX] Moved {name}.{buf_name} to {device}")
187
+
188
+ return fixed_count
189
+
190
+
191
+
192
  # ============================================================
193
  # GPU 推理函数(只包含实际的推理逻辑)
194
  # ============================================================
 
209
  print(f" Prompt: {prompt[:50]}...")
210
  print(f" Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}")
211
 
212
+ # Step 1: 移动 pipeline CUDA
 
213
  pipe.to('cuda')
214
 
215
+ # Step 2: 关键修复 - 手动修复 RoPE buffer
216
+ # ZeroGPU 可能没有正确移动 register_buffer 注册的张量
217
+ print(" [DEBUG] Fixing RoPE buffers...")
218
+ fixed = 0
219
+ fixed += fix_rope_buffers(pipe.text_encoder)
220
+ fixed += fix_rope_buffers(pipe.transformer)
221
+ fixed += fix_rope_buffers(pipe.vae)
222
+ print(f" [DEBUG] Fixed {fixed} buffer(s)")
223
+
224
+ # 调试信息:检查模型设备
225
+ print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
226
+ print(f" [DEBUG] transformer device: {next(pipe.transformer.parameters()).device}")
227
+ print(f" [DEBUG] vae device: {next(pipe.vae.parameters()).device}")
228
+
229
  # 调试信息:检查模型设备
230
  print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
231
  print(f" [DEBUG] transformer device: {next(pipe.transformer.parameters()).device}")