Ali Mohsin commited on
Commit
bdd42db
·
1 Parent(s): 77b97f3
Files changed (1) hide show
  1. app.py +122 -57
app.py CHANGED
@@ -106,23 +106,41 @@ else:
106
  def apply_torchvision_fix():
107
  """Apply comprehensive fix for torchvision compatibility issues"""
108
  try:
 
 
109
  # Pre-emptively create torch.ops structure if needed
110
  if not hasattr(torch, 'ops'):
111
- import types
112
  torch.ops = types.SimpleNamespace()
113
 
114
  if not hasattr(torch.ops, 'torchvision'):
115
  torch.ops.torchvision = types.SimpleNamespace()
116
 
117
- # Create dummy nms function to prevent operator errors
118
- if not hasattr(torch.ops.torchvision, 'nms'):
119
- torch.ops.torchvision.nms = lambda *args, **kwargs: torch.zeros(0, dtype=torch.int64)
120
-
121
- # Additional torchvision operators that might cause issues
122
- torchvision_ops = ['roi_align', 'roi_pool', 'ps_roi_align', 'ps_roi_pool']
123
  for op_name in torchvision_ops:
124
  if not hasattr(torch.ops.torchvision, op_name):
125
- setattr(torch.ops.torchvision, op_name, lambda *args, **kwargs: torch.zeros(0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  print("Applied comprehensive torchvision compatibility fixes")
128
  return True
@@ -142,6 +160,9 @@ def try_import_loop():
142
  global loop, loop_import_error
143
 
144
  try:
 
 
 
145
  # Try to import torchvision with error handling
146
  try:
147
  import torchvision
@@ -162,12 +183,33 @@ def try_import_loop():
162
  print(f"torchvision still has issues, but continuing: {e2}")
163
  else:
164
  print(f"Other torchvision error: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  # Now try to import the loop module
167
- from loop import loop as loop_func
168
- loop = loop_func
169
- print("Successfully imported loop module")
170
- return True
 
 
 
 
 
 
 
 
171
 
172
  except ImportError as e:
173
  error_msg = f"ImportError: {e}"
@@ -296,47 +338,65 @@ def process_garment(input_type, text_prompt, base_text_prompt, mesh_target_image
296
  config = DEFAULT_CONFIG.copy()
297
 
298
  # Set up input parameters based on mode
299
- if input_type == "Image to Mesh" and mesh_target_image is not None:
 
 
 
300
  # Image-to-Mesh processing
301
  progress(0.05, desc="Preparing mesh generation from image...")
302
 
303
  # Save target image to temp directory
304
  target_mesh_image_path = os.path.join(temp_dir, "target_mesh_image.jpg")
305
 
306
- if isinstance(mesh_target_image, str):
307
- shutil.copy(mesh_target_image, target_mesh_image_path)
308
- elif isinstance(mesh_target_image, np.ndarray):
309
- img = Image.fromarray(mesh_target_image.astype(np.uint8))
310
- img.save(target_mesh_image_path)
311
- elif hasattr(mesh_target_image, 'save'):
312
- mesh_target_image.save(target_mesh_image_path)
313
- else:
314
- print(f"Unsupported image type: {type(mesh_target_image)}")
315
- return "Error: Could not process the uploaded image. Please try a different image format."
316
-
317
- print(f"Target mesh image saved to {target_mesh_image_path}")
318
-
319
- # Set mesh paths based on selected source mesh type
320
- # Map display names to actual file names
321
- mesh_mapping = {
322
- "tshirt": "tshirt",
323
- "longsleeve": "longsleeve",
324
- "tanktop": "tanktop",
325
- "poncho": "poncho",
326
- "dress_shortsleeve": "dress_shortsleeve"
327
- }
328
- mesh_file = mesh_mapping.get(source_mesh_type, "tshirt")
329
- source_mesh_file = f"./meshes/{mesh_file}.obj"
330
-
331
- # Configure for image-to-mesh processing
332
- config.update({
333
- 'mesh': source_mesh_file,
334
- 'image_prompt': target_mesh_image_path,
335
- 'base_image_prompt': target_mesh_image_path, # Use same image as base
336
- 'use_target_mesh': True,
337
- 'fashion_image': True,
338
- 'fashion_text': False,
339
- })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
340
 
341
  else:
342
  # Text-based processing
@@ -503,12 +563,13 @@ def create_interface():
503
  """)
504
 
505
  with gr.Row():
506
- with gr.Column():
507
  # Input type selector
508
  input_type = gr.Radio(
509
  choices=["Text", "Image to Mesh"],
510
  value="Text",
511
- label="Generation Method"
 
512
  )
513
 
514
  # Text inputs (visible by default)
@@ -525,21 +586,20 @@ def create_interface():
525
  value="simple t-shirt"
526
  )
527
 
528
-
529
-
530
  # Image to Mesh inputs (hidden by default)
531
  with gr.Group(visible=False) as image_to_mesh_group:
532
- gr.Markdown("### Upload Garment Image")
533
  mesh_target_image = gr.Image(
534
  label="Target Garment Image for Mesh Generation",
535
- sources=["upload", "clipboard"],
536
  type="numpy",
537
  interactive=True,
538
- height=300
 
539
  )
540
  gr.Markdown("*Upload an image of the garment to convert directly to a 3D mesh*")
541
 
542
- gr.Markdown("### Select Base Mesh Type")
543
  source_mesh_type = gr.Dropdown(
544
  label="Source Mesh Type",
545
  choices=["tshirt", "longsleeve", "tanktop", "poncho", "dress_shortsleeve"],
@@ -621,6 +681,7 @@ def create_interface():
621
 
622
  # Define a function to handle mode changes with clearer UI feedback
623
  def update_mode(mode):
 
624
  text_visibility = mode == "Text"
625
  image_to_mesh_visibility = mode == "Image to Mesh"
626
  status_msg = f"Mode changed to {mode}. "
@@ -629,6 +690,8 @@ def create_interface():
629
  status_msg += "Enter garment descriptions and click Generate."
630
  else:
631
  status_msg += "Upload a garment image and select mesh type, then click Generate."
 
 
632
 
633
  return (
634
  gr.update(visible=text_visibility),
@@ -656,7 +719,8 @@ def create_interface():
656
  input_type.change(
657
  fn=update_mode,
658
  inputs=[input_type],
659
- outputs=[text_group, image_to_mesh_group, status_output]
 
660
  )
661
 
662
  # Connect the button to the processing function with error handling
@@ -700,7 +764,8 @@ if __name__ == "__main__":
700
  server_name="0.0.0.0",
701
  server_port=7860,
702
  show_error=True,
703
- quiet=False
 
704
  )
705
  except Exception as e:
706
  print(f"Error launching interface: {e}")
 
106
  def apply_torchvision_fix():
107
  """Apply comprehensive fix for torchvision compatibility issues"""
108
  try:
109
+ import types
110
+
111
  # Pre-emptively create torch.ops structure if needed
112
  if not hasattr(torch, 'ops'):
 
113
  torch.ops = types.SimpleNamespace()
114
 
115
  if not hasattr(torch.ops, 'torchvision'):
116
  torch.ops.torchvision = types.SimpleNamespace()
117
 
118
+ # Create dummy functions for all problematic torchvision operators
119
+ torchvision_ops = ['nms', 'roi_align', 'roi_pool', 'ps_roi_align', 'ps_roi_pool']
 
 
 
 
120
  for op_name in torchvision_ops:
121
  if not hasattr(torch.ops.torchvision, op_name):
122
+ if op_name == 'nms':
123
+ setattr(torch.ops.torchvision, op_name, lambda *args, **kwargs: torch.zeros(0, dtype=torch.int64))
124
+ else:
125
+ setattr(torch.ops.torchvision, op_name, lambda *args, **kwargs: torch.zeros(0))
126
+
127
+ # Fix for torchvision extension issues
128
+ try:
129
+ import torchvision
130
+ if not hasattr(torchvision, 'extension'):
131
+ torchvision.extension = types.SimpleNamespace()
132
+ torchvision.extension._has_ops = lambda: False
133
+ except:
134
+ pass
135
+
136
+ # Fix for torchvision meta registrations
137
+ try:
138
+ if 'torchvision' in sys.modules:
139
+ torchvision = sys.modules['torchvision']
140
+ if not hasattr(torchvision, '_meta_registrations'):
141
+ torchvision._meta_registrations = types.SimpleNamespace()
142
+ except:
143
+ pass
144
 
145
  print("Applied comprehensive torchvision compatibility fixes")
146
  return True
 
160
  global loop, loop_import_error
161
 
162
  try:
163
+ # Apply torchvision fixes before any imports
164
+ apply_torchvision_fix()
165
+
166
  # Try to import torchvision with error handling
167
  try:
168
  import torchvision
 
183
  print(f"torchvision still has issues, but continuing: {e2}")
184
  else:
185
  print(f"Other torchvision error: {e}")
186
+
187
+ # Try to import required modules with fallbacks
188
+ try:
189
+ import nvdiffrast
190
+ print("✓ nvdiffrast imported")
191
+ except ImportError:
192
+ print("⚠ nvdiffrast not available, will use fallback")
193
+
194
+ try:
195
+ import pytorch3d
196
+ print("✓ pytorch3d imported")
197
+ except ImportError:
198
+ print("⚠ pytorch3d not available, will use fallback")
199
 
200
  # Now try to import the loop module
201
+ try:
202
+ from loop import loop as loop_func
203
+ loop = loop_func
204
+ print("Successfully imported loop module")
205
+ return True
206
+ except ImportError as e:
207
+ print(f"Loop module import failed: {e}")
208
+ # Create a dummy loop function for fallback
209
+ def dummy_loop(config):
210
+ raise RuntimeError("Processing engine not available. Please check dependencies.")
211
+ loop = dummy_loop
212
+ return True
213
 
214
  except ImportError as e:
215
  error_msg = f"ImportError: {e}"
 
338
  config = DEFAULT_CONFIG.copy()
339
 
340
  # Set up input parameters based on mode
341
+ if input_type == "Image to Mesh":
342
+ if mesh_target_image is None:
343
+ return "Error: Please upload an image for Image to Mesh mode."
344
+
345
  # Image-to-Mesh processing
346
  progress(0.05, desc="Preparing mesh generation from image...")
347
 
348
  # Save target image to temp directory
349
  target_mesh_image_path = os.path.join(temp_dir, "target_mesh_image.jpg")
350
 
351
+ try:
352
+ if isinstance(mesh_target_image, str):
353
+ shutil.copy(mesh_target_image, target_mesh_image_path)
354
+ elif isinstance(mesh_target_image, np.ndarray):
355
+ # Ensure the array is in the correct format
356
+ if len(mesh_target_image.shape) == 3:
357
+ if mesh_target_image.shape[2] == 4: # RGBA
358
+ mesh_target_image = mesh_target_image[:,:,:3] # Convert to RGB
359
+ img = Image.fromarray(mesh_target_image.astype(np.uint8))
360
+ img.save(target_mesh_image_path)
361
+ else:
362
+ return "Error: Invalid image format. Please upload a valid RGB image."
363
+ elif hasattr(mesh_target_image, 'save'):
364
+ mesh_target_image.save(target_mesh_image_path)
365
+ else:
366
+ print(f"Unsupported image type: {type(mesh_target_image)}")
367
+ return "Error: Could not process the uploaded image. Please try a different image format."
368
+
369
+ print(f"Target mesh image saved to {target_mesh_image_path}")
370
+
371
+ # Set mesh paths based on selected source mesh type
372
+ # Map display names to actual file names
373
+ mesh_mapping = {
374
+ "tshirt": "tshirt",
375
+ "longsleeve": "longsleeve",
376
+ "tanktop": "tanktop",
377
+ "poncho": "poncho",
378
+ "dress_shortsleeve": "dress_shortsleeve"
379
+ }
380
+ mesh_file = mesh_mapping.get(source_mesh_type, "tshirt")
381
+ source_mesh_file = f"./meshes/{mesh_file}.obj"
382
+
383
+ # Check if the mesh file exists
384
+ if not os.path.exists(source_mesh_file):
385
+ return f"Error: Mesh file {source_mesh_file} not found. Please check if the mesh files are available."
386
+
387
+ # Configure for image-to-mesh processing
388
+ config.update({
389
+ 'mesh': source_mesh_file,
390
+ 'image_prompt': target_mesh_image_path,
391
+ 'base_image_prompt': target_mesh_image_path, # Use same image as base
392
+ 'use_target_mesh': True,
393
+ 'fashion_image': True,
394
+ 'fashion_text': False,
395
+ })
396
+
397
+ except Exception as e:
398
+ print(f"Error processing image: {e}")
399
+ return f"Error: Failed to process the uploaded image: {str(e)}"
400
 
401
  else:
402
  # Text-based processing
 
563
  """)
564
 
565
  with gr.Row():
566
+ with gr.Column(scale=1):
567
  # Input type selector
568
  input_type = gr.Radio(
569
  choices=["Text", "Image to Mesh"],
570
  value="Text",
571
+ label="Generation Method",
572
+ interactive=True
573
  )
574
 
575
  # Text inputs (visible by default)
 
586
  value="simple t-shirt"
587
  )
588
 
 
 
589
  # Image to Mesh inputs (hidden by default)
590
  with gr.Group(visible=False) as image_to_mesh_group:
591
+ gr.Markdown("### 📸 Upload Garment Image")
592
  mesh_target_image = gr.Image(
593
  label="Target Garment Image for Mesh Generation",
594
+ sources=["upload", "clipboard", "webcam"],
595
  type="numpy",
596
  interactive=True,
597
+ height=300,
598
+ show_label=True
599
  )
600
  gr.Markdown("*Upload an image of the garment to convert directly to a 3D mesh*")
601
 
602
+ gr.Markdown("### 🎯 Select Base Mesh Type")
603
  source_mesh_type = gr.Dropdown(
604
  label="Source Mesh Type",
605
  choices=["tshirt", "longsleeve", "tanktop", "poncho", "dress_shortsleeve"],
 
681
 
682
  # Define a function to handle mode changes with clearer UI feedback
683
  def update_mode(mode):
684
+ print(f"Mode changed to: {mode}")
685
  text_visibility = mode == "Text"
686
  image_to_mesh_visibility = mode == "Image to Mesh"
687
  status_msg = f"Mode changed to {mode}. "
 
690
  status_msg += "Enter garment descriptions and click Generate."
691
  else:
692
  status_msg += "Upload a garment image and select mesh type, then click Generate."
693
+
694
+ print(f"Text visibility: {text_visibility}, Image to Mesh visibility: {image_to_mesh_visibility}")
695
 
696
  return (
697
  gr.update(visible=text_visibility),
 
719
  input_type.change(
720
  fn=update_mode,
721
  inputs=[input_type],
722
+ outputs=[text_group, image_to_mesh_group, status_output],
723
+ show_progress=True
724
  )
725
 
726
  # Connect the button to the processing function with error handling
 
764
  server_name="0.0.0.0",
765
  server_port=7860,
766
  show_error=True,
767
+ quiet=False,
768
+ debug=True
769
  )
770
  except Exception as e:
771
  print(f"Error launching interface: {e}")