Ali Mohsin commited on
Commit
224f0c5
·
1 Parent(s): 6b38677

refactor: simplify garment model initialization and reloading logic

Browse files
Files changed (1) hide show
  1. app.py +24 -29
app.py CHANGED
@@ -170,30 +170,18 @@ def train_garment(video_path, garment_name):
170
  # --- Inference Logic ---
171
 
172
  def init_processor(garment_name):
173
- global frame_processor, current_garment_id
174
  if garment_name is None:
175
- return
176
 
177
- print(f"Loading garment: {garment_name}")
178
- # Initialize if needed
179
- if frame_processor is None:
180
- # Create processor with list. We just need to give it the name we want to load.
181
- # It handles loading from ./checkpoints or ./rtv_ckpts
182
- # We need to construct a list where the index matches.
183
- # For simplicity, we just re-instantiate or hold a list of 1.
184
- # But FrameProcessor holds a list. Let's make it hold just the current one for simplicity.
185
- frame_processor = FrameProcessor([garment_name], ckpt_dir='./checkpoints')
186
- # Note: FrameProcessor expects 'rtv_ckpts' for pretrained?
187
- # Check source: it uses opt.checkpoints_dir which defaults to rtv_ckpts in make_pix2pix_model
188
- # But we passed ckpt_dir='./checkpoints'.
189
- # We need to make sure pretrained and new ones are found.
190
- # Simpler approach: symlink rtv_ckpts content to checkpoints
191
- else:
192
- # Re-init for simplicity as ID mapping is complex dynamically
193
- frame_processor = FrameProcessor([garment_name], ckpt_dir='./checkpoints') # or './rtv_ckpts'
194
-
195
  # Trigger load
196
- frame_processor.switch_to_target_garment(0)
 
197
 
198
  @spaces.GPU
199
  def process_frame(image, garment_name, enable_tryon):
@@ -204,10 +192,17 @@ def process_frame(image, garment_name, enable_tryon):
204
  return image
205
 
206
  global frame_processor
207
- # Check if we need to load/reload
208
- # This is a basic check. Real app needs better state management.
209
- if frame_processor is None or frame_processor.garment_name_list[0] != garment_name:
210
- try:
 
 
 
 
 
 
 
211
  # Link Pretrained to checkpoints if needed
212
  if os.path.exists(f"rtv_ckpts/{garment_name}") and not os.path.exists(f"checkpoints/{garment_name}"):
213
  if not os.path.exists("checkpoints"): os.makedirs("checkpoints")
@@ -216,10 +211,10 @@ def process_frame(image, garment_name, enable_tryon):
216
  import shutil
217
  shutil.copytree(f"rtv_ckpts/{garment_name}", f"checkpoints/{garment_name}")
218
 
219
- init_processor(garment_name)
220
- except Exception as e:
221
- print(f"Error loading model: {e}")
222
- return image
223
 
224
  # Convert to RGB (Gradio is RGB, OpenCV is BGR)
225
  # RTV expects BGR usually? checking rtl_demo...
 
170
  # --- Inference Logic ---
171
 
172
  def init_processor(garment_name):
173
+ # global frame_processor, current_garment_id # Avoid global reliance in helper
174
  if garment_name is None:
175
+ return None
176
 
177
+ print(f"Loading garment: {garment_name}", flush=True)
178
+ # Initialize
179
+ # Always create new for now to ensure we have the right one
180
+ processor = FrameProcessor([garment_name], ckpt_dir='./checkpoints')
181
+
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  # Trigger load
183
+ processor.switch_to_target_garment(0)
184
+ return processor
185
 
186
  @spaces.GPU
187
  def process_frame(image, garment_name, enable_tryon):
 
192
  return image
193
 
194
  global frame_processor
195
+
196
+ try:
197
+ # Check if we need to load/reload
198
+ # We need to treat frame_processor as potentially stale or None
199
+ should_reload = False
200
+ if frame_processor is None:
201
+ should_reload = True
202
+ elif frame_processor.garment_name_list[0] != garment_name:
203
+ should_reload = True
204
+
205
+ if should_reload:
206
  # Link Pretrained to checkpoints if needed
207
  if os.path.exists(f"rtv_ckpts/{garment_name}") and not os.path.exists(f"checkpoints/{garment_name}"):
208
  if not os.path.exists("checkpoints"): os.makedirs("checkpoints")
 
211
  import shutil
212
  shutil.copytree(f"rtv_ckpts/{garment_name}", f"checkpoints/{garment_name}")
213
 
214
+ frame_processor = init_processor(garment_name)
215
+ except Exception as e:
216
+ print(f"Error loading model: {e}", flush=True)
217
+ return image
218
 
219
  # Convert to RGB (Gradio is RGB, OpenCV is BGR)
220
  # RTV expects BGR usually? checking rtl_demo...