# HeightAdaptor / app.py
# (Hugging Face Hub page header captured with this file: latest commit 05e03fc
#  "Update app.py" by PubAccount, marked verified.)
"""
app.py — HeightAdaptor Hugging Face Spaces App
Backbone : sd-research/stable-diffusion-2-1-base
Adaptor : UEXdo/HeightAdaptor-weight
Outputs : Height Map (2D) | Semantic Map | 3D Height Surface | 3D Height + RGB Texture
"""
# ── ZeroGPU compatibility: fall back to a no-op stub when `spaces` is not installed ──
try:
    import spaces
except ImportError:
    class spaces:  # noqa: N801 - lowercase on purpose so call sites stay `spaces.GPU`
        """Minimal stand-in for the Hugging Face ``spaces`` package outside ZeroGPU."""

        @staticmethod
        def GPU(fn=None, duration=120):
            """No-op replacement for ``spaces.GPU``.

            Works both as a bare decorator (``@spaces.GPU``) and as a decorator
            factory (``@spaces.GPU(duration=...)``). In either case the decorated
            function is returned unchanged and *duration* is ignored.
            """
            if callable(fn):
                # Bare-decorator form: we were handed the function itself.
                return fn
            # Factory form (including a positional duration): return an identity decorator.
            return lambda func: func
import os, io, traceback
import torch
import numpy as np
import matplotlib; matplotlib.use("Agg")
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D # noqa: F401 โ€” ๆณจๅ†Œ '3d' projection
from PIL import Image
from torch.nn import functional as F
from diffusers import StableDiffusionPipeline
from huggingface_hub import snapshot_download
from peft import PeftModel
import gradio as gr
import safetensors.torch
import warnings
warnings.filterwarnings("ignore", category=ResourceWarning)
from networks.semantic_head import SemanticHead
from networks.height_head import HeightHead
from networks.decoder import Decoder
# ════════════════════════════════════════════════════════════
# Utility functions
# ════════════════════════════════════════════════════════════
def fix_lora_state_dict(state_dict: dict) -> dict:
    """Lift legacy 2-D ``proj_in`` / ``proj_out`` LoRA weights to 4-D.

    Older checkpoints stored these projections as ``Linear`` weights of shape
    ``(out, in)``, whereas ``Conv2d`` needs a 4-D ``(out, in, 1, 1)`` weight.
    Matching entries get two trailing singleton dimensions appended. Every other
    entry (including 1-D tensors such as biases) is passed through as the very
    same object. A new dict is returned and *state_dict* is left untouched.
    """
    def _is_legacy_proj(key, tensor):
        # Test the key first so values of unrelated keys are never asked for .ndim.
        return ("proj_in" in key or "proj_out" in key) and tensor.ndim == 2

    return {
        key: tensor.unsqueeze(2).unsqueeze(3) if _is_legacy_proj(key, tensor) else tensor
        for key, tensor in state_dict.items()
    }
# ════════════════════════════════════════════════════════════
# Constants & configuration
# ════════════════════════════════════════════════════════════
# Multiplier applied to the sampled VAE latents before the UNet
# (presumably the Stable Diffusion VAE scaling factor).
RGB_LATENT_SCALE = 0.18215

# Hub repository holding the adaptor weights; override with $ADAPTOR_MODEL_ID.
ADAPTOR_REPO = os.environ.get("ADAPTOR_MODEL_ID", "UEXdo/HeightAdaptor-weight")

DATASET_NAME = "OpenDC"
H_TYPE = "ER"  # height-head variant, forwarded to HeightHead(h_type=...)

# Per-dataset settings; classes_num sizes the semantic head.
DATASET_CFG = {
    "OpenDC": dict(classes_num=8),
}

# Semantic class index -> RGB colour used when rendering the segmentation map.
# NOTE(review): classes 1 and 3 share (255, 0, 0) and are therefore
# indistinguishable in the rendered map; confirm this is intended.
LABEL_COLORS = {
    "OpenDC": dict(enumerate([
        (50, 125, 0),
        (255, 0, 0),
        (0, 255, 0),
        (255, 0, 0),
        (255, 255, 0),
        (255, 255, 255),
        (0, 255, 255),
        (0, 0, 0),
    ])),
}
# ════════════════════════════════════════════════════════════
# Download adaptor weights
# ════════════════════════════════════════════════════════════
# Fetch the adaptor snapshot from the Hub at import time. snapshot_download returns
# the local directory it was cached in; the backbone, LoRA and head checkpoints are
# all resolved relative to ADAPTOR_DIR. (Log messages repaired from mis-encoded text.)
print(f"📦 Downloading adaptor weights from {ADAPTOR_REPO} ...")
ADAPTOR_DIR = snapshot_download(repo_id=ADAPTOR_REPO)
print(f"✅ Weights cached at: {ADAPTOR_DIR}")
# ════════════════════════════════════════════════════════════
# Model management
# ════════════════════════════════════════════════════════════
# Module-level pipeline shared by run_inference() and _startup_gpu_test().
# None until the `_model = build_model()` statement further down has run.
_model: "StableDiffusionPipeline | None" = None
def build_model():
    """Assemble the HeightAdaptor pipeline on the CPU from the downloaded snapshot.

    Steps:

    * Load the Stable Diffusion pipeline stored under
      ``<ADAPTOR_DIR>/stable-diffusion-v2`` (float32, safety checker disabled).
    * Wrap its UNet with the PEFT/LoRA adapter in ``<ADAPTOR_DIR>/lora``.
    * Attach three custom modules as extra attributes. Each is restored from a
      ``.pth`` state dict and switched to eval mode:

      - ``pipe.decoder``: ``Decoder`` from ``decoder.pth``
      - ``pipe.height_head``: ``HeightHead`` from ``height_head.pth``
      - ``pipe.semantic_head``: ``SemanticHead`` from ``semantic_head.pth``

    Returns
    -------
    StableDiffusionPipeline
        The pipeline with the extra attributes, still on the CPU.
    """
    classes_num = DATASET_CFG[DATASET_NAME]["classes_num"]
    print(f"🔧 Building model — dataset={DATASET_NAME}, h_type={H_TYPE}")
    pipe = StableDiffusionPipeline.from_pretrained(
        os.path.join(ADAPTOR_DIR, "stable-diffusion-v2"),
        torch_dtype=torch.float32,
        safety_checker=None,
        requires_safety_checker=False,
    )

    # PeftModel reads the adapter config and weights from the lora directory itself.
    # NOTE(review): fix_lora_state_dict() is not applied to those weights. If a legacy
    # checkpoint with 2-D proj_in/proj_out LoRA weights ever fails to load, load its
    # state dict, pass it through fix_lora_state_dict() and apply it with peft's
    # set_peft_model_state_dict().
    pipe.unet = PeftModel.from_pretrained(pipe.unet, os.path.join(ADAPTOR_DIR, "lora"))

    def _restore(module, filename):
        """Load ``<ADAPTOR_DIR>/<filename>`` into *module*, set eval mode, return it."""
        # NOTE(review): torch.load unpickles the file. The checkpoints come from
        # ADAPTOR_REPO (overridable via $ADAPTOR_MODEL_ID); consider
        # weights_only=True if they are plain tensor state dicts.
        state_dict = torch.load(os.path.join(ADAPTOR_DIR, filename), map_location="cpu")
        module.load_state_dict(state_dict)
        module.eval()
        return module

    pipe.decoder = _restore(Decoder(in_channel=320), "decoder.pth")
    pipe.height_head = _restore(
        HeightHead(in_channels=192, h_type=H_TYPE), "height_head.pth"
    )
    pipe.semantic_head = _restore(
        SemanticHead(in_channels=192, num_classes=classes_num), "semantic_head.pth"
    )
    print("✅ Model ready (on CPU).")
    return pipe
def move_pipe_to(pipe, device: str):
    """Move the pipeline *and* its custom heads to *device*.

    ``pipe.to()`` only relocates the components the pipeline itself knows about.
    ``decoder``, ``height_head`` and ``semantic_head`` are plain attributes
    attached after construction, so each one is moved explicitly, in that order,
    after the pipeline. Returns None.
    """
    pipe.to(device)
    for attr_name in ("decoder", "height_head", "semantic_head"):
        getattr(pipe, attr_name).to(device)
# Build the model once at import time (dataset DATASET_NAME="OpenDC", H_TYPE="ER").
# It stays on the CPU until an inference call moves it to the GPU.
_model = build_model()
# ════════════════════════════════════════════════════════════
# VAE / UNet forward passes
# ════════════════════════════════════════════════════════════
def _vae_encode(pipe, x: torch.Tensor):
enc = pipe.vae.encoder
x = enc.conv_in(x)
feats = []
for blk in enc.down_blocks:
x = blk(x)
feats.append(x)
x = enc.mid_block(x)
x = enc.conv_norm_out(x)
x = enc.conv_act(x)
x = enc.conv_out(x)
return x, feats[:-1]
def _unet_forward(unet, sample, timestep, enc_hs):
    """Run the UNet by hand and return the output of its last up block.

    The down -> mid -> up traversal presumably mirrors diffusers'
    ``UNet2DConditionModel.forward``. No output head is applied after the up
    blocks, so the caller receives the final up-block features rather than a
    noise prediction. In this file the caller passes the PEFT-wrapped
    ``pipe.unet``.

    Parameters
    ----------
    unet : module exposing ``get_time_embed``, ``time_embedding``,
        ``process_encoder_hidden_states``, ``conv_in``, ``down_blocks``,
        ``mid_block`` and ``up_blocks``.
    sample : latent tensor fed to ``conv_in``.
    timestep : timestep(s) passed to ``get_time_embed``.
    enc_hs : encoder hidden states (text embeddings) passed to every block.
    """
    # Timestep -> embedding, passed as `temb` to every block below.
    t_emb = unet.get_time_embed(sample=sample, timestep=timestep)
    emb = unet.time_embedding(t_emb)
    enc_hs = unet.process_encoder_hidden_states(
        encoder_hidden_states=enc_hs, added_cond_kwargs=None)
    x = unet.conv_in(sample)
    # Skip-connection stack: conv_in output, then every residual the down blocks emit.
    skips = (x,)
    for blk in unet.down_blocks:
        x, res = blk(hidden_states=x, temb=emb, encoder_hidden_states=enc_hs)
        skips += res
    x = unet.mid_block(x, emb, encoder_hidden_states=enc_hs)
    for blk in unet.up_blocks:
        # Each up block consumes as many residuals as it has resnets, taken from
        # the end of the stack (the most recently pushed ones).
        res = skips[-len(blk.resnets):]
        skips = skips[:-len(blk.resnets)]
        x = blk(hidden_states=x, temb=emb,
                res_hidden_states_tuple=res, encoder_hidden_states=enc_hs)
    return x
# ════════════════════════════════════════════════════════════
# 3D surface rendering helpers
# ════════════════════════════════════════════════════════════
def _render_3d_surface(
    height_np: np.ndarray,
    rgb_img: Image.Image = None,
    title: str = "3D Height",
    grid_size: int = 128,
    elev: int = 35,
    azim: int = -30,
) -> Image.Image:
    """Render a normalised height map as a 3-D surface plot and return it as an image.

    Parameters
    ----------
    height_np : 2-D array of normalised heights, expected in [0, 1]. NaNs are
        treated as 0 and values outside [0, 1] are clipped. Otherwise they would
        wrap around in the uint8 conversion used for resampling.
    rgb_img : optional PIL image draped over the surface as face colours. It is
        converted to RGB first. When omitted, the surface uses the ``plasma``
        colormap.
    title : plot title.
    grid_size : the height map (and texture) are resampled to
        ``grid_size x grid_size`` vertices before plotting.
    elev, azim : camera elevation / azimuth passed to ``view_init``.

    Returns
    -------
    PIL.Image.Image
        The rendered 150-dpi PNG, fully loaded into memory.
    """
    heights = np.clip(np.nan_to_num(height_np, nan=0.0), 0.0, 1.0)

    # Resample through an 8-bit grayscale image (PIL bilinear resize), so the
    # plotted heights are quantised to 1/255 steps.
    h_pil = Image.fromarray((heights * 255).astype(np.uint8))
    h_pil = h_pil.resize((grid_size, grid_size), Image.BILINEAR)
    Z = np.array(h_pil, dtype=np.float32) / 255.0
    X, Y = np.meshgrid(np.linspace(0, 1, grid_size), np.linspace(0, 1, grid_size))

    # The z-axis spans 5x the tallest point, which draws the relief flattened.
    # An all-zero map would give the degenerate limits (0, 0), so use a unit range.
    z_top = float(np.max(heights)) * 5
    if z_top <= 0.0:
        z_top = 1.0

    fig = plt.figure(figsize=(6, 5))
    try:
        ax = fig.add_subplot(111, projection="3d")
        if rgb_img is not None:
            texture = rgb_img.convert("RGB").resize((grid_size, grid_size), Image.BILINEAR)
            ax.plot_surface(
                X, Y, Z,
                facecolors=np.array(texture, dtype=np.float32) / 255.0,
                rstride=1, cstride=1,
                shade=False, antialiased=False,
            )
        else:
            ax.plot_surface(
                X, Y, Z,
                cmap="plasma",
                rstride=1, cstride=1,
                antialiased=False,
            )
        # Axis labels are hidden by set_axis_off() below; kept for easy re-enabling.
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Height")
        ax.set_title(title)
        ax.set_zlim(0.0, z_top)
        ax.set_axis_off()
        ax.view_init(elev=elev, azim=azim)
        fig.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()
# ════════════════════════════════════════════════════════════
# Core inference logic
# ════════════════════════════════════════════════════════════
def _render_height_map(height_pred: np.ndarray) -> Image.Image:
    """Plot a 2-D height array with the ``plasma`` colormap and a colorbar.

    Returns the rendered 150-dpi PNG as a PIL image. The figure is closed even if
    drawing or saving fails, so errors never leak open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=(6, 5), tight_layout=True)
    try:
        im = ax.imshow(height_pred, cmap="plasma")
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04, label="Norm. height")
        ax.set_title("Predicted Height Map")
        ax.axis("off")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=150)
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).copy()


def _colorize_semantic(label_map: np.ndarray, colors: dict) -> Image.Image:
    """Paint each class index in the 2-D *label_map* with its RGB colour from *colors*.

    Pixels whose index has no entry in *colors* stay black.
    """
    canvas = np.zeros((*label_map.shape, 3), dtype=np.uint8)
    for label, rgb in colors.items():
        canvas[label_map == label] = rgb
    return Image.fromarray(canvas)


@torch.no_grad()
def _run_inference_core(pipe, device, image):
    """Run the full HeightAdaptor forward pass on one image and render the results.

    The pass runs these stages in order:

    1. Embed the empty prompt.
    2. Resize the image to 512x512 RGB and scale it to [-1, 1].
    3. Encode it with the VAE by hand, keeping the down-block skip features.
    4. Sample a latent from the encoder's mean / log-variance and scale it by
       RGB_LATENT_SCALE.
    5. Run one UNet pass at timestep 999, then the custom decoder.
    6. Apply the height head and the semantic head.

    Parameters
    ----------
    pipe : pipeline from ``build_model`` (with decoder / height_head /
        semantic_head attached), already moved to *device*.
    device : device string the pipeline lives on.
    image : PIL image in any mode (converted to RGB).

    Returns
    -------
    height_img : PIL.Image
        2-D height map (plasma colormap + colorbar).
    semantic_img : PIL.Image
        Semantic segmentation, coloured per class.
    d3_height_img : PIL.Image
        3-D height surface (plasma shading).
    d3_rgb_img : PIL.Image
        3-D height surface textured with the input RGB image.
    info : str
        Height range and the predicted class indices.

    Notes
    -----
    The latent is drawn with ``torch.randn_like``, so repeated calls on the same
    image are not bit-identical unless the torch RNG is seeded.
    """
    # ── 1. Text conditioning: the empty prompt ────────────────────────────
    tokens = pipe.tokenizer(
        "", padding="max_length", truncation=True,
        max_length=pipe.tokenizer.model_max_length, return_tensors="pt")
    text_emb = pipe.text_encoder(tokens.input_ids.to(device))[0].float()

    # ── 2. Preprocess to a [1, 3, 512, 512] tensor in [-1, 1] ────────────
    img = image.convert("RGB").resize((512, 512), Image.BILINEAR)
    arr = np.array(img, dtype=np.float32).transpose(2, 0, 1)
    norm = (torch.from_numpy(arr) / 255.0 * 2.0 - 1.0).unsqueeze(0).to(device)

    # ── 3. VAE encode and sample a latent (reparameterisation) ───────────
    h, h_list = _vae_encode(pipe, norm)
    moments = pipe.vae.quant_conv(h)
    mean, lv = torch.chunk(moments, 2, dim=1)
    latents = (mean + torch.exp(0.5 * lv) * torch.randn_like(mean)) * RGB_LATENT_SCALE

    # ── 4. One UNet pass at t=999, then the custom decoder ───────────────
    ts = torch.ones([latents.shape[0]], device=device) * 999
    unet_o = _unet_forward(pipe.unet, latents, ts, text_emb)
    # The decoder takes the VAE skip features deepest-first.
    dec_o = pipe.decoder(unet_o, res_list=h_list[::-1])

    # ── 5. Both heads share the decoder features ─────────────────────────
    h_out = pipe.height_head(dec_o)
    s_out = pipe.semantic_head(dec_o)

    # ── 6a. 2-D height map: map the head output from [-1, 1] to [0, 1] ──
    height_pred = F.interpolate(
        h_out[0].cpu(), (512, 512), mode="bilinear", align_corners=False)
    height_pred = ((height_pred + 1.0) / 2.0).clamp(0, 1).squeeze().numpy()
    height_img = _render_height_map(height_pred)

    # ── 6b. Semantic map: per-pixel argmax, coloured per class ───────────
    sem_pred = F.interpolate(s_out, (512, 512), mode="bilinear", align_corners=False)
    argmax = torch.argmax(sem_pred, dim=1).squeeze().cpu().numpy()
    semantic_img = _colorize_semantic(argmax, LABEL_COLORS[DATASET_NAME])

    # ── 6c/6d. 3-D surfaces: plasma-shaded and RGB-textured ──────────────
    d3_height_img = _render_3d_surface(
        height_pred,
        rgb_img=None,
        title="3D Height Surface",
        grid_size=256,
    )
    d3_rgb_img = _render_3d_surface(
        height_pred,
        rgb_img=img,
        title="3D Height + RGB Texture",
        grid_size=256,
    )

    info = (
        f"Height normalized range : [{height_pred.min():.4f}, {height_pred.max():.4f}]"
        " (0 ≈ 0 m, 1 ≈ 50 m)\n"
        f"Semantic class indices : {np.unique(argmax).tolist()}"
    )
    return height_img, semantic_img, d3_height_img, d3_rgb_img, info
# ════════════════════════════════════════════════════════════
# GPU inference entry point (triggered by the Gradio button)
# ════════════════════════════════════════════════════════════
@spaces.GPU(duration=120)
def run_inference(image):
    """Gradio click handler: run the preloaded model on *image*.

    Returns a 5-tuple ``(height_img, semantic_img, d3_height_img, d3_rgb_img, info)``
    matching the five output components. When there is no input, no model, or
    inference fails, the four images are ``None`` and ``info`` holds the message.
    The whole pipeline, including the custom heads, is moved back to the CPU
    afterwards.
    """
    _EMPTY = None
    if image is None:
        return _EMPTY, _EMPTY, _EMPTY, _EMPTY, "⚠️ Please upload an image first."
    if _model is None:
        return _EMPTY, _EMPTY, _EMPTY, _EMPTY, "⚠️ Model not loaded."
    # Prefer CUDA. Fall back to the CPU when no GPU is visible (e.g. running locally,
    # where the spaces stub turns this into a plain function).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = _model
    try:
        # Inside the try: a failed transfer (e.g. CUDA OOM) is reported in the UI,
        # and the finally-block still brings every module back to the CPU.
        move_pipe_to(pipe, device)
        return _run_inference_core(pipe, device, image)
    except Exception as e:
        traceback.print_exc()
        return _EMPTY, _EMPTY, _EMPTY, _EMPTY, f"❌ Inference error: {e}"
    finally:
        # pipe.to() alone would leave decoder / height_head / semantic_head on the GPU.
        move_pipe_to(pipe, "cpu")
        torch.cuda.empty_cache()
# ════════════════════════════════════════════════════════════
# Startup smoke test
# ════════════════════════════════════════════════════════════
@spaces.GPU(duration=120)
def _startup_gpu_test():
    """Smoke-test the full inference path once on CUDA and print PASS/FAIL.

    The test image is ``Demo1.png`` in the working directory, falling back to
    ``demo/Demo1.png`` (the file the Gradio examples use). If neither exists the
    test is skipped. The four rendered outputs are saved as ``Demo1_*.png``.
    Exceptions from the test are printed with their traceback rather than raised.
    Afterwards the pipeline is always moved back to the CPU.
    """
    demo_candidates = ("Demo1.png", "demo/Demo1.png")
    demo_path = next(
        (path for path in demo_candidates if os.path.exists(path)), demo_candidates[0]
    )
    print(f"\n{'='*60}")
    print(f"🧪 Startup inference test — {demo_path} (device=cuda)")
    print(f"{'='*60}")
    try:
        if not os.path.exists(demo_path):
            print(f"⚠️ {demo_path} not found, skipping test.")
            return
        # Close the file handle once inference has consumed the pixels.
        with Image.open(demo_path) as test_img:
            print(f" Image size : {test_img.size}, mode: {test_img.mode}")
            move_pipe_to(_model, "cuda")
            height_img, semantic_img, d3_height_img, d3_rgb_img, info = \
                _run_inference_core(_model, "cuda", test_img)
        height_img.save("Demo1_height.png")
        semantic_img.save("Demo1_semantic.png")
        d3_height_img.save("Demo1_3d_height.png")
        d3_rgb_img.save("Demo1_3d_rgb.png")
        print("✅ Test PASSED")
        print(f" Info : {info}")
    except Exception:
        print("❌ Test FAILED — full traceback below:")
        traceback.print_exc()
    finally:
        move_pipe_to(_model, "cpu")
        torch.cuda.empty_cache()
        print(f"{'='*60}\n")


# Run the smoke test once at import time.
_startup_gpu_test()
# ════════════════════════════════════════════════════════════
# Gradio UI
# ════════════════════════════════════════════════════════════
# Demo images offered as clickable examples below the results. Missing files are
# reported and skipped, so one deleted demo image cannot break the examples gallery.
_DEMO_EXAMPLE_PATHS = [f"demo/Demo{i}.png" for i in range(1, 8)]
_missing_demos = [path for path in _DEMO_EXAMPLE_PATHS if not os.path.exists(path)]
if _missing_demos:
    print(f"⚠️ Example images not found, skipping: {_missing_demos}")
_available_demos = [path for path in _DEMO_EXAMPLE_PATHS if path not in _missing_demos]

with gr.Blocks(title="HeightAdaptor") as demo:
    gr.Markdown("""
# 🏙️ HeightAdaptor
**Remote Sensing Image → Height Map · Semantic Segmentation · 3D Reconstruction**
""")
    with gr.Row():
        # Left column: input image, run button and the info box.
        with gr.Column(scale=1):
            inp_img = gr.Image(type="pil", label="📷 Input RGB Image")
            run_btn = gr.Button("🚀 Run Inference", variant="primary", size="lg")
            out_info = gr.Textbox(label="ℹ️ Info", interactive=False, lines=3)
        # Right column: the four output images in a 2×2 grid.
        with gr.Column(scale=2):
            gr.Markdown("#### 📊 Results")
            with gr.Row():
                out_height = gr.Image(type="pil", label="🗺️ Height Map")
                out_semantic = gr.Image(type="pil", label="🎨 Semantic Map")
            with gr.Row():
                out_3d_height = gr.Image(type="pil", label="🏔️ 3D Height Surface")
                out_3d_rgb = gr.Image(type="pil", label="🌍 3D Height + RGB Texture")
    # Example gallery at the bottom: one clickable thumbnail per available demo image,
    # all shown on a single page.
    if _available_demos:
        gr.Markdown("---\n#### 🖼️ Example Images — Click any image to load it, then click **Run Inference**")
        gr.Examples(
            examples=[[path] for path in _available_demos],
            inputs=[inp_img],
            label="Demo Samples",
            examples_per_page=len(_available_demos),
        )
    gr.Markdown("""
---
> 图像会自动缩放至 512 × 512，GPU 推理约需 15–45 秒（含 3D 渲染）。
""")
    # Output order matches run_inference's return tuple.
    run_btn.click(
        fn=run_inference,
        inputs=[inp_img],
        outputs=[out_height, out_semantic, out_3d_height, out_3d_rgb, out_info],
    )

if __name__ == "__main__":
    demo.launch()