Hunsain Mazhar commited on
Commit
39d25c5
Β·
1 Parent(s): 786d386

Improve model downloading and error handling; added robust download logic and enhanced memory management

Browse files
Files changed (1) hide show
  1. app.py +68 -21
app.py CHANGED
@@ -1,16 +1,17 @@
1
  import sys
2
  import os
3
- import gc # <--- ADDED: Garbage Collection
 
4
 
5
  # --- 1. System Setup & Error Handling ---
 
6
  try:
7
  import detectron2
8
  except ImportError:
 
9
  os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
10
 
11
  import requests
12
- from requests.adapters import HTTPAdapter
13
- from urllib3.util.retry import Retry
14
  import gradio as gr
15
  import spaces
16
  from PIL import Image
@@ -18,6 +19,7 @@ import numpy as np
18
  import torch
19
  from torchvision import transforms
20
  from torchvision.transforms.functional import to_pil_image
 
21
 
22
  sys.path.append('./')
23
 
@@ -32,7 +34,7 @@ try:
32
  from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
33
  import apply_net
34
  except ImportError as e:
35
- raise ImportError(f"CRITICAL ERROR: Missing core modules. {e}")
36
 
37
  from transformers import (
38
  CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPTextModel,
@@ -41,20 +43,67 @@ from transformers import (
41
  from diffusers import DDPMScheduler, AutoencoderKL
42
 
43
  # ---------------------------------------------------------
44
- # 2. DOWNLOADER
45
  # ---------------------------------------------------------
46
- def download_file(url, path):
47
- if os.path.exists(path): return
48
- # ... (Keep existing downloader logic if you wish, or use the robust one from before)
49
- # For brevity, assuming files exist or you use the previous robust downloader code here.
50
- # If not, paste the 'download_file' function from the previous response here.
51
- print(f"Checking {path}...")
52
- if not os.path.exists(path):
53
- os.system(f"wget -O {path} {url}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  def check_and_download_models():
56
- # ... (Same file list as before)
57
- pass # Call your download logic here
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  # ---------------------------------------------------------
60
  # 3. LOAD MODELS
@@ -71,9 +120,11 @@ def load_models():
71
  vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)
72
  UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)
73
 
 
74
  parsing_model = Parsing(0)
75
  openpose_model = OpenPose(0)
76
 
 
77
  UNet_Encoder.requires_grad_(False)
78
  image_encoder.requires_grad_(False)
79
  vae.requires_grad_(False)
@@ -94,15 +145,13 @@ pipe, openpose_model, parsing_model = load_models()
94
  tensor_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
95
 
96
  # ---------------------------------------------------------
97
- # 4. INFERENCE (Fixed Memory Leak)
98
  # ---------------------------------------------------------
99
- # Increase duration to 120s to prevent timeouts
100
  @spaces.GPU(duration=120)
101
  def start_tryon(human_img, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
102
  device = "cuda"
103
 
104
  try:
105
- # Move models to GPU
106
  openpose_model.preprocessor.body_estimation.model.to(device)
107
  pipe.to(device)
108
  pipe.unet_encoder.to(device)
@@ -185,9 +234,7 @@ def start_tryon(human_img, garm_img, garment_des, is_checked, is_checked_crop, d
185
  raise gr.Error(f"Error: {e}")
186
 
187
  finally:
188
- # --- CRITICAL MEMORY CLEANUP ---
189
- # This code runs no matter what, preventing the "3-4 run crash"
190
- print("Cleaning GPU memory...")
191
  try:
192
  del keypoints, model_parse, mask, pose_img, prompt_embeds, garm_tensor
193
  except:
 
1
  import sys
2
  import os
3
+ import gc
4
+ import shutil
5
 
6
  # --- 1. System Setup & Error Handling ---
7
+ # Force install detectron2 if missing
8
  try:
9
  import detectron2
10
  except ImportError:
11
+ print("⚠️ Detectron2 missing. Installing...")
12
  os.system('pip install git+https://github.com/facebookresearch/detectron2.git')
13
 
14
  import requests
 
 
15
  import gradio as gr
16
  import spaces
17
  from PIL import Image
 
19
  import torch
20
  from torchvision import transforms
21
  from torchvision.transforms.functional import to_pil_image
22
+ from huggingface_hub import hf_hub_download
23
 
24
  sys.path.append('./')
25
 
 
34
  from detectron2.data.detection_utils import convert_PIL_to_numpy, _apply_exif_orientation
35
  import apply_net
36
  except ImportError as e:
37
+ raise ImportError(f"CRITICAL ERROR: Missing core modules. Error: {e}")
38
 
39
  from transformers import (
40
  CLIPImageProcessor, CLIPVisionModelWithProjection, CLIPTextModel,
 
43
  from diffusers import DDPMScheduler, AutoencoderKL
44
 
45
  # ---------------------------------------------------------
46
+ # 2. ROBUST MODEL DOWNLOADER (The Fix)
47
  # ---------------------------------------------------------
48
+ def download_model_robust(repo_id, filename, local_path):
49
+ if os.path.exists(local_path):
50
+ # Quick size check to ensure it's not an empty corrupt file
51
+ if os.path.getsize(local_path) > 1000:
52
+ print(f"βœ… Found {local_path}")
53
+ return
54
+ else:
55
+ print(f"⚠️ Corrupt file found at {local_path}, redownloading...")
56
+ os.remove(local_path)
57
+
58
+ print(f"⬇️ Downloading {filename} to {local_path}...")
59
+ try:
60
+ # Create directory
61
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
62
+
63
+ # Download using Hugging Face Hub (Fast & Cached)
64
+ downloaded_file = hf_hub_download(
65
+ repo_id=repo_id,
66
+ filename=filename,
67
+ local_dir=os.path.dirname(local_path),
68
+ local_dir_use_symlinks=False
69
+ )
70
+
71
+ # If the filename in repo is different from target, rename it
72
+ # (hf_hub_download saves to local_dir/filename)
73
+ actual_download_path = os.path.join(os.path.dirname(local_path), filename)
74
+ if actual_download_path != local_path:
75
+ # Move it to the exact expected path if different
76
+ if os.path.exists(actual_download_path):
77
+ shutil.move(actual_download_path, local_path)
78
+
79
+ print(f"βœ… Successfully downloaded {local_path}")
80
+
81
+ except Exception as e:
82
+ print(f"❌ Failed to download {filename}: {e}")
83
+ # Manual Fallback for complex paths
84
+ try:
85
+ url = f"https://huggingface.co/{repo_id}/resolve/main/{filename}"
86
+ print(f"πŸ”„ Trying direct URL fallback: {url}")
87
+ os.system(f"wget -O {local_path} {url}")
88
+ except:
89
+ pass
90
 
91
  def check_and_download_models():
92
+ print("⏳ VALIDATING MODELS...")
93
+
94
+ # 1. Parsing & OpenPose (From Camenduru)
95
+ download_model_robust("camenduru/IDM-VTON", "humanparsing/parsing_atr.onnx", "ckpt/humanparsing/parsing_atr.onnx")
96
+ download_model_robust("camenduru/IDM-VTON", "humanparsing/parsing_lip.onnx", "ckpt/humanparsing/parsing_lip.onnx")
97
+ download_model_robust("camenduru/IDM-VTON", "densepose/model_final_162be9.pkl", "ckpt/densepose/model_final_162be9.pkl")
98
+ download_model_robust("camenduru/IDM-VTON", "openpose/ckpts/body_pose_model.pth", "ckpt/openpose/ckpts/body_pose_model.pth")
99
+
100
+ # 2. IP Adapter (From h94)
101
+ download_model_robust("h94/IP-Adapter", "sdxl_models/ip-adapter-plus_sdxl_vit-h.bin", "ckpt/ip_adapter/ip-adapter-plus_sdxl_vit-h.bin")
102
+ download_model_robust("h94/IP-Adapter", "models/image_encoder/config.json", "ckpt/image_encoder/config.json")
103
+ download_model_robust("h94/IP-Adapter", "models/image_encoder/pytorch_model.bin", "ckpt/image_encoder/pytorch_model.bin")
104
+
105
+ # EXECUTE DOWNLOAD BEFORE LOADING ANYTHING
106
+ check_and_download_models()
107
 
108
  # ---------------------------------------------------------
109
  # 3. LOAD MODELS
 
120
  vae = AutoencoderKL.from_pretrained(base_path, subfolder="vae", torch_dtype=torch.float16)
121
  UNet_Encoder = UNet2DConditionModel_ref.from_pretrained(base_path, subfolder="unet_encoder", torch_dtype=torch.float16)
122
 
123
+ # Initialize Preprocessors
124
  parsing_model = Parsing(0)
125
  openpose_model = OpenPose(0)
126
 
127
+ # Freeze Weights
128
  UNet_Encoder.requires_grad_(False)
129
  image_encoder.requires_grad_(False)
130
  vae.requires_grad_(False)
 
145
  tensor_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
146
 
147
  # ---------------------------------------------------------
148
+ # 4. INFERENCE
149
  # ---------------------------------------------------------
 
150
  @spaces.GPU(duration=120)
151
  def start_tryon(human_img, garm_img, garment_des, is_checked, is_checked_crop, denoise_steps, seed):
152
  device = "cuda"
153
 
154
  try:
 
155
  openpose_model.preprocessor.body_estimation.model.to(device)
156
  pipe.to(device)
157
  pipe.unet_encoder.to(device)
 
234
  raise gr.Error(f"Error: {e}")
235
 
236
  finally:
237
+ # Memory Cleanup
 
 
238
  try:
239
  del keypoints, model_parse, mask, pose_img, prompt_embeds, garm_tensor
240
  except: