| |
| """ |
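End-to-end diffusers inference demo for ControlEarth: turns the map images in
demo_images/ into satellite-style renders written to outputs/.
Loads only this repository's own checkpoints (base model + ControlNet), so no
external downloads are needed.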
| """ |
| from pathlib import Path |
|
|
| import torch |
| from diffusers import ( |
| ControlNetModel, |
| StableDiffusionControlNetPipeline, |
| UniPCMultistepScheduler, |
| ) |
| from PIL import Image |
|
|
|
|
# File extensions treated as control (map) input images.
IMAGE_EXTS = frozenset({".png", ".jpg", ".jpeg", ".webp"})

# Name of the single preview render this script writes into demo_images/.
DEMO_OUTPUT_NAME = "output.jpeg"


def find_input_images(directory, exclude_names=frozenset()):
    """Return the image files directly inside *directory*, sorted by path.

    Args:
        directory: Folder to scan (non-recursive).
        exclude_names: File names to skip even if they carry an image
            extension, e.g. outputs this script writes into the same folder.

    Returns:
        A sorted list of ``Path`` objects. The list is empty if *directory*
        does not exist or holds no matching files.
    """
    if not directory.is_dir():
        return []
    return sorted(
        p
        for p in directory.iterdir()
        if p.is_file()
        and p.suffix.lower() in IMAGE_EXTS
        and p.name not in exclude_names
    )


def main(num_samples=3, num_inference_steps=50):
    """Run the ControlEarth ControlNet pipeline over the demo map images.

    Loads the base model and the ControlNet from this repository's directory.
    For every image in ``demo_images/`` it generates *num_samples* renders,
    saved as ``outputs/<stem>-<i>.png``. The very first render is also saved
    as ``demo_images/output.jpeg`` as a preview. When no input images are
    found, a flat grey 512x512 placeholder is used as the control image.

    Args:
        num_samples: Number of renders to generate per input image.
        num_inference_steps: Denoising steps per render.
    """
    repo = Path(__file__).resolve().parent
    controlnet_path = repo / "controlnet"
    demo_images = repo / "demo_images"
    out_dir = repo / "outputs"
    out_dir.mkdir(exist_ok=True)

    print("Loading ControlNet...")
    controlnet = ControlNetModel.from_pretrained(
        str(controlnet_path), torch_dtype=torch.float16
    )

    print("Loading pipeline (base + controlnet)...")
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        str(repo),
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()

    prompt = "convert this openstreetmap into its satellite view"

    # The preview is written into demo_images/ itself; skip it here, otherwise
    # a rerun would feed the previous render back in as a control map.
    image_paths = find_input_images(
        demo_images, exclude_names={DEMO_OUTPUT_NAME}
    )
    if image_paths:
        # Keep each name paired with its own path so output file names always
        # match the image that actually produced them.
        jobs = [(p.stem, p) for p in image_paths]
    else:
        print(f"No images found in {demo_images}. Creating a placeholder run.")
        jobs = [("placeholder", None)]

    demo_saved = False
    for name, path in jobs:
        if path is None:
            control_image = Image.new("RGB", (512, 512), color=(200, 200, 200))
        else:
            # convert() returns an independent copy, so the file can be
            # closed immediately.
            with Image.open(path) as src:
                control_image = src.convert("RGB")

        print(f"Generating for {name}...")
        for i in range(num_samples):
            result = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                image=control_image,
            ).images[0]
            out_path = out_dir / f"{name}-{i}.png"
            result.save(out_path)
            print(f" Saved {out_path}")
            if not demo_saved:
                # demo_images/ may not exist when running on the placeholder.
                demo_images.mkdir(exist_ok=True)
                demo_path = demo_images / DEMO_OUTPUT_NAME
                result.save(demo_path)
                print(f" Saved demo {demo_path}")
                demo_saved = True

    print(f"Done. Outputs in {out_dir}")
|
|
|
|
if __name__ == "__main__":
    # Script entry point; main() returns None, so this exits with status 0.
    raise SystemExit(main())
|
|