Fix unet statedict load
Browse files
app.py
CHANGED
|
@@ -28,9 +28,6 @@ os.environ["CUDA_MODULE_LOADING"] = "LAZY"
|
|
| 28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
weight_dtype = torch.float16 if device == "cuda" else torch.float32
|
| 30 |
|
| 31 |
-
CHECKPOINT_DIR = "./checkpoints/VITONHD/model"
|
| 32 |
-
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
|
| 33 |
-
|
| 34 |
def load_models():
|
| 35 |
"""Загружает все необходимые модели"""
|
| 36 |
print("⚙️ Загрузка моделей...")
|
|
@@ -67,17 +64,11 @@ def load_models():
|
|
| 67 |
subfolder="unet"
|
| 68 |
)
|
| 69 |
|
| 70 |
-
unet_checkpoint_path =
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
filename="VITONHD/model/pytorch_model.bin",
|
| 76 |
-
token=TOKEN
|
| 77 |
-
)
|
| 78 |
-
import shutil
|
| 79 |
-
shutil.copy(temp_path, unet_checkpoint_path)
|
| 80 |
-
print(f"✅ Файл успешно сохранён по пути: {unet_checkpoint_path}")
|
| 81 |
unet.load_state_dict(torch.load(unet_checkpoint_path, map_location=device))
|
| 82 |
|
| 83 |
cloth_encoder = ClothEncoder.from_pretrained(
|
|
@@ -187,7 +178,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px}
|
|
| 187 |
)
|
| 188 |
|
| 189 |
if __name__ == "__main__":
|
| 190 |
-
demo.queue(
|
| 191 |
server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
|
| 192 |
share=os.getenv("GRADIO_SHARE") == "True"
|
| 193 |
)
|
|
|
|
| 28 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 29 |
weight_dtype = torch.float16 if device == "cuda" else torch.float32
|
| 30 |
|
|
|
|
|
|
|
|
|
|
| 31 |
def load_models():
|
| 32 |
"""Загружает все необходимые модели"""
|
| 33 |
print("⚙️ Загрузка моделей...")
|
|
|
|
| 64 |
subfolder="unet"
|
| 65 |
)
|
| 66 |
|
| 67 |
+
unet_checkpoint_path = hf_hub_download(
|
| 68 |
+
repo_id="Benrise/VITON-HD",
|
| 69 |
+
filename="VITONHD/model/pytorch_model.bin",
|
| 70 |
+
token=TOKEN
|
| 71 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
unet.load_state_dict(torch.load(unet_checkpoint_path, map_location=device))
|
| 73 |
|
| 74 |
cloth_encoder = ClothEncoder.from_pretrained(
|
|
|
|
| 178 |
)
|
| 179 |
|
| 180 |
if __name__ == "__main__":
|
| 181 |
+
demo.queue(max_size=1).launch(
|
| 182 |
server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
|
| 183 |
share=os.getenv("GRADIO_SHARE") == "True"
|
| 184 |
)
|