Spaces:
Running
on
A10G
Running
on
A10G
Refactor code for improved performance and readability
Browse files- .gitignore +2 -0
- TripoSR +1 -0
- app.py +13 -0
- imgGen.py +28 -0
- requirements.txt +12 -0
- worker.py +86 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
output.png
|
TripoSR
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit 8e51fec8095c9eae20e6ea7c9aef6368c5631a21
|
app.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from worker import worker # import the worker function
|
| 3 |
+
|
| 4 |
+
def greet(name):
|
| 5 |
+
return "Hello " + name + "!!"
|
| 6 |
+
|
| 7 |
+
def kickoff_worker():
|
| 8 |
+
worker() # call the worker function
|
| 9 |
+
|
| 10 |
+
iface = gr.Interface(fn=greet, inputs="text", outputs="text")
|
| 11 |
+
iface.launch()
|
| 12 |
+
print("Launching worker...")
|
| 13 |
+
kickoff_worker() # kickoff the worker after launching the interface
|
imgGen.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 8 |
+
repo = "ByteDance/SDXL-Lightning"
|
| 9 |
+
ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Load model.
|
| 13 |
+
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
|
| 14 |
+
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
|
| 15 |
+
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
|
| 16 |
+
|
| 17 |
+
# Ensure sampler uses "trailing" timesteps.
|
| 18 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
|
| 19 |
+
|
| 20 |
+
def generateTransparentImage(text):
|
| 21 |
+
# Ensure using the same inference steps as the loaded model and CFG set to 0.
|
| 22 |
+
image = pipe(text+', full body, transparent background', num_inference_steps=4, guidance_scale=0).images[0]
|
| 23 |
+
return image
|
| 24 |
+
|
| 25 |
+
if __name__ == "__main__":
|
| 26 |
+
text = "a cat"
|
| 27 |
+
img = generateTransparentImage(text)
|
| 28 |
+
img.save("output.png")
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
omegaconf==2.3.0
|
| 2 |
+
Pillow==10.1.0
|
| 3 |
+
einops==0.7.0
|
| 4 |
+
git+https://github.com/tatsy/torchmcubes.git
|
| 5 |
+
diffusers["torch"]
|
| 6 |
+
transformers==4.35.0
|
| 7 |
+
trimesh==4.0.5
|
| 8 |
+
rembg
|
| 9 |
+
huggingface-hub
|
| 10 |
+
imageio[ffmpeg]
|
| 11 |
+
setuptools --upgrade
|
| 12 |
+
torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
worker.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
from supabase import create_client, Client
|
| 3 |
+
from imgGen import generateTransparentImage
|
| 4 |
+
import sys
|
| 5 |
+
sys.path.append('./TripoSR')
|
| 6 |
+
|
| 7 |
+
import TripoSR.obj_gen as obj_gen
|
| 8 |
+
import os
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
import time
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
url: str = os.environ.get("SUPABASE_URL")
|
| 15 |
+
key: str = os.environ.get("SUPABASE_KEY")
|
| 16 |
+
supabase: Client = create_client(url, key)
|
| 17 |
+
|
| 18 |
+
def check_queue():
|
| 19 |
+
try:
|
| 20 |
+
tasks = supabase.table("Tasks").select("*").eq("status", "pending").execute()
|
| 21 |
+
assert len(tasks.data) > 0
|
| 22 |
+
if len(tasks.data) > 0:
|
| 23 |
+
return tasks.data[0]
|
| 24 |
+
else:
|
| 25 |
+
return None
|
| 26 |
+
except Exception as e:
|
| 27 |
+
print(f"Error checking queue: {e}")
|
| 28 |
+
return None
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def generate_image(text):
|
| 32 |
+
try:
|
| 33 |
+
img = generateTransparentImage(text)
|
| 34 |
+
return img
|
| 35 |
+
except Exception as e:
|
| 36 |
+
print(f"Error generating image: {e}")
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def create_obj_file(img, task_id):
|
| 41 |
+
try:
|
| 42 |
+
obj_gen.generate_obj_from_image(img, 'task_'+str(task_id)+'.obj')
|
| 43 |
+
except Exception as e:
|
| 44 |
+
print(f"Error creating obj file: {e}")
|
| 45 |
+
supabase.table("Tasks").update({"status": "error"}).eq("id", task_id).execute()
|
| 46 |
+
|
| 47 |
+
def send_back_to_supabase(task_id):
|
| 48 |
+
# check that a file was created
|
| 49 |
+
if os.path.exists('task_'+str(task_id)+'.obj'):
|
| 50 |
+
try:
|
| 51 |
+
with open('task_'+str(task_id)+'.obj', 'rb') as file:
|
| 52 |
+
data = file.read()
|
| 53 |
+
supabase.storage.from_('Results').upload('task_'+str(task_id)+'.obj', data)
|
| 54 |
+
public_url = supabase.storage.from_('Results').get_public_url('task_'+str(task_id)+'.obj')
|
| 55 |
+
supabase.table("Tasks").update({"status": "complete","result":public_url}).eq("id", task_id).execute()
|
| 56 |
+
os.remove('task_'+str(task_id)+'.obj')
|
| 57 |
+
except Exception as e:
|
| 58 |
+
print(f"Error sending file back to Supabase: {e}")
|
| 59 |
+
supabase.table("Tasks").update({"status": "error"}).eq("id", task_id).execute()
|
| 60 |
+
|
| 61 |
+
else:
|
| 62 |
+
print(f"Error: No file was created for task {task_id}")
|
| 63 |
+
|
| 64 |
+
def worker():
|
| 65 |
+
while True:
|
| 66 |
+
task = check_queue()
|
| 67 |
+
if task:
|
| 68 |
+
supabase.table("Tasks").update({"status": "processing"}).eq("id", task['id']).execute()
|
| 69 |
+
print(f"Processing task {task['id']}")
|
| 70 |
+
img = generate_image(task["text"])
|
| 71 |
+
if img:
|
| 72 |
+
print(f"Image generated for task {task['id']}")
|
| 73 |
+
create_obj_file(img,task["id"])
|
| 74 |
+
send_back_to_supabase(task["id"])
|
| 75 |
+
print(f"Task {task['id']} completed")
|
| 76 |
+
else:
|
| 77 |
+
print(f"Error generating image for task {task['id']}")
|
| 78 |
+
supabase.table("Tasks").update({"status": "error"}).eq("id", task['id']).execute()
|
| 79 |
+
|
| 80 |
+
else:
|
| 81 |
+
print("No pending tasks in the queue")
|
| 82 |
+
|
| 83 |
+
time.sleep(2) # Add a 2 second delay between checks
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
|
| 86 |
+
worker()
|