BiliSakura commited on
Commit
aa2ba63
·
verified ·
1 Parent(s): 9e7e880

Add files using upload-large-folder tool

Browse files
Files changed (38) hide show
  1. .gitattributes +1 -0
  2. README.md +115 -0
  3. __pycache__/pipeline.cpython-312.pyc +0 -0
  4. controlnet/GeoSynth-Location-Canny/config.json +56 -0
  5. controlnet/GeoSynth-Location-Canny/diffusion_pytorch_model.safetensors +3 -0
  6. controlnet/GeoSynth-Location-OSM/config.json +56 -0
  7. controlnet/GeoSynth-Location-OSM/diffusion_pytorch_model.safetensors +3 -0
  8. controlnet/GeoSynth-Location-SAM/config.json +56 -0
  9. convert_ckpts.py +50 -0
  10. coordnet/config.json +8 -0
  11. coordnet/diffusion_pytorch_model.safetensors +3 -0
  12. demo_images/GeoSynth-Location-Canny/input.jpeg +3 -0
  13. demo_images/GeoSynth-Location-Canny/output.jpeg +0 -0
  14. demo_images/GeoSynth-Location-OSM/input.jpeg +0 -0
  15. demo_images/GeoSynth-Location-OSM/output.jpeg +0 -0
  16. demo_images/GeoSynth-Location-SAM/input.jpeg +0 -0
  17. demo_outputs/output_input_lon-90.2_lat38.6.jpg +0 -0
  18. feature_extractor/preprocessor_config.json +27 -0
  19. geosynth_pipeline/__init__.py +13 -0
  20. geosynth_pipeline/controlnet_location_wrapper.py +51 -0
  21. geosynth_pipeline/coordnet.py +87 -0
  22. geosynth_pipeline/pipeline.py +106 -0
  23. geosynth_pipeline/satclip_location_encoder.py +116 -0
  24. geosynth_pipeline/spherical_harmonics.py +75 -0
  25. inference_demo.py +77 -0
  26. model_index.json +37 -0
  27. satclip_location_encoder/config.json +8 -0
  28. satclip_location_encoder/diffusion_pytorch_model.safetensors +3 -0
  29. scheduler/scheduler_config.json +36 -0
  30. test_output.jpeg +0 -0
  31. text_encoder/config.json +24 -0
  32. text_encoder/model.safetensors +3 -0
  33. tokenizer/tokenizer.json +0 -0
  34. tokenizer/tokenizer_config.json +14 -0
  35. unet/config.json +72 -0
  36. unet/diffusion_pytorch_model.safetensors +3 -0
  37. vae/config.json +38 -0
  38. vae/diffusion_pytorch_model.safetensors +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ demo_images/GeoSynth-Location-Canny/input.jpeg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: diffusers
4
+ pipeline_tag: image-to-image
5
+ tags:
6
+ - controlnet
7
+ - remote-sensing
8
+ - arxiv:2404.06637
9
+ widget:
10
+ # GeoSynth-Location-OSM: OSM tile -> satellite image (default lon=-90.2, lat=38.6)
11
+ - src: demo_images/GeoSynth-Location-OSM/input.jpeg
12
+ prompt: Satellite image features a city neighborhood
13
+ output:
14
+ url: demo_images/GeoSynth-Location-OSM/output.jpeg
15
+ # GeoSynth-Location-Canny: Canny edges -> satellite image (default lon=-90.2, lat=38.6)
16
+ - src: demo_images/GeoSynth-Location-Canny/input.jpeg
17
+ prompt: Satellite image features a city neighborhood
18
+ output:
19
+ url: demo_images/GeoSynth-Location-Canny/output.jpeg
20
+ # GeoSynth-Location-SAM: SAM segmentation -> satellite image
21
+ - src: demo_images/GeoSynth-Location-SAM/input.jpeg
22
+ prompt: Satellite image features a city neighborhood
23
+ output:
24
+ url: demo_images/GeoSynth-Location-SAM/output.jpeg
25
+ ---
26
+
27
+ # GeoSynth-ControlNets-Location
28
+
29
+ Repository for location-conditioned GeoSynth ControlNets.
30
+
31
+ **Location (lon/lat) conditioning** is the primary workflow for geo-aware synthesis. Default: St. Louis, MO (`lon=-90.2`, `lat=38.6`). Use SatCLIP + CoordNet for full diffusers-style location conditioning.
32
+
33
+ We maintain **two repositories**—one per base checkpoint—each with its compatible ControlNets:
34
+
35
+ | Repo | Base Model | ControlNets |
36
+ |------|------------|-------------|
37
+ | **[GeoSynth-ControlNets](https://huggingface.co/BiliSakura/GeoSynth-ControlNets)** | GeoSynth (text encoder & UNet same as SD 2.1) | GeoSynth-OSM, GeoSynth-Canny, GeoSynth-SAM |
38
+ | **This repo** | GeoSynth-Location (adds CoordNet branch) | GeoSynth-Location-OSM, GeoSynth-Location-SAM*, GeoSynth-Location-Canny |
39
+
40
+ *[GeoSynth-Location-SAM](https://huggingface.co/MVRL/GeoSynth-Location-SAM) controlnet ckpt is missing from source.*
41
+
42
+ ### This repository
43
+
44
+ 1. **GeoSynth-Location base** — Converted from `geosynth_sd_loc-v3.ckpt` to diffusers format. Text encoder and UNet are the same as SD 2.1 (not fine-tuned). The original checkpoint also includes a CoordNet branch for `[lon, lat]` conditioning (see Architecture).
45
+ 2. **ControlNet models** — GeoSynth-Location-OSM, GeoSynth-Location-Canny (converted from SD-style checkpoints under `MVRL/GeoSynth-Location-OSM` and `MVRL/GeoSynth-Location-Canny`), and GeoSynth-Location-SAM, located under [`controlnet/`](controlnet/).
46
+
47
+ ### Architecture
48
+
49
+ The full location pipeline adds a **CoordNet** branch to the base LDM:
50
+
51
+ - **Input**: `[lon, lat]` → **SatCLIP** location encoder → **CoordNet** (13 stacked cross-attention blocks, inner dim 256, 4 heads) → conditioning injected into UNet
52
+ - ControlNet and CoordNet jointly condition the UNet (see [GeoSynth paper](https://huggingface.co/papers/2404.06637) Figure 3)
53
+
54
+ ### ControlNet variants (this repo)
55
+
56
+ | Control | Subfolder | Status |
57
+ |---------|-----------|--------|
58
+ | OSM | `controlnet/GeoSynth-Location-OSM` | ✅ ready |
59
+ | Canny | `controlnet/GeoSynth-Location-Canny` | ✅ ready |
60
+ | SAM | `controlnet/GeoSynth-Location-SAM` | ⏳ ckpt pending |
61
+
62
+ ### Model Sources
63
+
64
+ - **Source:** [GeoSynth](https://github.com/mvrl/GeoSynth)
65
+ - **Paper:** [GeoSynth: Contextually-Aware High-Resolution Satellite Image Synthesis](https://huggingface.co/papers/2404.06637)
66
+ - **Base model:** [Stable Diffusion 2.1](https://huggingface.co/sd2-community/stable-diffusion-2-1-base)
67
+ - **Related:** [GeoSynth-ControlNets](https://huggingface.co/BiliSakura/GeoSynth-ControlNets) (non-location models)
68
+
69
+ ## Usage
70
+
71
+ **CLI:**
72
+ ```bash
73
+ python inference_demo.py --control demo_images/GeoSynth-Location-OSM/input.jpeg --control_type OSM --lon -90.2 --lat 38.6
74
+ ```
75
+
76
+ **Python:**
77
+ ```python
78
+ import sys, os
79
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
80
+
81
+ import torch
82
+ from PIL import Image
83
+ from geosynth_pipeline import load_geosynth_pipeline_with_location, run_with_location
84
+
85
+ pipe = load_geosynth_pipeline_with_location(".", controlnet_subfolder="controlnet/GeoSynth-Location-OSM", local_files_only=True)
86
+ pipe = pipe.to("cuda")
87
+
88
+ img = Image.open("demo_images/GeoSynth-Location-OSM/input.jpeg").convert("RGB").resize((512, 512))
89
+ output = run_with_location(pipe, "Satellite image features a city neighborhood", image=img, lon=-90.2, lat=38.6)
90
+ output.images[0].save("generated_city.jpg")
91
+ ```
92
+
93
+ ## Citation
94
+
95
+ If you use this model, please cite the GeoSynth paper. For location-conditioned variants, also cite SatCLIP.
96
+
97
+ ```bibtex
98
+ @inproceedings{sastry2024geosynth,
99
+ title={GeoSynth: Contextually-Aware High-Resolution Satellite Image Synthesis},
100
+ author={Sastry, Srikumar and Khanal, Subash and Dhakal, Aayush and Jacobs, Nathan},
101
+ booktitle={IEEE/ISPRS Workshop: Large Scale Computer Vision for Remote Sensing (EARTHVISION)},
102
+ year={2024}
103
+ }
104
+
105
+ @article{klemmer2025satclip,
106
+ title={{SatCLIP}: {Global}, General-Purpose Location Embeddings with Satellite Imagery},
107
+ author={Klemmer, Konstantin and Rolf, Esther and Robinson, Caleb and Mackey, Lester and Ru{\ss}wurm, Marc},
108
+ journal={Proceedings of the AAAI Conference on Artificial Intelligence},
109
+ volume={39},
110
+ number={4},
111
+ pages={4347--4355},
112
+ year={2025},
113
+ doi={10.1609/aaai.v39i4.32457}
114
+ }
115
+ ```
__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (5.9 kB). View file
 
controlnet/GeoSynth-Location-Canny/config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.37.0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": null,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20,
12
+ 20
13
+ ],
14
+ "block_out_channels": [
15
+ 320,
16
+ 640,
17
+ 1280,
18
+ 1280
19
+ ],
20
+ "class_embed_type": null,
21
+ "conditioning_channels": 3,
22
+ "conditioning_embedding_out_channels": [
23
+ 16,
24
+ 32,
25
+ 96,
26
+ 256
27
+ ],
28
+ "controlnet_conditioning_channel_order": "rgb",
29
+ "cross_attention_dim": 1024,
30
+ "down_block_types": [
31
+ "CrossAttnDownBlock2D",
32
+ "CrossAttnDownBlock2D",
33
+ "CrossAttnDownBlock2D",
34
+ "DownBlock2D"
35
+ ],
36
+ "downsample_padding": 1,
37
+ "encoder_hid_dim": null,
38
+ "encoder_hid_dim_type": null,
39
+ "flip_sin_to_cos": true,
40
+ "freq_shift": 0,
41
+ "global_pool_conditions": false,
42
+ "in_channels": 4,
43
+ "layers_per_block": 2,
44
+ "mid_block_scale_factor": 1,
45
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
46
+ "norm_eps": 1e-05,
47
+ "norm_num_groups": 32,
48
+ "num_attention_heads": null,
49
+ "num_class_embeds": null,
50
+ "only_cross_attention": false,
51
+ "projection_class_embeddings_input_dim": null,
52
+ "resnet_time_scale_shift": "default",
53
+ "transformer_layers_per_block": 1,
54
+ "upcast_attention": true,
55
+ "use_linear_projection": true
56
+ }
controlnet/GeoSynth-Location-Canny/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22f7bab41c6a6da224c1987e2b8d69e5b8e5f45218f78ee43ec3eb55bbd20508
3
+ size 1456953560
controlnet/GeoSynth-Location-OSM/config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.37.0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": null,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20,
12
+ 20
13
+ ],
14
+ "block_out_channels": [
15
+ 320,
16
+ 640,
17
+ 1280,
18
+ 1280
19
+ ],
20
+ "class_embed_type": null,
21
+ "conditioning_channels": 3,
22
+ "conditioning_embedding_out_channels": [
23
+ 16,
24
+ 32,
25
+ 96,
26
+ 256
27
+ ],
28
+ "controlnet_conditioning_channel_order": "rgb",
29
+ "cross_attention_dim": 1024,
30
+ "down_block_types": [
31
+ "CrossAttnDownBlock2D",
32
+ "CrossAttnDownBlock2D",
33
+ "CrossAttnDownBlock2D",
34
+ "DownBlock2D"
35
+ ],
36
+ "downsample_padding": 1,
37
+ "encoder_hid_dim": null,
38
+ "encoder_hid_dim_type": null,
39
+ "flip_sin_to_cos": true,
40
+ "freq_shift": 0,
41
+ "global_pool_conditions": false,
42
+ "in_channels": 4,
43
+ "layers_per_block": 2,
44
+ "mid_block_scale_factor": 1,
45
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
46
+ "norm_eps": 1e-05,
47
+ "norm_num_groups": 32,
48
+ "num_attention_heads": null,
49
+ "num_class_embeds": null,
50
+ "only_cross_attention": false,
51
+ "projection_class_embeddings_input_dim": null,
52
+ "resnet_time_scale_shift": "default",
53
+ "transformer_layers_per_block": 1,
54
+ "upcast_attention": true,
55
+ "use_linear_projection": true
56
+ }
controlnet/GeoSynth-Location-OSM/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5119803f858a8038a127b3251a582e8bd7c3e9f20cdfb8ae6f2bdc413ae880f4
3
+ size 1456953560
controlnet/GeoSynth-Location-SAM/config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ControlNetModel",
3
+ "_diffusers_version": "0.37.0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": null,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20,
12
+ 20
13
+ ],
14
+ "block_out_channels": [
15
+ 320,
16
+ 640,
17
+ 1280,
18
+ 1280
19
+ ],
20
+ "class_embed_type": null,
21
+ "conditioning_channels": 3,
22
+ "conditioning_embedding_out_channels": [
23
+ 16,
24
+ 32,
25
+ 96,
26
+ 256
27
+ ],
28
+ "controlnet_conditioning_channel_order": "rgb",
29
+ "cross_attention_dim": 1024,
30
+ "down_block_types": [
31
+ "CrossAttnDownBlock2D",
32
+ "CrossAttnDownBlock2D",
33
+ "CrossAttnDownBlock2D",
34
+ "DownBlock2D"
35
+ ],
36
+ "downsample_padding": 1,
37
+ "encoder_hid_dim": null,
38
+ "encoder_hid_dim_type": null,
39
+ "flip_sin_to_cos": true,
40
+ "freq_shift": 0,
41
+ "global_pool_conditions": false,
42
+ "in_channels": 4,
43
+ "layers_per_block": 2,
44
+ "mid_block_scale_factor": 1,
45
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
46
+ "norm_eps": 1e-05,
47
+ "norm_num_groups": 32,
48
+ "num_attention_heads": null,
49
+ "num_class_embeds": null,
50
+ "only_cross_attention": false,
51
+ "projection_class_embeddings_input_dim": null,
52
+ "resnet_time_scale_shift": "default",
53
+ "transformer_layers_per_block": 1,
54
+ "upcast_attention": true,
55
+ "use_linear_projection": true
56
+ }
convert_ckpts.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Convert SatCLIP and CoordNet from source .ckpt to diffusers-style format.
4
+ Saves into this repo (satclip_location_encoder/, coordnet/) for self-contained inference.
5
+
6
+ Usage:
7
+ 1. Download SatCLIP: huggingface.co/microsoft/SatCLIP-ViT16-L10 -> satclip-vit16-l10.ckpt
8
+ 2. Download GeoSynth-Location-OSM: huggingface.co/MVRL/GeoSynth-Location-OSM -> geosynth_loc_osm.ckpt
9
+ 3. Run: python convert_ckpts.py --satclip_ckpt /path/to/satclip-vit16-l10.ckpt --coordnet_ckpt /path/to/geosynth_loc_osm.ckpt
10
+ """
11
+ import argparse
12
+ import os
13
+ import sys
14
+
15
+ REPO_DIR = os.path.dirname(os.path.abspath(__file__))
16
+ sys.path.insert(0, REPO_DIR)
17
+
18
+ from geosynth_pipeline import GeoSynthCoordNetModel, SatCLIPLocationEncoder
19
+
20
+
21
def main():
    """Convert the SatCLIP and CoordNet source checkpoints into diffusers-style
    folders inside this repository (satclip_location_encoder/, coordnet/)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--satclip_ckpt", required=True, help="Path to satclip-vit16-l10.ckpt")
    parser.add_argument("--coordnet_ckpt", required=True, help="Path to geosynth_loc_osm.ckpt")
    args = parser.parse_args()

    # Validate inputs up front so we fail before creating any output directories.
    if not os.path.isfile(args.satclip_ckpt):
        sys.exit(f"SatCLIP checkpoint not found: {args.satclip_ckpt}")
    if not os.path.isfile(args.coordnet_ckpt):
        sys.exit(f"CoordNet checkpoint not found: {args.coordnet_ckpt}")

    satclip_out = os.path.join(REPO_DIR, "satclip_location_encoder")
    coordnet_out = os.path.join(REPO_DIR, "coordnet")

    print("Loading SatCLIP location encoder...")
    location_encoder = SatCLIPLocationEncoder.from_satclip_ckpt(args.satclip_ckpt)
    os.makedirs(satclip_out, exist_ok=True)
    location_encoder.save_pretrained(satclip_out)
    print(f"Saved to {satclip_out}")

    print("Loading CoordNet...")
    coord_model = GeoSynthCoordNetModel.from_geosynth_ckpt(args.coordnet_ckpt)
    os.makedirs(coordnet_out, exist_ok=True)
    coord_model.save_pretrained(coordnet_out)
    print(f"Saved to {coordnet_out}")
    print("Done. Use inference_demo.py with local paths.")
47
+
48
+
49
+ if __name__ == "__main__":
50
+ main()
coordnet/config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "GeoSynthCoordNetModel",
3
+ "_diffusers_version": "0.37.0",
4
+ "block_out_channels": null,
5
+ "embed_dim": 256,
6
+ "num_heads": 4,
7
+ "time_embed_dim": 1280
8
+ }
coordnet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:757124ffa2752db3822a70e40684da275648b9f9a4e1efa8eda407f452273576
3
+ size 65790520
demo_images/GeoSynth-Location-Canny/input.jpeg ADDED

Git LFS Details

  • SHA256: dba59aca702f3a6a6878e74a69a8fc70298e92933d676067852af168cd098294
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
demo_images/GeoSynth-Location-Canny/output.jpeg ADDED
demo_images/GeoSynth-Location-OSM/input.jpeg ADDED
demo_images/GeoSynth-Location-OSM/output.jpeg ADDED
demo_images/GeoSynth-Location-SAM/input.jpeg ADDED
demo_outputs/output_input_lon-90.2_lat38.6.jpg ADDED
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "image_mean": [
12
+ 0.48145466,
13
+ 0.4578275,
14
+ 0.40821073
15
+ ],
16
+ "image_processor_type": "CLIPImageProcessor",
17
+ "image_std": [
18
+ 0.26862954,
19
+ 0.26130258,
20
+ 0.27577711
21
+ ],
22
+ "resample": 3,
23
+ "rescale_factor": 0.00392156862745098,
24
+ "size": {
25
+ "shortest_edge": 224
26
+ }
27
+ }
geosynth_pipeline/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GeoSynth custom pipeline: SatCLIP + CoordNet location conditioning
2
+ from .satclip_location_encoder import SatCLIPLocationEncoder
3
+ from .coordnet import GeoSynthCoordNetModel
4
+ from .controlnet_location_wrapper import GeoSynthControlNetWithLocation
5
+ from .pipeline import load_geosynth_pipeline_with_location, run_with_location
6
+
7
+ __all__ = [
8
+ "SatCLIPLocationEncoder",
9
+ "GeoSynthCoordNetModel",
10
+ "GeoSynthControlNetWithLocation",
11
+ "load_geosynth_pipeline_with_location",
12
+ "run_with_location",
13
+ ]
geosynth_pipeline/controlnet_location_wrapper.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Wrapper combining ControlNet + CoordNet for location-conditioned generation.
2
+ Subclasses ControlNetModel so pipeline's isinstance checks pass."""
3
+ from typing import Any, Optional
4
+
5
+ import torch
6
+ from diffusers.models import ControlNetModel
7
+
8
+ from .coordnet import GeoSynthCoordNetModel
9
+ from .satclip_location_encoder import SatCLIPLocationEncoder
10
+
11
+
12
class GeoSynthControlNetWithLocation(ControlNetModel):
    """ControlNetModel + CoordNet. Passes isinstance(ControlNetModel) for pipeline compatibility."""

    def __init__(self, controlnet: ControlNetModel, coordnet: GeoSynthCoordNetModel,
                 location_encoder: SatCLIPLocationEncoder, unet: torch.nn.Module):
        # Re-create a ControlNetModel from the donor's config and copy its weights,
        # so `self` IS a ControlNetModel and downstream isinstance checks succeed.
        # NOTE(review): controlnet.config may carry private keys ("_class_name",
        # "_diffusers_version", ...) that ControlNetModel.__init__ might not accept —
        # confirm against the pinned diffusers version (0.37.0 per the configs).
        cfg = dict(controlnet.config)
        super().__init__(**cfg)
        self.load_state_dict(controlnet.state_dict(), strict=True)
        # Underscore names keep these helpers out of the public ControlNet namespace.
        self._coordnet = coordnet
        self._location_encoder = location_encoder
        self._unet = unet  # kept to reuse the UNet's timestep-embedding layers in forward()
        self._location_embedding: Optional[torch.Tensor] = None

    def set_location_embedding(self, embedding: torch.Tensor):
        # Stashed once per generation (see run_with_location); consumed by every forward().
        self._location_embedding = embedding

    def forward(self, sample: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: torch.Tensor,
                controlnet_cond: torch.Tensor, conditioning_scale: float = 1.0, guess_mode: bool = False,
                return_dict: bool = True, **kwargs) -> Any:
        """Standard ControlNet forward, then add CoordNet location residuals on top."""
        # Plain ControlNet pass first (tuple form so the residuals can be post-processed).
        down_block_res_samples, mid_block_res_sample = super().forward(
            sample, timestep, encoder_hidden_states=encoder_hidden_states, controlnet_cond=controlnet_cond,
            conditioning_scale=conditioning_scale, guess_mode=guess_mode, return_dict=False, **kwargs
        )
        if self._location_embedding is not None and self._coordnet is not None:
            coord_dtype = next(self._coordnet.parameters()).dtype
            loc_emb = self._location_embedding.to(coord_dtype)
            # Recompute the UNet's timestep embedding so CoordNet sees the same time
            # conditioning as the diffusion backbone.
            t_emb = self._unet.get_time_embed(sample=sample, timestep=timestep)
            time_emb = self._unet.time_embedding(t_emb, None).to(coord_dtype)
            # CoordNet broadcasts its per-block vectors to each residual's spatial size.
            down_shapes = [tuple(d.shape[2:]) for d in down_block_res_samples]
            mid_shape = tuple(mid_block_res_sample.shape[2:])
            loc_down, loc_mid = self._coordnet(
                loc_emb, time_emb,
                down_block_sample_shapes=down_shapes, mid_block_sample_shape=mid_shape, return_dict=False
            )
            # Additive conditioning: location residuals are summed onto ControlNet's.
            down_block_res_samples = [c + l for c, l in zip(down_block_res_samples, loc_down)]
            mid_block_res_sample = mid_block_res_sample + loc_mid
        if return_dict:
            from diffusers.models.controlnets.controlnet import ControlNetOutput
            return ControlNetOutput(down_block_res_samples=tuple(down_block_res_samples), mid_block_res_sample=mid_block_res_sample)
        return down_block_res_samples, mid_block_res_sample
geosynth_pipeline/coordnet.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GeoSynth CoordNet - diffusers-style model.
3
+ 13 cross-attention blocks producing location conditioning for UNet.
4
+ """
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from torch import nn
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+
12
+
13
class CoordNetLocationBlock(nn.Module):
    """One CoordNet block: cross-attend the location token to the timestep embedding.

    Returns the updated location token plus a zero-initialized projection whose
    channel count matches one UNet/ControlNet residual, so an untrained block
    contributes nothing (ControlNet-style zero init).
    """

    def __init__(self, embed_dim: int = 256, out_channels: int = 320, num_heads: int = 4, time_embed_dim: int = 1280):
        super().__init__()
        # Submodule names and creation order match the GeoSynth checkpoint layout.
        self.query_embed = nn.Linear(embed_dim, embed_dim, bias=False)
        self.key_embed = nn.Linear(time_embed_dim, embed_dim, bias=False)
        self.value_embed = nn.Linear(time_embed_dim, embed_dim, bias=False)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.ff = nn.Linear(embed_dim, embed_dim, bias=False)
        self.ff_zero = nn.Linear(embed_dim, out_channels, bias=False)
        nn.init.zeros_(self.ff_zero.weight)

    def forward(self, loc_emb: torch.Tensor, time_emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Query: location token (B, 1, E); key/value: timestep embedding as one token.
        query = self.query_embed(loc_emb)
        key = self.key_embed(time_emb).unsqueeze(1)
        value = self.value_embed(time_emb).unsqueeze(1)
        attended, _ = self.cross_attn(query, key, value)
        # Two post-norm residual connections, then the zero-init output projection.
        hidden = loc_emb + self.norm1(attended)
        hidden = hidden + self.norm2(self.ff(hidden))
        return hidden, self.ff_zero(hidden)
34
+
35
+
36
class GeoSynthCoordNetModel(ModelMixin, ConfigMixin):
    """CoordNet: maps a SatCLIP location embedding to additive residuals for the UNet.

    12 sequential cross-attention blocks (one per ControlNet down-block residual)
    plus a separate middle block — 13 total, matching the GeoSynth checkpoint's
    ``control_model.loc_blocks`` / ``control_model.loc_middle_block`` layout.
    """

    # 12 down blocks + 1 middle (separate loc_middle_block)
    # Channel counts mirror the SD 2.1 ControlNet down-residual channels.
    BLOCK_OUT_CHANNELS = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]

    @register_to_config
    def __init__(self, embed_dim: int = 256, time_embed_dim: int = 1280, num_heads: int = 4,
                 block_out_channels: Optional[List[int]] = None):
        super().__init__()
        chs = block_out_channels or self.BLOCK_OUT_CHANNELS
        self.loc_blocks = nn.ModuleList([
            CoordNetLocationBlock(embed_dim=embed_dim, out_channels=ch, num_heads=num_heads, time_embed_dim=time_embed_dim)
            for ch in chs
        ])
        self.loc_middle_block = CoordNetLocationBlock(embed_dim=embed_dim, out_channels=1280, num_heads=num_heads, time_embed_dim=time_embed_dim)

    def forward(self, location_embedding: torch.Tensor, time_emb: torch.Tensor,
                down_block_sample_shapes: Optional[List[tuple]] = None, mid_block_sample_shape: Optional[tuple] = None,
                return_dict: bool = True):
        """Produce one spatial residual per down block plus a middle residual.

        Each block's vector output is broadcast (``expand``, zero-copy) over the
        corresponding residual's spatial extent.
        """
        if down_block_sample_shapes is None:
            # SD 2.1 ControlNet down block spatial sizes
            # (presumably for 512x512 inputs / 64x64 latents — confirm against caller)
            down_block_sample_shapes = [(64, 64), (64, 64), (64, 64), (32, 32), (32, 32), (32, 32),
                                        (16, 16), (16, 16), (16, 16), (8, 8), (8, 8), (8, 8)]
        if mid_block_sample_shape is None:
            mid_block_sample_shape = (8, 8)
        # (B, E) -> (B, 1, E): a single location "token" threaded through the blocks.
        loc_input = location_embedding.unsqueeze(1)
        down_residuals = []
        for block, (h, w) in zip(self.loc_blocks, down_block_sample_shapes):
            loc_input, loc_zero = block(loc_input, time_emb)
            # (B, 1, C) -> (B, C, h, w) via expand (no data copy).
            loc_spatial = loc_zero.squeeze(1).unsqueeze(-1).unsqueeze(-1).expand(-1, -1, h, w)
            down_residuals.append(loc_spatial)
        _, mid_zero = self.loc_middle_block(loc_input, time_emb)
        h, w = mid_block_sample_shape
        mid_spatial = mid_zero.squeeze(1).unsqueeze(-1).unsqueeze(-1).expand(-1, -1, h, w)
        if return_dict:
            return {"down_block_residuals": down_residuals, "mid_block_residual": mid_spatial}
        return down_residuals, mid_spatial

    @classmethod
    def from_geosynth_ckpt(cls, ckpt_path: str, device: Optional[Union[str, torch.device]] = None) -> "GeoSynthCoordNetModel":
        """Build a CoordNet from an original GeoSynth .ckpt (Lightning-style state dict)."""
        # weights_only=False: the ckpt is a full pickle — only load trusted files.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        state_dict = ckpt.get("state_dict", ckpt)
        model = cls()
        # Strip the "control_model." prefix used by the original training code;
        # everything else in the checkpoint is discarded.
        sd = {}
        for k, v in state_dict.items():
            if "control_model.loc_blocks" in k:
                sd[k.replace("control_model.loc_blocks", "loc_blocks")] = v
            elif "control_model.loc_middle_block" in k:
                sd[k.replace("control_model.loc_middle_block", "loc_middle_block")] = v
        model.load_state_dict(sd, strict=True)
        if device is not None:
            model = model.to(device)
        return model
geosynth_pipeline/pipeline.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GeoSynth pipeline with full location conditioning.
3
+ Load from local diffusers-style paths (satclip_location_encoder/, coordnet/) — no external models.
4
+ """
5
+ from typing import Any, Optional, Union
6
+
7
+ import os
8
+ import torch
9
+ from diffusers import StableDiffusionControlNetPipeline
10
+ from diffusers.models import ControlNetModel
11
+
12
+ from .controlnet_location_wrapper import GeoSynthControlNetWithLocation
13
+ from .coordnet import GeoSynthCoordNetModel
14
+ from .satclip_location_encoder import SatCLIPLocationEncoder
15
+
16
+
17
+ def _is_diffusers_dir(path: str) -> bool:
18
+ return os.path.isdir(path) and os.path.isfile(os.path.join(path, "config.json"))
19
+
20
+
21
def load_geosynth_pipeline_with_location(
    model_path: str,
    controlnet_subfolder: str = "controlnet/GeoSynth-Location-OSM",
    satclip_path: Optional[str] = None,
    coordnet_path: Optional[str] = None,
    torch_dtype: Optional[torch.dtype] = None,
    device: Optional[str] = None,
    local_files_only: bool = False,
) -> StableDiffusionControlNetPipeline:
    """
    Load GeoSynth pipeline with full location conditioning.

    satclip_path and coordnet_path: paths relative to this repo (or abs).
    Use diffusers-style dirs (satclip_location_encoder/, coordnet/) — run convert_ckpts.py first.
    Returns a StableDiffusionControlNetPipeline whose .controlnet has been swapped
    for GeoSynthControlNetWithLocation (ControlNet + CoordNet + SatCLIP encoder).
    """
    # Defaults: the converted dirs that live next to this package.
    repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    satclip_path = satclip_path or os.path.join(repo_dir, "satclip_location_encoder")
    coordnet_path = coordnet_path or os.path.join(repo_dir, "coordnet")

    # Resolve to absolute (no-op for the defaults above, which are already absolute).
    if not os.path.isabs(satclip_path):
        satclip_path = os.path.join(repo_dir, satclip_path)
    if not os.path.isabs(coordnet_path):
        coordnet_path = os.path.join(repo_dir, coordnet_path)

    # Prefer GPU + fp16 when available; CPU falls back to fp32.
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch_dtype or (torch.float16 if torch.cuda.is_available() else torch.float32)

    # Base ControlNet weights come from a subfolder of the main repo path.
    controlnet = ControlNetModel.from_pretrained(
        model_path, subfolder=controlnet_subfolder, torch_dtype=dtype, local_files_only=local_files_only
    )
    pipe = StableDiffusionControlNetPipeline.from_pretrained(
        model_path, controlnet=controlnet, torch_dtype=dtype, local_files_only=local_files_only
    )

    # Load SatCLIP: from_pretrained (diffusers) or from_satclip_ckpt (.ckpt)
    if _is_diffusers_dir(satclip_path):
        location_encoder = SatCLIPLocationEncoder.from_pretrained(satclip_path)
    else:
        location_encoder = SatCLIPLocationEncoder.from_satclip_ckpt(satclip_path)
    location_encoder = location_encoder.to(device).to(dtype)

    # Load CoordNet: from_pretrained (diffusers) or from_geosynth_ckpt (.ckpt)
    if _is_diffusers_dir(coordnet_path):
        coordnet = GeoSynthCoordNetModel.from_pretrained(coordnet_path)
    else:
        coordnet = GeoSynthCoordNetModel.from_geosynth_ckpt(coordnet_path)
    coordnet = coordnet.to(device).to(dtype)

    # Swap the plain ControlNet for the wrapper that adds CoordNet location residuals.
    wrapper = GeoSynthControlNetWithLocation(
        controlnet=controlnet, coordnet=coordnet, location_encoder=location_encoder, unet=pipe.unet
    )
    wrapper = wrapper.to(device).to(dtype)
    pipe.controlnet = wrapper
    return pipe
75
+
76
+
77
def run_with_location(
    pipe: StableDiffusionControlNetPipeline,
    prompt: Union[str, list],
    image: Any,
    lon: float = -90.2,
    lat: float = 38.6,
    num_inference_steps: int = 20,
    guidance_scale: float = 7.5,
    generator: Optional[torch.Generator] = None,
    controlnet_conditioning_scale: float = 1.0,
    num_images_per_prompt: int = 1,
    **kwargs,
) -> Any:
    """Run inference with location conditioning.

    Encodes (lon, lat) with the wrapper's SatCLIP encoder, stashes the embedding
    on the ControlNet wrapper, then delegates to the normal pipeline call.
    """
    # Plain ControlNet (no location wrapper): fall through to a normal call.
    if not hasattr(pipe.controlnet, "set_location_embedding"):
        return pipe(prompt, image=image, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
                    generator=generator, controlnet_conditioning_scale=controlnet_conditioning_scale,
                    num_images_per_prompt=num_images_per_prompt, **kwargs)
    device = getattr(pipe, "_execution_device", None) or next(pipe.unet.parameters()).device
    # The embedding batch must match the latents the ControlNet will see:
    # prompts x images-per-prompt, doubled when classifier-free guidance is active
    # (NOTE(review): assumes guidance_scale > 1 duplicates the batch into
    # uncond + cond halves, as in the standard SD pipeline — confirm).
    batch_size = 1 if isinstance(prompt, str) else len(prompt)
    n = batch_size * num_images_per_prompt
    if guidance_scale > 1:
        n *= 2
    # Same (lon, lat) pair expanded across the whole batch.
    lonlat = torch.tensor([[lon, lat]], device=device, dtype=torch.float32).expand(n, 2)
    with torch.no_grad():
        loc_emb = pipe.controlnet._location_encoder(lonlat)
    pipe.controlnet.set_location_embedding(loc_emb)
    return pipe(prompt, image=image, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale,
                generator=generator, controlnet_conditioning_scale=controlnet_conditioning_scale,
                num_images_per_prompt=num_images_per_prompt, **kwargs)
geosynth_pipeline/satclip_location_encoder.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SatCLIP Location Encoder - diffusers-style model.
3
+ Encodes (lon, lat) -> 256-dim embedding via spherical harmonics + SIREN MLP.
4
+ """
5
+ import math
6
+ from typing import Optional, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+
12
+ from .spherical_harmonics import SphericalHarmonicsEncoder
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+
16
+
17
+ def _exists(val):
18
+ return val is not None
19
+
20
+
21
class Sine(nn.Module):
    """Sinusoidal activation ``sin(w0 * x)`` used by SIREN layers."""

    def __init__(self, w0: float = 1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = self.w0 * x
        return torch.sin(scaled)
28
+
29
+
30
class SirenLayer(nn.Module):
    """One SIREN linear layer with SIREN-scheme uniform init.

    First layers draw weights from U(-1/dim_in, 1/dim_in); deeper layers from
    U(-sqrt(c/dim_in)/w0, sqrt(c/dim_in)/w0). Default activation is Sine(w0).
    """

    def __init__(self, dim_in: int, dim_out: int, w0: float = 1.0, c: float = 6.0, is_first: bool = False,
                 use_bias: bool = True, activation: Optional[nn.Module] = None, dropout: bool = False):
        super().__init__()
        self.dropout = dropout
        if is_first:
            bound = 1 / dim_in
        else:
            bound = math.sqrt(c / dim_in) / w0
        # Weight first, then bias — keeps the RNG draw order stable.
        weight = torch.zeros(dim_out, dim_in).uniform_(-bound, bound)
        self.weight = nn.Parameter(weight)
        if use_bias:
            self.bias = nn.Parameter(torch.zeros(dim_out).uniform_(-bound, bound))
        else:
            self.bias = None
        self.activation = Sine(w0) if activation is None else activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pre = F.linear(x, self.weight, self.bias)
        if self.dropout:
            pre = F.dropout(pre, training=self.training)
        return self.activation(pre)
50
+
51
+
52
class SirenNet(nn.Module):
    """Stack of SIREN layers: sine-activated hidden layers plus a linear output."""

    def __init__(self, dim_in: int, dim_hidden: int, dim_out: int, num_layers: int, w0: float = 1.0,
                 w0_initial: float = 30.0, use_bias: bool = True, dropout: bool = True):
        super().__init__()
        hidden_layers = []
        for idx in range(num_layers):
            first = idx == 0
            hidden_layers.append(SirenLayer(
                dim_in if first else dim_hidden,
                dim_hidden,
                w0=w0_initial if first else w0,  # larger w0 on the first layer (SIREN convention)
                use_bias=use_bias,
                is_first=first,
                dropout=dropout,
            ))
        self.layers = nn.ModuleList(hidden_layers)
        # Output head: identity activation, never dropped out.
        self.last_layer = SirenLayer(dim_hidden, dim_out, w0=w0, use_bias=use_bias,
                                     activation=nn.Identity(), dropout=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)
        return self.last_layer(hidden)
70
+
71
+
72
class SatCLIPLocationEncoder(ModelMixin, ConfigMixin):
    """[lon, lat] -> 256-dim embedding. Load via from_pretrained() or from_satclip_ckpt()."""

    @register_to_config
    def __init__(self, legendre_polys: int = 10, embed_dim: int = 256, dim_hidden: int = 512, num_hidden_layers: int = 2):
        super().__init__()
        # Positional encoding: (lon, lat) degrees -> spherical-harmonic features.
        self.posenc = SphericalHarmonicsEncoder(legendre_polys=legendre_polys)
        # SIREN MLP maps the harmonic features to the final embedding.
        self.nnet = SirenNet(
            dim_in=self.posenc.embedding_dim,
            dim_hidden=dim_hidden,
            dim_out=embed_dim,
            num_layers=num_hidden_layers,
            dropout=True,
        )

    def forward(self, lonlat: torch.Tensor) -> torch.Tensor:
        """Encode lon/lat pairs (degrees) to embeddings via harmonics + SIREN MLP."""
        # Harmonics are evaluated in float64, then cast back to the MLP's dtype.
        if lonlat.dtype != torch.float64:
            lonlat = lonlat.double()
        x = self.posenc(lonlat).float()
        x = x.to(next(self.nnet.parameters()).dtype)
        return self.nnet(x)

    @classmethod
    def from_satclip_ckpt(cls, ckpt_path: str, device: Optional[Union[str, torch.device]] = None,
                          dtype: Optional[torch.dtype] = None) -> "SatCLIPLocationEncoder":
        """Build the encoder from an original SatCLIP .ckpt (Lightning-style)."""
        # weights_only=False: the ckpt is a full pickle — only load trusted files.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        hp = ckpt.get("hyper_parameters", {})
        state_dict = ckpt.get("state_dict", ckpt)
        # Hyperparameters fall back to the SatCLIP-ViT16-L10 defaults.
        model = cls(
            legendre_polys=hp.get("legendre_polys", 10),
            embed_dim=hp.get("embed_dim", 256),
            dim_hidden=hp.get("capacity", 512),
            num_hidden_layers=hp.get("num_hidden_layers", 2),
        )
        # Keep only the location-encoder MLP weights and re-key them for self.nnet.
        # NOTE(review): slicing at index("nnet") + 5 assumes each key contains
        # "nnet." exactly once (e.g. "...location_encoder.nnet.layers...") —
        # verify against the actual SatCLIP checkpoint keys.
        nnet_sd = {}
        for k, v in state_dict.items():
            if "location" in k and "nnet" in k:
                new_k = k[k.index("nnet") + 5:]
                nnet_sd[new_k] = v
        model.nnet.load_state_dict(nnet_sd, strict=True)
        if device is not None:
            model = model.to(device)
        if dtype is not None:
            model = model.to(dtype)
        return model
geosynth_pipeline/spherical_harmonics.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Spherical harmonics encoding for (lon, lat) -> embedding.
3
+ Adapted from Microsoft SatCLIP: https://github.com/microsoft/satclip
4
+ """
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+
9
+
10
def associated_legendre_polynomial(l: int, m: int, x: torch.Tensor) -> torch.Tensor:
    """Evaluate the associated Legendre polynomial P_l^m(x) by upward recurrence."""
    # Seed: P_m^m(x) = (-1)^m (2m-1)!! (1 - x^2)^(m/2).
    p_mm = torch.ones_like(x)
    if m > 0:
        sin_term = torch.sqrt((1 - x) * (1 + x))
        odd = 1.0
        for _ in range(m):
            p_mm = p_mm * (-odd) * sin_term
            odd += 2.0
    if l == m:
        return p_mm
    # P_{m+1}^m(x) = x (2m + 1) P_m^m(x).
    p_next = x * (2.0 * m + 1.0) * p_mm
    if l == m + 1:
        return p_next
    # Three-term upward recurrence in the degree until we reach l.
    result = torch.zeros_like(x)
    for degree in range(m + 2, l + 1):
        result = ((2.0 * degree - 1.0) * x * p_next - (degree + m - 1.0) * p_mm) / (degree - m)
        p_mm = p_next
        p_next = result
    return result
30
+
31
+
32
def sh_renormalization(l: int, m: int) -> float:
    """Return the normalization constant K_l^m for real spherical harmonics."""
    return math.sqrt((2.0 * l + 1.0) * math.factorial(l - m) / (4 * math.pi * math.factorial(l + m)))
34
+
35
+
36
def sh_basis(m: int, l: int, phi: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    """Real spherical harmonic Y_l^m at (phi=lon, theta=lat), both in radians."""
    cos_theta = torch.cos(theta)
    if m == 0:
        # Zonal harmonics carry no azimuthal factor.
        return sh_renormalization(l, m) * associated_legendre_polynomial(l, m, cos_theta)
    # Real-form harmonics: cosine branch for m > 0, sine branch for m < 0,
    # both using the positive order |m| in the Legendre/normalization terms.
    if m > 0:
        order = m
        azimuthal = torch.cos(m * phi)
    else:
        order = -m
        azimuthal = torch.sin(-m * phi)
    return (
        math.sqrt(2.0)
        * sh_renormalization(l, order)
        * azimuthal
        * associated_legendre_polynomial(l, order, cos_theta)
    )
54
+
55
+
56
class SphericalHarmonicsEncoder(nn.Module):
    """Encode (lon, lat) in degrees as spherical-harmonics features.

    Output dimension is L * M with L = M = ``legendre_polys``.
    """

    def __init__(self, legendre_polys: int = 10):
        super().__init__()
        self.L = self.M = int(legendre_polys)
        self.embedding_dim = self.L * self.M

    def forward(self, lonlat: torch.Tensor) -> torch.Tensor:
        """Return a (batch, embedding_dim) tensor of harmonic features."""
        # Shift lon to [0, 360) and lat to [0, 180] before converting to radians.
        phi = torch.deg2rad(lonlat[:, 0] + 180)
        theta = torch.deg2rad(lonlat[:, 1] + 90)
        features = []
        for degree in range(self.L):
            for order in range(-degree, degree + 1):
                basis = sh_basis(order, degree, phi, theta)
                # A constant harmonic may come back as a plain float; broadcast it.
                if isinstance(basis, float):
                    basis = basis * torch.ones_like(phi)
                features.append(basis)
        return torch.stack(features, dim=-1)
inference_demo.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ GeoSynth inference demo with full location conditioning.
4
+ Uses only paths relative to this repo — no external code or models.
5
+
6
+ Requirements:
7
+ - Run convert_ckpts.py first to populate satclip_location_encoder/ and coordnet/
8
+ - diffusers, torch, transformers, pillow
9
+ """
10
+ import argparse
11
+ import os
12
+ import sys
13
+
14
+ REPO_DIR = os.path.dirname(os.path.abspath(__file__))
15
+ sys.path.insert(0, REPO_DIR)
16
+
17
+ import torch
18
+ from PIL import Image
19
+ from geosynth_pipeline import load_geosynth_pipeline_with_location, run_with_location
20
+
21
+ DEFAULT_LON = -90.2
22
+ DEFAULT_LAT = 38.6
23
+
24
+
25
def main():
    """Parse CLI arguments, load the GeoSynth pipeline, and generate one image.

    Requires convert_ckpts.py to have populated satclip_location_encoder/ and
    coordnet/ inside this repo; exits with a message otherwise.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--control", required=True, help="Path to control image (OSM or Canny)")
    parser.add_argument("--model", default=REPO_DIR, help="Model path (default: this repo)")
    parser.add_argument("--control_type", choices=["OSM", "Canny"], default="OSM")
    parser.add_argument("--prompt", default="Satellite image features a city neighborhood")
    parser.add_argument("--lon", type=float, default=DEFAULT_LON)
    parser.add_argument("--lat", type=float, default=DEFAULT_LAT)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--steps", type=int, default=20)
    parser.add_argument("--out", default=None, help="Output path")
    args = parser.parse_args()

    # Resolve a relative --model against the repo dir instead of silently
    # discarding it (the previous behavior replaced any relative value with
    # REPO_DIR regardless of what the user passed).
    model_path = args.model if os.path.isabs(args.model) else os.path.join(REPO_DIR, args.model)
    satclip_path = os.path.join(REPO_DIR, "satclip_location_encoder")
    coordnet_path = os.path.join(REPO_DIR, "coordnet")

    # Both converted-checkpoint directories must exist before loading.
    if not os.path.isdir(satclip_path) or not os.path.isfile(os.path.join(satclip_path, "config.json")):
        sys.exit("Run convert_ckpts.py first to create satclip_location_encoder/")
    if not os.path.isdir(coordnet_path) or not os.path.isfile(os.path.join(coordnet_path, "config.json")):
        sys.exit("Run convert_ckpts.py first to create coordnet/")

    print("Loading pipeline (local paths only)...")
    pipe = load_geosynth_pipeline_with_location(
        model_path,
        controlnet_subfolder=f"controlnet/GeoSynth-Location-{args.control_type}",
        satclip_path=satclip_path,
        coordnet_path=coordnet_path,
        local_files_only=True,
    )
    pipe = pipe.to("cuda")

    # The ControlNet expects a 512x512 RGB conditioning image.
    img = Image.open(args.control).convert("RGB").resize((512, 512))
    generator = torch.manual_seed(args.seed)

    print(f"Generating (lon={args.lon}, lat={args.lat})...")
    output = run_with_location(
        pipe, args.prompt, image=img, lon=args.lon, lat=args.lat,
        num_inference_steps=args.steps, generator=generator,
    )

    # Default output name encodes the control image and coordinates.
    out_path = args.out
    if not out_path:
        os.makedirs(os.path.join(REPO_DIR, "demo_outputs"), exist_ok=True)
        base = os.path.splitext(os.path.basename(args.control))[0]
        out_path = os.path.join(REPO_DIR, "demo_outputs", f"output_{base}_lon{args.lon}_lat{args.lat}.jpg")
    output.images[0].save(out_path)
    print(f"Saved: {out_path}")
75
+
76
if __name__ == "__main__":
    # Script entry point; all logic lives in main().
    main()
model_index.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "StableDiffusionPipeline",
3
+ "_diffusers_version": "0.37.0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPImageProcessor"
7
+ ],
8
+ "image_encoder": [
9
+ null,
10
+ null
11
+ ],
12
+ "requires_safety_checker": false,
13
+ "safety_checker": [
14
+ null,
15
+ null
16
+ ],
17
+ "scheduler": [
18
+ "diffusers",
19
+ "DPMSolverMultistepScheduler"
20
+ ],
21
+ "text_encoder": [
22
+ "transformers",
23
+ "CLIPTextModel"
24
+ ],
25
+ "tokenizer": [
26
+ "transformers",
27
+ "CLIPTokenizer"
28
+ ],
29
+ "unet": [
30
+ "diffusers",
31
+ "UNet2DConditionModel"
32
+ ],
33
+ "vae": [
34
+ "diffusers",
35
+ "AutoencoderKL"
36
+ ]
37
+ }
satclip_location_encoder/config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "SatCLIPLocationEncoder",
3
+ "_diffusers_version": "0.37.0",
4
+ "dim_hidden": 512,
5
+ "embed_dim": 256,
6
+ "legendre_polys": 10,
7
+ "num_hidden_layers": 2
8
+ }
satclip_location_encoder/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4a04fc24fd0edf58754130c1cfcd75a74274476a010c509b5ee23191c3b9bcb5
3
+ size 1783336
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DPMSolverMultistepScheduler",
3
+ "_diffusers_version": "0.37.0",
4
+ "algorithm_type": "dpmsolver++",
5
+ "beta_end": 0.085,
6
+ "beta_schedule": "scaled_linear",
7
+ "beta_start": 0.02,
8
+ "clip_sample": false,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "euler_at_final": false,
11
+ "final_sigmas_type": "zero",
12
+ "flow_shift": 1.0,
13
+ "interpolation_type": "linear",
14
+ "lambda_min_clipped": -Infinity,
15
+ "lower_order_final": true,
16
+ "num_train_timesteps": 1000,
17
+ "prediction_type": "v_prediction",
18
+ "rescale_betas_zero_snr": false,
19
+ "sample_max_value": 1.0,
20
+ "set_alpha_to_one": false,
21
+ "skip_prk_steps": true,
22
+ "solver_order": 2,
23
+ "solver_type": "midpoint",
24
+ "steps_offset": 1,
25
+ "thresholding": false,
26
+ "time_shift_type": "exponential",
27
+ "timestep_spacing": "leading",
28
+ "trained_betas": null,
29
+ "use_beta_sigmas": false,
30
+ "use_dynamic_shifting": false,
31
+ "use_exponential_sigmas": false,
32
+ "use_flow_sigmas": false,
33
+ "use_karras_sigmas": false,
34
+ "use_lu_lambdas": false,
35
+ "variance_type": null
36
+ }
test_output.jpeg ADDED
text_encoder/config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "CLIPTextModel"
4
+ ],
5
+ "attention_dropout": 0.0,
6
+ "bos_token_id": 0,
7
+ "dropout": 0.0,
8
+ "dtype": "float32",
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1024,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 23,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "transformers_version": "5.3.0",
23
+ "vocab_size": 49408
24
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67e013543d4fac905c882e2993d86a2d454ee69dc9e8f37c0c23d33a48959d15
3
+ size 1361596304
tokenizer/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "backend": "tokenizers",
4
+ "bos_token": "<|startoftext|>",
5
+ "clean_up_tokenization_spaces": true,
6
+ "do_lower_case": true,
7
+ "eos_token": "<|endoftext|>",
8
+ "errors": "replace",
9
+ "is_local": true,
10
+ "model_max_length": 77,
11
+ "pad_token": "!",
12
+ "tokenizer_class": "CLIPTokenizer",
13
+ "unk_token": "<|endoftext|>"
14
+ }
unet/config.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "UNet2DConditionModel",
3
+ "_diffusers_version": "0.37.0",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": null,
8
+ "attention_head_dim": [
9
+ 5,
10
+ 10,
11
+ 20,
12
+ 20
13
+ ],
14
+ "attention_type": "default",
15
+ "block_out_channels": [
16
+ 320,
17
+ 640,
18
+ 1280,
19
+ 1280
20
+ ],
21
+ "center_input_sample": false,
22
+ "class_embed_type": null,
23
+ "class_embeddings_concat": false,
24
+ "conv_in_kernel": 3,
25
+ "conv_out_kernel": 3,
26
+ "cross_attention_dim": 1024,
27
+ "cross_attention_norm": null,
28
+ "down_block_types": [
29
+ "CrossAttnDownBlock2D",
30
+ "CrossAttnDownBlock2D",
31
+ "CrossAttnDownBlock2D",
32
+ "DownBlock2D"
33
+ ],
34
+ "downsample_padding": 1,
35
+ "dropout": 0.0,
36
+ "dual_cross_attention": false,
37
+ "encoder_hid_dim": null,
38
+ "encoder_hid_dim_type": null,
39
+ "flip_sin_to_cos": true,
40
+ "freq_shift": 0,
41
+ "in_channels": 4,
42
+ "layers_per_block": 2,
43
+ "mid_block_only_cross_attention": null,
44
+ "mid_block_scale_factor": 1,
45
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
46
+ "norm_eps": 1e-05,
47
+ "norm_num_groups": 32,
48
+ "num_attention_heads": null,
49
+ "num_class_embeds": null,
50
+ "only_cross_attention": false,
51
+ "out_channels": 4,
52
+ "projection_class_embeddings_input_dim": null,
53
+ "resnet_out_scale_factor": 1.0,
54
+ "resnet_skip_time_act": false,
55
+ "resnet_time_scale_shift": "default",
56
+ "reverse_transformer_layers_per_block": null,
57
+ "sample_size": 96,
58
+ "time_cond_proj_dim": null,
59
+ "time_embedding_act_fn": null,
60
+ "time_embedding_dim": null,
61
+ "time_embedding_type": "positional",
62
+ "timestep_post_act": null,
63
+ "transformer_layers_per_block": 1,
64
+ "up_block_types": [
65
+ "UpBlock2D",
66
+ "CrossAttnUpBlock2D",
67
+ "CrossAttnUpBlock2D",
68
+ "CrossAttnUpBlock2D"
69
+ ],
70
+ "upcast_attention": true,
71
+ "use_linear_projection": true
72
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f349f136057c7ddefbf454ba56b1f9a84ae15436b9b2e677a6590aa966f1db64
3
+ size 3463726504
vae/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.37.0",
4
+ "_name_or_path": "/home/s.sastry/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1-base/snapshots/5ede9e4bf3e3fd1cb0ef2f7a3fff13ee514fdf06/vae",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 256,
9
+ 512,
10
+ 512
11
+ ],
12
+ "down_block_types": [
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D"
17
+ ],
18
+ "force_upcast": true,
19
+ "in_channels": 3,
20
+ "latent_channels": 4,
21
+ "latents_mean": null,
22
+ "latents_std": null,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "sample_size": 768,
28
+ "scaling_factor": 0.18215,
29
+ "shift_factor": null,
30
+ "up_block_types": [
31
+ "UpDecoderBlock2D",
32
+ "UpDecoderBlock2D",
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D"
35
+ ],
36
+ "use_post_quant_conv": true,
37
+ "use_quant_conv": true
38
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2aa1f43011b553a4cba7f37456465cdbd48aab7b54b9348b890e8058ea7683ec
3
+ size 334643268