Spaces:
Sleeping
Sleeping
Add dtype parameter to fix CUDA bf16 compatibility issues
Browse files

Use fp16 instead of bf16 for inference to avoid CUBLAS_STATUS_INVALID_VALUE
errors on GPUs/CUDA versions with limited bf16 support.
- app.py +1 -0
- inference.py +30 -12
app.py
CHANGED
|
@@ -104,6 +104,7 @@ def init_generator():
|
|
| 104 |
author_descriptions_path='dataset/calligraphy_styles_en.json',
|
| 105 |
use_deepspeed=False,
|
| 106 |
use_4bit_quantization=False, # Full precision model
|
|
|
|
| 107 |
)
|
| 108 |
return generator
|
| 109 |
|
|
|
|
| 104 |
author_descriptions_path='dataset/calligraphy_styles_en.json',
|
| 105 |
use_deepspeed=False,
|
| 106 |
use_4bit_quantization=False, # Full precision model
|
| 107 |
+
dtype="fp16", # Use fp16 instead of bf16 for better CUDA compatibility
|
| 108 |
)
|
| 109 |
return generator
|
| 110 |
|
inference.py
CHANGED
|
@@ -150,7 +150,8 @@ class CalligraphyGenerator:
|
|
| 150 |
author_descriptions_path: str = "calligraphy_styles_en.json",
|
| 151 |
use_deepspeed: bool = False,
|
| 152 |
use_4bit_quantization: bool = False,
|
| 153 |
-
deepspeed_config: Optional[str] = None
|
|
|
|
| 154 |
):
|
| 155 |
"""
|
| 156 |
Initialize the calligraphy generator
|
|
@@ -166,6 +167,7 @@ class CalligraphyGenerator:
|
|
| 166 |
author_descriptions_path: path to author style descriptions JSON
|
| 167 |
use_deepspeed: whether to use DeepSpeed ZeRO for memory optimization
|
| 168 |
deepspeed_config: path to DeepSpeed config JSON file
|
|
|
|
| 169 |
"""
|
| 170 |
self.device = torch.device(device)
|
| 171 |
self.model_name = model_name
|
|
@@ -174,6 +176,7 @@ class CalligraphyGenerator:
|
|
| 174 |
self.use_deepspeed = use_deepspeed
|
| 175 |
self.deepspeed_config = deepspeed_config
|
| 176 |
self.use_4bit_quantization = use_4bit_quantization
|
|
|
|
| 177 |
|
| 178 |
# Load font and author style descriptions
|
| 179 |
if os.path.exists(font_descriptions_path):
|
|
@@ -343,18 +346,33 @@ class CalligraphyGenerator:
|
|
| 343 |
checkpoint_dtype = first_tensor.dtype
|
| 344 |
print(f"Checkpoint dtype: {checkpoint_dtype}")
|
| 345 |
|
| 346 |
-
#
|
| 347 |
-
|
| 348 |
-
if
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
print(f"
|
| 356 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
else:
|
|
|
|
|
|
|
| 358 |
target_dtype = torch.float16
|
| 359 |
print(f"Converting checkpoint from {checkpoint_dtype} to float16 for efficiency...")
|
| 360 |
checkpoint = {k: v.half() for k, v in checkpoint.items()}
|
|
|
|
| 150 |
author_descriptions_path: str = "calligraphy_styles_en.json",
|
| 151 |
use_deepspeed: bool = False,
|
| 152 |
use_4bit_quantization: bool = False,
|
| 153 |
+
deepspeed_config: Optional[str] = None,
|
| 154 |
+
dtype: Optional[str] = None
|
| 155 |
):
|
| 156 |
"""
|
| 157 |
Initialize the calligraphy generator
|
|
|
|
| 167 |
author_descriptions_path: path to author style descriptions JSON
|
| 168 |
use_deepspeed: whether to use DeepSpeed ZeRO for memory optimization
|
| 169 |
deepspeed_config: path to DeepSpeed config JSON file
|
| 170 |
+
dtype: force specific dtype for inference: "fp16", "bf16", "fp32", or None for auto
|
| 171 |
"""
|
| 172 |
self.device = torch.device(device)
|
| 173 |
self.model_name = model_name
|
|
|
|
| 176 |
self.use_deepspeed = use_deepspeed
|
| 177 |
self.deepspeed_config = deepspeed_config
|
| 178 |
self.use_4bit_quantization = use_4bit_quantization
|
| 179 |
+
self.forced_dtype = dtype # "fp16", "bf16", "fp32", or None for auto
|
| 180 |
|
| 181 |
# Load font and author style descriptions
|
| 182 |
if os.path.exists(font_descriptions_path):
|
|
|
|
| 346 |
checkpoint_dtype = first_tensor.dtype
|
| 347 |
print(f"Checkpoint dtype: {checkpoint_dtype}")
|
| 348 |
|
| 349 |
+
# Check if user forced a specific dtype
|
| 350 |
+
forced_dtype = getattr(self, 'forced_dtype', None)
|
| 351 |
+
if forced_dtype:
|
| 352 |
+
dtype_map = {
|
| 353 |
+
"fp16": torch.float16,
|
| 354 |
+
"bf16": torch.bfloat16,
|
| 355 |
+
"fp32": torch.float32,
|
| 356 |
+
}
|
| 357 |
+
if forced_dtype not in dtype_map:
|
| 358 |
+
print(f"Warning: Unknown dtype '{forced_dtype}', using auto selection")
|
| 359 |
+
forced_dtype = None
|
| 360 |
+
else:
|
| 361 |
+
target_dtype = dtype_map[forced_dtype]
|
| 362 |
+
print(f"Using forced dtype: {target_dtype}")
|
| 363 |
+
if checkpoint_dtype != target_dtype:
|
| 364 |
+
print(f"Converting checkpoint from {checkpoint_dtype} to {target_dtype}...")
|
| 365 |
+
checkpoint = {k: v.to(target_dtype) for k, v in checkpoint.items()}
|
| 366 |
+
|
| 367 |
+
if not forced_dtype:
|
| 368 |
+
# Use bfloat16 for inference if checkpoint is in bf16/fp16, otherwise keep as is
|
| 369 |
+
# bfloat16 is preferred for stability, fp16 for speed
|
| 370 |
+
if checkpoint_dtype in [torch.bfloat16, torch.float16]:
|
| 371 |
+
target_dtype = checkpoint_dtype
|
| 372 |
+
print(f"Using {target_dtype} for inference (memory efficient)")
|
| 373 |
else:
|
| 374 |
+
# Convert to float16 for memory efficiency (more compatible than bf16)
|
| 375 |
+
# bf16 can have CUBLAS issues on some GPUs/CUDA versions
|
| 376 |
target_dtype = torch.float16
|
| 377 |
print(f"Converting checkpoint from {checkpoint_dtype} to float16 for efficiency...")
|
| 378 |
checkpoint = {k: v.half() for k, v in checkpoint.items()}
|