File size: 4,134 Bytes
24a4bd7
 
 
 
 
 
03e7772
8fd199a
9549eae
03e7772
8fd199a
24a4bd7
 
 
 
 
03e7772
 
24a4bd7
8fd199a
 
03e7772
24a4bd7
 
03e7772
24a4bd7
 
03e7772
 
 
 
 
 
 
 
 
 
 
 
3305826
03e7772
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44cda22
03e7772
 
 
 
 
 
 
 
8fd199a
 
03e7772
 
 
 
24a4bd7
03e7772
24a4bd7
 
 
03e7772
24a4bd7
03e7772
24a4bd7
 
 
 
 
 
 
8fd199a
 
24a4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import gradio as gr
import sqlite3
import threading
import time
import os
import shutil
from gradio_client import Client, handle_file
# import spaces

# Job-tracking database: one row per uploaded image
_JOBS_DB_PATH = '/tmp/jobs.db'
_JOBS_SCHEMA = '''CREATE TABLE IF NOT EXISTS jobs
             (id INTEGER PRIMARY KEY, image_path TEXT, job_id TEXT, status TEXT, output_path TEXT)'''

# check_same_thread=False lifts sqlite3's same-thread check so the connection
# can also be used from the background threads this app starts.
conn = sqlite3.connect(_JOBS_DB_PATH, check_same_thread=False)
c = conn.cursor()
c.execute(_JOBS_SCHEMA)
conn.commit()

# Remote TRELLIS Space that performs the image-to-3D conversion
_TRELLIS_SPACE = "jkorstad/TRELLIS"
trellis_client = Client(_TRELLIS_SPACE)


# Only one job at a time may drive the remote TRELLIS pipeline. /extract_glb
# is called without any image argument, so it presumably acts on whatever the
# previous /image_to_3d call left in the shared client's session; letting two
# worker threads interleave those calls could hand a job another job's mesh.
_TRELLIS_PIPELINE_LOCK = threading.Lock()


def _set_job_status(job_id, status, output_path=None):
    """Store *status* (and *output_path*, when given) for a job and commit.

    Uses ``conn.execute`` so every call gets its own cursor rather than the
    module-level cursor that other threads may be using at the same time.
    """
    if output_path is None:
        conn.execute("UPDATE jobs SET status=? WHERE id=?", (status, job_id))
    else:
        conn.execute(
            "UPDATE jobs SET status=?, output_path=? WHERE id=?",
            (status, output_path, job_id),
        )
    conn.commit()


# @spaces.GPU
# Processing logic with three-step TRELLIS workflow
def process_job(job_id):
    """Run the three-step TRELLIS workflow for one queued job.

    The remote endpoints are called in order: ``/preprocess_image``,
    ``/image_to_3d`` and ``/extract_glb``. The job's ``status`` column moves
    through ``'preprocessing'``, ``'generating'`` and ``'extracting'`` and ends
    at ``'completed'`` (with ``output_path`` set) or ``'failed'``. A job that is
    waiting for another job's pipeline to finish keeps its previous status.

    Args:
        job_id: Primary key of the row in the ``jobs`` table to process.

    Returns:
        None. Pipeline errors are caught, printed and recorded as ``'failed'``
        instead of being raised, because this runs in a background thread.
    """
    try:
        # Private cursor: the shared module-level cursor can be re-executed by
        # another thread between execute() and fetchone().
        row = conn.execute(
            "SELECT image_path FROM jobs WHERE id=?", (job_id,)
        ).fetchone()
        if row is None:
            raise LookupError(f"no job with id {job_id}")
        image_path = row[0]

        with _TRELLIS_PIPELINE_LOCK:
            # Step 1: Preprocess the image
            _set_job_status(job_id, 'preprocessing')
            preprocessed_image = trellis_client.predict(
                image=handle_file(image_path),
                api_name="/preprocess_image"
            )

            # Step 2: Generate 3D asset
            _set_job_status(job_id, 'generating')
            # NOTE(review): fixed pause kept from the original code; its
            # purpose (rate limiting vs. waiting on the Space) is not visible
            # here, so adjust based on observed timing.
            time.sleep(30)
            # The return value is not needed; the GLB is fetched in step 3.
            trellis_client.predict(
                image=handle_file(preprocessed_image),  # Use preprocessed image
                multiimages=[],
                seed=0,  # Default; could make configurable
                ss_guidance_strength=7.5,
                ss_sampling_steps=12,
                slat_guidance_strength=3,
                slat_sampling_steps=12,
                multiimage_algo="stochastic",
                api_name="/image_to_3d"
            )

            # Step 3: Extract GLB
            _set_job_status(job_id, 'extracting')
            time.sleep(65)  # Wait for 3D processing; adjust as needed
            glb_result = trellis_client.predict(
                mesh_simplify=0.95,
                texture_size=1024,
                api_name="/extract_glb"
            )

        glb_path = glb_result[0]  # First element is the GLB filepath

        # Move the GLB out of the client's download location to a stable path
        os.makedirs('/tmp/outputs', exist_ok=True)
        output_path = f'/tmp/outputs/result_{job_id}.glb'
        shutil.move(glb_path, output_path)

        _set_job_status(job_id, 'completed', output_path)

    except Exception as e:
        # Report first so the original error is never lost if the status
        # update below fails as well.
        print(f"Error processing job {job_id}: {e}")
        try:
            _set_job_status(job_id, 'failed')
        except sqlite3.Error as db_error:
            print(f"Could not mark job {job_id} as failed: {db_error}")

# Gradio interface
def submit_images(files):
    """Queue one background TRELLIS job per uploaded file.

    For each upload a ``jobs`` row is inserted with status ``'submitted'``, the
    file is copied to ``/tmp/inputs/input_<job id><ext>`` (keeping the upload's
    own extension, ``.jpg`` when it has none), the row's ``image_path`` is set
    and a daemon thread running ``process_job`` is started.

    Args:
        files: Value of the multi-file upload component: a list whose items are
            either path strings or objects exposing the path as ``.name``.
            ``None`` or an empty list when nothing was uploaded.

    Returns:
        A short message for the UI.

    Raises:
        OSError: If an upload cannot be copied; that job is marked ``'failed'``
            before the error propagates.
    """
    if not files:
        return "No files uploaded."

    os.makedirs('/tmp/inputs', exist_ok=True)
    for file in files:
        # Accept both plain path strings and file wrappers with a .name path.
        source_path = getattr(file, "name", file)
        extension = os.path.splitext(source_path)[1] or '.jpg'

        # Private cursor: lastrowid on the shared module-level cursor could be
        # overwritten by a concurrent submission before it is read.
        insert_cursor = conn.execute("INSERT INTO jobs (status) VALUES ('submitted')")
        job_id = insert_cursor.lastrowid
        conn.commit()

        image_path = f'/tmp/inputs/input_{job_id}{extension}'
        try:
            shutil.copy(source_path, image_path)
        except OSError:
            # Do not leave a row stuck at 'submitted' with no worker behind it.
            conn.execute("UPDATE jobs SET status='failed' WHERE id=?", (job_id,))
            conn.commit()
            raise

        conn.execute("UPDATE jobs SET image_path=? WHERE id=?", (image_path, job_id))
        conn.commit()
        threading.Thread(target=process_job, args=(job_id,), daemon=True).start()
    return "Jobs submitted. Check the status tab."

def get_status():
    """Return all jobs for the status table.

    Returns:
        A list of ``(id, image_path, status, output_path)`` tuples, one per row
        of the ``jobs`` table.
    """
    # Private cursor: the shared module-level cursor can be re-executed by a
    # worker thread between execute() and fetchall(), which would discard or
    # replace this query's result set.
    rows = conn.execute("SELECT id, image_path, status, output_path FROM jobs")
    return rows.fetchall()

# Two-tab UI: "Upload" queues jobs, "Status" lists every job on demand
demo = gr.Blocks(title="TRELLIS 3D Generator")

with demo:
    # Layout
    with gr.Tab("Upload"):
        files_input = gr.File(label="Upload Images (JPG/PNG)", file_count="multiple")
        submit_btn = gr.Button("Submit")
        output_msg = gr.Textbox(label="Message")

    with gr.Tab("Status"):
        status_columns = ["ID", "Image Path", "Status", "Output Path"]
        status_table = gr.DataFrame(headers=status_columns, label="Job Status")
        refresh_btn = gr.Button("Refresh")

    # Event wiring, registered in the same order as the buttons appear
    submit_btn.click(submit_images, inputs=files_input, outputs=output_msg)
    refresh_btn.click(get_status, inputs=None, outputs=status_table)

demo.launch()