TSXu committed on
Commit
8bfa41d
·
1 Parent(s): c322e84

Add dtype parameter to fix CUDA bf16 compatibility issues

Browse files

Use fp16 instead of bf16 for inference to avoid CUBLAS_STATUS_INVALID_VALUE
errors on GPUs/CUDA versions with limited bf16 support.

Files changed (2) hide show
  1. app.py +1 -0
  2. inference.py +30 -12
app.py CHANGED
@@ -104,6 +104,7 @@ def init_generator():
104
  author_descriptions_path='dataset/calligraphy_styles_en.json',
105
  use_deepspeed=False,
106
  use_4bit_quantization=False, # Full precision model
 
107
  )
108
  return generator
109
 
 
104
  author_descriptions_path='dataset/calligraphy_styles_en.json',
105
  use_deepspeed=False,
106
  use_4bit_quantization=False, # Full precision model
107
+ dtype="fp16", # Use fp16 instead of bf16 for better CUDA compatibility
108
  )
109
  return generator
110
 
inference.py CHANGED
@@ -150,7 +150,8 @@ class CalligraphyGenerator:
150
  author_descriptions_path: str = "calligraphy_styles_en.json",
151
  use_deepspeed: bool = False,
152
  use_4bit_quantization: bool = False,
153
- deepspeed_config: Optional[str] = None
 
154
  ):
155
  """
156
  Initialize the calligraphy generator
@@ -166,6 +167,7 @@ class CalligraphyGenerator:
166
  author_descriptions_path: path to author style descriptions JSON
167
  use_deepspeed: whether to use DeepSpeed ZeRO for memory optimization
168
  deepspeed_config: path to DeepSpeed config JSON file
 
169
  """
170
  self.device = torch.device(device)
171
  self.model_name = model_name
@@ -174,6 +176,7 @@ class CalligraphyGenerator:
174
  self.use_deepspeed = use_deepspeed
175
  self.deepspeed_config = deepspeed_config
176
  self.use_4bit_quantization = use_4bit_quantization
 
177
 
178
  # Load font and author style descriptions
179
  if os.path.exists(font_descriptions_path):
@@ -343,18 +346,33 @@ class CalligraphyGenerator:
343
  checkpoint_dtype = first_tensor.dtype
344
  print(f"Checkpoint dtype: {checkpoint_dtype}")
345
 
346
- # Use bfloat16 for inference if checkpoint is in bf16/fp16, otherwise keep as is
347
- # bfloat16 is preferred for stability, fp16 for speed
348
- if checkpoint_dtype in [torch.bfloat16, torch.float16]:
349
- target_dtype = checkpoint_dtype
350
- print(f"Using {target_dtype} for inference (memory efficient)")
351
- else:
352
- # Convert to bfloat16 for memory efficiency if available
353
- if torch.cuda.is_bf16_supported():
354
- target_dtype = torch.bfloat16
355
- print(f"Converting checkpoint from {checkpoint_dtype} to bfloat16 for efficiency...")
356
- checkpoint = {k: v.to(torch.bfloat16) for k, v in checkpoint.items()}
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  else:
 
 
358
  target_dtype = torch.float16
359
  print(f"Converting checkpoint from {checkpoint_dtype} to float16 for efficiency...")
360
  checkpoint = {k: v.half() for k, v in checkpoint.items()}
 
150
  author_descriptions_path: str = "calligraphy_styles_en.json",
151
  use_deepspeed: bool = False,
152
  use_4bit_quantization: bool = False,
153
+ deepspeed_config: Optional[str] = None,
154
+ dtype: Optional[str] = None
155
  ):
156
  """
157
  Initialize the calligraphy generator
 
167
  author_descriptions_path: path to author style descriptions JSON
168
  use_deepspeed: whether to use DeepSpeed ZeRO for memory optimization
169
  deepspeed_config: path to DeepSpeed config JSON file
170
+ dtype: force specific dtype for inference: "fp16", "bf16", "fp32", or None for auto
171
  """
172
  self.device = torch.device(device)
173
  self.model_name = model_name
 
176
  self.use_deepspeed = use_deepspeed
177
  self.deepspeed_config = deepspeed_config
178
  self.use_4bit_quantization = use_4bit_quantization
179
+ self.forced_dtype = dtype # "fp16", "bf16", "fp32", or None for auto
180
 
181
  # Load font and author style descriptions
182
  if os.path.exists(font_descriptions_path):
 
346
  checkpoint_dtype = first_tensor.dtype
347
  print(f"Checkpoint dtype: {checkpoint_dtype}")
348
 
349
+ # Check if user forced a specific dtype
350
+ forced_dtype = getattr(self, 'forced_dtype', None)
351
+ if forced_dtype:
352
+ dtype_map = {
353
+ "fp16": torch.float16,
354
+ "bf16": torch.bfloat16,
355
+ "fp32": torch.float32,
356
+ }
357
+ if forced_dtype not in dtype_map:
358
+ print(f"Warning: Unknown dtype '{forced_dtype}', using auto selection")
359
+ forced_dtype = None
360
+ else:
361
+ target_dtype = dtype_map[forced_dtype]
362
+ print(f"Using forced dtype: {target_dtype}")
363
+ if checkpoint_dtype != target_dtype:
364
+ print(f"Converting checkpoint from {checkpoint_dtype} to {target_dtype}...")
365
+ checkpoint = {k: v.to(target_dtype) for k, v in checkpoint.items()}
366
+
367
+ if not forced_dtype:
368
+ # Use bfloat16 for inference if checkpoint is in bf16/fp16, otherwise keep as is
369
+ # bfloat16 is preferred for stability, fp16 for speed
370
+ if checkpoint_dtype in [torch.bfloat16, torch.float16]:
371
+ target_dtype = checkpoint_dtype
372
+ print(f"Using {target_dtype} for inference (memory efficient)")
373
  else:
374
+ # Convert to float16 for memory efficiency (more compatible than bf16)
375
+ # bf16 can have CUBLAS issues on some GPUs/CUDA versions
376
  target_dtype = torch.float16
377
  print(f"Converting checkpoint from {checkpoint_dtype} to float16 for efficiency...")
378
  checkpoint = {k: v.half() for k, v in checkpoint.items()}