File size: 2,619 Bytes
05f3422 0643d49 05f3422 0643d49 05f3422 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 | #!/usr/bin/env python3
"""
Full diffusers-style inference demo for ControlEarth.
Uses only this repo's checkpoints (base + controlnet), no external downloads.
"""
from pathlib import Path
import torch
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler,
)
from PIL import Image
def _load_pipeline(repo, controlnet_path):
    """Build the fp16 ControlNet pipeline from checkpoints stored locally.

    Args:
        repo: Directory holding the base Stable Diffusion checkpoint.
        controlnet_path: Directory holding the ControlNet checkpoint.

    Returns:
        A ``StableDiffusionControlNetPipeline`` using a UniPC scheduler with
        model CPU offload enabled.
    """
    print("Loading ControlNet...")
    controlnet = ControlNetModel.from_pretrained(
        str(controlnet_path), torch_dtype=torch.float16
    )
    print("Loading pipeline (base + controlnet)...")
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        str(repo),
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    # Replace the checkpoint's scheduler with UniPC, reusing its config.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    # diffusers model CPU offload: keeps idle sub-models off the GPU.
    pipe.enable_model_cpu_offload()
    return pipe


def _find_control_images(image_dir, exclude_names=()):
    """Return the image files in *image_dir*, sorted, to use as control inputs.

    Files whose name is in *exclude_names* are skipped so that images this
    script writes back into the same directory are not re-used as inputs on
    a later run. A missing directory yields an empty list instead of raising.
    """
    image_exts = {".png", ".jpg", ".jpeg", ".webp"}
    if not image_dir.is_dir():
        return []
    return sorted(
        p
        for p in image_dir.iterdir()
        if p.is_file()
        and p.suffix.lower() in image_exts
        and p.name not in exclude_names
    )


def _load_control_image(path):
    """Load *path* as an RGB image, or a grey 512x512 placeholder if None."""
    if path is None:
        return Image.new("RGB", (512, 512), color=(200, 200, 200))
    # Context manager closes the file handle; convert() returns a new,
    # fully loaded image that stays valid after the file is closed.
    with Image.open(path) as img:
        return img.convert("RGB")


def main():
    """Run ControlEarth on every OSM image in ``demo_images`` and save results.

    For each input image, ``num_samples`` generations are written to
    ``outputs/<stem>-<i>.png``. The first generation of each input is also
    written to ``demo_images/output.jpeg``, so after the loop that file holds
    the last input's first generation. If no input images are found (or the
    directory is missing), a single grey 512x512 placeholder is used so the
    pipeline still runs end to end.
    """
    repo = Path(__file__).resolve().parent
    controlnet_path = repo / "controlnet"
    demo_images = repo / "demo_images"
    out_dir = repo / "outputs"
    # The demo output goes back into the input directory; it is excluded
    # from input discovery below so it is never treated as an OSM tile.
    demo_out = demo_images
    demo_out_path = demo_out / "output.jpeg"
    out_dir.mkdir(exist_ok=True)

    pipe = _load_pipeline(repo, controlnet_path)

    prompt = "convert this openstreetmap into its satellite view"
    num_inference_steps = 50
    num_samples = 3

    image_paths = _find_control_images(
        demo_images, exclude_names={demo_out_path.name}
    )
    if image_paths:
        # Pair each name with its own path so outputs are labelled with the
        # input that actually produced them.
        inputs = [(p.stem, p) for p in image_paths]
    else:
        print(f"No images found in {demo_images}. Creating a placeholder run.")
        inputs = [("placeholder", None)]

    for name, path in inputs:
        # Load lazily: only one control image is held in memory at a time.
        control_image = _load_control_image(path)
        print(f"Generating for {name}...")
        for i in range(num_samples):
            result = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                image=control_image,
            ).images[0]
            out_path = out_dir / f"{name}-{i}.png"
            result.save(out_path)
            print(f" Saved {out_path}")
            if i == 0:
                demo_out.mkdir(exist_ok=True)
                result.save(demo_out_path)
                print(f" Saved demo {demo_out / 'output.jpeg'}")
    print(f"Done. Outputs in {out_dir}")
if __name__ == "__main__":
    # Only run the (GPU-heavy) demo when executed as a script, not on import.
    main()
|