""" app.py — HeightAdaptor Hugging Face Spaces App Backbone : sd-research/stable-diffusion-2-1-base Adaptor : UEXdo/HeightAdaptor-weight Outputs : Height Map (2D) | Semantic Map | 3D Height Surface | 3D Height + RGB Texture """ # ── ZeroGPU compatibility(无 spaces 库时自动降级)───────────────────── try: import spaces except ImportError: class spaces: @staticmethod def GPU(duration=120): return lambda fn: fn import os, io, traceback import torch import numpy as np import matplotlib; matplotlib.use("Agg") import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D # noqa: F401 — 注册 '3d' projection from PIL import Image from torch.nn import functional as F from diffusers import StableDiffusionPipeline from huggingface_hub import snapshot_download from peft import PeftModel import gradio as gr import safetensors.torch import warnings warnings.filterwarnings("ignore", category=ResourceWarning) from networks.semantic_head import SemanticHead from networks.height_head import HeightHead from networks.decoder import Decoder # ══════════════════════════════════════════════════════════════ # 工具函数 # ══════════════════════════════════════════════════════════════ def fix_lora_state_dict(state_dict: dict) -> dict: """把旧版 Linear proj_in/proj_out 的 2D LoRA 权重升维到 Conv2d 所需的 4D""" fixed = {} for k, v in state_dict.items(): if ("proj_in" in k or "proj_out" in k) and v.ndim == 2: v = v.unsqueeze(-1).unsqueeze(-1) fixed[k] = v return fixed # ══════════════════════════════════════════════════════════════ # 常量 & 配置 # ══════════════════════════════════════════════════════════════ RGB_LATENT_SCALE = 0.18215 ADAPTOR_REPO = os.environ.get("ADAPTOR_MODEL_ID", "UEXdo/HeightAdaptor-weight") DATASET_NAME = "OpenDC" H_TYPE = "ER" DATASET_CFG = { "OpenDC": {"classes_num": 8}, } LABEL_COLORS = { "OpenDC": { 0: (50,125,0), 1: (255,0,0), 2: (0,255,0), 3: (255,0,0), 4: (255,255,0), 5: (255,255,255), 6: (0,255,255), 7: (0,0,0), }, } # ══════════════════════════════════════════════════════════════ # 
# ══════════════════════════════════════════════════════════════
# Download adaptor weights
# ══════════════════════════════════════════════════════════════
print(f"📦 Downloading adaptor weights from {ADAPTOR_REPO} ...")
# Snapshot the whole adaptor repo (SD backbone, LoRA, head checkpoints)
# into the local HF cache; subsequent loads are all file-system reads.
ADAPTOR_DIR = snapshot_download(repo_id=ADAPTOR_REPO)
print(f"✅ Weights cached at: {ADAPTOR_DIR}")

# ══════════════════════════════════════════════════════════════
# Model management
# ══════════════════════════════════════════════════════════════
# Global singleton pipeline; populated once at startup.
_model = None

def build_model():
    """Construct the full HeightAdaptor pipeline on CPU.

    Loads the Stable Diffusion backbone from the adaptor snapshot,
    attaches the LoRA adapter to the UNet, then mounts the custom
    decoder / height head / semantic head as extra attributes on the
    pipeline object (they are NOT built-in pipeline components — see
    ``move_pipe_to``).

    Returns
    -------
    StableDiffusionPipeline
        Pipeline with ``decoder``, ``height_head`` and ``semantic_head``
        attached, all in eval mode, on CPU.
    """
    classes_num = DATASET_CFG[DATASET_NAME]["classes_num"]
    print(f"🔧 Building model — dataset={DATASET_NAME}, h_type={H_TYPE}")
    pipe = StableDiffusionPipeline.from_pretrained(
        os.path.join(ADAPTOR_DIR, "stable-diffusion-v2"),
        torch_dtype=torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
    )
    lora_path = os.path.join(ADAPTOR_DIR, "lora")
    ckpt_file = os.path.join(lora_path, "adapter_model.safetensors")
    # Prefer the safetensors checkpoint; fall back to the legacy .bin.
    if os.path.exists(ckpt_file):
        from safetensors.torch import load_file
        raw_sd = load_file(ckpt_file)
    else:
        raw_sd = torch.load(
            os.path.join(lora_path, "adapter_model.bin"), map_location="cpu"
        )
    # NOTE(review): `fixed_sd` is computed but never applied —
    # PeftModel.from_pretrained below re-reads the adapter from disk and
    # ignores this dict. Confirm whether fixed_sd should be loaded into
    # the PEFT model (e.g. via peft's set_peft_model_state_dict) for
    # legacy 2-D checkpoints, or whether this call can be removed.
    fixed_sd = fix_lora_state_dict(raw_sd)  # noqa: F841
    pipe.unet = PeftModel.from_pretrained(pipe.unet, lora_path)
    # Custom decoder consuming UNet features + VAE skip connections.
    pipe.decoder = Decoder(in_channel=320)
    pipe.decoder.load_state_dict(
        torch.load(os.path.join(ADAPTOR_DIR, "decoder.pth"), map_location="cpu"))
    pipe.decoder.eval()
    pipe.height_head = HeightHead(in_channels=192, h_type=H_TYPE)
    pipe.height_head.load_state_dict(
        torch.load(os.path.join(ADAPTOR_DIR, "height_head.pth"), map_location="cpu"))
    pipe.height_head.eval()
    pipe.semantic_head = SemanticHead(in_channels=192, num_classes=classes_num)
    pipe.semantic_head.load_state_dict(
        torch.load(os.path.join(ADAPTOR_DIR, "semantic_head.pth"), map_location="cpu"))
    pipe.semantic_head.eval()
    print("✅ Model ready (on CPU).")
    return pipe

def move_pipe_to(pipe, device: str):
    """Move the pipeline AND its custom heads to *device*.

    ``pipe.to()`` only moves the pipeline's built-in components;
    ``decoder`` / ``height_head`` / ``semantic_head`` are custom
    attributes attached after construction and must be moved manually.
    """
    pipe.to(device)
    pipe.decoder.to(device)
    pipe.height_head.to(device)
    pipe.semantic_head.to(device)
# Preload the model at startup (OpenDC / ER).
_model = build_model()

# ══════════════════════════════════════════════════════════════
# VAE / UNet forward
# ══════════════════════════════════════════════════════════════
def _vae_encode(pipe, x: torch.Tensor):
    """Manually run the VAE encoder.

    Returns
    -------
    x : torch.Tensor
        Output of the encoder's ``conv_out`` (latent moments before
        ``quant_conv``).
    feats : list[torch.Tensor]
        Intermediate down-block activations (all but the last), used
        later as skip connections for the custom decoder.
    """
    enc = pipe.vae.encoder
    x = enc.conv_in(x)
    feats = []
    for blk in enc.down_blocks:
        x = blk(x)
        feats.append(x)
    x = enc.mid_block(x)
    x = enc.conv_norm_out(x)
    x = enc.conv_act(x)
    x = enc.conv_out(x)
    return x, feats[:-1]

def _unet_forward(unet, sample, timestep, enc_hs):
    """Run the (LoRA-adapted) UNet manually and return the final
    up-block feature map instead of the usual noise prediction."""
    t_emb = unet.get_time_embed(sample=sample, timestep=timestep)
    emb = unet.time_embedding(t_emb)
    enc_hs = unet.process_encoder_hidden_states(
        encoder_hidden_states=enc_hs, added_cond_kwargs=None)
    x = unet.conv_in(sample)
    skips = (x,)
    for blk in unet.down_blocks:
        x, res = blk(hidden_states=x, temb=emb, encoder_hidden_states=enc_hs)
        skips += res
    x = unet.mid_block(x, emb, encoder_hidden_states=enc_hs)
    for blk in unet.up_blocks:
        # Pop as many skip tensors as this up-block has resnets.
        res = skips[-len(blk.resnets):]
        skips = skips[:-len(blk.resnets)]
        x = blk(hidden_states=x, temb=emb, res_hidden_states_tuple=res,
                encoder_hidden_states=enc_hs)
    return x

# ══════════════════════════════════════════════════════════════
# 3D surface rendering helper
# ══════════════════════════════════════════════════════════════
def _render_3d_surface(
    height_np: np.ndarray,
    rgb_img: Image.Image = None,
    title: str = "3D Height",
    grid_size: int = 128,
    elev: int = 35,
    azim: int = -30,
) -> Image.Image:
    """Render a normalized height map (H, W) in [0, 1] as a 3D surface.

    If *rgb_img* (PIL Image) is given, it is draped over the surface as a
    color texture; otherwise the surface is colored with the ``plasma``
    colormap.
    """
    # Downsample the height map to the render grid resolution.
    h_pil = Image.fromarray((height_np * 255).astype(np.uint8))
    h_pil = h_pil.resize((grid_size, grid_size), Image.BILINEAR)
    Z = np.array(h_pil, dtype=np.float32) / 255.0
    x = np.linspace(0, 1, grid_size)
    y = np.linspace(0, 1, grid_size)
    X, Y = np.meshgrid(x, y)
    fig = plt.figure(figsize=(6, 5))
    ax = fig.add_subplot(111, projection="3d")
    if rgb_img is not None:
        rgb_small = rgb_img.resize((grid_size, grid_size), Image.BILINEAR)
        rgb_arr = np.array(rgb_small, dtype=np.float32) / 255.0
        ax.plot_surface(
            X, Y, Z,
            facecolors=rgb_arr,
            rstride=1, cstride=1,
            shade=False, antialiased=False,
        )
    else:
        ax.plot_surface(
            X, Y, Z,
            cmap="plasma",
            rstride=1, cstride=1,
            antialiased=False,
        )
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Height")
    ax.set_title(title)
    # Exaggerate relief 5x; guard against a degenerate (0, 0) z-range
    # when the height map is completely flat.
    z_top = max(float(np.max(height_np)) * 5, 1e-3)
    ax.set_zlim(0.0, z_top)
    ax.set_axis_off()
    ax.view_init(elev=elev, azim=azim)
    plt.tight_layout()
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=150)
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()

# ══════════════════════════════════════════════════════════════
# Core inference logic
# ══════════════════════════════════════════════════════════════
@torch.no_grad()
def _run_inference_core(pipe, device, image):
    """Run both height_head and semantic_head, producing 4 output images.

    Returns
    -------
    height_img : PIL Image
        2D height map (plasma pseudo-color + colorbar).
    semantic_img : PIL Image
        Semantic segmentation map (class-color encoded).
    d3_height_img : PIL Image
        3D height surface (plasma shading).
    d3_rgb_img : PIL Image
        3D height surface with the RGB input draped as texture.
    info : str
        Human-readable numeric summary.
    """
    # ── 1. Text encoding (empty prompt) ─────────────────────────
    tokens = pipe.tokenizer(
        "", padding="max_length", truncation=True,
        max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
    text_emb = pipe.text_encoder(tokens.input_ids.to(device))[0].float()
    # ── 2. Image preprocessing → [1, 3, 512, 512] ∈ [-1, 1] ─────
    img = image.convert("RGB").resize((512, 512), Image.BILINEAR)
    arr = np.array(img, dtype=np.float32).transpose(2, 0, 1)
    norm = (torch.from_numpy(arr) / 255.0 * 2.0 - 1.0).unsqueeze(0).to(device)
    # ── 3. VAE encoding (reparameterized sample, SD latent scale) ─
    h, h_list = _vae_encode(pipe, norm)
    moments = pipe.vae.quant_conv(h)
    mean, lv = torch.chunk(moments, 2, dim=1)
    latents = (mean + torch.exp(0.5 * lv) * torch.randn_like(mean)) * RGB_LATENT_SCALE
    # ── 4. UNet + custom decoder (fixed timestep 999) ────────────
    ts = torch.ones([latents.shape[0]], device=device) * 999
    unet_o = _unet_forward(pipe.unet, latents, ts, text_emb)
    # Skip connections are consumed deepest-first, hence the reversal.
    dec_o = pipe.decoder(unet_o, res_list=h_list[::-1])
    # ── 5. Run both heads on the shared decoder features ─────────
    h_out = pipe.height_head(dec_o)
    s_out = pipe.semantic_head(dec_o)
    # ── 6a. Height map (2D, plasma pseudo-color) ─────────────────
    # h_out[0] — presumably the head returns a tuple/list whose first
    # element is the height tensor; confirm against HeightHead.
    height_pred = F.interpolate(
        h_out[0].cpu(), (512, 512), mode="bilinear", align_corners=False)
    # Map model output from [-1, 1] to [0, 1].
    height_pred = ((height_pred + 1.0) / 2.0).clamp(0, 1).squeeze().numpy()
    fig, ax = plt.subplots(figsize=(6, 5), tight_layout=True)
    im = ax.imshow(height_pred, cmap="plasma")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="Norm. height")
    ax.set_title("Predicted Height Map")
    ax.axis("off")
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=150)
    plt.close(fig); buf.seek(0)
    height_img = Image.open(buf).copy()
    # ── 6b. Semantic map (2D, class-color encoded) ───────────────
    sem_pred = F.interpolate(s_out, (512, 512), mode="bilinear", align_corners=False)
    argmax = torch.argmax(sem_pred, dim=1).squeeze().cpu().numpy()
    canvas = np.zeros((512, 512, 3), dtype=np.uint8)
    for lbl, col in LABEL_COLORS[DATASET_NAME].items():
        canvas[argmax == lbl] = col
    semantic_img = Image.fromarray(canvas)
    # ── 6c. 3D height surface (plasma shading) ───────────────────
    d3_height_img = _render_3d_surface(
        height_pred, rgb_img=None, title="3D Height Surface", grid_size=256,
    )
    # ── 6d. 3D height surface + RGB texture ──────────────────────
    d3_rgb_img = _render_3d_surface(
        height_pred, rgb_img=img, title="3D Height + RGB Texture", grid_size=256,
    )
    info = (
        f"Height normalized range : [{height_pred.min():.4f}, {height_pred.max():.4f}]"
        f" (0 ≈ 0 m, 1 ≈ 50 m)\n"
        f"Semantic class indices : {np.unique(argmax).tolist()}"
    )
    return height_img, semantic_img, d3_height_img, d3_rgb_img, info

# ══════════════════════════════════════════════════════════════
# GPU inference entry point (triggered by the Gradio button)
# ══════════════════════════════════════════════════════════════
@spaces.GPU(duration=120)
def run_inference(image):
    """Gradio callback: run the full pipeline on CUDA for one image.

    Returns the four result images plus an info string; on any failure
    the images are None and the info string carries the error.
    """
    _EMPTY = None
    if image is None:
        return _EMPTY, _EMPTY, _EMPTY, _EMPTY, "⚠️ Please upload an image first."
    if _model is None:
        return _EMPTY, _EMPTY, _EMPTY, _EMPTY, "⚠️ Model not loaded."
    device = "cuda"
    pipe = _model
    move_pipe_to(pipe, device)
    try:
        return _run_inference_core(pipe, device, image)
    except Exception as e:
        traceback.print_exc()
        return _EMPTY, _EMPTY, _EMPTY, _EMPTY, f"❌ Inference error: {e}"
    finally:
        # FIX: `pipe.to("cpu")` alone leaves decoder / height_head /
        # semantic_head on CUDA (they are custom attributes — see
        # move_pipe_to's docstring), so GPU memory was never actually
        # released. Move everything back, matching _startup_gpu_test.
        move_pipe_to(pipe, "cpu")
        torch.cuda.empty_cache()

# ══════════════════════════════════════════════════════════════
# Startup smoke test
# ══════════════════════════════════════════════════════════════
@spaces.GPU(duration=120)
def _startup_gpu_test():
    """Run one inference on a bundled demo image at startup (best effort).

    Skips silently if the demo image is missing; never raises, so a
    failing smoke test does not prevent the app from launching.
    """
    # NOTE(review): this looks for "Demo1.png" in the CWD while the
    # Gradio examples use "demo/Demo1.png" — confirm which path is right.
    _DEMO_IMG_PATH = "Demo1.png"
    print(f"\n{'='*60}")
    print(f"🧪 Startup inference test — {_DEMO_IMG_PATH} (device=cuda)")
    print(f"{'='*60}")
    try:
        if not os.path.exists(_DEMO_IMG_PATH):
            print(f"⚠️ {_DEMO_IMG_PATH} not found, skipping test.")
            return
        _test_img = Image.open(_DEMO_IMG_PATH)
        print(f"   Image size : {_test_img.size}, mode: {_test_img.mode}")
        move_pipe_to(_model, "cuda")
        height_img, semantic_img, d3_height_img, d3_rgb_img, info = \
            _run_inference_core(_model, "cuda", _test_img)
        height_img.save("Demo1_height.png")
        semantic_img.save("Demo1_semantic.png")
        d3_height_img.save("Demo1_3d_height.png")
        d3_rgb_img.save("Demo1_3d_rgb.png")
        print("✅ Test PASSED")
        print(f"   Info : {info}")
    except Exception:
        print("❌ Test FAILED — full traceback below:")
        traceback.print_exc()
    finally:
        move_pipe_to(_model, "cpu")
        torch.cuda.empty_cache()
        print(f"{'='*60}\n")

_startup_gpu_test()

# ══════════════════════════════════════════════════════════════
# Gradio UI
# ══════════════════════════════════════════════════════════════
with gr.Blocks(title="HeightAdaptor") as demo:
    gr.Markdown("""
# 🏙️ HeightAdaptor
**Remote Sensing Image → Height Map · Semantic Segmentation · 3D Reconstruction**
""")
    with gr.Row():
        # ── Left column: input ───────────────────────────────────
        with gr.Column(scale=1):
            inp_img = gr.Image(type="pil", label="📷 Input RGB Image")
            run_btn = gr.Button("🚀 Run Inference", variant="primary", size="lg")
            out_info = gr.Textbox(label="ℹ️ Info", interactive=False, lines=3)
        # ── Right column: 4 output panes (2×2 grid) ──────────────
        with gr.Column(scale=2):
            gr.Markdown("#### 📊 Results")
            with gr.Row():
                out_height = gr.Image(type="pil", label="🗺️ Height Map")
                out_semantic = gr.Image(type="pil", label="🎨 Semantic Map")
            with gr.Row():
                out_3d_height = gr.Image(type="pil", label="🏔️ 3D Height Surface")
                out_3d_rgb = gr.Image(type="pil", label="🌍 3D Height + RGB Texture")
    # ── Example gallery (bottom, clickable samples) ──────────────
    gr.Markdown("---\n#### 🖼️ Example Images — Click any image to load it, then click **Run Inference**")
    gr.Examples(
        examples=[
            ["demo/Demo1.png"],
            ["demo/Demo2.png"],
            ["demo/Demo3.png"],
            ["demo/Demo4.png"],
            ["demo/Demo5.png"],
            ["demo/Demo6.png"],
            ["demo/Demo7.png"],
        ],
        inputs=[inp_img],
        label="Demo Samples",
        examples_per_page=7,
    )
    gr.Markdown("""
---
> 图像会自动缩放至 512 × 512,GPU 推理约需 15–45 秒(含 3D 渲染)。
""")
    run_btn.click(
        fn=run_inference,
        inputs=[inp_img],
        outputs=[out_height, out_semantic, out_3d_height, out_3d_rgb, out_info],
    )

if __name__ == "__main__":
    demo.launch()