Spaces:
Sleeping
Sleeping
"""
app.py — HeightAdaptor Hugging Face Spaces App
Backbone : sd-research/stable-diffusion-2-1-base
Adaptor : UEXdo/HeightAdaptor-weight
Outputs : Height Map (2D) | Semantic Map | 3D Height Surface | 3D Height + RGB Texture
"""
# ── ZeroGPU compatibility (degrades to a no-op when the `spaces` package is missing) ──
try:
    import spaces
except ImportError:
    class spaces:
        """Fallback stand-in for the Hugging Face ``spaces`` package.

        Only ``spaces.GPU`` is provided. It is a decorator that returns the
        wrapped function unchanged, so decorated code also runs outside ZeroGPU.
        """

        @staticmethod
        def GPU(fn=None, duration=120):
            """No-op replacement for ``spaces.GPU``.

            Accepts both the bare form (``@spaces.GPU``) and the factory form
            (``@spaces.GPU(duration=...)``). ``duration`` is accepted and ignored.
            """
            if callable(fn):
                # Bare decorator usage: ``fn`` is the function being decorated.
                return fn
            # Factory usage (including a positional duration): return a decorator.
            return lambda f: f
| import os, io, traceback | |
| import torch | |
| import numpy as np | |
| import matplotlib; matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from mpl_toolkits.mplot3d import Axes3D # noqa: F401 โ ๆณจๅ '3d' projection | |
| from PIL import Image | |
| from torch.nn import functional as F | |
| from diffusers import StableDiffusionPipeline | |
| from huggingface_hub import snapshot_download | |
| from peft import PeftModel | |
| import gradio as gr | |
| import safetensors.torch | |
| import warnings | |
| warnings.filterwarnings("ignore", category=ResourceWarning) | |
| from networks.semantic_head import SemanticHead | |
| from networks.height_head import HeightHead | |
| from networks.decoder import Decoder | |
# ──────────────────────────────────────────────────────────────
# Utility functions
# ──────────────────────────────────────────────────────────────
def fix_lora_state_dict(state_dict: dict) -> dict:
    """Return a copy of ``state_dict`` with proj_in/proj_out LoRA weights made 4D.

    Older checkpoints store the ``proj_in`` / ``proj_out`` LoRA weights as 2D
    Linear matrices. Where the target layer is a Conv2d, those weights need
    two trailing singleton dims (``(out, in)`` -> ``(out, in, 1, 1)``). Every
    other entry is passed through untouched, and key order is preserved.
    """
    def _reshape_if_projection(key, tensor):
        # The key test runs first so non-tensor values under other keys are
        # never asked for ``ndim``.
        is_projection = "proj_in" in key or "proj_out" in key
        if is_projection and tensor.ndim == 2:
            return tensor.unsqueeze(-1).unsqueeze(-1)
        return tensor

    return {key: _reshape_if_projection(key, tensor) for key, tensor in state_dict.items()}
# ──────────────────────────────────────────────────────────────
# Constants & configuration
# ──────────────────────────────────────────────────────────────
# Factor applied to sampled VAE latents before they enter the UNet
# (the standard Stable Diffusion VAE latent scale).
RGB_LATENT_SCALE = 0.18215
# HF repo holding the SD backbone, the LoRA adapter and the head checkpoints.
# It can be overridden through the ADAPTOR_MODEL_ID environment variable.
ADAPTOR_REPO = os.getenv("ADAPTOR_MODEL_ID", "UEXdo/HeightAdaptor-weight")
# Active dataset configuration and height-head type (passed to HeightHead as h_type).
DATASET_NAME = "OpenDC"
H_TYPE = "ER"
# Per-dataset settings; classes_num sizes the semantic head's output.
DATASET_CFG = {"OpenDC": dict(classes_num=8)}
# Per-dataset RGB colour for each semantic class index (index = list position).
# NOTE(review): classes 1 and 3 both map to (255, 0, 0), so they cannot be told
# apart in the rendered semantic map. Confirm the intended palette.
LABEL_COLORS = {
    "OpenDC": dict(enumerate([
        (50, 125, 0),
        (255, 0, 0),
        (0, 255, 0),
        (255, 0, 0),
        (255, 255, 0),
        (255, 255, 255),
        (0, 255, 255),
        (0, 0, 0),
    ])),
}
# ──────────────────────────────────────────────────────────────
# Download adaptor weights
# ──────────────────────────────────────────────────────────────
# Fetch a snapshot of the whole adaptor repo at import time.
# ADAPTOR_DIR is the local snapshot directory that build_model() reads
# the backbone, LoRA adapter and head checkpoints from.
print(f"๐ฆ Downloading adaptor weights from {ADAPTOR_REPO} ...")
ADAPTOR_DIR = snapshot_download(repo_id=ADAPTOR_REPO)
print(f"โ Weights cached at: {ADAPTOR_DIR}")
# ──────────────────────────────────────────────────────────────
# Model management
# ──────────────────────────────────────────────────────────────
# Global pipeline handle. It is assigned from build_model() further down and
# read by run_inference() and _startup_gpu_test().
_model = None
def _load_eval_module(module, filename):
    """Load the state dict at ``ADAPTOR_DIR/filename`` into ``module``; return it in eval mode.

    ``weights_only=True`` limits unpickling to tensors and plain containers.
    The checkpoints come from a remote repo that the ``ADAPTOR_MODEL_ID`` env
    var can redirect, and an unrestricted ``torch.load`` would execute
    arbitrary pickled code from it.
    """
    state_dict = torch.load(
        os.path.join(ADAPTOR_DIR, filename), map_location="cpu", weights_only=True)
    module.load_state_dict(state_dict)
    module.eval()
    return module


def build_model():
    """Build the HeightAdaptor pipeline on the CPU.

    Steps:
    1. Load the Stable Diffusion pipeline bundled in the adaptor snapshot.
    2. Wrap its UNet with the LoRA adapter stored in ``ADAPTOR_DIR/lora``.
    3. Attach three custom modules as plain pipeline attributes:
       ``decoder``, ``height_head`` and ``semantic_head``, each in eval mode.

    Returns:
        The assembled pipeline. ``pipe.to`` does not move the attached
        modules; use ``move_pipe_to`` to change device.
    """
    classes_num = DATASET_CFG[DATASET_NAME]["classes_num"]
    print(f"๐ง Building model โ dataset={DATASET_NAME}, h_type={H_TYPE}")
    pipe = StableDiffusionPipeline.from_pretrained(
        os.path.join(ADAPTOR_DIR, "stable-diffusion-v2"),
        torch_dtype=torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
    )
    # PeftModel.from_pretrained loads the adapter weights from lora_path itself.
    # The manual adapter_model.* load and fix_lora_state_dict pass that used to
    # sit here produced a state dict that was never used, so they were removed.
    lora_path = os.path.join(ADAPTOR_DIR, "lora")
    pipe.unet = PeftModel.from_pretrained(pipe.unet, lora_path)
    # Custom modules; in_channels must match the checkpoints' layer shapes.
    pipe.decoder = _load_eval_module(Decoder(in_channel=320), "decoder.pth")
    pipe.height_head = _load_eval_module(
        HeightHead(in_channels=192, h_type=H_TYPE), "height_head.pth")
    pipe.semantic_head = _load_eval_module(
        SemanticHead(in_channels=192, num_classes=classes_num), "semantic_head.pth")
    print("โ Model ready (on CPU).")
    return pipe
def move_pipe_to(pipe, device: str):
    """Move the pipeline and its attached custom modules to ``device``.

    ``pipe.to()`` only moves the pipeline's built-in components. The
    ``decoder`` / ``height_head`` / ``semantic_head`` modules are custom
    attributes attached afterwards, so each must be moved explicitly.
    """
    pipe.to(device)
    for attr_name in ("decoder", "height_head", "semantic_head"):
        getattr(pipe, attr_name).to(device)
# Preload the model once at import time (dataset OpenDC, height type ER).
# run_inference() and _startup_gpu_test() read this global.
_model = build_model()
# ──────────────────────────────────────────────────────────────
# VAE / UNet forward
# ──────────────────────────────────────────────────────────────
| def _vae_encode(pipe, x: torch.Tensor): | |
| enc = pipe.vae.encoder | |
| x = enc.conv_in(x) | |
| feats = [] | |
| for blk in enc.down_blocks: | |
| x = blk(x) | |
| feats.append(x) | |
| x = enc.mid_block(x) | |
| x = enc.conv_norm_out(x) | |
| x = enc.conv_act(x) | |
| x = enc.conv_out(x) | |
| return x, feats[:-1] | |
| def _unet_forward(unet, sample, timestep, enc_hs): | |
| t_emb = unet.get_time_embed(sample=sample, timestep=timestep) | |
| emb = unet.time_embedding(t_emb) | |
| enc_hs = unet.process_encoder_hidden_states( | |
| encoder_hidden_states=enc_hs, added_cond_kwargs=None) | |
| x = unet.conv_in(sample) | |
| skips = (x,) | |
| for blk in unet.down_blocks: | |
| x, res = blk(hidden_states=x, temb=emb, encoder_hidden_states=enc_hs) | |
| skips += res | |
| x = unet.mid_block(x, emb, encoder_hidden_states=enc_hs) | |
| for blk in unet.up_blocks: | |
| res = skips[-len(blk.resnets):] | |
| skips = skips[:-len(blk.resnets)] | |
| x = blk(hidden_states=x, temb=emb, | |
| res_hidden_states_tuple=res, encoder_hidden_states=enc_hs) | |
| return x | |
# ──────────────────────────────────────────────────────────────
# 3D surface rendering helper
# ──────────────────────────────────────────────────────────────
def _render_3d_surface(
    height_np: np.ndarray,
    rgb_img: Image.Image = None,
    title: str = "3D Height",
    grid_size: int = 128,
    elev: int = 35,
    azim: int = -30,
) -> Image.Image:
    """Render a normalized height map as a 3D surface and return it as an image.

    Args:
        height_np: 2D height array expected in [0, 1]. It is quantized to
            uint8 for resampling, so values outside [0, 1] wrap around.
            Callers should clamp first.
        rgb_img: Optional PIL image draped over the surface as a colour
            texture. It is resized to ``grid_size`` and assumed to be RGB(A).
            When None, the surface is coloured with the ``plasma`` colormap.
        title: Axes title.
        grid_size: Side length of the square mesh that both inputs are resampled to.
        elev: Camera elevation passed to ``view_init``.
        azim: Camera azimuth passed to ``view_init``.

    Returns:
        The rendered figure as an in-memory PIL image (PNG, 150 dpi).
    """
    # Resample through PIL so the mesh resolution is independent of the input size.
    h_pil = Image.fromarray((height_np * 255).astype(np.uint8))
    h_pil = h_pil.resize((grid_size, grid_size), Image.BILINEAR)
    Z = np.array(h_pil, dtype=np.float32) / 255.0
    X, Y = np.meshgrid(np.linspace(0, 1, grid_size), np.linspace(0, 1, grid_size))

    fig = plt.figure(figsize=(6, 5))
    try:
        ax = fig.add_subplot(111, projection="3d")
        if rgb_img is not None:
            rgb_small = rgb_img.resize((grid_size, grid_size), Image.BILINEAR)
            rgb_arr = np.array(rgb_small, dtype=np.float32) / 255.0
            ax.plot_surface(
                X, Y, Z,
                facecolors=rgb_arr,
                rstride=1, cstride=1,
                shade=False, antialiased=False,
            )
        else:
            ax.plot_surface(
                X, Y, Z,
                cmap="plasma",
                rstride=1, cstride=1,
                antialiased=False,
            )
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Height")
        ax.set_title(title)
        # The z range is 5x the tallest point, which keeps the relief visually
        # shallow. For an all-zero (or NaN) map that upper bound would be
        # degenerate or invalid, so fall back to a unit range.
        z_top = float(np.max(height_np)) * 5
        if not z_top > 0:
            z_top = 1.0
        ax.set_zlim(0.0, z_top)
        ax.set_axis_off()
        ax.view_init(elev=elev, azim=azim)
        # Lay out this specific figure rather than pyplot's "current" figure.
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150)
    finally:
        # Always release the figure, even on error; pyplot keeps open figures alive.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()
# ──────────────────────────────────────────────────────────────
# Core inference logic
# ──────────────────────────────────────────────────────────────
def _run_inference_core(pipe, device, image):
    """Run height and semantic prediction on one image and render four outputs.

    Both heads share a single VAE-encoder / UNet / Decoder pass.

    Args:
        pipe: Pipeline from ``build_model`` (with ``decoder``, ``height_head``
            and ``semantic_head`` attached), already moved to ``device``.
        device: Torch device string the pipeline lives on (e.g. ``"cuda"``).
        image: Input PIL image. It is converted to RGB and resized to 512x512.

    Returns:
        tuple: ``(height_img, semantic_img, d3_height_img, d3_rgb_img, info)``
            height_img    -- PIL image: 2D height map (plasma colormap + colorbar)
            semantic_img  -- PIL image: semantic map coloured via LABEL_COLORS
            d3_height_img -- PIL image: 3D height surface (plasma colouring)
            d3_rgb_img    -- PIL image: 3D height surface textured with the input RGB
            info          -- str: normalized height range and predicted class indices
    """
    # Pure inference. Disabling autograd avoids recording the activation graph
    # of the whole network. It also lets ``.numpy()`` be called on the height
    # prediction even if some head parameters have ``requires_grad=True``.
    with torch.no_grad():
        # ── 1. Text conditioning: embedding of the empty prompt ──
        tokens = pipe.tokenizer(
            "", padding="max_length", truncation=True,
            max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
        text_emb = pipe.text_encoder(tokens.input_ids.to(device))[0].float()

        # ── 2. Preprocessing: [1, 3, 512, 512] scaled to [-1, 1] ──
        img = image.convert("RGB").resize((512, 512), Image.BILINEAR)
        arr = np.array(img, dtype=np.float32).transpose(2, 0, 1)
        norm = (torch.from_numpy(arr) / 255.0 * 2.0 - 1.0).unsqueeze(0).to(device)

        # ── 3. VAE encoding ──
        h, h_list = _vae_encode(pipe, norm)
        moments = pipe.vae.quant_conv(h)
        mean, lv = torch.chunk(moments, 2, dim=1)
        # Samples the latent (mean + std * noise) instead of taking the mean,
        # so repeated runs on the same image differ slightly.
        # NOTE(review): confirm sampling (vs. using ``mean``) is intended.
        latents = (mean + torch.exp(0.5 * lv) * torch.randn_like(mean)) * RGB_LATENT_SCALE

        # ── 4. Single UNet pass at t=999, then the custom Decoder ──
        ts = torch.ones([latents.shape[0]], device=device) * 999
        unet_o = _unet_forward(pipe.unet, latents, ts, text_emb)
        dec_o = pipe.decoder(unet_o, res_list=h_list[::-1])

        # ── 5. Both heads on the shared decoder features ──
        h_out = pipe.height_head(dec_o)
        s_out = pipe.semantic_head(dec_o)

        # Height: upsample to 512x512, then apply (x + 1) / 2 and clamp to [0, 1].
        # Assumes h_out[0] is a 4D (N, C, H, W) map and the head output lies in
        # [-1, 1]. TODO confirm against HeightHead.
        height_pred = F.interpolate(
            h_out[0].detach().cpu(), (512, 512), mode="bilinear", align_corners=False)
        height_pred = ((height_pred + 1.0) / 2.0).clamp(0, 1).squeeze().numpy()

        # Semantics: upsample the logits, then take the per-pixel argmax over classes.
        sem_pred = F.interpolate(s_out, (512, 512), mode="bilinear", align_corners=False)
        argmax = torch.argmax(sem_pred, dim=1).squeeze().cpu().numpy()

    # ── 6a. 2D height map (plasma pseudo-colour + colorbar) ──
    fig, ax = plt.subplots(figsize=(6, 5), tight_layout=True)
    try:
        im = ax.imshow(height_pred, cmap="plasma")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="Norm. height")
        ax.set_title("Predicted Height Map")
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150)
    finally:
        # Always release the figure, even on error; pyplot keeps open figures alive.
        plt.close(fig)
    buf.seek(0)
    height_img = Image.open(buf).copy()

    # ── 6b. 2D semantic map (per-class colour coding) ──
    # Pixels whose label has no entry in LABEL_COLORS stay black.
    canvas = np.zeros((512, 512, 3), dtype=np.uint8)
    for lbl, col in LABEL_COLORS[DATASET_NAME].items():
        canvas[argmax == lbl] = col
    semantic_img = Image.fromarray(canvas)

    # ── 6c. 3D height surface (plasma colouring) ──
    d3_height_img = _render_3d_surface(
        height_pred,
        rgb_img=None,
        title="3D Height Surface",
        grid_size=256,
    )
    # ── 6d. 3D height surface textured with the input RGB ──
    d3_rgb_img = _render_3d_surface(
        height_pred,
        rgb_img=img,
        title="3D Height + RGB Texture",
        grid_size=256,
    )
    info = (
        f"Height normalized range : [{height_pred.min():.4f}, {height_pred.max():.4f}]"
        f" (0 โ 0 m, 1 โ 50 m)\n"
        f"Semantic class indices : {np.unique(argmax).tolist()}"
    )
    return height_img, semantic_img, d3_height_img, d3_rgb_img, info
# ──────────────────────────────────────────────────────────────
# GPU inference entry point (triggered by the Gradio button)
# ──────────────────────────────────────────────────────────────
def run_inference(image):
    """Gradio click handler: run the full pipeline on ``image``.

    For the duration of the call, the preloaded model is moved to CUDA
    (or the CPU when CUDA is unavailable). It is always moved back to the
    CPU afterwards.

    Returns:
        tuple: ``(height_img, semantic_img, d3_height_img, d3_rgb_img, info)``.
        If the image or model is missing, or inference fails, the four images
        are ``None`` and ``info`` holds a user-facing message.
    """
    # NOTE(review): on ZeroGPU hardware this handler must be wrapped with
    # ``@spaces.GPU(duration=120)``; the shim at the top of the file exists for
    # exactly that. CUDA is only usable inside decorated functions there.
    _EMPTY = None
    if image is None:
        return _EMPTY, _EMPTY, _EMPTY, _EMPTY, "โ ๏ธ Please upload an image first."
    if _model is None:
        return _EMPTY, _EMPTY, _EMPTY, _EMPTY, "โ ๏ธ Model not loaded."
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = _model
    try:
        # Inside the try so a failed device move is reported in the Info box
        # instead of escaping as an unhandled exception.
        move_pipe_to(pipe, device)
        return _run_inference_core(pipe, device, image)
    except Exception as e:
        traceback.print_exc()
        return _EMPTY, _EMPTY, _EMPTY, _EMPTY, f"โ Inference error: {e}"
    finally:
        # Must use move_pipe_to: pipe.to() alone would leave decoder /
        # height_head / semantic_head resident on the GPU after every request.
        move_pipe_to(pipe, "cpu")
        torch.cuda.empty_cache()
# ──────────────────────────────────────────────────────────────
# Startup self-test
# ──────────────────────────────────────────────────────────────
def _startup_gpu_test():
    """One-shot startup smoke test of the full pipeline on ``Demo1.png``.

    If the image exists and CUDA is available, runs inference on ``"cuda"``.
    The four outputs are written to ``Demo1_height.png``,
    ``Demo1_semantic.png``, ``Demo1_3d_height.png`` and ``Demo1_3d_rgb.png``
    in the working directory.

    Every exception is caught and its traceback printed, so a failing test
    never blocks app startup. The model is moved back to the CPU in all cases.
    """
    _DEMO_IMG_PATH = "Demo1.png"
    print(f"\n{'='*60}")
    print(f"๐งช Startup inference test โ {_DEMO_IMG_PATH} (device=cuda)")
    print(f"{'='*60}")
    try:
        # NOTE(review): the Gradio examples live under ``demo/``. Confirm a
        # root-level Demo1.png is actually shipped, otherwise this always skips.
        if not os.path.exists(_DEMO_IMG_PATH):
            print(f"โ ๏ธ {_DEMO_IMG_PATH} not found, skipping test.")
            return
        if not torch.cuda.is_available():
            # Skip cleanly instead of failing with a CUDA error on CPU-only hosts.
            print("CUDA not available, skipping test.")
            return
        # ``with`` closes the image file once inference has consumed it.
        with Image.open(_DEMO_IMG_PATH) as _test_img:
            print(f" Image size : {_test_img.size}, mode: {_test_img.mode}")
            move_pipe_to(_model, "cuda")
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                height_img, semantic_img, d3_height_img, d3_rgb_img, info = \
                    _run_inference_core(_model, "cuda", _test_img)
        height_img.save("Demo1_height.png")
        semantic_img.save("Demo1_semantic.png")
        d3_height_img.save("Demo1_3d_height.png")
        d3_rgb_img.save("Demo1_3d_rgb.png")
        print(f"โ Test PASSED")
        print(f" Info : {info}")
    except Exception:
        print("โ Test FAILED โ full traceback below:")
        traceback.print_exc()
    finally:
        move_pipe_to(_model, "cpu")
        torch.cuda.empty_cache()
        print(f"{'='*60}\n")


_startup_gpu_test()
# ──────────────────────────────────────────────────────────────
# Gradio UI
# ──────────────────────────────────────────────────────────────
# Gradio lays out components in creation order inside each ``with`` context,
# so the statement order in this block defines the page layout.
with gr.Blocks(title="HeightAdaptor") as demo:
    # Page header.
    gr.Markdown("""
# ๐๏ธ HeightAdaptor
**Remote Sensing Image โ Height Map ยท Semantic Segmentation ยท 3D Reconstruction**
""")
    with gr.Row():
        # ── Left column: input image, run button and info text ──
        with gr.Column(scale=1):
            inp_img = gr.Image(type="pil", label="๐ท Input RGB Image")
            run_btn = gr.Button("๐ Run Inference", variant="primary", size="lg")
            out_info = gr.Textbox(label="โน๏ธ Info", interactive=False, lines=3)
        # ── Right column: 4 output views in a 2x2 grid ──
        with gr.Column(scale=2):
            gr.Markdown("#### ๐ Results")
            with gr.Row():
                out_height = gr.Image(type="pil", label="๐บ๏ธ Height Map")
                out_semantic = gr.Image(type="pil", label="๐จ Semantic Map")
            with gr.Row():
                out_3d_height = gr.Image(type="pil", label="๐๏ธ 3D Height Surface")
                out_3d_rgb = gr.Image(type="pil", label="๐ 3D Height + RGB Texture")
    # ── Example gallery at the bottom: 7 clickable demo images. ──
    # No ``fn`` is given, so clicking an example only loads it into inp_img;
    # inference still has to be started with the button.
    gr.Markdown("---\n#### ๐ผ๏ธ Example Images โ Click any image to load it, then click **Run Inference**")
    gr.Examples(
        examples=[
            ["demo/Demo1.png"],
            ["demo/Demo2.png"],
            ["demo/Demo3.png"],
            ["demo/Demo4.png"],
            ["demo/Demo5.png"],
            ["demo/Demo6.png"],
            ["demo/Demo7.png"],
        ],
        inputs=[inp_img],
        label="Demo Samples",
        # Shows all 7 examples on a single page; keep in sync with the list above.
        examples_per_page=7,
    )
    gr.Markdown("""
---
> ๅพๅไผ่ชๅจ็ผฉๆพ่ณ 512 ร 512๏ผGPU ๆจ็็บฆ้ 15โ45 ็ง๏ผๅซ 3D ๆธฒๆ๏ผใ
""")
    # The five outputs map positionally onto the 5-tuple returned by run_inference.
    run_btn.click(
        fn=run_inference,
        inputs=[inp_img],
        outputs=[out_height, out_semantic, out_3d_height, out_3d_rgb, out_info],
    )
if __name__ == "__main__":
    demo.launch()