File size: 2,619 Bytes
05f3422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0643d49
05f3422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0643d49
 
 
 
05f3422
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python3
"""
Full diffusers-style inference demo for ControlEarth.
Uses only this repo's checkpoints (base + controlnet), no external downloads.
"""
from pathlib import Path

import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image


# File name of the showcase image that main() writes back into demo_images/.
# It must be skipped when scanning for inputs; otherwise a previous run's
# generated image would be fed back in as a control image on the next run.
_DEMO_OUTPUT_NAME = "output.jpeg"

# File suffixes (compared case-insensitively) accepted as control images.
_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp"}


def _find_control_images(directory: Path):
    """Load every supported image in *directory* as an RGB control image.

    Returns a list of ``(name, image)`` pairs, where ``name`` is the file
    stem. The pairs are ordered by path, so each name always belongs to its
    own image. Returns an empty list when *directory* does not exist or
    contains no supported images. The demo output file is always excluded.
    """
    if not directory.is_dir():
        return []
    paths = sorted(
        p
        for p in directory.iterdir()
        if p.is_file()
        and p.suffix.lower() in _IMAGE_EXTS
        and p.name != _DEMO_OUTPUT_NAME
    )
    controls = []
    for p in paths:
        # Close the file handle promptly; convert() returns an independent
        # in-memory image that stays usable after the file is closed.
        with Image.open(p) as im:
            controls.append((p.stem, im.convert("RGB")))
    return controls


def main():
    """Run ControlNet inference on every image in ``demo_images/``.

    Loads the ControlNet weights from ``<repo>/controlnet`` and the base
    pipeline from the repo root. Both use fp16, the safety checker is
    disabled, the scheduler is UniPC, and model CPU offload is enabled.
    Writes ``num_samples`` generations per control image to
    ``<repo>/outputs/<name>-<i>.png``. Each input's first sample is also
    saved as ``demo_images/output.jpeg``, so later inputs overwrite earlier
    ones. When no input images are found, a flat grey 512x512 placeholder
    is used instead.
    """
    repo = Path(__file__).resolve().parent
    controlnet_path = repo / "controlnet"
    demo_images = repo / "demo_images"
    out_dir = repo / "outputs"
    demo_output = demo_images / _DEMO_OUTPUT_NAME
    out_dir.mkdir(exist_ok=True)

    print("Loading ControlNet...")
    controlnet = ControlNetModel.from_pretrained(
        str(controlnet_path), torch_dtype=torch.float16
    )

    print("Loading pipeline (base + controlnet)...")
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        str(repo),
        controlnet=controlnet,
        torch_dtype=torch.float16,
        safety_checker=None,
        requires_safety_checker=False,
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()

    prompt = "convert this openstreetmap into its satellite view"
    num_inference_steps = 50
    num_samples = 3

    # Find OSM images in demo_images (names and images stay paired).
    controls = _find_control_images(demo_images)
    if not controls:
        print(f"No images found in {demo_images}. Creating a placeholder run.")
        # Minimal 512x512 RGB placeholder so the demo still exercises the pipeline.
        placeholder = Image.new("RGB", (512, 512), color=(200, 200, 200))
        controls = [("placeholder", placeholder)]

    for name, control_image in controls:
        print(f"Generating for {name}...")
        for i in range(num_samples):
            result = pipe(
                prompt,
                num_inference_steps=num_inference_steps,
                image=control_image,
            ).images[0]
            out_path = out_dir / f"{name}-{i}.png"
            result.save(out_path)
            print(f"  Saved {out_path}")
            if i == 0:
                # demo_images may not exist if we fell back to the placeholder.
                demo_images.mkdir(exist_ok=True)
                result.save(demo_output)
                print(f"  Saved demo {demo_output}")

    print(f"Done. Outputs in {out_dir}")


# Script entry point: importing this module does not start a generation run.
if "__main__" == __name__:
    main()