Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,12 +14,7 @@ from transformers import (
|
|
| 14 |
T5EncoderModel,
|
| 15 |
T5Tokenizer
|
| 16 |
)
|
| 17 |
-
from accelerate import
|
| 18 |
-
init_empty_weights,
|
| 19 |
-
set_module_tensor_to_device,
|
| 20 |
-
infer_auto_device_map,
|
| 21 |
-
load_checkpoint_and_dispatch
|
| 22 |
-
)
|
| 23 |
from safetensors import safe_open
|
| 24 |
|
| 25 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -45,7 +40,7 @@ tokenizer_3 = T5Tokenizer.from_pretrained(model_repo_id, subfolder="tokenizer_3"
|
|
| 45 |
config_file = hf_hub_download(repo_id=model_repo_id, filename="transformer/config.json")
|
| 46 |
with open(config_file, "r") as fp:
|
| 47 |
config = json.loads(fp)
|
| 48 |
-
with
|
| 49 |
transformer = SD3Transformer2DModel.from_config(config)
|
| 50 |
|
| 51 |
# Get transformer state dict and load
|
|
|
|
| 14 |
T5EncoderModel,
|
| 15 |
T5Tokenizer
|
| 16 |
)
|
| 17 |
+
from accelerate import init_empty_weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
from safetensors import safe_open
|
| 19 |
|
| 20 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 40 |
config_file = hf_hub_download(repo_id=model_repo_id, filename="transformer/config.json")
|
| 41 |
with open(config_file, "r") as fp:
|
| 42 |
config = json.loads(fp)
|
| 43 |
+
with init_empty_weights():
|
| 44 |
transformer = SD3Transformer2DModel.from_config(config)
|
| 45 |
|
| 46 |
# Get transformer state dict and load
|