Spaces:
Running on Zero
Running on Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -144,6 +144,51 @@ pipe.to('cuda')
|
|
| 144 |
print("✅ Models loaded successfully!")
|
| 145 |
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
# ============================================================
|
| 148 |
# GPU 推理函数(只包含实际的推理逻辑)
|
| 149 |
# ============================================================
|
|
@@ -164,10 +209,23 @@ def generate_image(
|
|
| 164 |
print(f" Prompt: {prompt[:50]}...")
|
| 165 |
print(f" Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}")
|
| 166 |
|
| 167 |
-
|
| 168 |
-
# ZeroGPU 可能没有正确移动 register_buffer 注册的张量
|
| 169 |
pipe.to('cuda')
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
# 调试信息:检查模型设备
|
| 172 |
print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
|
| 173 |
print(f" [DEBUG] transformer device: {next(pipe.transformer.parameters()).device}")
|
|
|
|
| 144 |
print("✅ Models loaded successfully!")
|
| 145 |
|
| 146 |
|
| 147 |
+
def fix_rope_buffers(model):
|
| 148 |
+
"""
|
| 149 |
+
修复 RoPE (Rotary Position Embedding) 中的 buffer 张量
|
| 150 |
+
ZeroGPU 环境下,register_buffer 注册的张量可能不会被正确移动到 GPU
|
| 151 |
+
|
| 152 |
+
这个函数会遍历模型的所有子模块,检查并修复以下 buffer:
|
| 153 |
+
- inv_freq: RoPE 的核心频率 buffer
|
| 154 |
+
- cos_cached / sin_cached: 某些实现会缓存的 cos/sin 值
|
| 155 |
+
- 其他所有未在 CUDA 上的 buffer
|
| 156 |
+
"""
|
| 157 |
+
device = 'cuda'
|
| 158 |
+
fixed_count = 0
|
| 159 |
+
|
| 160 |
+
for name, module in model.named_modules():
|
| 161 |
+
# 修复 inv_freq buffer (RoPE 的核心 buffer)
|
| 162 |
+
if hasattr(module, 'inv_freq') and module.inv_freq is not None:
|
| 163 |
+
if module.inv_freq.device.type != 'cuda':
|
| 164 |
+
module.inv_freq = module.inv_freq.to(device)
|
| 165 |
+
fixed_count += 1
|
| 166 |
+
print(f" [FIX] Moved {name}.inv_freq to {device}")
|
| 167 |
+
|
| 168 |
+
# 修复 cos_cached 和 sin_cached (某些 RoPE 实现会缓存这些)
|
| 169 |
+
if hasattr(module, 'cos_cached') and module.cos_cached is not None:
|
| 170 |
+
if module.cos_cached.device.type != 'cuda':
|
| 171 |
+
module.cos_cached = module.cos_cached.to(device)
|
| 172 |
+
fixed_count += 1
|
| 173 |
+
print(f" [FIX] Moved {name}.cos_cached to {device}")
|
| 174 |
+
|
| 175 |
+
if hasattr(module, 'sin_cached') and module.sin_cached is not None:
|
| 176 |
+
if module.sin_cached.device.type != 'cuda':
|
| 177 |
+
module.sin_cached = module.sin_cached.to(device)
|
| 178 |
+
fixed_count += 1
|
| 179 |
+
print(f" [FIX] Moved {name}.sin_cached to {device}")
|
| 180 |
+
|
| 181 |
+
# 通用:修复所有 buffer(更全面的修复)
|
| 182 |
+
for buf_name, buf in module.named_buffers(recurse=False):
|
| 183 |
+
if buf is not None and buf.device.type != 'cuda':
|
| 184 |
+
setattr(module, buf_name, buf.to(device))
|
| 185 |
+
fixed_count += 1
|
| 186 |
+
print(f" [FIX] Moved {name}.{buf_name} to {device}")
|
| 187 |
+
|
| 188 |
+
return fixed_count
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
|
| 192 |
# ============================================================
|
| 193 |
# GPU 推理函数(只包含实际的推理逻辑)
|
| 194 |
# ============================================================
|
|
|
|
| 209 |
print(f" Prompt: {prompt[:50]}...")
|
| 210 |
print(f" Steps: {num_steps}, CFG: {true_cfg_scale}, Seed: {seed}")
|
| 211 |
|
| 212 |
+
# Step 1: 移动 pipeline 到 CUDA
|
|
|
|
| 213 |
pipe.to('cuda')
|
| 214 |
|
| 215 |
+
# Step 2: 关键修复 - 手动修复 RoPE buffer
|
| 216 |
+
# ZeroGPU 可能没有正确移动 register_buffer 注册的张量
|
| 217 |
+
print(" [DEBUG] Fixing RoPE buffers...")
|
| 218 |
+
fixed = 0
|
| 219 |
+
fixed += fix_rope_buffers(pipe.text_encoder)
|
| 220 |
+
fixed += fix_rope_buffers(pipe.transformer)
|
| 221 |
+
fixed += fix_rope_buffers(pipe.vae)
|
| 222 |
+
print(f" [DEBUG] Fixed {fixed} buffer(s)")
|
| 223 |
+
|
| 224 |
+
# 调试信息:检查模型设备
|
| 225 |
+
print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
|
| 226 |
+
print(f" [DEBUG] transformer device: {next(pipe.transformer.parameters()).device}")
|
| 227 |
+
print(f" [DEBUG] vae device: {next(pipe.vae.parameters()).device}")
|
| 228 |
+
|
| 229 |
# 调试信息:检查模型设备
|
| 230 |
print(f" [DEBUG] text_encoder device: {next(pipe.text_encoder.parameters()).device}")
|
| 231 |
print(f" [DEBUG] transformer device: {next(pipe.transformer.parameters()).device}")
|