TSXu committed on
Commit
9d88d74
·
1 Parent(s): e7ca422

Use fp32 for inference to fix CUBLAS errors on ZeroGPU

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. inference.py +6 -6
app.py CHANGED
@@ -104,7 +104,7 @@ def init_generator():
104
  author_descriptions_path='dataset/calligraphy_styles_en.json',
105
  use_deepspeed=False,
106
  use_4bit_quantization=False, # Full precision model
107
- dtype="fp16", # Use fp16 instead of bf16 for better CUDA compatibility
108
  )
109
  return generator
110
 
 
104
  author_descriptions_path='dataset/calligraphy_styles_en.json',
105
  use_deepspeed=False,
106
  use_4bit_quantization=False, # Full precision model
107
+ dtype="fp32", # Use fp32 to avoid CUBLAS errors on ZeroGPU
108
  )
109
  return generator
110
 
inference.py CHANGED
@@ -365,13 +365,13 @@ class CalligraphyGenerator:
365
  checkpoint = {k: v.to(target_dtype) for k, v in checkpoint.items()}
366
 
367
  if not forced_dtype:
368
- # Always use fp16 for inference - bf16 has CUDA/CUBLAS compatibility issues
369
- target_dtype = torch.float16
370
- if checkpoint_dtype != torch.float16:
371
- print(f"Converting checkpoint from {checkpoint_dtype} to float16...")
372
- checkpoint = {k: v.to(torch.float16) for k, v in checkpoint.items()}
373
  else:
374
- print(f"Using float16 for inference")
375
 
376
  # Load weights into model
377
  model.load_state_dict(checkpoint, strict=False, assign=True)
 
365
  checkpoint = {k: v.to(target_dtype) for k, v in checkpoint.items()}
366
 
367
  if not forced_dtype:
368
+ # Always use fp32 for inference - fp16/bf16 have CUDA/CUBLAS compatibility issues on ZeroGPU
369
+ target_dtype = torch.float32
370
+ if checkpoint_dtype != torch.float32:
371
+ print(f"Converting checkpoint from {checkpoint_dtype} to float32...")
372
+ checkpoint = {k: v.to(torch.float32) for k, v in checkpoint.items()}
373
  else:
374
+ print(f"Using float32 for inference")
375
 
376
  # Load weights into model
377
  model.load_state_dict(checkpoint, strict=False, assign=True)