"""Gradio front end that queues uploaded images for TRELLIS 3D generation.

Each uploaded image becomes a row in a small SQLite ``jobs`` table and is
processed on a background thread through the three-step TRELLIS workflow
(preprocess -> image_to_3d -> extract_glb).  The "Status" tab lists all jobs.
"""

import gradio as gr
import sqlite3
import threading
import time
import os
import shutil
from gradio_client import Client, handle_file
# import spaces

DB_PATH = '/tmp/jobs.db'
INPUT_DIR = '/tmp/inputs'
OUTPUT_DIR = '/tmp/outputs'

# Pauses inserted between the TRELLIS steps; adjust based on observed timing.
GENERATE_DELAY_SECONDS = 30
EXTRACT_DELAY_SECONDS = 65

# Database setup
conn = sqlite3.connect(DB_PATH, check_same_thread=False)
c = conn.cursor()
c.execute('''CREATE TABLE IF NOT EXISTS jobs
             (id INTEGER PRIMARY KEY, image_path TEXT, job_id TEXT, status TEXT, output_path TEXT)''')
conn.commit()

# The connection is shared by the Gradio request threads and every worker
# thread.  check_same_thread=False only disables sqlite3's ownership check; it
# does not make concurrent use safe, so all access goes through this lock and
# each statement gets its own cursor (a shared cursor's pending results and
# lastrowid can be overwritten by another thread's execute()).
_db_lock = threading.Lock()

# TRELLIS API client
trellis_client = Client("jkorstad/TRELLIS")

# "/extract_glb" is called without any reference to the asset it should
# export, so it presumably operates on whatever "/image_to_3d" last produced
# for this client's session (TODO confirm against the Space's API).  Running
# jobs concurrently on the shared client could therefore export another job's
# mesh; the pipeline is serialized to keep each job's three steps together.
_trellis_lock = threading.Lock()


def _execute(sql, params=()):
    """Run one write statement, commit it, and return the cursor's lastrowid."""
    with _db_lock:
        cursor = conn.execute(sql, params)
        conn.commit()
        return cursor.lastrowid


def _query(sql, params=()):
    """Run one read statement and return all resulting rows."""
    with _db_lock:
        return conn.execute(sql, params).fetchall()


def _set_status(job_id, status):
    """Record a new status string for the given job."""
    _execute("UPDATE jobs SET status=? WHERE id=?", (status, job_id))


# @spaces.GPU
# Processing logic with three-step TRELLIS workflow
def process_job(job_id):
    """Run one job through the TRELLIS workflow and record the outcome.

    The job's status column is advanced through 'preprocessing',
    'generating' and 'extracting', and ends as 'completed' (with
    ``output_path`` set to the saved GLB) or 'failed'.  Errors are printed
    rather than raised because this runs on a background thread.

    Args:
        job_id: Primary key of the row in the ``jobs`` table to process.
    """
    try:
        # Get the uploaded image path
        rows = _query("SELECT image_path FROM jobs WHERE id=?", (job_id,))
        if not rows or not rows[0][0]:
            raise ValueError(f"no image recorded for job {job_id}")
        image_path = rows[0][0]

        with _trellis_lock:
            # Step 1: Preprocess the image
            _set_status(job_id, 'preprocessing')
            preprocessed_image = trellis_client.predict(
                image=handle_file(image_path),
                api_name="/preprocess_image"
            )

            # Step 2: Generate 3D asset
            _set_status(job_id, 'generating')
            time.sleep(GENERATE_DELAY_SECONDS)
            # The return value (a preview video in the original code) is not
            # needed here; only the GLB from step 3 is kept.
            trellis_client.predict(
                image=handle_file(preprocessed_image),  # Use preprocessed image
                multiimages=[],
                seed=0,  # Default; could make configurable
                ss_guidance_strength=7.5,
                ss_sampling_steps=12,
                slat_guidance_strength=3,
                slat_sampling_steps=12,
                multiimage_algo="stochastic",
                api_name="/image_to_3d"
            )

            # Step 3: Extract GLB
            _set_status(job_id, 'extracting')
            time.sleep(EXTRACT_DELAY_SECONDS)
            glb_result = trellis_client.predict(
                mesh_simplify=0.95,
                texture_size=1024,
                api_name="/extract_glb"
            )

        glb_path = glb_result[0]  # First element is the GLB filepath

        # Move the GLB out of the client's download location
        os.makedirs(OUTPUT_DIR, exist_ok=True)
        output_path = f'{OUTPUT_DIR}/result_{job_id}.glb'
        shutil.move(glb_path, output_path)

        # Update job status
        _execute(
            "UPDATE jobs SET status='completed', output_path=? WHERE id=?",
            (output_path, job_id),
        )
    except Exception as e:
        # Top-level boundary of the worker thread: report and mark the job.
        print(f"Error processing job {job_id}: {e}")
        try:
            _set_status(job_id, 'failed')
        except sqlite3.Error as db_error:
            print(f"Could not mark job {job_id} as failed: {db_error}")


# Gradio interface
def submit_images(files):
    """Create one job per uploaded file and start processing it.

    Args:
        files: Value of the ``gr.File`` component: a list whose items are
            either filesystem paths or objects exposing the path as ``.name``
            (which one depends on the Gradio version / component ``type``).

    Returns:
        A short message for the "Message" textbox.
    """
    if not files:
        return "No files uploaded."

    os.makedirs(INPUT_DIR, exist_ok=True)
    for file in files:
        # Accept both plain paths and tempfile-like wrappers.
        if isinstance(file, (str, os.PathLike)):
            source_path = os.fspath(file)
        else:
            source_path = file.name
        # Keep the upload's real extension (PNG uploads are allowed too).
        extension = os.path.splitext(source_path)[1].lower() or '.jpg'

        job_id = _execute("INSERT INTO jobs (status) VALUES ('submitted')")
        image_path = f'{INPUT_DIR}/input_{job_id}{extension}'
        shutil.copy(source_path, image_path)
        _execute("UPDATE jobs SET image_path=? WHERE id=?", (image_path, job_id))

        threading.Thread(target=process_job, args=(job_id,), daemon=True).start()
    return "Jobs submitted. Check the status tab."


def get_status():
    """Return every job as ``(id, image_path, status, output_path)`` rows."""
    return _query("SELECT id, image_path, status, output_path FROM jobs")


with gr.Blocks(title="TRELLIS 3D Generator") as demo:
    with gr.Tab("Upload"):
        files_input = gr.File(file_count="multiple", label="Upload Images (JPG/PNG)")
        submit_btn = gr.Button("Submit")
        output_msg = gr.Textbox(label="Message")
        submit_btn.click(fn=submit_images, inputs=files_input, outputs=output_msg)

    with gr.Tab("Status"):
        status_table = gr.DataFrame(
            headers=["ID", "Image Path", "Status", "Output Path"],
            label="Job Status"
        )
        refresh_btn = gr.Button("Refresh")
        refresh_btn.click(fn=get_status, inputs=None, outputs=status_table)

if __name__ == "__main__":
    demo.launch()