#!/usr/bin/env python3
# /// script
# dependencies = [
#   "gradio>=5.0.0",
#   "pillow>=10.0.0",
#   "torch>=2.0.0",
#   "numpy>=1.24.0",
#   "spaces>=0.28.0",
#   "teamwork[diffusers]",
#   "hf-transfer>=0.1.9",
# ]
# [tool.uv.sources]
# teamwork = { git = "https://github.com/samsartor/teamwork" }
# ///
"""Gradio demo for the Teamwork intrinsic-decomposition pipeline.

Loads a ``TeamworkPipeline`` checkpoint onto the GPU at import time, then
serves a small web UI: the user uploads an image and gets back a labeled
gallery with one entry per key in ``OUTPUT_KEYS``.
"""

import logging

# NOTE(review): third-party import order is kept exactly as in the original;
# `spaces` is imported before the torch-based modules -- confirm that this is
# not required before reordering.
import gradio as gr
import spaces
from teamwork.pipelines import TeamworkPipeline
from PIL import Image
import torch
import numpy as np

logger = logging.getLogger(__name__)

# Loaded once at import so every request reuses the same weights.
pipe = TeamworkPipeline.from_checkpoint(
    "samsartor/teamwork-release",
    "decomposition_heterogeneous_sd3.safetensors",
).to("cuda")

# Pipeline output keys shown in the gallery, in display order.
OUTPUT_KEYS = [
    "diffuse",
    "specular",
    "roughness",
    "normals",
    "depth",
    "albedo",
    "inverseshading",
    "diffuseshading",
    "residual",
]


def _array_to_pil(arr):
    """Convert a NumPy array into a PIL image.

    Accepts HW, HWC or CHW layouts, optionally preceded by singleton batch
    dimensions. ``uint8`` data is used as-is, other integer data is clipped to
    0..255, and everything else is treated as floats in [0, 1].

    Raises whatever ``Image.fromarray`` raises for layouts it cannot handle
    (for example a 2-channel image); the caller decides how to report that.
    """
    # Drop leading singleton batch dimensions, e.g. (1, C, H, W) -> (C, H, W).
    while arr.ndim > 3 and arr.shape[0] == 1:
        arr = arr[0]

    # CHW -> HWC, only when the first axis looks like channels and the last
    # axis does not.
    if (
        arr.ndim == 3
        and arr.shape[0] in (1, 3, 4)
        and arr.shape[-1] not in (1, 3, 4)
    ):
        arr = np.transpose(arr, (1, 2, 0))

    if arr.dtype != np.uint8:
        if np.issubdtype(arr.dtype, np.integer):
            # Wider integer types are assumed to already hold 0..255 values;
            # scaling them like [0, 1] floats would binarize the image.
            arr = np.clip(arr, 0, 255).astype(np.uint8)
        else:
            # NaN/Inf have no defined uint8 value, so pin them first.
            arr = np.nan_to_num(
                arr.astype(np.float32), nan=0.0, posinf=1.0, neginf=0.0
            )
            arr = (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)

    # HWC with a single channel is collapsed to HW so it becomes mode "L".
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]

    # For uint8 data PIL infers "L" from HW and RGB/RGBA from HWC.
    return Image.fromarray(arr)


def _to_pil(x):
    """Convert a pipeline output (PIL / numpy / torch) into a ``PIL.Image``.

    Returns ``None`` when ``x`` is ``None`` or cannot be converted, so callers
    can report the entry as missing instead of failing the whole request.
    """
    if x is None:
        return None
    if isinstance(x, Image.Image):
        return x

    if torch.is_tensor(x):
        x = x.detach()
        # Upcast half/bfloat16 so NumPy can represent the values, but keep
        # integer tensors integral so 0..255 data is not read as [0, 1].
        if x.is_floating_point():
            x = x.float()
        x = x.cpu().numpy()

    try:
        if isinstance(x, np.ndarray):
            return _array_to_pil(x)
        # Fallback: let PIL try any other array-like object.
        return Image.fromarray(x)
    except Exception:
        # Best-effort boundary: one unreadable output must not sink the other
        # results, but the failure is logged rather than silently dropped.
        logger.exception(
            "Could not convert pipeline output of type %s to an image",
            type(x).__name__,
        )
        return None


@spaces.GPU(duration=120)
def process_image(input_image, num_inference_steps=28, guidance_scale=1):
    """Run the pipeline on one image and build the gallery payload.

    Args:
        input_image: A ``PIL.Image``, an array accepted by
            ``Image.fromarray``, or ``None`` when nothing was uploaded.
        num_inference_steps: Number of inference steps passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.

    Returns:
        A ``(gallery_items, status)`` tuple. ``gallery_items`` is a list of
        ``(image, label)`` pairs ordered like ``OUTPUT_KEYS``; ``status`` is a
        message that also names any keys that were missing or unreadable.
    """
    if input_image is None:
        return [], "Please upload an image."

    if not isinstance(input_image, Image.Image):
        input_image = Image.fromarray(input_image)

    with torch.no_grad():
        result = pipe(
            {"image": input_image},
            # UI values may arrive as floats (e.g. 28.0); a step count must be
            # an integer.
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
        )

    # Build a labeled gallery: [(image, "label"), ...]
    gallery_items = []
    missing = []
    for key in OUTPUT_KEYS:
        # _to_pil maps both None and unconvertible values to None.
        image = _to_pil(result[key]) if key in result else None
        if image is None:
            missing.append(key)
        else:
            gallery_items.append((image, key))

    status = "Successfully processed!"
    if missing:
        status += f" Missing/unreadable keys: {', '.join(missing)}"
    return gallery_items, status


with gr.Blocks(title="Teamwork: Collaborative Diffusion") as demo:
    gr.Markdown(
        """
        # 🎨 Teamwork: Collaborative Diffusion

        Upload an image to process it using the Teamwork model for intrinsic decomposition.

        [Paper](https://bin.samsartor.com/teamwork.pdf) | [Code](https://github.com/samsartor/teamwork) | [Models](https://huggingface.co/samsartor/teamwork-release)
        """
    )

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image", type="pil")
            num_steps = gr.Slider(1, 50, value=28, step=1, label="Inference Steps")
            guidance = gr.Slider(1.0, 10.0, value=1.0, step=0.5, label="Guidance Scale")
            submit_btn = gr.Button("Process Image", variant="primary")

        with gr.Column():
            outputs_gallery = gr.Gallery(
                label="Outputs",
                columns=3,
                height="auto",
                preview=True,
                show_label=True,
            )
            status_text = gr.Textbox(label="Status", interactive=False)

    submit_btn.click(
        fn=process_image,
        inputs=[input_image, num_steps, guidance],
        outputs=[outputs_gallery, status_text],
    )

    gr.Examples(
        examples=[
            ["https://raw.githubusercontent.com/samsartor/teamwork/refs/heads/main/demo/red_glass_sphere.png"],
            ["https://raw.githubusercontent.com/samsartor/teamwork/refs/heads/main/demo/gift_card.png"],
            ["https://raw.githubusercontent.com/samsartor/teamwork/refs/heads/main/demo/pink_input.png"],
        ],
        inputs=input_image,
        label="Example Images",
    )


if __name__ == "__main__":
    # Launch only when run as a script, so importing the module (tests,
    # tooling) does not start a server.
    demo.launch()