Spaces:
Running on Zero
Running on Zero
Use fp32 for inference to fix CUBLAS errors on ZeroGPU
Browse files- app.py +1 -1
- inference.py +6 -6
app.py
CHANGED
|
@@ -104,7 +104,7 @@ def init_generator():
|
|
| 104 |
author_descriptions_path='dataset/calligraphy_styles_en.json',
|
| 105 |
use_deepspeed=False,
|
| 106 |
use_4bit_quantization=False, # Full precision model
|
| 107 |
-
dtype="bf16",  # [removed line truncated in extraction; value reconstructed — original was a half-precision dtype, likely "bf16" or "fp16", per the replacement comment]
|
| 108 |
)
|
| 109 |
return generator
|
| 110 |
|
|
|
|
| 104 |
author_descriptions_path='dataset/calligraphy_styles_en.json',
|
| 105 |
use_deepspeed=False,
|
| 106 |
use_4bit_quantization=False, # Full precision model
|
| 107 |
+
dtype="fp32", # Use fp32 to avoid CUBLAS errors on ZeroGPU
|
| 108 |
)
|
| 109 |
return generator
|
| 110 |
|
inference.py
CHANGED
|
@@ -365,13 +365,13 @@ class CalligraphyGenerator:
|
|
| 365 |
checkpoint = {k: v.to(target_dtype) for k, v in checkpoint.items()}
|
| 366 |
|
| 367 |
if not forced_dtype:
|
| 368 |
-
# Always use bf16 for inference  [removed comment truncated in extraction; dtype reconstructed from the replacement's reference to fp16/bf16 issues]
|
| 369 |
-
target_dtype = torch.bfloat16  # [truncated in extraction; attribute reconstructed — original was a half-precision torch dtype]
|
| 370 |
-
if checkpoint_dtype != torch.bfloat16:  # [truncated in extraction; dtype reconstructed — original compared against the half-precision target]
|
| 371 |
-
print(f"Converting checkpoint from {checkpoint_dtype} to {target_dtype}...")  # [truncated in extraction; tail reconstructed to match the replacement line's shape]
|
| 372 |
-
checkpoint = {k: v.to(torch.bfloat16) for k, v in checkpoint.items()}  # [truncated in extraction; reconstructed to mirror the replacement line with the pre-change dtype]
|
| 373 |
else:
|
| 374 |
-
print(f"Using {target_dtype} for inference")  # [truncated in extraction; tail reconstructed to match the replacement line's shape]
|
| 375 |
|
| 376 |
# Load weights into model
|
| 377 |
model.load_state_dict(checkpoint, strict=False, assign=True)
|
|
|
|
| 365 |
checkpoint = {k: v.to(target_dtype) for k, v in checkpoint.items()}
|
| 366 |
|
| 367 |
if not forced_dtype:
|
| 368 |
+
# Always use fp32 for inference - fp16/bf16 have CUDA/CUBLAS compatibility issues on ZeroGPU
|
| 369 |
+
target_dtype = torch.float32
|
| 370 |
+
if checkpoint_dtype != torch.float32:
|
| 371 |
+
print(f"Converting checkpoint from {checkpoint_dtype} to float32...")
|
| 372 |
+
checkpoint = {k: v.to(torch.float32) for k, v in checkpoint.items()}
|
| 373 |
else:
|
| 374 |
+
print(f"Using float32 for inference")
|
| 375 |
|
| 376 |
# Load weights into model
|
| 377 |
model.load_state_dict(checkpoint, strict=False, assign=True)
|