| import os |
| import gc |
| import copy |
| from io import BytesIO |
|
|
| import cv2 |
| import numpy as np |
| import rasterio |
| import matplotlib.pyplot as plt |
| import streamlit as st |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from huggingface_hub import hf_hub_download |
| from torchvision.transforms.functional import normalize |
|
|
| |
| |
| |
# Page-level settings, issued ahead of every other Streamlit call in the script.
st.set_page_config(layout="wide", page_title="Prior2DSM | LoRA")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Let cuDNN benchmark and pick convolution algorithms for the shapes it sees.
torch.backends.cudnn.benchmark = True

# ViT patch size in pixels, and the pixel shift between the sub-patch offsets
# used for dense (sliding-offset) inference.
PATCH_SIZE, STRIDE = 16, 4

# Per-channel RGB normalisation statistics passed to torchvision `normalize`.
# NOTE(review): despite the name these are not the standard ImageNet values
# (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225); presumably they are the
# DINOv3 SAT-493M satellite statistics matching the loaded backbone - confirm.
IMAGENET_MEAN, IMAGENET_STD = (0.430, 0.411, 0.296), (0.213, 0.156, 0.143)

# Paths of the demo inputs inside the Hugging Face Space repository.
EXAMPLE_RGB_FILENAME, EXAMPLE_PRIOR_FILENAME = (
    "examples/example_rgb.tif",
    "examples/example_prior.tif",
)
|
|
|
|
| |
| |
| |
def normalize_01(arr, valid_mask=None):
    """Min-max scale ``arr`` to [0, 1] using only valid, finite entries.

    Args:
        arr: array-like of numbers; converted to float32.
        valid_mask: optional boolean array-like. It is AND-ed with
            ``isfinite(arr)``. When omitted, every finite entry is valid.

    Returns:
        float32 array shaped like ``arr``. Valid entries are mapped linearly
        so that their min becomes 0 and their max becomes 1. Invalid entries
        are 0, and everything is 0 when no entry is valid. A constant input
        maps to 0 because the denominator is floored at 1e-8.
    """
    values = np.asarray(arr, dtype=np.float32)
    finite = np.isfinite(values)
    if valid_mask is None:
        mask = finite
    else:
        mask = np.asarray(valid_mask, dtype=bool) & finite

    result = np.zeros_like(values, dtype=np.float32)
    if not np.any(mask):
        return result

    selected = values[mask]
    lo = float(np.nanmin(selected))
    hi = float(np.nanmax(selected))
    span = max(1e-8, hi - lo)
    result[mask] = (selected - lo) / span
    return np.clip(result, 0.0, 1.0)
|
|
|
|
def preview_rgb(rgb_raw):
    """Convert a (bands, H, W) RGB array to an (H, W, 3) float image in [0, 1].

    If the data is not already in [0, ~1] (max above 1.5), it is divided by
    its 98th percentile before clipping. This stretches high-bit-depth or
    0-255 data for display.
    """
    hwc = np.transpose(rgb_raw, (1, 2, 0)).astype(np.float32)
    if hwc.max() > 1.5:
        scale = np.percentile(hwc, 98) + 1e-6
        hwc = hwc / scale
    return np.clip(hwc, 0, 1)
|
|
|
|
def draw_roi_preview(viz_rgb, x1, y1, x2, y2):
    """Return a uint8 copy of ``viz_rgb`` (floats in [0, 1]) with the ROI outlined.

    The rectangle spans corners (x1, y1)-(x2, y2), drawn in colour
    (255, 0, 0) with a 2-pixel line. The input array is not modified.
    """
    canvas = (np.clip(viz_rgb, 0, 1) * 255).astype(np.uint8)
    canvas = canvas.copy()
    outline_colour = (255, 0, 0)
    cv2.rectangle(canvas, (x1, y1), (x2, y2), outline_colour, 2)
    return canvas
|
|
|
|
@st.cache_data(show_spinner=False)
def load_tiff_from_hf(repo_id, filename, repo_type="space"):
    """Download ``filename`` from a Hugging Face Hub repo and return its local path.

    Wrapped in ``st.cache_data``, so reruns with the same arguments reuse the
    returned path instead of calling ``hf_hub_download`` again.
    """
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type=repo_type,
    )
    return local_path
|
|
|
|
def read_rgb_tiff(path_or_bytes):
    """Read bands 1-3 of a raster given as a filesystem path or raw bytes.

    Returns:
        (rgb_raw, height, width, meta): ``rgb_raw`` is the array returned by
        ``read([1, 2, 3])`` in the file's native dtype, and ``meta`` is a copy
        of the dataset's ``meta`` dict.
    """
    # Paths are opened directly; anything else is treated as in-memory bytes.
    is_path = isinstance(path_or_bytes, (str, os.PathLike))
    source = path_or_bytes if is_path else BytesIO(path_or_bytes)
    with rasterio.open(source) as dataset:
        bands = dataset.read([1, 2, 3])
        height, width = dataset.height, dataset.width
        metadata = dataset.meta.copy()
    return bands, height, width, metadata
|
|
|
|
def read_prior_tiff(path_or_bytes):
    """Read band 1 of a raster (path or raw bytes) as a float32 array.

    Returns:
        (prior_raw, meta): the float32 band and a copy of the dataset's
        ``meta`` dict.
    """
    if isinstance(path_or_bytes, (str, os.PathLike)):
        source = path_or_bytes
    else:
        source = BytesIO(path_or_bytes)
    with rasterio.open(source) as dataset:
        band = dataset.read(1).astype(np.float32)
        metadata = dataset.meta.copy()
    return band, metadata
|
|
|
|
def init_roi_state(h_f, w_f):
    """Seed the ROI-related session-state keys for an image of size h_f x w_f.

    Missing keys are filled with defaults: the image centre, a box of
    min(200, shorter side), and normalised relative depth enabled. If a
    different image shape was recorded earlier, every key is reset to these
    defaults so the old ROI cannot fall outside the new image.
    """
    state = st.session_state
    shape = (h_f, w_f)
    defaults = {
        "x_center": w_f // 2,
        "y_center": h_f // 2,
        "bbox_size": min(200, min(h_f, w_f)),
        "use_normalized_rel": True,
        "loaded_shape": shape,
    }
    shape_changed = "loaded_shape" in state and state["loaded_shape"] != shape
    for key, value in defaults.items():
        if shape_changed or key not in state:
            state[key] = value
|
|
|
|
| |
| |
| |
class MLPDecoder(nn.Module):
    """Per-token head mapping a feature vector to two outputs.

    Architecture: Linear(in_dim, 256) -> LayerNorm -> GELU -> Linear(256, 128)
    -> GELU -> Linear(128, 2). The last layer starts with zero weights and
    bias (1, 0), so an untrained head outputs exactly (1.0, 0.0) for every
    input. The pipeline reads the two outputs as (scale, shift), so training
    starts from the identity mapping.
    """

    def __init__(self, in_dim=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 2),
        )
        head = self.net[-1]
        nn.init.zeros_(head.weight)
        # Fill the existing bias Parameter in place instead of swapping its
        # storage via `.data = ...`, which bypasses autograd bookkeeping.
        with torch.no_grad():
            head.bias.copy_(torch.tensor([1.0, 0.0]))

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``; output last dim is 2."""
        return self.net(x)
|
|
|
|
class LoRALinear(nn.Module):
    """Low-rank adapter around a linear layer.

    Computes ``base(x) + scaling * B(A(x))``, where ``A`` maps in_features to
    ``r`` and ``B`` maps ``r`` to out_features, and ``scaling = alpha / r``
    (1.0 when ``r <= 0``). The wrapped layer's weight and bias are frozen.
    ``B`` is zero-initialised, so right after construction the module
    reproduces ``base`` exactly.
    """

    def __init__(self, base_linear, r=8, alpha=16.0):
        super().__init__()
        self.base = base_linear

        # Freeze the wrapped layer; only A and B are meant to train.
        frozen = [self.base.weight, getattr(self.base, "bias", None)]
        for tensor in frozen:
            if tensor is not None:
                tensor.requires_grad_(False)

        self.r = r
        self.alpha = alpha
        if r > 0:
            self.scaling = alpha / r
        else:
            self.scaling = 1.0

        self.A = nn.Linear(base_linear.in_features, r, bias=False)
        self.B = nn.Linear(r, base_linear.out_features, bias=False)
        nn.init.kaiming_uniform_(self.A.weight, a=np.sqrt(5))
        nn.init.zeros_(self.B.weight)

    # Mirror the wrapped layer's public attributes so code that inspects
    # `.in_features`, `.weight`, etc. keeps working after wrapping.
    @property
    def in_features(self):
        return self.base.in_features

    @property
    def out_features(self):
        return self.base.out_features

    @property
    def weight(self):
        return self.base.weight

    @property
    def bias(self):
        return self.base.bias

    def forward(self, x):
        frozen_out = self.base(x)
        low_rank = self.B(self.A(x))
        return frozen_out + self.scaling * low_rank
|
|
|
|
def inject_lora(model, r=8, alpha=16.0, targets=("qkv", "proj")):
    """Wrap attention projections of ``model`` with ``LoRALinear``, in place.

    For every submodule that has an ``attn`` attribute, each child named in
    ``targets`` (default: ``qkv`` and ``proj``) is replaced by
    ``LoRALinear(child, r, alpha)``. Children that are missing, ``None``, or
    already wrapped are left untouched, so repeated calls never double-wrap.

    Args:
        model: the ``nn.Module`` to modify.
        r, alpha: LoRA rank and scaling numerator passed to ``LoRALinear``.
        targets: attribute names on ``attn`` to wrap.

    Returns:
        The same ``model`` object, for chaining.
    """
    # Snapshot the module list first: wrapping rewrites the module tree, and
    # mutating it while the live `model.modules()` generator walks it is fragile.
    for blk in list(model.modules()):
        if not hasattr(blk, "attn"):
            continue
        attn = blk.attn
        for name in targets:
            layer = getattr(attn, name, None)
            if layer is None or isinstance(layer, LoRALinear):
                continue
            setattr(attn, name, LoRALinear(layer, r, alpha))
    return model
|
|
|
|
def get_lora_params(model):
    """Return a list of the trainable LoRA parameters in ``model``.

    For each ``LoRALinear`` found by ``model.modules()``, its ``A`` parameters
    are listed first, then its ``B`` parameters. The frozen base weights are
    not included.
    """
    return [
        param
        for module in model.modules()
        if isinstance(module, LoRALinear)
        for factor in (module.A, module.B)
        for param in factor.parameters()
    ]
|
|
|
|
| |
| |
| |
def _install_dynamo_config_shim():
    """Make assignments to ``torch._dynamo.config.accumulated_cache_size_limit`` no-ops.

    Replaces ``torch._dynamo.config`` with a proxy that forwards every
    attribute read and write to the real config object, except writes to
    ``accumulated_cache_size_limit``, which are dropped. Calling it again
    after the proxy is installed does nothing, so proxies never nest.
    NOTE(review): presumably needed because importing ``dinov3`` sets that
    option and some torch versions reject it - confirm.
    """
    if not (hasattr(torch, "_dynamo") and hasattr(torch._dynamo, "config")):
        return
    orig_config = torch._dynamo.config
    if getattr(orig_config, "_prior2dsm_shim", False):
        return

    class ConfigWrapper:
        _prior2dsm_shim = True  # marker checked above

        def __getattr__(self, name):
            return getattr(orig_config, name)

        def __setattr__(self, name, value):
            if name == "accumulated_cache_size_limit":
                return
            setattr(orig_config, name, value)

    torch._dynamo.config = ConfigWrapper()


def _strip_checkpoint_prefixes(key):
    """Remove leading wrapper prefixes from a state-dict key, repeatedly.

    For example, "module.backbone.blocks.0.x" becomes "blocks.0.x" and
    "teacher.backbone.norm.weight" becomes "norm.weight". Only leading
    prefixes are removed, never substrings in the middle of a key.
    """
    prefixes = ("module.", "teacher.backbone.", "backbone.")
    changed = True
    while changed:
        changed = False
        for prefix in prefixes:
            if key.startswith(prefix):
                key = key[len(prefix):]
                changed = True
    return key


@st.cache_resource
def load_models(repo_id, dav_file, dino_file):
    """Download and build the Depth Anything V2 (ViT-L) and DINOv3 (ViT-L/16) models.

    Both checkpoints are fetched from ``repo_id`` on the Hugging Face Hub.
    Both models are moved to ``DEVICE`` and put in eval mode. Because of
    ``st.cache_resource``, they are built once per argument set and reused
    across reruns.

    Returns:
        (dav_model, dino_model)
    """
    dav_path = hf_hub_download(repo_id=repo_id, filename=dav_file)
    from depth_anything_v2.dpt import DepthAnythingV2

    dav_model = DepthAnythingV2(
        encoder="vitl",
        features=256,
        out_channels=[256, 512, 1024, 1024]
    )
    dav_model.load_state_dict(torch.load(dav_path, map_location="cpu", weights_only=True))
    dav_model = dav_model.to(DEVICE).eval()

    # Installed before the dinov3 import below (see the helper's note).
    _install_dynamo_config_shim()

    dino_path = hf_hub_download(repo_id=repo_id, filename=dino_file)
    from dinov3.models.vision_transformer import DinoVisionTransformer

    dino_model = DinoVisionTransformer(
        img_size=1024,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        qkv_bias=True
    ).to(DEVICE).eval()

    # NOTE(review): unlike the DAv2 load above, weights_only is not forced
    # here; the checkpoint comes from a fixed repo, but torch.load unpickles it.
    ckpt = torch.load(dino_path, map_location="cpu")
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]

    # Strip prefixes only at the start of each key. A chained
    # str.replace("backbone.", "") removes "backbone." first, so a
    # "teacher.backbone." rule can never match. Those keys become
    # "teacher.<name>" and are then silently ignored by strict=False.
    clean_ckpt = {_strip_checkpoint_prefixes(k): v for k, v in ckpt.items()}
    load_result = dino_model.load_state_dict(clean_ckpt, strict=False)

    # strict=False silently skips mismatches; surface them so a wrong
    # architecture/checkpoint pairing does not go unnoticed.
    missing = load_result.missing_keys
    unexpected = load_result.unexpected_keys
    if missing or unexpected:
        st.warning(
            f"DINOv3 checkpoint loaded with strict=False: {len(missing)} missing "
            f"and {len(unexpected)} unexpected keys. The model configuration "
            "may not match the checkpoint."
        )

    return dav_model, dino_model
|
|
|
|
| |
| |
| |
@st.cache_data(show_spinner=False)
def run_dav_inference(_dav, rgb_raw, h_f, w_f):
    """Predict relative depth for the whole image with Depth Anything V2.

    The (bands, H, W) image is resized to 448 x 448 and scaled by 1/255. It
    is passed through ``_dav``, and the prediction is bilinearly resized back
    to ``(h_f, w_f)``. The leading underscore on ``_dav`` tells
    ``st.cache_data`` not to hash the model; results are cached per image and
    output size.

    Returns:
        (raw_depth_map, raw_depth_01): the float32 model output of shape
        (h_f, w_f), and the same map min-max normalised to [0, 1] with NaN
        wherever the raw output is non-finite.
    """
    hwc = rgb_raw.transpose(1, 2, 0)
    img_448 = cv2.resize(hwc, (448, 448))
    # Cast to float32 before handing the array to torch. Some PyTorch versions
    # cannot wrap uint16 arrays (common for 16-bit imagery), and for uint8
    # input the values are identical to a uint8 -> float conversion.
    img_448 = np.asarray(img_448, dtype=np.float32)

    # NOTE(review): the /255 scaling assumes 8-bit imagery. Also, the upstream
    # Depth Anything V2 `infer_image` helper normalises with ImageNet mean/std
    # before the forward pass, while this path feeds plain [0, 1]-scaled
    # values - confirm that is intended.
    dav_in = torch.tensor(img_448, device=DEVICE).permute(2, 0, 1).unsqueeze(0) / 255.0

    with torch.no_grad():
        raw_depth = _dav(dav_in)
        if isinstance(raw_depth, (list, tuple)):
            raw_depth = raw_depth[-1]

    # Assumes the model returns (batch, H, W); add a channel axis for interpolate.
    raw_depth = F.interpolate(
        raw_depth.unsqueeze(1),
        size=(h_f, w_f),
        mode="bilinear",
        align_corners=False
    ).squeeze(1)

    raw_depth_map = raw_depth[0].detach().float().cpu().numpy()

    valid = np.isfinite(raw_depth_map)
    raw_depth_01 = normalize_01(raw_depth_map, valid)
    raw_depth_01[~valid] = np.nan

    return raw_depth_map, raw_depth_01
|
|
|
|
| |
| |
| |
def _patch_level_targets(prior_t, rel_t, anchor_mask, H, W):
    """Resample prior, relative depth and anchor mask onto the ViT patch grid.

    Returns flattened (row-major) ``(prior_p, rel_p, mask_p)`` with one entry
    per ``PATCH_SIZE`` x ``PATCH_SIZE`` patch. ``mask_p`` marks patches that
    are more than half covered by anchor pixels AND whose bilinear samples of
    prior and relative depth are finite. ``prior_t`` still holds the raw
    values outside the anchors (e.g. NaN nodata), and a single NaN under the
    bilinear footprint would otherwise make the loss and every gradient NaN.
    """
    Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
    prior_p = F.interpolate(prior_t.view(1, 1, H, W), size=(Hp, Wp), mode="bilinear").flatten()
    rel_p = F.interpolate(rel_t.view(1, 1, H, W), size=(Hp, Wp), mode="bilinear").flatten()
    coverage = F.interpolate(
        anchor_mask.float().view(1, 1, H, W), size=(Hp, Wp), mode="area"
    ).flatten()
    mask_p = (coverage > 0.5) & torch.isfinite(prior_p) & torch.isfinite(rel_p)
    return prior_p, rel_p, mask_p


def _run_tto_steps(dino, mlp_head, rgb_tto, prior_p, rel_p, mask_p, tto_steps, tto_lr):
    """Optimise the LoRA factors and the MLP head for ``tto_steps`` AdamW steps.

    Each step predicts per-patch (scale, shift) from DINO patch tokens and
    minimises the Huber loss (delta=1) between ``scale * rel_p + shift`` and
    ``prior_p`` on the patches selected by ``mask_p``. Progress is shown with
    a Streamlit progress bar. Returns the list of per-step loss values.
    """
    params = list(mlp_head.parameters()) + get_lora_params(dino)
    opt = torch.optim.AdamW(params, lr=tto_lr)

    loss_hist = []
    prog = st.progress(0, text="Running LoRA TTO...")

    for step in range(tto_steps):
        opt.zero_grad(set_to_none=True)

        tokens = dino.forward_features(rgb_tto)["x_norm_patchtokens"].squeeze(0)
        sb = mlp_head(tokens)
        s, b = sb[:, 0], sb[:, 1]

        pred_p = s * rel_p + b
        loss = F.huber_loss(pred_p[mask_p], prior_p[mask_p], delta=1.0)

        loss.backward()
        opt.step()

        loss_hist.append(float(loss.item()))
        prog.progress((step + 1) / tto_steps, text=f"Running LoRA TTO... {step + 1}/{tto_steps}")

    prog.empty()
    return loss_hist


def _dense_scale_shift(dino, mlp_head, rgb_cpu, H, W):
    """Predict full-resolution (scale, shift) maps by sub-patch offset inference.

    The image is reflect-padded by one patch on each side. DINO + head run
    once per (dy, dx) offset in ``range(0, PATCH_SIZE, STRIDE)``, and each
    offset's patch predictions are interleaved into a grid with ``STRIDE``
    pixel spacing. A fixed offset window is cropped back to (H // STRIDE,
    W // STRIDE) and bilinearly upsampled to (H, W). Call under
    ``torch.no_grad()``.
    NOTE(review): assumes DINO patch tokens come out in row-major order,
    matching the ``reshape(2, gh, gw)`` below.
    """
    p, stride = PATCH_SIZE, STRIDE
    rgb_pad = F.pad(rgb_cpu.unsqueeze(0), (p, p, p, p), mode="reflect")
    Hp_pad, Wp_pad = rgb_pad.shape[-2:]

    sb_acc = torch.zeros((2, Hp_pad // stride, Wp_pad // stride), device=DEVICE)
    cnt_acc = torch.zeros((1, Hp_pad // stride, Wp_pad // stride), device=DEVICE)

    rgb_norm = normalize(rgb_pad, IMAGENET_MEAN, IMAGENET_STD).to(DEVICE)
    cells_per_patch = p // stride

    for dy in range(0, p, stride):
        for dx in range(0, p, stride):
            # Largest crop starting at (dy, dx) that is a whole number of patches.
            hc = ((Hp_pad - dy) // p) * p
            wc = ((Wp_pad - dx) // p) * p
            if hc <= 0 or wc <= 0:
                continue

            patch = rgb_norm[:, :, dy:dy + hc, dx:dx + wc]
            t = dino.forward_features(patch)["x_norm_patchtokens"].squeeze(0)

            gh, gw = hc // p, wc // p
            sb_local = mlp_head(t).t().reshape(2, gh, gw)

            # Every `cells_per_patch`-th cell, starting at this offset's cell.
            rows = slice(dy // stride, dy // stride + gh * cells_per_patch, cells_per_patch)
            cols = slice(dx // stride, dx // stride + gw * cells_per_patch, cells_per_patch)
            sb_acc[:, rows, cols] += sb_local
            cnt_acc[:, rows, cols] += 1

    sb_dense = sb_acc / (cnt_acc + 1e-8)
    offset = (p - (p // 2)) // stride + 1
    sb_final = sb_dense[:, offset:offset + (H // stride), offset:offset + (W // stride)]

    sb_hr = F.interpolate(
        sb_final.unsqueeze(0),
        size=(H, W),
        mode="bilinear",
        align_corners=False
    ).squeeze(0)
    return sb_hr[0], sb_hr[1]


def run_lora_pipeline(
    rgb_raw,
    prior_raw,
    rel_map,
    bbox_mask,
    dino_base,
    lora_r,
    lora_alpha,
    tto_steps,
    tto_lr
):
    """Refine a DSM inside an ROI via LoRA test-time optimisation.

    A LoRA-wrapped copy of ``dino_base`` plus an ``MLPDecoder`` head are fit
    so that per-patch ``scale * rel_map + shift`` matches the LiDAR prior on
    anchor pixels. Anchors are pixels outside ``bbox_mask`` whose prior is
    finite and non-zero. The fitted maps are then evaluated densely over the
    whole image. ``dino_base`` itself is never modified; a deep copy is used.

    Args:
        rgb_raw: (3, H, W) array, scaled by 1/255.
        prior_raw: (H, W) prior elevations.
        rel_map: (H, W) relative depth.
        bbox_mask: (H, W) bool, True inside the held-out ROI.
        dino_base: backbone to copy and adapt.
        lora_r, lora_alpha: LoRA rank/alpha.
        tto_steps, tto_lr: optimisation steps and AdamW learning rate.

    Returns:
        (final_dsm, loss_hist, anchor_mask): the (H, W) refined DSM as a numpy
        array, the per-step losses, and the (H, W) bool anchor mask.

    Raises:
        ValueError: if the inputs do not share the prior's (H, W) grid, or no
            patch has enough valid anchors to supervise the fit.
    """
    H, W = prior_raw.shape
    if (tuple(rgb_raw.shape[-2:]) != (H, W)
            or tuple(rel_map.shape) != (H, W)
            or tuple(bbox_mask.shape) != (H, W)):
        raise ValueError(
            f"RGB {tuple(rgb_raw.shape)}, relative depth {tuple(rel_map.shape)} and "
            f"ROI mask {tuple(bbox_mask.shape)} must match the prior grid {(H, W)}."
        )

    rgb_cpu = torch.tensor(rgb_raw.astype(np.float32) / 255.0)
    prior_raw_t = torch.tensor(prior_raw.astype(np.float32))
    rel_dev = torch.tensor(rel_map.astype(np.float32), device=DEVICE)

    # Anchors: outside the ROI, finite, and non-zero (0 is treated as no data).
    anchor_mask_cpu = (~torch.tensor(bbox_mask)) & torch.isfinite(prior_raw_t) & (prior_raw_t != 0)
    anchor_mask = anchor_mask_cpu.to(DEVICE)
    prior_dev = prior_raw_t.to(DEVICE)

    prior_p, rel_p, mask_p = _patch_level_targets(prior_dev, rel_dev, anchor_mask, H, W)
    if not bool(mask_p.any()):
        # Without supervised patches the Huber loss is NaN and the fit is meaningless.
        raise ValueError(
            "No image patch outside the ROI has enough valid LiDAR prior values "
            "to fit against; shrink the ROI or check the prior raster."
        )

    # Work on a copy so the cached base model stays untouched between runs.
    dino = copy.deepcopy(dino_base)
    dino = inject_lora(dino, r=lora_r, alpha=lora_alpha).to(DEVICE).train()
    mlp_head = MLPDecoder(in_dim=1024).to(DEVICE).train()

    # Train only the LoRA factors and the head.
    for p in dino.parameters():
        p.requires_grad_(False)
    for p in get_lora_params(dino):
        p.requires_grad_(True)
    for p in mlp_head.parameters():
        p.requires_grad_(True)

    rgb_tto = normalize(rgb_cpu.unsqueeze(0), IMAGENET_MEAN, IMAGENET_STD).to(DEVICE)
    loss_hist = _run_tto_steps(dino, mlp_head, rgb_tto, prior_p, rel_p, mask_p, tto_steps, tto_lr)

    dino.eval()
    mlp_head.eval()
    with torch.no_grad():
        s_hr, b_hr = _dense_scale_shift(dino, mlp_head, rgb_cpu, H, W)
        final_dsm = (s_hr * rel_dev + b_hr).detach().cpu().numpy()

    # Free the adapted backbone copy before returning.
    del dino, mlp_head, rgb_tto
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return final_dsm, loss_hist, anchor_mask_cpu.cpu().numpy()
|
|
|
|
| |
| |
| |
st.title("Prior2DSM | LoRA")

# Direct download links for the demo inputs hosted in the Space repository.
example_links_md = f"""
**Example TIFFs**
- [Download example RGB TIFF](https://huggingface.co/spaces/osherr/Prior2DSM/resolve/main/{EXAMPLE_RGB_FILENAME})
- [Download example Prior TIFF](https://huggingface.co/spaces/osherr/Prior2DSM/resolve/main/{EXAMPLE_PRIOR_FILENAME})
"""
st.markdown(example_links_md)
|
|
with st.sidebar:
    st.header("📂 Data")

    data_mode = st.radio(
        "Data source",
        ["Upload TIFFs", "Use example TIFFs"],
        index=0,
    )

    # At most one of these pairs gets populated, depending on the data source.
    rgb_file = prior_file = None
    rgb_example_path = prior_example_path = None

    if data_mode == "Upload TIFFs":
        rgb_file = st.file_uploader("RGB Image", type=["tif", "tiff"])
        prior_file = st.file_uploader("LiDAR Prior", type=["tif", "tiff"])
    else:
        st.caption("Load demo RGB/Prior TIFFs from the Hugging Face Space.")
        if st.button("Load example TIFFs"):
            # Remember the choice so the examples stay loaded on later reruns.
            st.session_state["use_examples"] = True

        if st.session_state.get("use_examples", False):
            rgb_example_path, prior_example_path = (
                load_tiff_from_hf(
                    repo_id="osherr/Prior2DSM",
                    filename=example_name,
                    repo_type="space",
                )
                for example_name in (EXAMPLE_RGB_FILENAME, EXAMPLE_PRIOR_FILENAME)
            )
            st.success("Example TIFFs loaded.")

    st.divider()
    st.write("#### LoRA / TTO")
    lora_r = st.slider("LoRA rank", 2, 32, 8, step=2)
    lora_alpha = st.slider("LoRA alpha", 4.0, 64.0, 16.0, step=4.0)
    tto_steps = st.slider("TTO steps", 10, 300, 100, step=10)
    tto_lr = st.select_slider("TTO LR", options=[1e-4, 3e-4, 1e-3, 3e-3], value=1e-3)

# Inputs are ready when both uploads exist, or when example mode is active
# and both example files were fetched.
has_uploaded = all(f is not None for f in (rgb_file, prior_file))
has_examples = (
    data_mode == "Use example TIFFs"
    and st.session_state.get("use_examples", False)
    and all(path is not None for path in (rgb_example_path, prior_example_path))
)
|
|
if has_uploaded or has_examples:
    dav_m, dino_base = load_models(
        repo_id="osherr/Prior2DSM",
        dav_file="depth_anything_v2_vitl.pth",
        dino_file="dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth"
    )

    if has_uploaded:
        # getvalue() returns the whole upload regardless of the buffer's
        # current position; read() returns only what has not been consumed yet.
        rgb_raw, h_f, w_f, _ = read_rgb_tiff(rgb_file.getvalue())
        prior_raw, prior_meta = read_prior_tiff(prior_file.getvalue())
    else:
        rgb_raw, h_f, w_f, _ = read_rgb_tiff(rgb_example_path)
        prior_raw, prior_meta = read_prior_tiff(prior_example_path)

    init_roi_state(h_f, w_f)

    with st.spinner("Generating relative depth with Depth Anything V2..."):
        rel_depth_map, rel_depth_01 = run_dav_inference(dav_m, rgb_raw, h_f, w_f)

    st.subheader("1. ROI Selection")

    viz_rgb = preview_rgb(rgb_raw)
    col_img, col_ctrl = st.columns([1.2, 0.8])

    with col_ctrl:
        # Form widgets only deliver new values on submit, so the ROI below is
        # committed to session state only when one of the buttons is pressed.
        with st.form("roi_form", clear_on_submit=False):
            x_center_form = st.slider(
                "X center",
                0, w_f - 1,
                int(st.session_state["x_center"])
            )
            y_center_form = st.slider(
                "Y center",
                0, h_f - 1,
                int(st.session_state["y_center"])
            )
            bbox_size_form = st.slider(
                "BBox Size (px)",
                50, min(400, min(h_f, w_f)),
                int(st.session_state["bbox_size"])
            )
            use_normalized_rel_form = st.checkbox(
                "Use normalized relative depth for LoRA",
                value=bool(st.session_state["use_normalized_rel"])
            )

            c1, c2 = st.columns(2)
            with c1:
                update_roi = st.form_submit_button("Update ROI")
            with c2:
                run_btn = st.form_submit_button("🚀 Run LoRA Pipeline", type="primary")

    if update_roi or run_btn:
        st.session_state["x_center"] = x_center_form
        st.session_state["y_center"] = y_center_form
        st.session_state["bbox_size"] = bbox_size_form
        st.session_state["use_normalized_rel"] = use_normalized_rel_form
    x_center = int(st.session_state["x_center"])
    y_center = int(st.session_state["y_center"])
    bbox_size = int(st.session_state["bbox_size"])
    use_normalized_rel = bool(st.session_state["use_normalized_rel"])

    # Square ROI around the centre, clipped to the image bounds.
    half_s = bbox_size // 2
    x1, x2 = max(0, x_center - half_s), min(w_f, x_center + half_s)
    y1, y2 = max(0, y_center - half_s), min(h_f, y_center + half_s)

    bbox_mask = np.zeros((h_f, w_f), dtype=bool)
    bbox_mask[y1:y2, x1:x2] = True

    with col_img:
        roi_preview = draw_roi_preview(viz_rgb, x1, y1, x2, y2)
        st.image(roi_preview, caption="ROI Preview", use_container_width=True)

    if run_btn:
        rel_for_lora = rel_depth_01 if use_normalized_rel else rel_depth_map

        with st.spinner("Running LoRA adaptation..."):
            final_dsm, loss_hist, anchor_mask_np = run_lora_pipeline(
                rgb_raw=rgb_raw,
                prior_raw=prior_raw,
                rel_map=rel_for_lora,
                bbox_mask=bbox_mask,
                dino_base=dino_base,
                lora_r=lora_r,
                lora_alpha=lora_alpha,
                tto_steps=tto_steps,
                tto_lr=tto_lr
            )

        st.subheader("Results")
        tab1, tab2, tab3, tab4 = st.tabs(
            ["Final Result", "Relative Depth", "Loss", "Masks"]
        )

        # Each figure is closed after rendering: pyplot keeps every figure it
        # created alive until closed, and Streamlit reruns share the process.
        with tab1:
            fig, ax = plt.subplots(1, 3, figsize=(18, 6))

            masked_prior = prior_raw.copy()
            masked_prior[bbox_mask] = np.nan

            ax[0].imshow(viz_rgb)
            ax[0].add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="red", lw=2))
            ax[0].set_title("Input RGB")
            ax[0].axis("off")

            ax[1].set_facecolor("black")
            im1 = ax[1].imshow(masked_prior, cmap="terrain")
            ax[1].add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="red", lw=2))
            ax[1].set_title("Input LiDAR (BBox Masked)")
            ax[1].axis("off")
            plt.colorbar(im1, ax=ax[1], fraction=0.046)

            im2 = ax[2].imshow(final_dsm, cmap="terrain")
            ax[2].add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="red", lw=2))
            ax[2].set_title("LoRA Refined DSM")
            ax[2].axis("off")
            plt.colorbar(im2, ax=ax[2], fraction=0.046)

            st.pyplot(fig)
            plt.close(fig)

        with tab2:
            fig_rel, ax_rel = plt.subplots(1, 2, figsize=(12, 5))

            im0 = ax_rel[0].imshow(rel_depth_map, cmap="viridis")
            ax_rel[0].set_title("Depth Anything Raw Relative Depth")
            ax_rel[0].axis("off")
            plt.colorbar(im0, ax=ax_rel[0], fraction=0.046)

            im1 = ax_rel[1].imshow(rel_depth_01, cmap="viridis")
            ax_rel[1].set_title("Normalized Relative Depth")
            ax_rel[1].axis("off")
            plt.colorbar(im1, ax=ax_rel[1], fraction=0.046)

            st.pyplot(fig_rel)
            plt.close(fig_rel)

        with tab3:
            fig_loss, ax_loss = plt.subplots(figsize=(8, 3))
            ax_loss.plot(loss_hist)
            ax_loss.set_title("TTO Huber Loss")
            ax_loss.set_yscale("log")
            ax_loss.grid(True, alpha=0.3)
            st.pyplot(fig_loss)
            plt.close(fig_loss)

        with tab4:
            fig_mask, axm = plt.subplots(1, 2, figsize=(10, 4))

            axm[0].imshow(bbox_mask, cmap="gray")
            axm[0].set_title("Target BBox Mask")
            axm[0].axis("off")

            axm[1].imshow(anchor_mask_np, cmap="gray")
            axm[1].set_title("Anchor Mask")
            axm[1].axis("off")

            st.pyplot(fig_mask)
            plt.close(fig_mask)

        # Write the result as a single-band float32 GeoTIFF that reuses the
        # prior's metadata (CRS/transform) so it stays georeferenced.
        out_buf = BytesIO()
        prior_meta.update({
            "driver": "GTiff",
            "height": h_f,
            "width": w_f,
            "dtype": "float32",
            "count": 1
        })

        with rasterio.open(out_buf, "w", **prior_meta) as dst:
            dst.write(final_dsm.astype(np.float32), 1)

        st.download_button(
            "Download Georeferenced DSM",
            out_buf.getvalue(),
            file_name="lora_refined_dsm_georef.tif"
        )
else:
    st.info("Upload RGB and Prior TIFFs, or switch to example TIFFs in the sidebar.")