#!/usr/bin/env python3 """ Full diffusers-style inference demo for ControlEarth. Uses only this repo's checkpoints (base + controlnet), no external downloads. """ from pathlib import Path import torch from diffusers import ( ControlNetModel, StableDiffusionControlNetPipeline, UniPCMultistepScheduler, ) from PIL import Image def main(): repo = Path(__file__).resolve().parent controlnet_path = repo / "controlnet" demo_images = repo / "demo_images" out_dir = repo / "outputs" demo_out = repo / "demo_images" out_dir.mkdir(exist_ok=True) print("Loading ControlNet...") controlnet = ControlNetModel.from_pretrained( str(controlnet_path), torch_dtype=torch.float16 ) print("Loading pipeline (base + controlnet)...") pipe = StableDiffusionControlNetPipeline.from_pretrained( str(repo), controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False, ) pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) pipe.enable_model_cpu_offload() prompt = "convert this openstreetmap into its satellite view" num_inference_steps = 50 # Find OSM images in demo_images image_exts = {".png", ".jpg", ".jpeg", ".webp"} image_paths = [ p for p in demo_images.iterdir() if p.is_file() and p.suffix.lower() in image_exts ] if not image_paths: print(f"No images found in {demo_images}. Creating a placeholder run.") # Create minimal 512x512 RGB placeholder for demo placeholder = Image.new("RGB", (512, 512), color=(200, 200, 200)) image_paths = [None] control_images = [placeholder] else: control_images = [ Image.open(p).convert("RGB") for p in sorted(image_paths) ] for idx, control_image in enumerate(control_images): name = image_paths[idx].stem if image_paths[idx] else "placeholder" print(f"Generating for {name}...") for i in range(3): result = pipe( prompt, num_inference_steps=num_inference_steps, image=control_image, ).images[0] out_path = out_dir / f"{name}-{i}.png" result.save(out_path) print(f" Saved {out_path}") if i == 0: demo_out.mkdir(exist_ok=True) result.save(demo_out / "output.jpeg") print(f" Saved demo {demo_out / 'output.jpeg'}") print(f"Done. Outputs in {out_dir}") if __name__ == "__main__": main()