Spaces:
Running
Running
Upload wan/modules/t5.py with huggingface_hub
Browse files- wan/modules/t5.py +3 -2
wan/modules/t5.py
CHANGED
|
@@ -480,7 +480,7 @@ class T5EncoderModel:
|
|
| 480 |
self,
|
| 481 |
text_len,
|
| 482 |
dtype=torch.bfloat16,
|
| 483 |
-
device=
|
| 484 |
checkpoint_path=None,
|
| 485 |
tokenizer_path=None,
|
| 486 |
shard_fn=None,
|
|
@@ -490,7 +490,8 @@ class T5EncoderModel:
|
|
| 490 |
assert quant is None or quant in ("int8", "fp8")
|
| 491 |
self.text_len = text_len
|
| 492 |
self.dtype = dtype
|
| 493 |
-
|
|
|
|
| 494 |
self.checkpoint_path = checkpoint_path
|
| 495 |
self.tokenizer_path = tokenizer_path
|
| 496 |
|
|
|
|
| 480 |
self,
|
| 481 |
text_len,
|
| 482 |
dtype=torch.bfloat16,
|
| 483 |
+
device=None,
|
| 484 |
checkpoint_path=None,
|
| 485 |
tokenizer_path=None,
|
| 486 |
shard_fn=None,
|
|
|
|
| 490 |
assert quant is None or quant in ("int8", "fp8")
|
| 491 |
self.text_len = text_len
|
| 492 |
self.dtype = dtype
|
| 493 |
+
# Defer CUDA device lookup to runtime (for ZeroGPU compatibility)
|
| 494 |
+
self.device = device if device is not None else torch.cuda.current_device()
|
| 495 |
self.checkpoint_path = checkpoint_path
|
| 496 |
self.tokenizer_path = tokenizer_path
|
| 497 |
|