jkorstad commited on
Commit
3b02325
·
verified ·
1 Parent(s): 9a270a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -99
app.py CHANGED
@@ -1,113 +1,86 @@
1
  import gradio as gr
2
- import sqlite3
3
- import threading
4
- import time
5
  import os
6
  import shutil
7
  from gradio_client import Client, handle_file
8
- # import spaces
9
 
10
- # Database setup
11
- conn = sqlite3.connect('/tmp/jobs.db', check_same_thread=False)
12
- c = conn.cursor()
13
- c.execute('''CREATE TABLE IF NOT EXISTS jobs
14
- (id INTEGER PRIMARY KEY, image_path TEXT, job_id TEXT, status TEXT, output_path TEXT)''')
15
- conn.commit()
16
 
17
- # TRELLIS API client
18
- trellis_client = Client("jkorstad/TRELLIS")
 
 
 
 
 
 
 
 
19
 
 
 
 
 
 
 
20
 
21
- # @spaces.GPU
22
- # Processing logic with three-step TRELLIS workflow
23
- def process_job(job_id):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  try:
25
- # Get the uploaded image path
26
- c.execute("SELECT image_path FROM jobs WHERE id=?", (job_id,))
27
- image_path = c.fetchone()[0]
28
-
29
- # Step 1: Preprocess the image
30
- c.execute("UPDATE jobs SET status='preprocessing' WHERE id=?", (job_id,))
31
- conn.commit()
32
- preprocessed_image = trellis_client.predict(
33
- image=handle_file(image_path),
34
- api_name="/preprocess_image"
35
- )
36
-
37
- # Step 2: Generate 3D asset
38
- c.execute("UPDATE jobs SET status='generating' WHERE id=?", (job_id,))
39
- conn.commit()
40
- time.sleep(30) # Wait between steps; adjust based on observed timing
41
- result_3d = trellis_client.predict(
42
- image=handle_file(preprocessed_image), # Use preprocessed image
43
- multiimages=[],
44
- seed=0, # Default; could make configurable
45
- ss_guidance_strength=7.5,
46
- ss_sampling_steps=12,
47
- slat_guidance_strength=3,
48
- slat_sampling_steps=12,
49
- multiimage_algo="stochastic",
50
- api_name="/image_to_3d"
51
- )
52
- video_path = result_3d['video'] # Extract video filepath from dict
53
-
54
- # Step 3: Extract GLB
55
- c.execute("UPDATE jobs SET status='extracting' WHERE id=?", (job_id,))
56
- conn.commit()
57
- time.sleep(65) # Wait for 3D processing; adjust as needed
58
- glb_result = trellis_client.predict(
59
- mesh_simplify=0.95,
60
- texture_size=1024,
61
- api_name="/extract_glb"
62
- )
63
- glb_path = glb_result[0] # First element is the GLB filepath
64
-
65
- # Move GLB to persistent storage
66
- output_path = f'/tmp/outputs/result_{job_id}.glb'
67
- os.makedirs('/tmp/outputs', exist_ok=True)
68
- shutil.move(glb_path, output_path)
69
-
70
- # Update job status
71
- c.execute("UPDATE jobs SET status='completed', output_path=? WHERE id=?", (output_path, job_id))
72
- conn.commit()
73
-
74
  except Exception as e:
75
- c.execute("UPDATE jobs SET status='failed' WHERE id=?", (job_id,))
76
- conn.commit()
77
- print(f"Error processing job {job_id}: {e}")
78
 
79
- # Gradio interface
80
- def submit_images(files):
81
- if not files:
82
- return "No files uploaded."
83
- for file in files:
84
- c.execute("INSERT INTO jobs (status) VALUES ('submitted')")
85
- job_id = c.lastrowid
86
- conn.commit()
87
- image_path = f'/tmp/inputs/input_{job_id}.jpg'
88
- os.makedirs('/tmp/inputs', exist_ok=True)
89
- shutil.copy(file.name, image_path)
90
- c.execute("UPDATE jobs SET image_path=? WHERE id=?", (image_path, job_id))
91
- conn.commit()
92
- threading.Thread(target=process_job, args=(job_id,), daemon=True).start()
93
- return "Jobs submitted. Check the status tab."
94
 
95
- def get_status():
96
- c.execute("SELECT id, image_path, status, output_path FROM jobs")
97
- return c.fetchall()
 
98
 
99
- with gr.Blocks(title="TRELLIS 3D Generator") as demo:
100
- with gr.Tab("Upload"):
101
- files_input = gr.File(file_count="multiple", label="Upload Images (JPG/PNG)")
102
- submit_btn = gr.Button("Submit")
103
- output_msg = gr.Textbox(label="Message")
104
- submit_btn.click(fn=submit_images, inputs=files_input, outputs=output_msg)
105
- with gr.Tab("Status"):
106
- status_table = gr.DataFrame(
107
- headers=["ID", "Image Path", "Status", "Output Path"],
108
- label="Job Status"
109
- )
110
- refresh_btn = gr.Button("Refresh")
111
- refresh_btn.click(fn=get_status, inputs=None, outputs=status_table)
112
 
113
- demo.launch()
 
 
1
  import gradio as gr
 
 
 
2
  import os
3
  import shutil
4
  from gradio_client import Client, handle_file
5
+ from smolagents import Tool, CodeAgent, HfApiModel
6
 
7
+ # import spaces - if using ZeroGPU
 
 
 
 
 
8
 
9
+ # Define tools from Spaces
10
+ spaces = [
11
+ {"repo_id": "black-forest-labs/FLUX.1-schnell",
12
+ "name": "image_generator",
13
+ "description": "Generate an image from a prompt"},
14
+
15
+ {"repo_id": "jamesliu1217/EasyControl_Ghibli",
16
+ "name": "Ghibli_style_Image_control",
17
+ "description": "Create Ghibli style image"},
18
+ ]
19
 
20
+ tools = []
21
+ for space in spaces:
22
+ # Access repo_id, name, and description
23
+ repo_id = space['repo_id']
24
+ name = space.get('name', repo_id) # Use repo_id as name if not specified
25
+ description = space.get('description', '') # Use empty string if not specified
26
 
27
+ # Create Tool instance
28
+ tool = Tool.from_space(repo_id, name=name, description=description)
29
+ tools.append(tool)
30
+
31
+ # Define a custom tool
32
+ class CustomTool(Tool):
33
+ name = "custom_tool"
34
+ description = "A custom tool that processes input text"
35
+ inputs = {"input": {"type": "string", "description": "Some input text to process"}}
36
+ output_type = "string"
37
+ def forward(self, input: str):
38
+ return f"Processed: {input}"
39
+
40
+ tools.append(CustomTool())
41
+
42
+
43
+ # Initialize the model
44
+ model = HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct")
45
+
46
+ # Create the agent
47
+ agent = CodeAgent(tools=tools, model=model)
48
+
49
+ # Function to run the agent and return the image path
50
+ def generate_and_transform(prompt):
51
+ result = agent.run(prompt)
52
+
53
+ if isinstance(result, str): # Assuming result is a file path
54
+ # Copy the temporary file to a permanent location
55
+ permanent_path = "ghibli_output.webp"
56
+ shutil.copy(result, permanent_path)
57
+ return permanent_path
58
+ else:
59
+ raise ValueError("Unexpected result type from agent")
60
+
61
+ # Gradio interface function
62
+ def gradio_interface(prompt):
63
  try:
64
+ image_path = generate_and_transform(prompt)
65
+ return image_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  except Exception as e:
67
+ return str(e)
 
 
68
 
69
+ # Create the Gradio app
70
+ with gr.Blocks() as app:
71
+ gr.Markdown("### Smolagent Image Generator with Ghibli Style")
72
+ with gr.Row():
73
+ prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., Generate an image of a dog and then make a Ghibli style version of that image")
74
+ submit_button = gr.Button("Generate")
75
+ output_image = gr.Image(label="Generated Image")
76
+ download_button = gr.File(label="Download Image")
 
 
 
 
 
 
 
77
 
78
+ # Connect the button to the function
79
+ def on_submit(prompt):
80
+ image_path = gradio_interface(prompt)
81
+ return image_path, image_path
82
 
83
+ submit_button.click(on_submit, inputs=prompt_input, outputs=[output_image, download_button])
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
+ # Launch the app
86
+ app.launch()