Space status: Sleeping
| #!/usr/bin/env python3 | |
| # /// script | |
| # dependencies = [ | |
| # "gradio>=5.0.0", | |
| # "pillow>=10.0.0", | |
| # "torch>=2.0.0", | |
| # "numpy>=1.24.0", | |
| # "spaces>=0.28.0", | |
| # "teamwork[diffusers]", | |
| # "hf-transfer>=0.1.9", | |
| # ] | |
| # [tool.uv.sources] | |
| # teamwork = { git = "https://github.com/samsartor/teamwork" } | |
| # /// | |
| import gradio as gr | |
| import spaces | |
| from teamwork.pipelines import TeamworkPipeline | |
| from PIL import Image | |
| import torch | |
| import numpy as np | |
# Load the `decomposition_heterogeneous_sd3` checkpoint from the
# `samsartor/teamwork-release` repo once, at import time, and move it to CUDA.
# process_image() reuses this single instance for every request.
pipe = TeamworkPipeline.from_checkpoint(
    "samsartor/teamwork-release",
    "decomposition_heterogeneous_sd3.safetensors",
)
pipe = pipe.to("cuda")
# Pipeline result keys to display. process_image() walks these in order, so
# this list also fixes the order of images in the output gallery.
OUTPUT_KEYS = [
    "diffuse", "specular", "roughness", "normals", "depth",
    "albedo", "inverseshading", "diffuseshading", "residual",
]
def _to_pil(x):
    """Convert a pipeline output (PIL image, numpy array or torch tensor) to a PIL image.

    Array/tensor data is interpreted as follows:

    * leading size-1 dimensions beyond rank 3 (e.g. a batch of one) are dropped;
    * a rank-3 array whose first axis has 1, 3 or 4 entries and whose last
      axis does not is treated as channels-first and transposed to
      channels-last;
    * any dtype other than uint8 is assumed to hold values in [0, 1]: NaN is
      mapped to 0, values are clipped to [0, 1] and scaled to 0..255;
    * single-channel data becomes a grayscale ("L") image; otherwise Pillow
      infers the mode (e.g. RGB/RGBA) from the channel count.

    Args:
        x: The value to convert; may be ``None``.

    Returns:
        A ``PIL.Image.Image``, or ``None`` if ``x`` is ``None`` or the data
        cannot be represented as an image. The caller reports such keys as
        unreadable instead of failing the whole request.
    """
    if x is None:
        return None
    if isinstance(x, Image.Image):
        return x
    if torch.is_tensor(x):
        # .float() also upcasts half/bfloat16, which numpy cannot hold directly.
        x = x.detach().float().cpu().numpy()

    if not isinstance(x, np.ndarray):
        # Unknown type: best-effort hand-off to Pillow (works for objects that
        # expose __array_interface__); anything else is treated as unreadable.
        try:
            return Image.fromarray(x)
        except (AttributeError, TypeError, ValueError):
            return None

    arr = x
    # Drop singleton batch dimensions such as (1, C, H, W) -> (C, H, W).
    while arr.ndim > 3 and arr.shape[0] == 1:
        arr = arr[0]

    # CHW -> HWC
    if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
        arr = np.transpose(arr, (1, 2, 0))

    if arr.dtype != np.uint8:
        # Cast first so bool/int inputs clip cleanly; NaN would otherwise turn
        # into an arbitrary byte value when cast to uint8.
        arr = np.nan_to_num(arr.astype(np.float64), nan=0.0)
        arr = (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)

    # HxWx1 -> HxW so Pillow produces an "L" image. Pillow infers "L" for 2-D
    # uint8 data, so the deprecated `mode=` argument is not needed.
    if arr.ndim == 3 and arr.shape[2] == 1:
        arr = arr[:, :, 0]

    try:
        return Image.fromarray(arr)
    except (TypeError, ValueError):
        # Shape/channel count Pillow cannot handle (e.g. rank > 3 after squeezing).
        return None
@spaces.GPU
def process_image(input_image, num_inference_steps=28, guidance_scale=1):
    """Run the Teamwork pipeline on one image and build a labelled gallery.

    Decorated with ``spaces.GPU`` so that, on Hugging Face ZeroGPU hardware, a
    GPU is attached for the duration of the call. Without it, the CUDA work in
    ``pipe`` runs outside any GPU allocation. The decorator has no effect on
    other hardware.

    Args:
        input_image: Image from the ``gr.Image(type="pil")`` input, or an array
            accepted by ``Image.fromarray``. ``None`` when nothing was uploaded.
        num_inference_steps: Denoising steps forwarded to the pipeline
            (coerced to ``int``).
        guidance_scale: Guidance scale forwarded to the pipeline
            (coerced to ``float``).

    Returns:
        ``(gallery_items, status)``:

        * ``gallery_items`` is a list of ``(PIL.Image, key)`` tuples in
          ``OUTPUT_KEYS`` order.
        * ``status`` is a human-readable message that also names every key that
          was missing from the result or could not be converted to an image.
    """
    if input_image is None:
        return [], "Please upload an image."
    if not isinstance(input_image, Image.Image):
        input_image = Image.fromarray(input_image)

    with torch.no_grad():
        # Slider values can arrive as floats (e.g. 28.0); step counts must be
        # integers for the scheduler.
        result = pipe(
            {"image": input_image},
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
        )

    # NOTE(review): assumes `result` supports `in` and `[]` by key (dict-like).
    gallery_items = []
    missing = []
    for key in OUTPUT_KEYS:
        # _to_pil(None) returns None, so absent and None-valued keys end up together.
        img = _to_pil(result[key]) if key in result else None
        if img is None:
            missing.append(key)
        else:
            gallery_items.append((img, key))

    # Don't claim success when nothing could be displayed.
    status = "Successfully processed!" if gallery_items else "No outputs could be produced."
    if missing:
        status += f" Missing/unreadable keys: {', '.join(missing)}"
    return gallery_items, status
# Example inputs served from the Teamwork repository's demo folder.
_EXAMPLE_URLS = (
    "https://raw.githubusercontent.com/samsartor/teamwork/refs/heads/main/demo/red_glass_sphere.png",
    "https://raw.githubusercontent.com/samsartor/teamwork/refs/heads/main/demo/gift_card.png",
    "https://raw.githubusercontent.com/samsartor/teamwork/refs/heads/main/demo/pink_input.png",
)

# Page layout: header, a two-column row (controls | results), then the examples.
with gr.Blocks(title="Teamwork: Collaborative Diffusion") as demo:
    gr.Markdown(
        """
# 🎨 Teamwork: Collaborative Diffusion
Upload an image to process it using the Teamwork model for intrinsic decomposition.
[Paper](https://bin.samsartor.com/teamwork.pdf) | [Code](https://github.com/samsartor/teamwork) | [Models](https://huggingface.co/samsartor/teamwork-release)
"""
    )

    with gr.Row():
        # Left: input image, sampling controls and the run button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Input Image")
            num_steps = gr.Slider(
                minimum=1, maximum=50, value=28, step=1, label="Inference Steps"
            )
            guidance = gr.Slider(
                minimum=1.0, maximum=10.0, value=1.0, step=0.5, label="Guidance Scale"
            )
            submit_btn = gr.Button("Process Image", variant="primary")
        # Right: labelled output gallery and a read-only status line.
        with gr.Column():
            outputs_gallery = gr.Gallery(
                label="Outputs",
                show_label=True,
                columns=3,
                height="auto",
                preview=True,
            )
            status_text = gr.Textbox(label="Status", interactive=False)

    # Inputs map positionally onto process_image(image, steps, guidance);
    # its (gallery_items, status) return value fills the two outputs.
    submit_btn.click(
        process_image,
        inputs=[input_image, num_steps, guidance],
        outputs=[outputs_gallery, status_text],
    )

    gr.Examples(
        examples=[[url] for url in _EXAMPLE_URLS],
        inputs=input_image,
        label="Example Images",
    )

demo.launch()