Benrise commited on
Commit
06e301f
·
1 Parent(s): 9366128

Fix unet statedict load

Browse files
Files changed (1) hide show
  1. app.py +6 -15
app.py CHANGED
@@ -28,9 +28,6 @@ os.environ["CUDA_MODULE_LOADING"] = "LAZY"
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  weight_dtype = torch.float16 if device == "cuda" else torch.float32
30
 
31
- CHECKPOINT_DIR = "./checkpoints/VITONHD/model"
32
- os.makedirs(CHECKPOINT_DIR, exist_ok=True)
33
-
34
  def load_models():
35
  """Загружает все необходимые модели"""
36
  print("⚙️ Загрузка моделей...")
@@ -67,17 +64,11 @@ def load_models():
67
  subfolder="unet"
68
  )
69
 
70
- unet_checkpoint_path = os.path.join(CHECKPOINT_DIR, "pytorch_model.bin")
71
- if not os.path.exists(unet_checkpoint_path):
72
- print("⏳ Загрузка чекпоинта модели...")
73
- temp_path = hf_hub_download(
74
- repo_id="Benrise/VITON-HD",
75
- filename="VITONHD/model/pytorch_model.bin",
76
- token=TOKEN
77
- )
78
- import shutil
79
- shutil.copy(temp_path, unet_checkpoint_path)
80
- print(f"✅ Файл успешно сохранён по пути: {unet_checkpoint_path}")
81
  unet.load_state_dict(torch.load(unet_checkpoint_path, map_location=device))
82
 
83
  cloth_encoder = ClothEncoder.from_pretrained(
@@ -187,7 +178,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 900px}
187
  )
188
 
189
  if __name__ == "__main__":
190
- demo.queue(concurrency_count=1, max_size=2).launch(
191
  server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
192
  share=os.getenv("GRADIO_SHARE") == "True"
193
  )
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  weight_dtype = torch.float16 if device == "cuda" else torch.float32
30
 
 
 
 
31
  def load_models():
32
  """Загружает все необходимые модели"""
33
  print("⚙️ Загрузка моделей...")
 
64
  subfolder="unet"
65
  )
66
 
67
+ unet_checkpoint_path = hf_hub_download(
68
+ repo_id="Benrise/VITON-HD",
69
+ filename="VITONHD/model/pytorch_model.bin",
70
+ token=TOKEN
71
+ )
 
 
 
 
 
 
72
  unet.load_state_dict(torch.load(unet_checkpoint_path, map_location=device))
73
 
74
  cloth_encoder = ClothEncoder.from_pretrained(
 
178
  )
179
 
180
  if __name__ == "__main__":
181
+ demo.queue(max_size=1).launch(
182
  server_name="0.0.0.0" if os.getenv("SPACE_ID") else None,
183
  share=os.getenv("GRADIO_SHARE") == "True"
184
  )