teamwork-demo / app.py
blanchon's picture
Update app.py
a696b33 verified
#!/usr/bin/env python3
# /// script
# dependencies = [
# "gradio>=5.0.0",
# "pillow>=10.0.0",
# "torch>=2.0.0",
# "numpy>=1.24.0",
# "spaces>=0.28.0",
# "teamwork[diffusers]",
# "hf-transfer>=0.1.9",
# ]
# [tool.uv.sources]
# teamwork = { git = "https://github.com/samsartor/teamwork" }
# ///
import gradio as gr
import spaces
from teamwork.pipelines import TeamworkPipeline
from PIL import Image
import torch
import numpy as np
# Build the Teamwork pipeline once, when the module is imported, so every
# request served by `process_image` reuses the same loaded weights.
pipe = TeamworkPipeline.from_checkpoint(
    "samsartor/teamwork-release",
    "decomposition_heterogeneous_sd3.safetensors",
)
# Move the loaded pipeline to the CUDA device and keep the returned handle.
pipe = pipe.to("cuda")
# Keys looked up in the pipeline result, in the order they are shown in the
# output gallery. Keys absent from a result are reported in the status text.
OUTPUT_KEYS = [
    "diffuse", "specular", "roughness",
    "normals", "depth", "albedo",
    "inverseshading", "diffuseshading", "residual",
]
def _to_pil(x):
"""Convert pipeline output (PIL / numpy / torch) into a PIL.Image."""
if x is None:
return None
if isinstance(x, Image.Image):
return x
if torch.is_tensor(x):
x = x.detach().float().cpu().numpy()
if isinstance(x, np.ndarray):
arr = x
# handle CHW -> HWC
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
# If float in [0..1] (common), scale to [0..255]
if arr.dtype != np.uint8:
arr = np.clip(arr, 0.0, 1.0)
arr = (arr * 255.0).round().astype(np.uint8)
# If single channel, make it 'L'
if arr.ndim == 2:
return Image.fromarray(arr, mode="L")
# If HWC with 1 channel, squeeze to L
if arr.ndim == 3 and arr.shape[2] == 1:
return Image.fromarray(arr[:, :, 0], mode="L")
# Otherwise RGB/RGBA
return Image.fromarray(arr)
# Fallback: try to coerce
try:
return Image.fromarray(x)
except Exception:
return None
@spaces.GPU(duration=120)
def process_image(input_image, num_inference_steps=28, guidance_scale=1):
    """Run the Teamwork pipeline on one image and collect its outputs for the gallery.

    Args:
        input_image: The uploaded image. Anything that is not already a
            ``PIL.Image.Image`` is passed through ``Image.fromarray``.
        num_inference_steps: Forwarded unchanged to the pipeline call.
        guidance_scale: Forwarded unchanged to the pipeline call.

    Returns:
        A ``(gallery_items, status)`` pair. ``gallery_items`` is a list of
        ``(image, key)`` tuples in ``OUTPUT_KEYS`` order; ``status`` is a
        message that also names every key that was absent from the result or
        could not be converted to an image.
    """
    if input_image is None:
        return [], "Please upload an image."

    source = (
        input_image
        if isinstance(input_image, Image.Image)
        else Image.fromarray(input_image)
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        result = pipe(
            {"image": source},
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )

    # Pair each convertible output with its key so the gallery shows a label.
    labelled = []
    skipped = []
    for key in OUTPUT_KEYS:
        raw = result[key] if key in result else None
        picture = _to_pil(raw) if raw is not None else None
        if picture is None:
            skipped.append(key)
        else:
            labelled.append((picture, key))

    suffix = f" Missing/unreadable keys: {', '.join(skipped)}" if skipped else ""
    return labelled, "Successfully processed!" + suffix
# Gradio UI definition: a header, the input controls, the output widgets, the
# button wiring and a set of example images, all registered on `demo`.
with gr.Blocks(title="Teamwork: Collaborative Diffusion") as demo:
# Page header: title, one-line description and links to paper, code and models.
gr.Markdown(
"""
# 🎨 Teamwork: Collaborative Diffusion
Upload an image to process it using the Teamwork model for intrinsic decomposition.
[Paper](https://bin.samsartor.com/teamwork.pdf) | [Code](https://github.com/samsartor/teamwork) | [Models](https://huggingface.co/samsartor/teamwork-release)
"""
)
with gr.Row():
with gr.Column():
# Input controls. The image is handed to the callback as a PIL image
# (type="pil"); the slider defaults (28 steps, guidance 1.0) match the
# defaults of `process_image`.
input_image = gr.Image(label="Input Image", type="pil")
num_steps = gr.Slider(1, 50, value=28, step=1, label="Inference Steps")
guidance = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="Guidance Scale")
submit_btn = gr.Button("Process Image", variant="primary")
with gr.Column():
# Output widgets: a three-column gallery for the (image, label) pairs and
# a read-only text box for the status message.
outputs_gallery = gr.Gallery(
label="Outputs",
columns=3,
height="auto",
preview=True,
show_label=True,
)
status_text = gr.Textbox(label="Status", interactive=False)
# Wire the button to `process_image`; the inputs list follows its parameter
# order and the outputs list follows its (gallery_items, status) return order.
submit_btn.click(
fn=process_image,
inputs=[input_image, num_steps, guidance],
outputs=[outputs_gallery, status_text],
)
# Example inputs, referenced by URL from the teamwork repository's demo folder;
# selecting one fills the input image component.
gr.Examples(
examples=[
["https://raw.githubusercontent.com/samsartor/teamwork/refs/heads/main/demo/red_glass_sphere.png"],
["https://raw.githubusercontent.com/samsartor/teamwork/refs/heads/main/demo/gift_card.png"],
["https://raw.githubusercontent.com/samsartor/teamwork/refs/heads/main/demo/pink_input.png"],
],
inputs=input_image,
label="Example Images",
)
# Start the Gradio server for the app defined above.
demo.launch()