File size: 3,772 Bytes
24a4bd7
 
 
 
 
 
 
 
9549eae
 
 
 
24a4bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import sqlite3
import threading
import time
import requests
import os
import shutil


# Persistent storage lives under /data. Create it first because
# sqlite3.connect() does not create missing parent directories.
os.makedirs('/data', exist_ok=True)

# A single connection is shared by the Gradio callbacks and the worker threads
# started below, hence check_same_thread=False. `c` is a shared module-level
# cursor that the rest of this file also uses.
conn = sqlite3.connect('/data/jobs.db', check_same_thread=False)
c = conn.cursor()
# Leaving the `with conn:` block commits the CREATE TABLE (or rolls back on error).
with conn:
    c.execute('''CREATE TABLE IF NOT EXISTS jobs
             (id INTEGER PRIMARY KEY, image_path TEXT, job_id TEXT, status TEXT, output_path TEXT)''')

# TRELLIS API client functions. The endpoint URLs below are placeholders; adjust them to match the TRELLIS Space's actual API.
def upload_image(image_path, timeout=60):
    """Upload an image to the TRELLIS Space and return the remote job id.

    Args:
        image_path: Local path of the image file to upload.
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without a timeout a stalled server blocks the calling worker
            thread forever.

    Returns:
        The ``job_id`` field of the JSON response.

    Raises:
        OSError: if ``image_path`` cannot be opened.
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
        KeyError: if the JSON response has no ``job_id`` field.
    """
    # NOTE(review): endpoint URL is unverified; confirm against the Space's API.
    with open(image_path, 'rb') as f:
        response = requests.post(
            'https://huggingface.co/spaces/jkorstad/TRELLIS/api/upload',
            files={'file': f},
            timeout=timeout,
        )
    # The request body has been fully sent at this point, so the file can be
    # closed before inspecting the response.
    response.raise_for_status()
    return response.json()['job_id']

def check_status(job_id, timeout=30):
    """Return the remote ``status`` string for a TRELLIS job.

    Args:
        job_id: Job id previously returned by the upload endpoint.
        timeout: Seconds ``requests`` waits for the server before giving up,
            so a stalled server cannot hang the polling thread.

    Returns:
        The ``status`` field of the JSON response.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
        KeyError: if the JSON response has no ``status`` field.
    """
    from urllib.parse import quote

    # job_id comes from the remote service; quote it so characters such as '/'
    # cannot change which path is requested.
    quoted_id = quote(str(job_id), safe='')
    response = requests.get(
        f'https://huggingface.co/spaces/jkorstad/TRELLIS/api/status/{quoted_id}',
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()['status']

def get_result(job_id, timeout=120, output_dir='/data/outputs'):
    """Download a finished job's result and save it as a ``.glb`` file.

    Args:
        job_id: Job id previously returned by the upload endpoint.
        timeout: Seconds ``requests`` waits for the server before giving up.
        output_dir: Directory to write the result into; it is created if missing.

    Returns:
        Path of the written file, ``<output_dir>/result_<job_id>.glb``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server does not respond within ``timeout``.
        OSError: if the output file cannot be written.
    """
    from urllib.parse import quote

    quoted_id = quote(str(job_id), safe='')
    response = requests.get(
        f'https://huggingface.co/spaces/jkorstad/TRELLIS/api/result/{quoted_id}',
        timeout=timeout,
    )
    response.raise_for_status()
    # job_id is remote data: strip directory components so a value such as
    # '../x' cannot write outside output_dir.
    safe_id = os.path.basename(str(job_id))
    output_path = os.path.join(output_dir, f'result_{safe_id}.glb')
    os.makedirs(output_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(response.content)
    return output_path

# Background job processing: upload the image, poll the remote status, download the result.
def process_job(job_id, poll_interval=5, max_polls=None):
    """Run one queued job end to end against the TRELLIS API.

    The job moves through these steps:
    1. Read the job's image_path from the ``jobs`` table and upload the image.
    2. Store the remote job id with status 'processing'.
    3. Poll the remote status every ``poll_interval`` seconds:
       - On 'completed', download the result and store its path with status 'completed'.
       - On 'failed', stop polling.

    A remote 'failed', any exception, or exceeding ``max_polls`` marks the row 'failed'.

    Args:
        job_id: Primary key of the row in the ``jobs`` table.
        poll_interval: Seconds to sleep between status checks.
        max_polls: Maximum number of non-terminal status checks before giving
            up; ``None`` (the default) polls indefinitely.
    """
    import logging

    log = logging.getLogger(__name__)
    try:
        # conn.execute() creates a fresh cursor per statement. The module-level
        # cursor `c` is shared with other worker threads and the UI callbacks.
        # Using it here would let another thread's execute() replace the result
        # set between our execute() and fetchone().
        row = conn.execute("SELECT image_path FROM jobs WHERE id=?", (job_id,)).fetchone()
        if row is None or row[0] is None:
            raise LookupError(f"job {job_id} has no image_path")
        api_job_id = upload_image(row[0])
        conn.execute("UPDATE jobs SET job_id=?, status='processing' WHERE id=?", (api_job_id, job_id))
        conn.commit()
        polls = 0
        while True:
            status = check_status(api_job_id)
            if status == 'completed':
                output_path = get_result(api_job_id)
                conn.execute("UPDATE jobs SET status='completed', output_path=? WHERE id=?", (output_path, job_id))
                conn.commit()
                return
            if status == 'failed':
                break
            polls += 1
            if max_polls is not None and polls >= max_polls:
                log.warning("TRELLIS job %s still %r after %d polls; giving up", job_id, status, polls)
                break
            time.sleep(poll_interval)
    except Exception:
        # Do not swallow the error silently: keep the traceback for debugging.
        log.exception("TRELLIS job %s failed", job_id)
    conn.execute("UPDATE jobs SET status='failed' WHERE id=?", (job_id,))
    conn.commit()

# Gradio callbacks and user interface.
def submit_images(files):
    """Queue each uploaded file as a job and start a background worker for it.

    For each file, this callback:
    1. Inserts a 'submitted' row into ``jobs``.
    2. Copies the file to ``/data/inputs/input_<id><ext>``, keeping its extension
       (``.jpg`` when it has none), and records that path.
    3. Starts a daemon thread running ``process_job``.

    Args:
        files: Value of the multi-file ``gr.File`` input. Items may be path
            strings or objects exposing the path as ``.name``.

    Returns:
        A status message for the UI.
    """
    if not files:
        return "No files uploaded."
    for file in files:
        # Accept both plain paths and file wrappers exposing the path as `.name`.
        src_path = str(file) if isinstance(file, (str, os.PathLike)) else file.name
        # Keep the real extension so PNG uploads are not stored as .jpg.
        ext = os.path.splitext(src_path)[1].lower() or '.jpg'
        # Per-statement cursor: lastrowid on the shared module cursor `c` could
        # reflect a statement run concurrently by a worker thread.
        job_id = conn.execute("INSERT INTO jobs (status) VALUES ('submitted')").lastrowid
        conn.commit()
        image_path = f'/data/inputs/input_{job_id}{ext}'
        os.makedirs('/data/inputs', exist_ok=True)
        shutil.copy(src_path, image_path)
        conn.execute("UPDATE jobs SET image_path=? WHERE id=?", (image_path, job_id))
        conn.commit()
        threading.Thread(target=process_job, args=(job_id,), daemon=True).start()
    return "Jobs submitted. Check the status tab."

def get_status():
    """Return all jobs as ``(id, image_path, status, output_path)`` tuples, ordered by id.

    Uses a fresh cursor from ``conn.execute`` instead of the shared module
    cursor, so a worker thread's concurrent query cannot replace the result
    set before ``fetchall()`` runs.
    """
    return conn.execute(
        "SELECT id, image_path, status, output_path FROM jobs ORDER BY id"
    ).fetchall()

with gr.Blocks(title="TRELLIS 3D Generator") as demo:
    # Upload tab: submit one or more images; submit_images queues them as jobs.
    with gr.Tab("Upload"):
        files_input = gr.File(
            file_count="multiple",
            file_types=["image"],  # match the "JPG/PNG" label: only offer image files
            label="Upload Images (JPG/PNG)",
        )
        submit_btn = gr.Button("Submit")
        output_msg = gr.Textbox(label="Message")
        submit_btn.click(fn=submit_images, inputs=files_input, outputs=output_msg)
    # Status tab: shows the rows returned by get_status.
    with gr.Tab("Status"):
        status_table = gr.DataFrame(
            headers=["ID", "Image Path", "Status", "Output Path"],
            label="Job Status"
        )
        refresh_btn = gr.Button("Refresh")
        refresh_btn.click(fn=get_status, inputs=None, outputs=status_table)
    # Fill the table on page load so jobs already stored in the database are
    # visible without pressing Refresh first.
    demo.load(fn=get_status, inputs=None, outputs=status_table)

demo.launch()