# controlearth / inference_demo.py
# Author: BiliSakura
# Commit 0643d49 (verified): Add files using upload-large-folder tool
#!/usr/bin/env python3
"""
Full diffusers-style inference demo for ControlEarth.
Uses only this repo's checkpoints (base + controlnet), no external downloads.
"""
from pathlib import Path
import torch
from diffusers import (
ControlNetModel,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler,
)
from PIL import Image
def _load_pipeline(repo):
    """Build the ControlNet pipeline from checkpoints stored inside *repo*.

    ControlNet weights are read from ``repo / "controlnet"`` and the base
    Stable Diffusion components from *repo* itself, both in float16. The
    safety checker is disabled, the scheduler is replaced with UniPC, and
    model CPU offload is enabled.
    """
    print("Loading ControlNet...")
    controlnet = ControlNetModel.from_pretrained(
        str(repo / "controlnet"), torch_dtype=torch.float16
    )
    print("Loading pipeline (base + controlnet)...")
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        str(repo),
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    # Swap in UniPC while reusing the base scheduler's configuration.
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    return pipe


def _find_control_images(folder, exclude=()):
    """Return ``(name, RGB image)`` pairs for every image file in *folder*.

    Files are visited in sorted path order, and each name is the stem of the
    file it was loaded from, so names and images can never get out of step.
    Paths listed in *exclude* (e.g. a previously written demo output) are
    skipped. A missing folder yields an empty list instead of raising.
    """
    image_exts = {".png", ".jpg", ".jpeg", ".webp"}
    if not folder.is_dir():
        return []
    excluded = {p.resolve() for p in exclude}
    pairs = []
    for path in sorted(folder.iterdir()):
        if not path.is_file() or path.suffix.lower() not in image_exts:
            continue
        if path.resolve() in excluded:
            continue
        # convert() returns a fully loaded copy, so the file can be closed.
        with Image.open(path) as img:
            pairs.append((path.stem, img.convert("RGB")))
    return pairs


def main():
    """Generate three satellite-view samples per OSM control image.

    Control images are read from ``demo_images/`` next to this script. A grey
    512x512 placeholder is used when that folder is missing or contains no
    images. Every sample is written to ``outputs/<name>-<i>.png``. The first
    sample of each input is also written to ``demo_images/output.jpeg``; that
    file is excluded from the inputs so re-running the demo does not feed its
    own output back in as a control image.
    """
    repo = Path(__file__).resolve().parent
    demo_dir = repo / "demo_images"
    out_dir = repo / "outputs"
    demo_output = demo_dir / "output.jpeg"
    out_dir.mkdir(exist_ok=True)

    pipe = _load_pipeline(repo)
    prompt = "convert this openstreetmap into its satellite view"
    num_inference_steps = 50
    samples_per_image = 3

    controls = _find_control_images(demo_dir, exclude=[demo_output])
    if not controls:
        print(f"No images found in {demo_dir}. Creating a placeholder run.")
        # Minimal 512x512 RGB placeholder so the pipeline still runs end to end.
        placeholder = Image.new("RGB", (512, 512), color=(200, 200, 200))
        controls = [("placeholder", placeholder)]

    demo_dir.mkdir(exist_ok=True)
    for name, control_image in controls:
        print(f"Generating for {name}...")
        for i in range(samples_per_image):
            result = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                image=control_image,
            ).images[0]
            out_path = out_dir / f"{name}-{i}.png"
            result.save(out_path)
            print(f" Saved {out_path}")
            if i == 0:
                # Overwritten for each input: the last input's first sample wins.
                result.save(demo_output)
                print(f" Saved demo {demo_output}")
    print(f"Done. Outputs in {out_dir}")
if __name__ == "__main__":
    # Entry point: run the demo only when executed as a script, not on import.
    # main() returns None, so SystemExit(None) exits with status 0 as before.
    raise SystemExit(main())