PubAccount committed on
Commit
12a114a
·
verified ·
1 Parent(s): 47ab827

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -98
app.py CHANGED
@@ -12,7 +12,7 @@ except ImportError:
12
  def GPU(duration=120):
13
  return lambda fn: fn
14
 
15
- import os, io
16
  import torch
17
  import numpy as np
18
  import matplotlib; matplotlib.use("Agg")
@@ -25,30 +25,29 @@ from peft import PeftModel
25
  import gradio as gr
26
  import safetensors.torch
27
  import warnings
28
- import asyncio
29
 
30
- # ๅฟฝ็•ฅ asyncio ไบ‹ไปถๅพช็Žฏๆžๆž„ๆ—ถ็š„ ResourceWarning
31
  warnings.filterwarnings("ignore", category=ResourceWarning)
32
 
33
  from networks.semantic_head import SemanticHead
34
  from networks.height_head import HeightHead
35
  from networks.decoder import Decoder
36
 
 
37
  def fix_lora_state_dict(state_dict: dict) -> dict:
38
  """ๆŠŠๆ—ง็‰ˆ Linear proj_in/proj_out ็š„ 2D LoRA ๆƒ้‡ๅ‡็ปดๅˆฐ Conv2d ๆ‰€้œ€็š„ 4D"""
39
  fixed = {}
40
  for k, v in state_dict.items():
41
  if ("proj_in" in k or "proj_out" in k) and v.ndim == 2:
42
- v = v.unsqueeze(-1).unsqueeze(-1) # (out, in) โ†’ (out, in, 1, 1)
43
  fixed[k] = v
44
  return fixed
45
 
 
46
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
47
  # ๅธธ้‡ & ้…็ฝฎ
48
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
49
  RGB_LATENT_SCALE = 0.18215
50
 
51
- # ้€š่ฟ‡็Žฏๅขƒๅ˜้‡ๅฏ่ฆ†็›–๏ผŒๅฆๅˆ™ไฝฟ็”จ้ป˜่ฎค HF Repo ID
52
  SD_MODEL_ID = os.environ.get("SD_MODEL_ID", "sd-research/stable-diffusion-2-1-base")
53
  ADAPTOR_REPO = os.environ.get("ADAPTOR_MODEL_ID", "UEXdo/HeightAdaptor-weight")
54
 
@@ -68,26 +67,26 @@ LABEL_COLORS = {
68
  },
69
  }
70
 
 
71
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
72
- # ๅฏๅŠจๆ—ถไธ‹่ฝฝ Adaptor ๆƒ้‡๏ผˆ็ผ“ๅญ˜ๅˆฐๆœฌๅœฐ๏ผŒๅŽ็ปญๆ— ้œ€้‡ๅคไธ‹่ฝฝ๏ผ‰
73
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
74
  print(f"๐Ÿ“ฆ Downloading adaptor weights from {ADAPTOR_REPO} ...")
75
  ADAPTOR_DIR = snapshot_download(repo_id=ADAPTOR_REPO)
76
  print(f"โœ… Weights cached at: {ADAPTOR_DIR}")
77
 
 
78
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
79
- # ๆจกๅž‹็ฎก็†๏ผˆไธป่ฟ›็จ‹็ปดๆŠค CPU ๆจกๅž‹๏ผŒGPU ๅญ่ฟ›็จ‹ copy-on-use๏ผ‰
80
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
81
  _model = None
82
- _model_key = None # (dataset_name, h_type)
83
 
84
 
85
- def build_model(dataset_name: str, h_type: str) -> StableDiffusionPipeline:
86
- """ไปŽ HF Hub ๆ‹‰ๅ–ๅŸบ็ก€ๆจกๅž‹๏ผŒๅ ๅŠ  LoRA + ไธ‰ไธช่‡ชๅฎšไน‰ Head๏ผŒ่ฟ”ๅ›ž CPU ๆจกๅž‹ใ€‚"""
87
  classes_num = DATASET_CFG[dataset_name]["classes_num"]
88
  print(f"๐Ÿ”ง Building model โ€” dataset={dataset_name}, h_type={h_type}")
89
 
90
- # 1. ๅŠ ่ฝฝ SD v1.5 ๅŸบ็ก€ Pipeline
91
  pipe = StableDiffusionPipeline.from_pretrained(
92
  SD_MODEL_ID,
93
  torch_dtype=torch.float32,
@@ -95,8 +94,6 @@ def build_model(dataset_name: str, h_type: str) -> StableDiffusionPipeline:
95
  requires_safety_checker=False,
96
  )
97
 
98
- # 2. ็”จ PEFT ๆŠŠ LoRA ๆƒ้‡ๆณจๅ…ฅ UNet
99
- # ๅฐ่ฏ•ๅŠ ่ฝฝ safetensors ๆˆ– pytorch bin
100
  lora_path = os.path.join(ADAPTOR_DIR, "lora")
101
  ckpt_file = os.path.join(lora_path, "adapter_model.safetensors")
102
  if os.path.exists(ckpt_file):
@@ -107,23 +104,20 @@ def build_model(dataset_name: str, h_type: str) -> StableDiffusionPipeline:
107
  os.path.join(lora_path, "adapter_model.bin"),
108
  map_location="cpu"
109
  )
110
-
111
- fixed_sd = fix_lora_state_dict(raw_sd)
112
  pipe.unet = PeftModel.from_pretrained(pipe.unet, lora_path)
113
 
114
- # 3. ๅŠ ่ฝฝ Decoder
115
  pipe.decoder = Decoder(in_channel=320)
116
  pipe.decoder.load_state_dict(
117
  torch.load(os.path.join(ADAPTOR_DIR, "decoder.pth"), map_location="cpu"))
118
  pipe.decoder.eval()
119
 
120
- # 4. ๅŠ ่ฝฝ HeightHead
121
  pipe.height_head = HeightHead(in_channels=192, h_type=h_type)
122
  pipe.height_head.load_state_dict(
123
  torch.load(os.path.join(ADAPTOR_DIR, "height_head.pth"), map_location="cpu"))
124
  pipe.height_head.eval()
125
 
126
- # 5. ๅŠ ่ฝฝ SemanticHead๏ผˆ็ฑปๅˆซๆ•ฐ็”ฑ dataset ๅ†ณๅฎš๏ผ‰
127
  pipe.semantic_head = SemanticHead(in_channels=192, num_classes=classes_num)
128
  pipe.semantic_head.load_state_dict(
129
  torch.load(os.path.join(ADAPTOR_DIR, "semantic_head.pth"), map_location="cpu"))
@@ -134,11 +128,6 @@ def build_model(dataset_name: str, h_type: str) -> StableDiffusionPipeline:
134
 
135
 
136
  def reload_model(dataset_name: str, h_type: str) -> str:
137
- """
138
- ๅœจไธป่ฟ›็จ‹ไธญ้‡ๅปบๆจกๅž‹๏ผŒไพ› Gradio ๆŒ‰้’ฎ่ฐƒ็”จใ€‚
139
- ๆณจๆ„๏ผšๆญคๅ‡ฝๆ•ฐ **ไธๅŠ ** @spaces.GPU๏ผŒ็›ดๆŽฅ่ฟ่กŒๅœจไธป่ฟ›็จ‹๏ผŒ
140
- ๅ…จๅฑ€ _model ๆ›ดๆ–ฐๅŽ๏ผŒไธ‹ไธ€ๆฌก @spaces.GPU ่ฐƒ็”จไผš fork ๅˆฐๆ–ฐๆจกๅž‹ใ€‚
141
- """
142
  global _model, _model_key
143
  key = (dataset_name, h_type)
144
  if _model is not None and _model_key == key:
@@ -153,11 +142,9 @@ reload_model("OpenDC", "ER")
153
 
154
 
155
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
156
- # VAE / UNet forward๏ผˆ็งป้™คไบ† DistributedDataParallel ๅˆ†ๆ”ฏ๏ผŒ
157
- # Spaces ๅ•ๅกๅœบๆ™ฏไธ้œ€่ฆ๏ผ‰
158
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
159
  def _vae_encode(pipe, x: torch.Tensor):
160
- """้€š่ฟ‡ VAE Encoder ๅ‰ๅ‘๏ผŒ่ฟ”ๅ›ž (ๆœ€็ปˆ็‰นๅพ, ไธญ้—ด็‰นๅพๅˆ—่กจ)ใ€‚"""
161
  enc = pipe.vae.encoder
162
  x = enc.conv_in(x)
163
  feats = []
@@ -168,7 +155,7 @@ def _vae_encode(pipe, x: torch.Tensor):
168
  x = enc.conv_norm_out(x)
169
  x = enc.conv_act(x)
170
  x = enc.conv_out(x)
171
- return x, feats[:-1] # ไธŽๅŽŸๅง‹ไปฃ็ ไธ€่‡ด๏ผŒไธขๅผƒๆœ€ๅŽไธ€ๅฑ‚็‰นๅพ
172
 
173
 
174
  def _unet_forward(unet, sample, timestep, enc_hs):
@@ -194,17 +181,74 @@ def _unet_forward(unet, sample, timestep, enc_hs):
194
 
195
 
196
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
197
- # GPU ๆŽจ็†๏ผˆ็”จ @spaces.GPU ่ฃ…้ฅฐ๏ผŒ็”ณ่ฏทๆœ€ๅคš 120s GPU๏ผ‰
198
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
199
- @spaces.GPU(duration=120)
200
  @torch.no_grad()
201
- def run_inference(
202
- image: Image.Image,
203
- task: str,
204
- dataset_name: str,
205
- h_type: str,
206
- mode_type: str,
207
- ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  if image is None:
209
  return None, "โš ๏ธ Please upload an image first."
210
  if _model is None:
@@ -212,71 +256,57 @@ def run_inference(
212
 
213
  device = "cuda"
214
  pipe = _model
215
- pipe.to(device) # ZeroGPU ๅญ่ฟ›็จ‹ๆ‹ฟๅˆฐ CPU ๅ‰ฏๆœฌๅŽ็งปๅˆฐ GPU
216
 
217
  try:
218
- # โ”€โ”€ 1. ๆ–‡ๆœฌ็ผ–็  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
219
- tokens = pipe.tokenizer(
220
- "", padding="max_length", truncation=True,
221
- max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
222
- text_emb = pipe.text_encoder(tokens.input_ids.to(device))[0].float()
223
- # text_emb: [1, 77, 768] (SD v1.5 ็š„ text dim ไธบ 768)
224
-
225
- # โ”€โ”€ 2. ๅ›พๅƒ้ข„ๅค„็† โ†’ [1, 3, 512, 512] โˆˆ [-1, 1] โ”€โ”€โ”€โ”€โ”€โ”€
226
- img = image.convert("RGB").resize((512, 512), Image.BILINEAR)
227
- arr = np.array(img, dtype=np.float32).transpose(2, 0, 1)
228
- norm = (torch.from_numpy(arr) / 255.0 * 2.0 - 1.0).unsqueeze(0).to(device)
229
-
230
- # โ”€โ”€ 3. VAE ็ผ–็  โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
231
- h, h_list = _vae_encode(pipe, norm)
232
- moments = pipe.vae.quant_conv(h)
233
- mean, lv = torch.chunk(moments, 2, dim=1)
234
- latents = (mean + torch.exp(0.5 * lv) * torch.randn_like(mean)) * RGB_LATENT_SCALE
235
-
236
- # โ”€โ”€ 4. UNet + ่‡ชๅฎšไน‰ Decoder โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
237
- ts = torch.ones([latents.shape[0]], device=device) * 999
238
- unet_o = _unet_forward(pipe.unet, latents, ts, text_emb)
239
- dec_o = pipe.decoder(unet_o, res_list=h_list[::-1])
240
-
241
- # โ”€โ”€ 5. ไปปๅŠก Head โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
242
- h_out = pipe.height_head(dec_o)
243
- s_out = pipe.semantic_head(dec_o)
244
-
245
- # โ”€โ”€ 6. ๅŽๅค„็† & ๅฏ่ง†ๅŒ– โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
246
- if mode_type == "Height Map":
247
- pred = F.interpolate(h_out[0].cpu(), (512, 512),
248
- mode="bilinear", align_corners=False)
249
- pred = ((pred + 1.0) / 2.0).clamp(0, 1).squeeze().numpy()
250
-
251
- fig, ax = plt.subplots(figsize=(6, 5), tight_layout=True)
252
- im = ax.imshow(pred, cmap="plasma")
253
- fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
254
- ax.set_title("Predicted Height Map"); ax.axis("off")
255
- buf = io.BytesIO()
256
- fig.savefig(buf, format="png", dpi=150)
257
- plt.close(fig); buf.seek(0)
258
- out_img = Image.open(buf).copy()
259
- info = (f"Normalized range: [{pred.min():.4f}, {pred.max():.4f}]\n"
260
- "(0 โ‰ˆ 0 m, 1 โ‰ˆ 50 m before denormalization)")
261
-
262
- else: # Semantic Map
263
- pred = F.interpolate(s_out, (512, 512), mode="bilinear", align_corners=False)
264
- argmax = torch.argmax(pred, dim=1).squeeze().cpu().numpy()
265
- canvas = np.zeros((512, 512, 3), dtype=np.uint8)
266
- for lbl, col in LABEL_COLORS[dataset_name].items():
267
- canvas[argmax == lbl] = col
268
- out_img = Image.fromarray(canvas)
269
- info = f"Detected class indices: {np.unique(argmax).tolist()}"
270
-
271
- return out_img, info
272
-
273
  finally:
274
- # ZeroGPU ๅญ่ฟ›็จ‹็ป“ๆŸๅŽ GPU ๅ†…ๅญ˜่‡ชๅŠจ้‡Šๆ”พ๏ผŒ
275
- # ่ฟ™้‡Œๆ˜พๅผ็งปๅ›ž CPU ๅชๆ˜ฏ้ขๅค–ไฟ้™ฉ
276
  pipe.to("cpu")
277
  torch.cuda.empty_cache()
278
 
279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
280
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
281
  # Gradio UI
282
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
@@ -284,12 +314,10 @@ with gr.Blocks(title="HeightAdaptor") as demo:
284
  gr.Markdown("""
285
  # ๐Ÿ™๏ธ HeightAdaptor
286
  **Remote Sensing Image โ†’ Height Map / Semantic Segmentation**
287
-
288
  Backbone: `stable-diffusion-v1-5` + LoRA adaptor (`UEXdo/HeightAdaptor-weight`) + ่‡ชๅฎšไน‰ Task Heads
289
  """)
290
 
291
  with gr.Row():
292
- # โ”€โ”€ ๅทฆๆ ๏ผš่พ“ๅ…ฅ & ้…็ฝฎ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
293
  with gr.Column(scale=1):
294
  inp_img = gr.Image(type="pil", label="๐Ÿ“ท Input RGB Image")
295
 
@@ -313,7 +341,6 @@ with gr.Blocks(title="HeightAdaptor") as demo:
313
 
314
  run_btn = gr.Button("๐Ÿš€ Run Inference", variant="primary", size="lg")
315
 
316
- # โ”€โ”€ ๅณๆ ๏ผš่พ“ๅ‡บ โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
317
  with gr.Column(scale=1):
318
  out_img = gr.Image(type="pil", label="๐Ÿ“Š Output")
319
  out_info = gr.Textbox(label="โ„น๏ธ Info", interactive=False, lines=3)
@@ -324,7 +351,6 @@ with gr.Blocks(title="HeightAdaptor") as demo:
324
  > ๅ›พๅƒไผš่‡ชๅŠจ็ผฉๆ”พ่‡ณ 512 ร— 512๏ผŒGPU ๆŽจ็†็บฆ้œ€ 10โ€“30 ็ง’ใ€‚
325
  """)
326
 
327
- # โ”€โ”€ ไบ‹ไปถ็ป‘ๅฎš โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
328
  load_btn.click(
329
  fn=reload_model,
330
  inputs=[dataset_radio, h_type_radio],
 
12
  def GPU(duration=120):
13
  return lambda fn: fn
14
 
15
+ import os, io, traceback # โ† ๆ–ฐๅขž traceback
16
  import torch
17
  import numpy as np
18
  import matplotlib; matplotlib.use("Agg")
 
25
  import gradio as gr
26
  import safetensors.torch
27
  import warnings
 
28
 
 
29
  warnings.filterwarnings("ignore", category=ResourceWarning)
30
 
31
  from networks.semantic_head import SemanticHead
32
  from networks.height_head import HeightHead
33
  from networks.decoder import Decoder
34
 
35
+
36
  def fix_lora_state_dict(state_dict: dict) -> dict:
37
  """ๆŠŠๆ—ง็‰ˆ Linear proj_in/proj_out ็š„ 2D LoRA ๆƒ้‡ๅ‡็ปดๅˆฐ Conv2d ๆ‰€้œ€็š„ 4D"""
38
  fixed = {}
39
  for k, v in state_dict.items():
40
  if ("proj_in" in k or "proj_out" in k) and v.ndim == 2:
41
+ v = v.unsqueeze(-1).unsqueeze(-1)
42
  fixed[k] = v
43
  return fixed
44
 
45
+
46
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
47
  # Constants & configuration
48
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
49
  RGB_LATENT_SCALE = 0.18215
50
 
 
51
  SD_MODEL_ID = os.environ.get("SD_MODEL_ID", "sd-research/stable-diffusion-2-1-base")
52
  ADAPTOR_REPO = os.environ.get("ADAPTOR_MODEL_ID", "UEXdo/HeightAdaptor-weight")
53
 
 
67
  },
68
  }
69
 
70
+
71
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
72
+ # Download adaptor weights
73
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
74
  print(f"๐Ÿ“ฆ Downloading adaptor weights from {ADAPTOR_REPO} ...")
75
  ADAPTOR_DIR = snapshot_download(repo_id=ADAPTOR_REPO)
76
  print(f"โœ… Weights cached at: {ADAPTOR_DIR}")
77
 
78
+
79
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
80
+ # Model management
81
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
82
  _model = None
83
+ _model_key = None
84
 
85
 
86
+ def build_model(dataset_name: str, h_type: str):
 
87
  classes_num = DATASET_CFG[dataset_name]["classes_num"]
88
  print(f"๐Ÿ”ง Building model โ€” dataset={dataset_name}, h_type={h_type}")
89
 
 
90
  pipe = StableDiffusionPipeline.from_pretrained(
91
  SD_MODEL_ID,
92
  torch_dtype=torch.float32,
 
94
  requires_safety_checker=False,
95
  )
96
 
 
 
97
  lora_path = os.path.join(ADAPTOR_DIR, "lora")
98
  ckpt_file = os.path.join(lora_path, "adapter_model.safetensors")
99
  if os.path.exists(ckpt_file):
 
104
  os.path.join(lora_path, "adapter_model.bin"),
105
  map_location="cpu"
106
  )
107
+
108
+ fixed_sd = fix_lora_state_dict(raw_sd) # noqa: F841๏ผˆไฟฎๅคๅŽๆš‚ๅญ˜๏ผŒPeftModel ไผš่ฏปๆ–‡ไปถ๏ผ‰
109
  pipe.unet = PeftModel.from_pretrained(pipe.unet, lora_path)
110
 
 
111
  pipe.decoder = Decoder(in_channel=320)
112
  pipe.decoder.load_state_dict(
113
  torch.load(os.path.join(ADAPTOR_DIR, "decoder.pth"), map_location="cpu"))
114
  pipe.decoder.eval()
115
 
 
116
  pipe.height_head = HeightHead(in_channels=192, h_type=h_type)
117
  pipe.height_head.load_state_dict(
118
  torch.load(os.path.join(ADAPTOR_DIR, "height_head.pth"), map_location="cpu"))
119
  pipe.height_head.eval()
120
 
 
121
  pipe.semantic_head = SemanticHead(in_channels=192, num_classes=classes_num)
122
  pipe.semantic_head.load_state_dict(
123
  torch.load(os.path.join(ADAPTOR_DIR, "semantic_head.pth"), map_location="cpu"))
 
128
 
129
 
130
  def reload_model(dataset_name: str, h_type: str) -> str:
 
 
 
 
 
131
  global _model, _model_key
132
  key = (dataset_name, h_type)
133
  if _model is not None and _model_key == key:
 
142
 
143
 
144
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
145
+ # VAE / UNet forward
 
146
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
147
  def _vae_encode(pipe, x: torch.Tensor):
 
148
  enc = pipe.vae.encoder
149
  x = enc.conv_in(x)
150
  feats = []
 
155
  x = enc.conv_norm_out(x)
156
  x = enc.conv_act(x)
157
  x = enc.conv_out(x)
158
+ return x, feats[:-1]
159
 
160
 
161
  def _unet_forward(unet, sample, timestep, enc_hs):
 
181
 
182
 
183
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
184
+ # Core inference logic (decoupled from @spaces.GPU so it can be tested on CPU on its own)
185
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
 
186
  @torch.no_grad()
187
def _run_inference_core(pipe, device, image, task, dataset_name, h_type, mode_type):
    """Run one forward pass through the pipeline and render the result.

    This is the pure inference path; it does not depend on ``@spaces.GPU``.
    ``pipe`` must already live on ``device``; inputs created here are moved
    to ``device`` before use.

    Args:
        pipe: Pipeline object exposing ``tokenizer``, ``text_encoder``, ``vae``,
            ``unet``, ``decoder``, ``height_head`` and ``semantic_head``.
        device: Device the tensors are moved to (e.g. ``"cuda"`` or ``"cpu"``).
        image: PIL image; converted to RGB and resized to 512 x 512.
        task: Not used by this function.
        dataset_name: Key into ``LABEL_COLORS`` for the semantic colour map.
        h_type: Not used by this function.
        mode_type: ``"Height Map"`` renders the height prediction; any other
            value renders the semantic map.

    Returns:
        ``(output_image, info_text)`` — a PIL image and a short description.
    """
    side = 512

    # 1. Text encoding: embed the empty prompt.
    tokenizer = pipe.tokenizer
    prompt_ids = tokenizer(
        "",
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids
    text_emb = pipe.text_encoder(prompt_ids.to(device))[0].float()

    # 2. Image preprocessing -> [1, 3, 512, 512] with values in [-1, 1].
    rgb = image.convert("RGB").resize((side, side), Image.BILINEAR)
    chw = np.array(rgb, dtype=np.float32).transpose(2, 0, 1)
    pixels = torch.from_numpy(chw) / 255.0 * 2.0 - 1.0
    batch = pixels.unsqueeze(0).to(device)

    # 3. VAE encoding: sample latents from the predicted mean / log-variance.
    enc_out, skip_feats = _vae_encode(pipe, batch)
    mean, log_var = torch.chunk(pipe.vae.quant_conv(enc_out), 2, dim=1)
    std = torch.exp(0.5 * log_var)
    latents = (mean + std * torch.randn_like(mean)) * RGB_LATENT_SCALE

    # 4. UNet + custom decoder, always evaluated at timestep 999.
    timesteps = torch.ones([latents.shape[0]], device=device) * 999
    unet_out = _unet_forward(pipe.unet, latents, timesteps, text_emb)
    decoded = pipe.decoder(unet_out, res_list=skip_feats[::-1])

    # 5. Task heads (both are evaluated regardless of mode_type).
    height_out = pipe.height_head(decoded)
    semantic_out = pipe.semantic_head(decoded)

    # 6. Post-processing & visualisation.
    if mode_type != "Height Map":
        # Semantic map: per-pixel argmax painted with the dataset palette.
        logits = F.interpolate(
            semantic_out, (side, side), mode="bilinear", align_corners=False)
        labels = torch.argmax(logits, dim=1).squeeze().cpu().numpy()
        canvas = np.zeros((side, side, 3), dtype=np.uint8)
        for class_id, colour in LABEL_COLORS[dataset_name].items():
            canvas[labels == class_id] = colour
        rendered = Image.fromarray(canvas)
        summary = f"Detected class indices: {np.unique(labels).tolist()}"
        return rendered, summary

    # Height map: map the first head output from [-1, 1] to [0, 1] and plot it.
    height = F.interpolate(
        height_out[0].cpu(), (side, side), mode="bilinear", align_corners=False)
    height = ((height + 1.0) / 2.0).clamp(0, 1).squeeze().numpy()

    fig, ax = plt.subplots(figsize=(6, 5), tight_layout=True)
    mappable = ax.imshow(height, cmap="plasma")
    fig.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
    ax.set_title("Predicted Height Map")
    ax.axis("off")

    png = io.BytesIO()
    fig.savefig(png, format="png", dpi=150)
    plt.close(fig)
    png.seek(0)
    # .copy() detaches the image from the in-memory buffer.
    rendered = Image.open(png).copy()
    summary = (f"Normalized range: [{height.min():.4f}, {height.max():.4f}]\n"
               "(0 โ‰ˆ 0 m, 1 โ‰ˆ 50 m before denormalization)")
    return rendered, summary
245
+
246
+
247
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
248
+ # GPU inference entry point (triggered by the Gradio button)
249
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
250
+ @spaces.GPU(duration=120)
251
+ def run_inference(image, task, dataset_name, h_type, mode_type):
252
  if image is None:
253
  return None, "โš ๏ธ Please upload an image first."
254
  if _model is None:
 
256
 
257
  device = "cuda"
258
  pipe = _model
259
+ pipe.to(device)
260
 
261
  try:
262
+ return _run_inference_core(pipe, device, image, task, dataset_name, h_type, mode_type)
263
+ except Exception as e:
264
+ traceback.print_exc() # โ† ็ปˆ็ซฏๆ‰“ๅฎŒๆ•ด stack trace
265
+ return None, f"โŒ Inference error: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  finally:
 
 
267
  pipe.to("cpu")
268
  torch.cuda.empty_cache()
269
 
270
 
271
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
272
+ # ★ Startup test: run one full inference pass on CPU with Demo1.png
273
+ # Success → print the result range and save the output image to Demo1_result.png
274
+ # Failure → print the full traceback to help locate the error
275
+ # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
276
# Startup smoke test: run one complete inference pass on the demo image and
# report the outcome on stdout. A failure is logged but never aborts start-up.
_DEMO_IMG_PATH = "Demo1.png"

# The banner below announces a CPU run and this code executes at import time,
# outside the @spaces.GPU-decorated entry point. The test therefore stays on
# the CPU (the previous code moved the model to "cuda" here, contradicting the
# banner and leaving the global model on the GPU after a successful run).
_TEST_DEVICE = "cpu"

print(f"\n{'='*60}")
print(f"๐Ÿงช Startup inference test โ€” {_DEMO_IMG_PATH} (device={_TEST_DEVICE})")
print(f"{'='*60}")
try:
    if not os.path.exists(_DEMO_IMG_PATH):
        print(f"โš ๏ธ {_DEMO_IMG_PATH} not found, skipping test.")
    else:
        # Context manager so the demo file handle is released after the run.
        with Image.open(_DEMO_IMG_PATH) as _test_img:
            print(f" Image size : {_test_img.size}, mode: {_test_img.mode}")

            # Explicitly place the model on the test device before inference.
            _model.to(_TEST_DEVICE)

            _out_img, _info = _run_inference_core(
                _model, _TEST_DEVICE,
                _test_img,
                "Height Estimation",  # task
                "OpenDC",             # dataset_name
                "ER",                 # h_type
                "Height Map",         # mode_type
            )
        _out_img.save("Demo1_result.png")
        print("โœ… Test PASSED")
        print(f" Info : {_info}")
        print(" Saved to : Demo1_result.png")

except Exception:
    # Top-level boundary: report the full stack trace and keep starting the app.
    print("โŒ Test FAILED โ€” full traceback below:")
    traceback.print_exc()

print(f"{'='*60}\n")
308
+
309
+
310
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
311
  # Gradio UI
312
  # โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•
 
314
  gr.Markdown("""
315
  # ๐Ÿ™๏ธ HeightAdaptor
316
  **Remote Sensing Image โ†’ Height Map / Semantic Segmentation**
 
317
  Backbone: `stable-diffusion-v1-5` + LoRA adaptor (`UEXdo/HeightAdaptor-weight`) + ่‡ชๅฎšไน‰ Task Heads
318
  """)
319
 
320
  with gr.Row():
 
321
  with gr.Column(scale=1):
322
  inp_img = gr.Image(type="pil", label="๐Ÿ“ท Input RGB Image")
323
 
 
341
 
342
  run_btn = gr.Button("๐Ÿš€ Run Inference", variant="primary", size="lg")
343
 
 
344
  with gr.Column(scale=1):
345
  out_img = gr.Image(type="pil", label="๐Ÿ“Š Output")
346
  out_info = gr.Textbox(label="โ„น๏ธ Info", interactive=False, lines=3)
 
351
  > ๅ›พๅƒไผš่‡ชๅŠจ็ผฉๆ”พ่‡ณ 512 ร— 512๏ผŒGPU ๆŽจ็†็บฆ้œ€ 10โ€“30 ็ง’ใ€‚
352
  """)
353
 
 
354
  load_btn.click(
355
  fn=reload_model,
356
  inputs=[dataset_radio, h_type_radio],