Add files using upload-large-folder tool
Browse files
demo_images/GeoSynth-Canny/output.jpeg
CHANGED
|
|
demo_images/GeoSynth-OSM/output.jpeg
CHANGED
|
|
demo_images/GeoSynth-SAM/output.jpeg
CHANGED
|
|
run_demo_inference.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Run ControlNet inference for OSM, Canny, SAM and save outputs to demo_images/."""
|
| 3 |
+
import os
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
|
| 9 |
+
|
| 10 |
+
REPO = Path(__file__).resolve().parent
|
| 11 |
+
PROMPT = "Satellite image features a city neighborhood"
|
| 12 |
+
SEED = 42
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main():
|
| 16 |
+
for control_type in ["OSM", "Canny", "SAM"]:
|
| 17 |
+
subfolder = f"controlnet/GeoSynth-{control_type}"
|
| 18 |
+
in_dir = REPO / "demo_images" / f"GeoSynth-{control_type}"
|
| 19 |
+
in_path = in_dir / "input.jpeg"
|
| 20 |
+
out_path = in_dir / "output.jpeg"
|
| 21 |
+
if not in_path.exists():
|
| 22 |
+
print(f"Skipping {control_type}: {in_path} not found")
|
| 23 |
+
continue
|
| 24 |
+
print(f"Loading GeoSynth-{control_type}...")
|
| 25 |
+
controlnet = ControlNetModel.from_pretrained(
|
| 26 |
+
str(REPO), subfolder=subfolder, torch_dtype=torch.float16
|
| 27 |
+
)
|
| 28 |
+
pipe = StableDiffusionControlNetPipeline.from_pretrained(
|
| 29 |
+
str(REPO), controlnet=controlnet, torch_dtype=torch.float16
|
| 30 |
+
)
|
| 31 |
+
pipe = pipe.to("cuda")
|
| 32 |
+
img = Image.open(in_path).convert("RGB").resize((512, 512))
|
| 33 |
+
gen = torch.manual_seed(SEED)
|
| 34 |
+
out = pipe(
|
| 35 |
+
PROMPT, image=img, generator=gen, num_inference_steps=20
|
| 36 |
+
).images[0]
|
| 37 |
+
out.save(out_path)
|
| 38 |
+
print(f"Saved {out_path}")
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
main()
|