Spaces:
Sleeping
Sleeping
| from flask import Flask, request, jsonify, send_file | |
| from flask_cors import CORS | |
| import torch | |
| import numpy as np | |
| import trimesh | |
| import os | |
| from io import BytesIO | |
| import base64 | |
| from PIL import Image | |
| import uuid | |
| import time | |
| import sys | |
| import gc # For explicit garbage collection | |
| import threading | |
| import queue | |
| import psutil | |
# Environment must be configured before the Shap-E imports further down run.
# NOTE(review): SHAPEE_NO_INTERACTIVE presumably disables interactive
# behaviour in Shap-E -- confirm against the library.
os.environ['SHAPEE_NO_INTERACTIVE'] = '1'

# Create the model cache folder inside the current working directory.
cache_dir = os.path.join(os.getcwd(), 'shap_e_model_cache')
os.makedirs(cache_dir, exist_ok=True)

# Point XDG_CACHE_HOME at the working directory as well.
os.environ['XDG_CACHE_HOME'] = os.getcwd()

print(f"Using cache directory: {cache_dir}")
# Import Shap-E. If the plain import fails with ImportError, retry once after
# registering an empty stand-in for ``ipywidgets``; any other failure, or a
# failure of the retry, terminates the process with exit status 1.
print("Importing Shap-E modules...")
try:
    # First attempt: plain imports.
    from shap_e.diffusion.sample import sample_latents
    from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
    from shap_e.models.download import load_model, load_config
    from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
    print("Shap-E modules imported successfully!")
except ImportError as e:
    print(f"Error importing Shap-E modules: {e}")
    try:
        print("Attempting alternative import approach...")
        import types

        # Install an empty placeholder module in case ipywidgets is the
        # missing dependency.
        if 'ipywidgets' not in sys.modules:
            sys.modules['ipywidgets'] = types.ModuleType('ipywidgets')
            print("Added mock ipywidgets module")

        # Second attempt with the placeholder in place.
        from shap_e.diffusion.sample import sample_latents
        from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
        from shap_e.models.download import load_model, load_config
        from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
        print("Shap-E modules imported successfully with workaround!")
    except Exception as e2:
        print(f"Alternative import also failed: {e2}")
        sys.exit(1)
except Exception as e:
    print(f"Unexpected error importing Shap-E modules: {e}")
    sys.exit(1)
# Flask application with CORS applied to it.
app = Flask(__name__)
CORS(app)

# Generated mesh files are written to ./outputs; create it when missing.
output_dir = os.path.join(os.getcwd(), "outputs")
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")
def _report_directory_writable(directory, label):
    """Probe *directory* by writing and deleting a small file, and print the outcome.

    *label* is the capitalised name used in the printed message ("Cache" or
    "Output"). Failures are reported as a warning and never raised.
    """
    probe_path = os.path.join(directory, "test_write_permissions.txt")
    try:
        with open(probe_path, 'w') as probe:
            probe.write("Testing write permissions")
        os.remove(probe_path)
        print(f"{label} directory is writable")
    except Exception as e:
        print(f"WARNING: {label} directory is not writable: {e}")


# Check permissions on both working directories at start-up.
_report_directory_writable(cache_dir, "Cache")
_report_directory_writable(output_dir, "Output")
print("Setting up device...")
device = torch.device('cpu')  # Force CPU for Hugging Face Spaces
print(f"Using device: {device}")

# Model handles; they stay None until the loader functions populate them.
xm = model = diffusion = None

# Pending jobs flow through the queue; finished results are keyed by job id.
job_queue = queue.Queue()
job_results = {}

# Worker thread handle and the flag the worker sets while it is running.
generation_thread = None
is_thread_running = False

# Bookkeeping for the memory optimisations.
last_usage_time = None   # time.time() of the most recent model load request
active_jobs = 0          # jobs submitted but not yet finished
max_concurrent_jobs = 1  # Limit concurrent jobs for 2vCPU
def get_adaptive_parameters():
    """Pick sampling parameters according to current system memory usage.

    Returns a dict with ``karras_steps``, ``batch_size`` and
    ``guidance_scale``. Only ``karras_steps`` varies: 4 when memory usage is
    above 70%, 8 when it is below 50%, and 6 otherwise.
    """
    memory_percent = psutil.virtual_memory().percent

    if memory_percent > 70:
        # Memory is tight: use the fewest steps.
        steps = 4
    elif memory_percent < 50:
        # Plenty of headroom: afford a few more steps.
        steps = 8
    else:
        # Conservative default.
        steps = 6

    print(f"Adaptive parameters chosen: karras_steps={steps}, mem={memory_percent}%")
    return {
        'karras_steps': steps,
        'batch_size': 1,
        'guidance_scale': 15.0,
    }
def check_memory_pressure():
    """Relieve memory pressure when system memory usage exceeds 80%.

    Runs a garbage collection (and empties the CUDA cache when CUDA is
    available). If usage is still above 75% afterwards, the global model
    handles are reset to None so the loader functions reload them later.

    Returns True when the 80% threshold was exceeded, False otherwise.
    """
    global xm, model, diffusion

    if psutil.virtual_memory().percent <= 80:
        return False

    print("WARNING: Memory pressure critical. Forcing garbage collection.")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Re-measure: collection alone may not have been enough.
    if psutil.virtual_memory().percent > 75:
        print("EMERGENCY: Memory still critical. Clearing model cache.")
        xm = model = diffusion = None
        gc.collect()

    return True
def load_transmitter_model():
    """Load the 'transmitter' model into the global ``xm`` if it is not loaded.

    Always refreshes ``last_usage_time``, whether or not a load happens.
    """
    global xm, last_usage_time

    last_usage_time = time.time()
    if xm is not None:
        return

    print("Loading transmitter model...")
    xm = load_model('transmitter', device=device)
    print("Transmitter model loaded!")
def load_primary_model():
    """Load the 'text300M' model and the diffusion object if either is missing.

    Always refreshes ``last_usage_time``. Both globals are (re)loaded together
    whenever at least one of them is None.
    """
    global model, diffusion, last_usage_time

    last_usage_time = time.time()
    if model is not None and diffusion is not None:
        return

    print("Loading primary models...")
    torch.set_default_dtype(torch.float32)  # Use float32 instead of float64
    model = load_model('text300M', device=device)
    diffusion = diffusion_from_config(load_config('diffusion'))
    print("Primary models loaded!")
def load_models_if_needed():
    """Compatibility wrapper: load the primary models, then the transmitter."""
    for loader in (load_primary_model, load_transmitter_model):
        loader()
def model_unloader_thread():
    """Background loop that drops the loaded models after a period of inactivity.

    Every 180 seconds it checks whether the models were last requested more
    than 300 seconds ago while no job is active. If so, and system memory
    usage is above 40%, the global model handles are reset to None and a
    garbage collection is forced. Never returns.
    """
    global xm, model, diffusion

    while True:
        time.sleep(180)  # poll every 3 minutes

        if last_usage_time is None:
            # Models were never requested; nothing to unload.
            continue

        idle_time = time.time() - last_usage_time
        if not (idle_time > 300 and active_jobs == 0):
            continue

        mem = psutil.virtual_memory()
        if mem.percent > 40:
            print(f"Models idle for {idle_time:.1f} seconds and memory at {mem.percent}%. Unloading...")
            xm = model = diffusion = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
def save_trimesh(mesh, filename_base):
    """Export *mesh* as GLB, OBJ and PLY files next to *filename_base*.

    *mesh* is either a ``trimesh.Trimesh`` or an object exposing ``vertices``
    and ``faces``, which is converted first (conversion errors are re-raised).
    Each format is attempted independently; GLB and OBJ have a second,
    alternative export path.

    Returns a dict mapping ``"glb"``, ``"obj"`` and ``"ply"`` to the written
    file's base name, or None for a format that could not be saved.
    """
    if isinstance(mesh, trimesh.Trimesh):
        trimesh_obj = mesh
    else:
        try:
            trimesh_obj = trimesh.Trimesh(
                vertices=np.array(mesh.vertices),
                faces=np.array(mesh.faces),
            )
        except Exception as e:
            print(f"Error converting to trimesh: {e}")
            raise

    # --- GLB: direct export first, then via a Scene wrapper ---
    glb_path = f"{filename_base}.glb"
    try:
        trimesh_obj.export(glb_path, file_type='glb')
        print(f"Saved GLB file: {glb_path}")
    except Exception as e:
        print(f"Error saving GLB: {e}")
        try:
            scene = trimesh.Scene()
            scene.add_geometry(trimesh_obj)
            scene.export(glb_path)
            print(f"Saved GLB using scene approach: {glb_path}")
        except Exception as e2:
            print(f"Alternative GLB export also failed: {e2}")
            glb_path = None

    # --- OBJ: library export first, then a hand-written vertex/face dump ---
    obj_path = f"{filename_base}.obj"
    try:
        trimesh_obj.export(obj_path, file_type='obj')
        print(f"Saved OBJ file: {obj_path}")
    except Exception as e:
        print(f"Error saving OBJ: {e}")
        try:
            with open(obj_path, 'w') as out:
                out.writelines(
                    f"v {v[0]} {v[1]} {v[2]}\n" for v in trimesh_obj.vertices
                )
                # Face indices are written 1-based.
                out.writelines(
                    f"f {face[0]+1} {face[1]+1} {face[2]+1}\n"
                    for face in trimesh_obj.faces
                )
            print(f"Saved OBJ using direct write: {obj_path}")
        except Exception as e2:
            print(f"Alternative OBJ export also failed: {e2}")
            obj_path = None

    # --- PLY: single attempt ---
    ply_path = f"{filename_base}.ply"
    try:
        trimesh_obj.export(ply_path, file_type='ply')
        print(f"Saved PLY file: {ply_path}")
    except Exception as e:
        print(f"Error saving PLY: {e}")
        ply_path = None

    written = (("glb", glb_path), ("obj", obj_path), ("ply", ply_path))
    return {
        fmt: (os.path.basename(path) if path else None)
        for fmt, path in written
    }
def _build_job_result(saved_files, vertices_count, faces_count):
    """Assemble the result payload for a job from the files that were saved.

    *saved_files* maps "glb"/"obj"/"ply" to a file base name or None. When no
    format was saved the result is marked unsuccessful and carries an
    ``error`` key, matching the shape of the exception-path result.
    """
    result = {
        "success": True,
        "message": "3D model generated successfully",
        "timestamp": time.time(),
        "stats": {
            "vertices": vertices_count,
            "faces": faces_count
        }
    }

    formats = ("glb", "obj", "ply")
    # Add a download URL for every format that was written.
    for fmt in formats:
        if saved_files[fmt]:
            result[f"{fmt}_url"] = f"/download/{saved_files[fmt]}"

    if not any(saved_files[fmt] for fmt in formats):
        result["success"] = False
        result["message"] = "Failed to save mesh in any format"
        result["error"] = result["message"]
    return result


def process_job(job_id, prompt):
    """Generate a 3D mesh for *prompt* and save it to the output directory.

    Steps: choose adaptive parameters, load the primary models, sample
    latents, load the transmitter model, decode the first latent to a mesh,
    and save it through ``save_trimesh``.

    Returns a result dict. On success it contains ``success``, ``message``,
    ``timestamp``, ``stats`` and one ``*_url`` entry per saved format. Any
    exception is caught, logged with a traceback, and reported as
    ``{"success": False, "error": ..., "timestamp": ...}``; nothing is raised.
    """
    try:
        adaptive_params = get_adaptive_parameters()
        karras_steps = adaptive_params['karras_steps']
        batch_size = adaptive_params['batch_size']
        guidance_scale = adaptive_params['guidance_scale']

        # Load primary models for generation
        load_primary_model()

        # Free what we can before the memory-intensive sampling step.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print(f"Starting latent generation for job {job_id} with {karras_steps} steps...")
        with torch.inference_mode():
            latents = sample_latents(
                batch_size=batch_size,
                model=model,
                diffusion=diffusion,
                guidance_scale=guidance_scale,
                model_kwargs=dict(texts=[prompt] * batch_size),
                progress=True,
                clip_denoised=True,
                use_fp16=False,  # CPU doesn't support fp16
                use_karras=True,
                karras_steps=karras_steps,
                sigma_min=1e-3,
                sigma_max=160,
                s_churn=0,
            )
        print(f"Latent generation complete for job {job_id}!")

        # May unload the models; the transmitter is (re)loaded just below.
        check_memory_pressure()

        # Unique base path (no extension) for this job's output files.
        unique_id = str(uuid.uuid4())
        filename = os.path.join(output_dir, unique_id)

        # Load transmitter model for decoding
        load_transmitter_model()

        print(f"Decoding mesh for job {job_id}...")
        t0 = time.time()
        mem_before = psutil.Process().memory_info().rss / (1024 * 1024)
        print(f"Memory before mesh decoding: {mem_before:.2f} MB")

        mesh = decode_latent_mesh(xm, latents[0]).tri_mesh()
        print(f"Mesh decoded in {time.time() - t0:.2f} seconds")

        mem_after = psutil.Process().memory_info().rss / (1024 * 1024)
        print(f"Memory after decoding: {mem_after:.2f} MB (delta: {mem_after - mem_before:.2f} MB)")

        # Report mesh complexity if possible
        try:
            vertices_count = len(mesh.vertices)
            faces_count = len(mesh.faces)
            print(f"Mesh complexity: {vertices_count} vertices, {faces_count} faces")
        except Exception as e:
            print(f"Could not determine mesh complexity: {e}")
            vertices_count = 0
            faces_count = 0

        # The latents are no longer needed once the mesh exists.
        del latents
        gc.collect()

        print(f"Converting and saving mesh for job {job_id}...")
        saved_files = save_trimesh(mesh, filename)

        # Clear mesh from memory
        del mesh
        gc.collect()

        result = _build_job_result(saved_files, vertices_count, faces_count)

        # Only claim success in the log when at least one file was written.
        if result["success"]:
            print(f"Files saved successfully for job {job_id}!")
        else:
            print(f"No mesh files could be saved for job {job_id}")

        gc.collect()
        return result
    except Exception as e:
        print(f"Error during generation for job {job_id}: {str(e)}")
        import traceback
        traceback.print_exc()
        return {
            "success": False,
            "error": str(e),
            "timestamp": time.time()
        }
def worker_thread():
    """Consume jobs from ``job_queue`` forever and publish their results.

    For each ``(job_id, prompt)`` item the result of ``process_job`` is stored
    in ``job_results[job_id]`` and ``active_jobs`` is decremented exactly
    once. Unexpected errors are logged and, when they happened while a job was
    in hand, recorded as that job's failed result.

    ``is_thread_running`` is True while the loop runs and reset to False if
    the function ever exits.
    """
    global is_thread_running, active_jobs
    import traceback

    is_thread_running = True
    try:
        while True:
            # Reset on every iteration: a stale id from an earlier job must
            # never be blamed for (or overwritten by) a later failure.
            job_id = None
            try:
                # Short timeout keeps the loop responsive without busy-waiting.
                job_id, prompt = job_queue.get(timeout=1)
                print(f"Processing job {job_id} with prompt: {prompt}")
                job_results[job_id] = process_job(job_id, prompt)
            except queue.Empty:
                # No jobs in queue, continue waiting
                pass
            except Exception as e:
                print(f"Error in worker thread: {e}")
                traceback.print_exc()
                if job_id is not None:
                    job_results[job_id] = {
                        "success": False,
                        "error": str(e),
                        "timestamp": time.time()
                    }
            finally:
                if job_id is not None:
                    # One decrement per dequeued job, on success or failure.
                    active_jobs -= 1
                    gc.collect()
    finally:
        is_thread_running = False
def purge_old_results_thread():
    """Background loop that drops job results older than two hours.

    Runs every 30 minutes. Results without a ``timestamp`` are never purged.
    Errors are logged and the loop continues. Never returns.
    """
    while True:
        try:
            time.sleep(1800)  # Run every 30 minutes
            now = time.time()
            threshold_time = now - (2 * 3600)

            # Iterate over a snapshot: the worker thread inserts into
            # job_results concurrently, and iterating the live dict could
            # raise "dictionary changed size during iteration".
            jobs_to_remove = [
                job_id
                for job_id, result in list(job_results.items())
                if result.get('timestamp', now) < threshold_time
            ]

            for job_id in jobs_to_remove:
                job_results.pop(job_id, None)

            if jobs_to_remove:
                print(f"Auto-purged {len(jobs_to_remove)} old job results")

            # Force garbage collection
            gc.collect()
        except Exception as e:
            print(f"Error in purge thread: {e}")
def ensure_worker_thread_running():
    """Start the daemon worker thread unless one is already alive."""
    global generation_thread

    already_running = generation_thread is not None and generation_thread.is_alive()
    if already_running:
        return

    print("Starting worker thread...")
    generation_thread = threading.Thread(target=worker_thread, daemon=True)
    generation_thread.start()
def start_monitoring_threads():
    """Launch the maintenance daemons: the model unloader, then the results purger."""
    maintenance_targets = (model_unloader_thread, purge_old_results_thread)
    for target in maintenance_targets:
        threading.Thread(target=target, daemon=True).start()
def generate_3d():
    """Accept a text prompt and queue a generation job.

    Expects a JSON object body with a non-empty string ``prompt``.

    Returns:
        * 503 with ``retry_after`` when ``active_jobs`` has reached
          ``max_concurrent_jobs``;
        * 400 when the body is not a JSON object or the prompt is missing,
          empty, or not a string;
        * otherwise a JSON payload with ``job_id`` and ``status_url``.

    NOTE(review): no route decorator is visible on this function here; the
    landing page documents it as ``POST /generate`` -- confirm registration.
    """
    global active_jobs

    # Check if we're already at max capacity
    if active_jobs >= max_concurrent_jobs:
        return jsonify({
            "success": False,
            "error": "Server is at maximum capacity. Please try again later.",
            "retry_after": 300
        }), 503

    # silent=True yields None for a missing/invalid JSON body instead of
    # letting Flask abort, so the documented 400 below is what clients see.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'prompt' not in data:
        return jsonify({"error": "No prompt provided"}), 400

    prompt = data['prompt']
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Prompt must be a non-empty string"}), 400
    print(f"Received prompt: {prompt}")

    job_id = str(uuid.uuid4())

    ensure_worker_thread_running()
    # Count the job before it becomes visible to the worker, so the worker's
    # decrement can never run ahead of this increment.
    active_jobs += 1
    job_queue.put((job_id, prompt))

    # Return job ID immediately
    return jsonify({
        "success": True,
        "message": "Job submitted successfully",
        "job_id": job_id,
        "status_url": f"/status/{job_id}"
    })
def job_status(job_id):
    """Return the stored result for *job_id*, or a "still processing" payload.

    Any id without a stored result -- including ids that were never issued or
    whose result has been purged -- receives the "still processing" response.

    NOTE(review): no route decorator is visible on this function here; the
    landing page documents it as ``GET /status/<job_id>`` -- confirm.
    """
    # Single lookup: a separate "in" test followed by indexing could raise
    # KeyError if the purge thread removed the entry in between.
    result = job_results.get(job_id)
    if result is not None:
        return jsonify(result)

    # Job is still in progress
    return jsonify({
        "success": None,
        "message": "Job is still processing",
        "job_id": job_id
    })
def download_file(filename):
    """Send a generated file from the output directory as an attachment.

    *filename* is resolved against ``output_dir``. Requests that resolve
    outside that directory (absolute paths, ``..`` segments, symlinks leading
    elsewhere) or that do not name a regular file get a 404. Unexpected errors
    produce a 500 with the error text.

    NOTE(review): no route decorator is visible on this function here; result
    URLs have the form ``/download/<filename>`` -- confirm registration.
    """
    try:
        base_dir = os.path.realpath(output_dir)
        file_path = os.path.realpath(os.path.join(base_dir, filename))

        # os.path.join discards base_dir when filename is absolute, and ".."
        # can climb out of it; insist the resolved path stays inside.
        inside_output_dir = file_path.startswith(os.path.join(base_dir, ""))
        if not inside_output_dir or not os.path.isfile(file_path):
            return jsonify({"error": "File not found"}), 404

        return send_file(file_path, as_attachment=True)
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def health_check():
    """Report service status together with memory, CPU, queue and model state.

    Returns a JSON payload with ``status`` "ok" and the collected metrics. If
    collecting them raises, returns ``status`` "warning" with the error text.

    NOTE(review): no route decorator is visible on this function here; the
    landing page documents it as ``GET /health`` -- confirm registration.
    """
    try:
        # System-wide memory, then CPU sampled over a 0.1 s window.
        system_memory = psutil.virtual_memory()
        available_gb = system_memory.available / (1024**3)
        memory_usage = f"{system_memory.percent}% (Available: {available_gb:.2f} GB)"
        cpu_usage = f"{psutil.cpu_percent(interval=0.1)}%"

        # Resident memory of this process.
        rss_gb = psutil.Process().memory_info().rss / (1024**3)
        process_memory = f"{rss_gb:.2f} GB"

        # Names of the model globals that are currently populated.
        model_slots = (
            ("text300M", model),
            ("diffusion", diffusion),
            ("transmitter", xm),
        )
        models_loaded = [name for name, handle in model_slots if handle is not None]

        queue_size = job_queue.qsize()

        # Minutes since the models were last requested, if they ever were.
        if last_usage_time is None:
            model_inactive = "N/A"
        else:
            model_inactive = f"{(time.time() - last_usage_time) / 60:.1f} minutes"

        saved_jobs = len(job_results)

        return jsonify({
            "status": "ok",
            "message": "Service is running",
            "memory_usage": memory_usage,
            "process_memory": process_memory,
            "cpu_usage": cpu_usage,
            "queue_size": queue_size,
            "active_jobs": active_jobs,
            "saved_jobs": saved_jobs,
            "worker_running": is_thread_running,
            "models_loaded": models_loaded,
            "model_inactive_time": model_inactive
        })
    except Exception as e:
        return jsonify({
            "status": "warning",
            "error": str(e)
        })
def home():
    """Return the static HTML landing page that documents the API."""
    # The page is a fixed string: usage steps for /generate, /status and
    # /download, plus the /health endpoint.
    landing_page = """
<html>
<head>
<title>Text to 3D API</title>
<style>
body { font-family: Arial, sans-serif; line-height: 1.6; margin: 0; padding: 20px; max-width: 800px; margin: 0 auto; }
pre { background: #f4f4f4; padding: 15px; border-radius: 5px; overflow-x: auto; }
code { background: #f4f4f4; padding: 2px 5px; border-radius: 3px; }
h1, h2 { color: #333; }
</style>
</head>
<body>
<h1>Text to 3D API</h1>
<p>This is an optimized API that converts text prompts to 3D models.</p>
<h2>How to use:</h2>
<h3>Step 1: Submit a generation job</h3>
<pre>
POST /generate
Content-Type: application/json
{
"prompt": "A futuristic building"
}
</pre>
<p>Response:</p>
<pre>
{
"success": true,
"message": "Job submitted successfully",
"job_id": "123e4567-e89b-12d3-a456-426614174000",
"status_url": "/status/123e4567-e89b-12d3-a456-426614174000"
}
</pre>
<h3>Step 2: Check job status</h3>
<pre>
GET /status/123e4567-e89b-12d3-a456-426614174000
</pre>
<p>Response (while processing):</p>
<pre>
{
"success": null,
"message": "Job is still processing",
"job_id": "123e4567-e89b-12d3-a456-426614174000"
}
</pre>
<p>Response (when complete):</p>
<pre>
{
"success": true,
"message": "3D model generated successfully",
"glb_url": "/download/abc123.glb",
"obj_url": "/download/abc123.obj",
"ply_url": "/download/abc123.ply"
}
</pre>
<h3>Step 3: Download the files</h3>
<p>Use the provided URLs to download the GLB, OBJ, and PLY files.</p>
<h2>Health Check:</h2>
<pre>GET /health</pre>
<p>Provides information about the service status and resource usage.</p>
</body>
</html>
"""
    return landing_page
def purge_old_results():
    """Remove stored job results older than a caller-supplied age.

    Reads an optional ``threshold_hours`` number from a JSON object body
    (default 1). Results without a ``timestamp`` are never removed.

    Returns a JSON summary with the number purged and remaining; 400 when
    ``threshold_hours`` is not a number; 500 on unexpected errors.

    NOTE(review): no route decorator is visible on this function here, so its
    URL and HTTP method cannot be confirmed from this view.
    """
    try:
        # silent=True tolerates a missing or non-JSON body and falls back to
        # the default instead of surfacing as a server error.
        payload = request.get_json(silent=True)
        threshold_hours = payload.get('threshold_hours', 1) if isinstance(payload, dict) else 1

        # bool is an int subclass; reject it along with strings, lists, etc.
        if isinstance(threshold_hours, bool) or not isinstance(threshold_hours, (int, float)):
            return jsonify({
                "success": False,
                "error": "threshold_hours must be a number"
            }), 400

        now = time.time()
        threshold_time = now - (threshold_hours * 3600)

        # Iterate over a snapshot: the worker thread may insert results while
        # this request is being served.
        jobs_to_remove = [
            job_id
            for job_id, result in list(job_results.items())
            if result.get('timestamp', now) < threshold_time
        ]

        for job_id in jobs_to_remove:
            job_results.pop(job_id, None)

        # Force garbage collection
        gc.collect()

        return jsonify({
            "success": True,
            "message": f"Purged {len(jobs_to_remove)} old job results",
            "remaining_jobs": len(job_results)
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
def force_garbage_collection():
    """Run a garbage collection on demand and report process memory before and after.

    Returns a JSON payload with the resident memory (GB, rounded to two
    decimals) before and after collection and the amount freed (0 when memory
    did not shrink); 500 with the error text on failure.

    NOTE(review): no route decorator is visible on this function here, so its
    URL and HTTP method cannot be confirmed from this view.
    """
    bytes_per_gb = 1024**3
    try:
        before_mem = psutil.Process().memory_info().rss / bytes_per_gb

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        after_mem = psutil.Process().memory_info().rss / bytes_per_gb
        freed = before_mem - after_mem

        # Report 0 rather than a negative number when usage grew.
        if freed > 0:
            freed_report = round(freed, 2)
        else:
            freed_report = 0

        return jsonify({
            "success": True,
            "message": "Garbage collection completed",
            "before_memory_gb": round(before_mem, 2),
            "after_memory_gb": round(after_mem, 2),
            "freed_memory_gb": freed_report
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
if __name__ == '__main__':
    # Background maintenance first, then the job worker, then the web server.
    # This block only runs when the file is executed directly; when the module
    # is imported by a WSGI server (as in the gunicorn command below) it does
    # not run, so the monitoring threads are not started from here.
    start_monitoring_threads()
    ensure_worker_thread_running()

    # Recommended to run with gunicorn for production with increased timeout:
    # $ gunicorn app:app --bind 0.0.0.0:7860 --timeout 300 --workers 1
    app.run(port=7860, host='0.0.0.0', debug=False)  # Set debug=False in production