Ali Mohsin commited on
Commit ·
224f0c5
1
Parent(s): 6b38677
refactor: simplify garment model initialization and reloading logic
Browse files
app.py
CHANGED
|
@@ -170,30 +170,18 @@ def train_garment(video_path, garment_name):
|
|
| 170 |
# --- Inference Logic ---
|
| 171 |
|
| 172 |
def init_processor(garment_name):
|
| 173 |
-
global frame_processor, current_garment_id
|
| 174 |
if garment_name is None:
|
| 175 |
-
return
|
| 176 |
|
| 177 |
-
print(f"Loading garment: {garment_name}")
|
| 178 |
-
# Initialize
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
# We need to construct a list where the index matches.
|
| 183 |
-
# For simplicity, we just re-instantiate or hold a list of 1.
|
| 184 |
-
# But FrameProcessor holds a list. Let's make it hold just the current one for simplicity.
|
| 185 |
-
frame_processor = FrameProcessor([garment_name], ckpt_dir='./checkpoints')
|
| 186 |
-
# Note: FrameProcessor expects 'rtv_ckpts' for pretrained?
|
| 187 |
-
# Check source: it uses opt.checkpoints_dir which defaults to rtv_ckpts in make_pix2pix_model
|
| 188 |
-
# But we passed ckpt_dir='./checkpoints'.
|
| 189 |
-
# We need to make sure pretrained and new ones are found.
|
| 190 |
-
# Simpler approach: symlink rtv_ckpts content to checkpoints
|
| 191 |
-
else:
|
| 192 |
-
# Re-init for simplicity as ID mapping is complex dynamically
|
| 193 |
-
frame_processor = FrameProcessor([garment_name], ckpt_dir='./checkpoints') # or './rtv_ckpts'
|
| 194 |
-
|
| 195 |
# Trigger load
|
| 196 |
-
|
|
|
|
| 197 |
|
| 198 |
@spaces.GPU
|
| 199 |
def process_frame(image, garment_name, enable_tryon):
|
|
@@ -204,10 +192,17 @@ def process_frame(image, garment_name, enable_tryon):
|
|
| 204 |
return image
|
| 205 |
|
| 206 |
global frame_processor
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
# Link Pretrained to checkpoints if needed
|
| 212 |
if os.path.exists(f"rtv_ckpts/{garment_name}") and not os.path.exists(f"checkpoints/{garment_name}"):
|
| 213 |
if not os.path.exists("checkpoints"): os.makedirs("checkpoints")
|
|
@@ -216,10 +211,10 @@ def process_frame(image, garment_name, enable_tryon):
|
|
| 216 |
import shutil
|
| 217 |
shutil.copytree(f"rtv_ckpts/{garment_name}", f"checkpoints/{garment_name}")
|
| 218 |
|
| 219 |
-
init_processor(garment_name)
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
|
| 224 |
# Convert to RGB (Gradio is RGB, OpenCV is BGR)
|
| 225 |
# RTV expects BGR usually? checking rtl_demo...
|
|
|
|
| 170 |
# --- Inference Logic ---
|
| 171 |
|
| 172 |
def init_processor(garment_name):
|
| 173 |
+
# global frame_processor, current_garment_id # Avoid global reliance in helper
|
| 174 |
if garment_name is None:
|
| 175 |
+
return None
|
| 176 |
|
| 177 |
+
print(f"Loading garment: {garment_name}", flush=True)
|
| 178 |
+
# Initialize
|
| 179 |
+
# Always create new for now to ensure we have the right one
|
| 180 |
+
processor = FrameProcessor([garment_name], ckpt_dir='./checkpoints')
|
| 181 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
# Trigger load
|
| 183 |
+
processor.switch_to_target_garment(0)
|
| 184 |
+
return processor
|
| 185 |
|
| 186 |
@spaces.GPU
|
| 187 |
def process_frame(image, garment_name, enable_tryon):
|
|
|
|
| 192 |
return image
|
| 193 |
|
| 194 |
global frame_processor
|
| 195 |
+
|
| 196 |
+
try:
|
| 197 |
+
# Check if we need to load/reload
|
| 198 |
+
# We need to treat frame_processor as potentially stale or None
|
| 199 |
+
should_reload = False
|
| 200 |
+
if frame_processor is None:
|
| 201 |
+
should_reload = True
|
| 202 |
+
elif frame_processor.garment_name_list[0] != garment_name:
|
| 203 |
+
should_reload = True
|
| 204 |
+
|
| 205 |
+
if should_reload:
|
| 206 |
# Link Pretrained to checkpoints if needed
|
| 207 |
if os.path.exists(f"rtv_ckpts/{garment_name}") and not os.path.exists(f"checkpoints/{garment_name}"):
|
| 208 |
if not os.path.exists("checkpoints"): os.makedirs("checkpoints")
|
|
|
|
| 211 |
import shutil
|
| 212 |
shutil.copytree(f"rtv_ckpts/{garment_name}", f"checkpoints/{garment_name}")
|
| 213 |
|
| 214 |
+
frame_processor = init_processor(garment_name)
|
| 215 |
+
except Exception as e:
|
| 216 |
+
print(f"Error loading model: {e}", flush=True)
|
| 217 |
+
return image
|
| 218 |
|
| 219 |
# Convert to RGB (Gradio is RGB, OpenCV is BGR)
|
| 220 |
# RTV expects BGR usually? checking rtl_demo...
|