Spaces:
Paused
Paused
Ali Mohsin
commited on
Commit
·
bdd42db
1
Parent(s):
77b97f3
fixes 9.0
Browse files
app.py
CHANGED
|
@@ -106,23 +106,41 @@ else:
|
|
| 106 |
def apply_torchvision_fix():
|
| 107 |
"""Apply comprehensive fix for torchvision compatibility issues"""
|
| 108 |
try:
|
|
|
|
|
|
|
| 109 |
# Pre-emptively create torch.ops structure if needed
|
| 110 |
if not hasattr(torch, 'ops'):
|
| 111 |
-
import types
|
| 112 |
torch.ops = types.SimpleNamespace()
|
| 113 |
|
| 114 |
if not hasattr(torch.ops, 'torchvision'):
|
| 115 |
torch.ops.torchvision = types.SimpleNamespace()
|
| 116 |
|
| 117 |
-
# Create dummy
|
| 118 |
-
|
| 119 |
-
torch.ops.torchvision.nms = lambda *args, **kwargs: torch.zeros(0, dtype=torch.int64)
|
| 120 |
-
|
| 121 |
-
# Additional torchvision operators that might cause issues
|
| 122 |
-
torchvision_ops = ['roi_align', 'roi_pool', 'ps_roi_align', 'ps_roi_pool']
|
| 123 |
for op_name in torchvision_ops:
|
| 124 |
if not hasattr(torch.ops.torchvision, op_name):
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
print("Applied comprehensive torchvision compatibility fixes")
|
| 128 |
return True
|
|
@@ -142,6 +160,9 @@ def try_import_loop():
|
|
| 142 |
global loop, loop_import_error
|
| 143 |
|
| 144 |
try:
|
|
|
|
|
|
|
|
|
|
| 145 |
# Try to import torchvision with error handling
|
| 146 |
try:
|
| 147 |
import torchvision
|
|
@@ -162,12 +183,33 @@ def try_import_loop():
|
|
| 162 |
print(f"torchvision still has issues, but continuing: {e2}")
|
| 163 |
else:
|
| 164 |
print(f"Other torchvision error: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
# Now try to import the loop module
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
except ImportError as e:
|
| 173 |
error_msg = f"ImportError: {e}"
|
|
@@ -296,47 +338,65 @@ def process_garment(input_type, text_prompt, base_text_prompt, mesh_target_image
|
|
| 296 |
config = DEFAULT_CONFIG.copy()
|
| 297 |
|
| 298 |
# Set up input parameters based on mode
|
| 299 |
-
if input_type == "Image to Mesh"
|
|
|
|
|
|
|
|
|
|
| 300 |
# Image-to-Mesh processing
|
| 301 |
progress(0.05, desc="Preparing mesh generation from image...")
|
| 302 |
|
| 303 |
# Save target image to temp directory
|
| 304 |
target_mesh_image_path = os.path.join(temp_dir, "target_mesh_image.jpg")
|
| 305 |
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
"
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
|
| 341 |
else:
|
| 342 |
# Text-based processing
|
|
@@ -503,12 +563,13 @@ def create_interface():
|
|
| 503 |
""")
|
| 504 |
|
| 505 |
with gr.Row():
|
| 506 |
-
with gr.Column():
|
| 507 |
# Input type selector
|
| 508 |
input_type = gr.Radio(
|
| 509 |
choices=["Text", "Image to Mesh"],
|
| 510 |
value="Text",
|
| 511 |
-
label="Generation Method"
|
|
|
|
| 512 |
)
|
| 513 |
|
| 514 |
# Text inputs (visible by default)
|
|
@@ -525,21 +586,20 @@ def create_interface():
|
|
| 525 |
value="simple t-shirt"
|
| 526 |
)
|
| 527 |
|
| 528 |
-
|
| 529 |
-
|
| 530 |
# Image to Mesh inputs (hidden by default)
|
| 531 |
with gr.Group(visible=False) as image_to_mesh_group:
|
| 532 |
-
gr.Markdown("### Upload Garment Image")
|
| 533 |
mesh_target_image = gr.Image(
|
| 534 |
label="Target Garment Image for Mesh Generation",
|
| 535 |
-
sources=["upload", "clipboard"],
|
| 536 |
type="numpy",
|
| 537 |
interactive=True,
|
| 538 |
-
height=300
|
|
|
|
| 539 |
)
|
| 540 |
gr.Markdown("*Upload an image of the garment to convert directly to a 3D mesh*")
|
| 541 |
|
| 542 |
-
gr.Markdown("### Select Base Mesh Type")
|
| 543 |
source_mesh_type = gr.Dropdown(
|
| 544 |
label="Source Mesh Type",
|
| 545 |
choices=["tshirt", "longsleeve", "tanktop", "poncho", "dress_shortsleeve"],
|
|
@@ -621,6 +681,7 @@ def create_interface():
|
|
| 621 |
|
| 622 |
# Define a function to handle mode changes with clearer UI feedback
|
| 623 |
def update_mode(mode):
|
|
|
|
| 624 |
text_visibility = mode == "Text"
|
| 625 |
image_to_mesh_visibility = mode == "Image to Mesh"
|
| 626 |
status_msg = f"Mode changed to {mode}. "
|
|
@@ -629,6 +690,8 @@ def create_interface():
|
|
| 629 |
status_msg += "Enter garment descriptions and click Generate."
|
| 630 |
else:
|
| 631 |
status_msg += "Upload a garment image and select mesh type, then click Generate."
|
|
|
|
|
|
|
| 632 |
|
| 633 |
return (
|
| 634 |
gr.update(visible=text_visibility),
|
|
@@ -656,7 +719,8 @@ def create_interface():
|
|
| 656 |
input_type.change(
|
| 657 |
fn=update_mode,
|
| 658 |
inputs=[input_type],
|
| 659 |
-
outputs=[text_group, image_to_mesh_group, status_output]
|
|
|
|
| 660 |
)
|
| 661 |
|
| 662 |
# Connect the button to the processing function with error handling
|
|
@@ -700,7 +764,8 @@ if __name__ == "__main__":
|
|
| 700 |
server_name="0.0.0.0",
|
| 701 |
server_port=7860,
|
| 702 |
show_error=True,
|
| 703 |
-
quiet=False
|
|
|
|
| 704 |
)
|
| 705 |
except Exception as e:
|
| 706 |
print(f"Error launching interface: {e}")
|
|
|
|
| 106 |
def apply_torchvision_fix():
|
| 107 |
"""Apply comprehensive fix for torchvision compatibility issues"""
|
| 108 |
try:
|
| 109 |
+
import types
|
| 110 |
+
|
| 111 |
# Pre-emptively create torch.ops structure if needed
|
| 112 |
if not hasattr(torch, 'ops'):
|
|
|
|
| 113 |
torch.ops = types.SimpleNamespace()
|
| 114 |
|
| 115 |
if not hasattr(torch.ops, 'torchvision'):
|
| 116 |
torch.ops.torchvision = types.SimpleNamespace()
|
| 117 |
|
| 118 |
+
# Create dummy functions for all problematic torchvision operators
|
| 119 |
+
torchvision_ops = ['nms', 'roi_align', 'roi_pool', 'ps_roi_align', 'ps_roi_pool']
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
for op_name in torchvision_ops:
|
| 121 |
if not hasattr(torch.ops.torchvision, op_name):
|
| 122 |
+
if op_name == 'nms':
|
| 123 |
+
setattr(torch.ops.torchvision, op_name, lambda *args, **kwargs: torch.zeros(0, dtype=torch.int64))
|
| 124 |
+
else:
|
| 125 |
+
setattr(torch.ops.torchvision, op_name, lambda *args, **kwargs: torch.zeros(0))
|
| 126 |
+
|
| 127 |
+
# Fix for torchvision extension issues
|
| 128 |
+
try:
|
| 129 |
+
import torchvision
|
| 130 |
+
if not hasattr(torchvision, 'extension'):
|
| 131 |
+
torchvision.extension = types.SimpleNamespace()
|
| 132 |
+
torchvision.extension._has_ops = lambda: False
|
| 133 |
+
except:
|
| 134 |
+
pass
|
| 135 |
+
|
| 136 |
+
# Fix for torchvision meta registrations
|
| 137 |
+
try:
|
| 138 |
+
if 'torchvision' in sys.modules:
|
| 139 |
+
torchvision = sys.modules['torchvision']
|
| 140 |
+
if not hasattr(torchvision, '_meta_registrations'):
|
| 141 |
+
torchvision._meta_registrations = types.SimpleNamespace()
|
| 142 |
+
except:
|
| 143 |
+
pass
|
| 144 |
|
| 145 |
print("Applied comprehensive torchvision compatibility fixes")
|
| 146 |
return True
|
|
|
|
| 160 |
global loop, loop_import_error
|
| 161 |
|
| 162 |
try:
|
| 163 |
+
# Apply torchvision fixes before any imports
|
| 164 |
+
apply_torchvision_fix()
|
| 165 |
+
|
| 166 |
# Try to import torchvision with error handling
|
| 167 |
try:
|
| 168 |
import torchvision
|
|
|
|
| 183 |
print(f"torchvision still has issues, but continuing: {e2}")
|
| 184 |
else:
|
| 185 |
print(f"Other torchvision error: {e}")
|
| 186 |
+
|
| 187 |
+
# Try to import required modules with fallbacks
|
| 188 |
+
try:
|
| 189 |
+
import nvdiffrast
|
| 190 |
+
print("✓ nvdiffrast imported")
|
| 191 |
+
except ImportError:
|
| 192 |
+
print("⚠ nvdiffrast not available, will use fallback")
|
| 193 |
+
|
| 194 |
+
try:
|
| 195 |
+
import pytorch3d
|
| 196 |
+
print("✓ pytorch3d imported")
|
| 197 |
+
except ImportError:
|
| 198 |
+
print("⚠ pytorch3d not available, will use fallback")
|
| 199 |
|
| 200 |
# Now try to import the loop module
|
| 201 |
+
try:
|
| 202 |
+
from loop import loop as loop_func
|
| 203 |
+
loop = loop_func
|
| 204 |
+
print("Successfully imported loop module")
|
| 205 |
+
return True
|
| 206 |
+
except ImportError as e:
|
| 207 |
+
print(f"Loop module import failed: {e}")
|
| 208 |
+
# Create a dummy loop function for fallback
|
| 209 |
+
def dummy_loop(config):
|
| 210 |
+
raise RuntimeError("Processing engine not available. Please check dependencies.")
|
| 211 |
+
loop = dummy_loop
|
| 212 |
+
return True
|
| 213 |
|
| 214 |
except ImportError as e:
|
| 215 |
error_msg = f"ImportError: {e}"
|
|
|
|
| 338 |
config = DEFAULT_CONFIG.copy()
|
| 339 |
|
| 340 |
# Set up input parameters based on mode
|
| 341 |
+
if input_type == "Image to Mesh":
|
| 342 |
+
if mesh_target_image is None:
|
| 343 |
+
return "Error: Please upload an image for Image to Mesh mode."
|
| 344 |
+
|
| 345 |
# Image-to-Mesh processing
|
| 346 |
progress(0.05, desc="Preparing mesh generation from image...")
|
| 347 |
|
| 348 |
# Save target image to temp directory
|
| 349 |
target_mesh_image_path = os.path.join(temp_dir, "target_mesh_image.jpg")
|
| 350 |
|
| 351 |
+
try:
|
| 352 |
+
if isinstance(mesh_target_image, str):
|
| 353 |
+
shutil.copy(mesh_target_image, target_mesh_image_path)
|
| 354 |
+
elif isinstance(mesh_target_image, np.ndarray):
|
| 355 |
+
# Ensure the array is in the correct format
|
| 356 |
+
if len(mesh_target_image.shape) == 3:
|
| 357 |
+
if mesh_target_image.shape[2] == 4: # RGBA
|
| 358 |
+
mesh_target_image = mesh_target_image[:,:,:3] # Convert to RGB
|
| 359 |
+
img = Image.fromarray(mesh_target_image.astype(np.uint8))
|
| 360 |
+
img.save(target_mesh_image_path)
|
| 361 |
+
else:
|
| 362 |
+
return "Error: Invalid image format. Please upload a valid RGB image."
|
| 363 |
+
elif hasattr(mesh_target_image, 'save'):
|
| 364 |
+
mesh_target_image.save(target_mesh_image_path)
|
| 365 |
+
else:
|
| 366 |
+
print(f"Unsupported image type: {type(mesh_target_image)}")
|
| 367 |
+
return "Error: Could not process the uploaded image. Please try a different image format."
|
| 368 |
+
|
| 369 |
+
print(f"Target mesh image saved to {target_mesh_image_path}")
|
| 370 |
+
|
| 371 |
+
# Set mesh paths based on selected source mesh type
|
| 372 |
+
# Map display names to actual file names
|
| 373 |
+
mesh_mapping = {
|
| 374 |
+
"tshirt": "tshirt",
|
| 375 |
+
"longsleeve": "longsleeve",
|
| 376 |
+
"tanktop": "tanktop",
|
| 377 |
+
"poncho": "poncho",
|
| 378 |
+
"dress_shortsleeve": "dress_shortsleeve"
|
| 379 |
+
}
|
| 380 |
+
mesh_file = mesh_mapping.get(source_mesh_type, "tshirt")
|
| 381 |
+
source_mesh_file = f"./meshes/{mesh_file}.obj"
|
| 382 |
+
|
| 383 |
+
# Check if the mesh file exists
|
| 384 |
+
if not os.path.exists(source_mesh_file):
|
| 385 |
+
return f"Error: Mesh file {source_mesh_file} not found. Please check if the mesh files are available."
|
| 386 |
+
|
| 387 |
+
# Configure for image-to-mesh processing
|
| 388 |
+
config.update({
|
| 389 |
+
'mesh': source_mesh_file,
|
| 390 |
+
'image_prompt': target_mesh_image_path,
|
| 391 |
+
'base_image_prompt': target_mesh_image_path, # Use same image as base
|
| 392 |
+
'use_target_mesh': True,
|
| 393 |
+
'fashion_image': True,
|
| 394 |
+
'fashion_text': False,
|
| 395 |
+
})
|
| 396 |
+
|
| 397 |
+
except Exception as e:
|
| 398 |
+
print(f"Error processing image: {e}")
|
| 399 |
+
return f"Error: Failed to process the uploaded image: {str(e)}"
|
| 400 |
|
| 401 |
else:
|
| 402 |
# Text-based processing
|
|
|
|
| 563 |
""")
|
| 564 |
|
| 565 |
with gr.Row():
|
| 566 |
+
with gr.Column(scale=1):
|
| 567 |
# Input type selector
|
| 568 |
input_type = gr.Radio(
|
| 569 |
choices=["Text", "Image to Mesh"],
|
| 570 |
value="Text",
|
| 571 |
+
label="Generation Method",
|
| 572 |
+
interactive=True
|
| 573 |
)
|
| 574 |
|
| 575 |
# Text inputs (visible by default)
|
|
|
|
| 586 |
value="simple t-shirt"
|
| 587 |
)
|
| 588 |
|
|
|
|
|
|
|
| 589 |
# Image to Mesh inputs (hidden by default)
|
| 590 |
with gr.Group(visible=False) as image_to_mesh_group:
|
| 591 |
+
gr.Markdown("### 📸 Upload Garment Image")
|
| 592 |
mesh_target_image = gr.Image(
|
| 593 |
label="Target Garment Image for Mesh Generation",
|
| 594 |
+
sources=["upload", "clipboard", "webcam"],
|
| 595 |
type="numpy",
|
| 596 |
interactive=True,
|
| 597 |
+
height=300,
|
| 598 |
+
show_label=True
|
| 599 |
)
|
| 600 |
gr.Markdown("*Upload an image of the garment to convert directly to a 3D mesh*")
|
| 601 |
|
| 602 |
+
gr.Markdown("### 🎯 Select Base Mesh Type")
|
| 603 |
source_mesh_type = gr.Dropdown(
|
| 604 |
label="Source Mesh Type",
|
| 605 |
choices=["tshirt", "longsleeve", "tanktop", "poncho", "dress_shortsleeve"],
|
|
|
|
| 681 |
|
| 682 |
# Define a function to handle mode changes with clearer UI feedback
|
| 683 |
def update_mode(mode):
|
| 684 |
+
print(f"Mode changed to: {mode}")
|
| 685 |
text_visibility = mode == "Text"
|
| 686 |
image_to_mesh_visibility = mode == "Image to Mesh"
|
| 687 |
status_msg = f"Mode changed to {mode}. "
|
|
|
|
| 690 |
status_msg += "Enter garment descriptions and click Generate."
|
| 691 |
else:
|
| 692 |
status_msg += "Upload a garment image and select mesh type, then click Generate."
|
| 693 |
+
|
| 694 |
+
print(f"Text visibility: {text_visibility}, Image to Mesh visibility: {image_to_mesh_visibility}")
|
| 695 |
|
| 696 |
return (
|
| 697 |
gr.update(visible=text_visibility),
|
|
|
|
| 719 |
input_type.change(
|
| 720 |
fn=update_mode,
|
| 721 |
inputs=[input_type],
|
| 722 |
+
outputs=[text_group, image_to_mesh_group, status_output],
|
| 723 |
+
show_progress=True
|
| 724 |
)
|
| 725 |
|
| 726 |
# Connect the button to the processing function with error handling
|
|
|
|
| 764 |
server_name="0.0.0.0",
|
| 765 |
server_port=7860,
|
| 766 |
show_error=True,
|
| 767 |
+
quiet=False,
|
| 768 |
+
debug=True
|
| 769 |
)
|
| 770 |
except Exception as e:
|
| 771 |
print(f"Error launching interface: {e}")
|