# DAP-weights / test / infer.py
# Insta360-Research -- "Upload 372 files" (f4d2177, verified)
from __future__ import absolute_import, division, print_function
import os
import sys
import cv2
import torch
import yaml
import argparse
import numpy as np
import torch.nn as nn
from tqdm import tqdm
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(PROJECT_ROOT)
from networks.models import make # 建议用 make,而不是 import *
import matplotlib
def colorize_depth_fixed(depth_u8: np.ndarray, cmap: str = "Spectral") -> np.ndarray:
    """Turn an 8-bit depth image into a color image via a matplotlib colormap.

    The input is divided by the constant 255 (not by its own min/max), so a
    given input code always maps to the same color across images.

    Args:
        depth_u8: Depth codes, expected to be uint8 in the range 0..255.
        cmap: Key looked up in ``matplotlib.colormaps``.

    Returns:
        A C-contiguous uint8 array holding the first three channels of the
        colormap output (RGB; the alpha channel is dropped).
    """
    normalized = depth_u8.astype(np.float32) / 255.0
    palette = matplotlib.colormaps[cmap]
    # The colormap yields float channels; keep only the first three.
    rgb_float = palette(normalized)[..., :3]
    rgb_u8 = (rgb_float * 255).astype(np.uint8)
    return np.ascontiguousarray(rgb_u8)
def ensure_dir_for_file(path: str) -> None:
    """Create the parent directory of ``path`` if it does not exist yet.

    Only the directory is created; the file itself is never touched.

    Args:
        path: Path of a file that is about to be written.
    """
    parent = os.path.dirname(path)
    # A bare file name has an empty dirname and os.makedirs("") raises
    # FileNotFoundError. In that case the target is the current directory,
    # which already exists, so there is nothing to create.
    if parent:
        os.makedirs(parent, exist_ok=True)
def load_model(config):
    """Build the network described by ``config`` and load its checkpoint.

    The checkpoint is read from ``<config["load_weights_dir"]>/model.pth``.
    Loading is best-effort: only checkpoint entries whose key exists in the
    model's ``state_dict`` are applied. Any mismatch is reported on stdout so
    that a wrong checkpoint does not silently produce an uninitialised model.

    Args:
        config: Mapping with at least the keys ``"load_weights_dir"`` and
            ``"model"`` (the latter is passed to ``make``).

    Returns:
        Tuple ``(model, device)`` where ``model`` is in eval mode and
        ``device`` is the string ``"cuda"`` or ``"cpu"``.
    """
    model_path = os.path.join(config["load_weights_dir"], "model.pth")
    print(f"🔹 Loading model weights from: {model_path}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    state = torch.load(model_path, map_location=device)

    m = make(config["model"])
    # Keys starting with "module" are taken as a sign that the checkpoint was
    # saved from a DataParallel wrapper; wrap the model so the names line up.
    if any(k.startswith("module") for k in state.keys()):
        m = nn.DataParallel(m)
    m = m.to(device)

    m_state = m.state_dict()
    matched = {k: v for k, v in state.items() if k in m_state}
    skipped = [k for k in state if k not in m_state]
    result = m.load_state_dict(matched, strict=False)

    # Surface mismatches instead of hiding them behind strict=False.
    if skipped:
        print(f"⚠️ {len(skipped)} checkpoint key(s) not present in the model were skipped, e.g. {skipped[:5]}")
    missing = list(result.missing_keys)
    if missing:
        print(f"⚠️ {len(missing)} model key(s) not found in the checkpoint keep their initial values, e.g. {missing[:5]}")

    m.eval()
    print("✅ Model loaded successfully.\n")
    return m, device
def infer_raw(model, device, img_rgb_u8: np.ndarray) -> np.ndarray:
    """Run the model on one RGB image and return its raw prediction.

    The image is scaled to [0, 1], reordered from HWC to CHW and given a
    leading batch axis before being passed to ``model``.

    Two output layouts are handled:
      * a dict containing ``"pred_depth"``: if ``"pred_mask"`` is also
        present, every position where ``1 - pred_mask`` is not above 0.5 has
        its depth overwritten with 1 (presumably invalid/sky regions --
        confirm against the model); the first batch element is returned;
      * anything else: ``outputs[0]`` is returned.

    Args:
        model: Callable taking a ``(1, C, H, W)`` float tensor.
        device: Device the input tensor is moved to.
        img_rgb_u8: HWC uint8 RGB image.

    Returns:
        float32 array with all size-1 dimensions squeezed out, i.e. ``(H, W)``
        for a single-channel prediction.
    """
    scaled = img_rgb_u8.astype(np.float32) / 255.0
    chw = scaled.transpose(2, 0, 1)
    batch = torch.from_numpy(chw).unsqueeze(0).to(device)

    with torch.inference_mode():
        outputs = model(batch)
        has_depth = isinstance(outputs, dict) and "pred_depth" in outputs
        if not has_depth:
            selected = outputs[0]
        else:
            if "pred_mask" in outputs:
                keep = (1 - outputs["pred_mask"]) > 0.5
                # In-place overwrite of the positions that are not kept.
                outputs["pred_depth"][~keep] = 1
            selected = outputs["pred_depth"][0]
        pred = selected.detach().cpu().squeeze().numpy()

    return pred.astype(np.float32)
def pred_to_vis(pred: np.ndarray, vis_range: str = "100m", cmap: str = "Spectral"):
    """Convert a raw prediction into grayscale and color visualizations.

    ``"100m"`` clips the prediction to [0, 1] and stretches it to 0..255.
    ``"10m"`` clips to [0, 0.1] and stretches that narrower band to 0..255.
    (Presumably a prediction of 1.0 corresponds to 100 m -- confirm against
    the model's output convention.)

    Args:
        pred: Raw prediction array.
        vis_range: Either ``"100m"`` or ``"10m"``.
        cmap: Colormap name forwarded to ``colorize_depth_fixed``.

    Returns:
        Tuple ``(depth_gray, depth_color)``: the uint8 grayscale image and
        its uint8 RGB colorization.

    Raises:
        ValueError: If ``vis_range`` is not one of the supported values.
    """
    if vis_range not in ("100m", "10m"):
        raise ValueError(f"Unknown vis_range: {vis_range} (use '100m' or '10m')")

    if vis_range == "10m":
        stretched = np.clip(pred, 0.0, 0.1) * 10.0 * 255
    else:
        stretched = np.clip(pred, 0.0, 1.0) * 255
    gray_u8 = stretched.astype(np.uint8)

    color_rgb = colorize_depth_fixed(gray_u8, cmap=cmap)
    return gray_u8, color_rgb
def infer_and_save(model, device, img_path, out_root, idx, vis_range="100m", cmap="Spectral"):
    """Run inference on one image file and write the results under ``out_root``.

    Three files named after the zero-padded ``idx`` are written:
      * ``depth_npy/<idx>.npy``                   -- raw float prediction,
      * ``depth_vis_gray_<vis_range>/<idx>.png``  -- grayscale visualization,
      * ``depth_vis_color_<vis_range>/<idx>.png`` -- colorized visualization.

    An unreadable input image is reported and skipped. A PNG that cannot be
    written is reported as well, instead of failing silently.

    Args:
        model: Network passed through to ``infer_raw``.
        device: Device passed through to ``infer_raw``.
        img_path: Path of the input image.
        out_root: Root directory for all outputs.
        idx: Integer used to build the six-digit output file name.
        vis_range: Visualization range, forwarded to ``pred_to_vis``.
        cmap: Colormap name, forwarded to ``pred_to_vis``.
    """
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        print(f"⚠️ Cannot read image: {img_path}")
        return

    # OpenCV loads BGR; the model input is RGB.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    pred = infer_raw(model, device, img_rgb)
    depth_gray, depth_color_rgb = pred_to_vis(pred, vis_range=vis_range, cmap=cmap)

    filename = f"{idx:06d}"
    pred_npy_path = os.path.join(out_root, "depth_npy", filename + ".npy")
    gray_png_path = os.path.join(out_root, f"depth_vis_gray_{vis_range}", filename + ".png")
    color_png_path = os.path.join(out_root, f"depth_vis_color_{vis_range}", filename + ".png")
    for out_path in (pred_npy_path, gray_png_path, color_png_path):
        ensure_dir_for_file(out_path)

    np.save(pred_npy_path, pred)

    # cv2.imwrite reports failure through its boolean return value rather
    # than by raising, so check it explicitly.
    if not cv2.imwrite(gray_png_path, depth_gray):
        print(f"⚠️ Cannot write image: {gray_png_path}")
    # Convert back to BGR because that is the channel order imwrite expects.
    depth_color_bgr = cv2.cvtColor(depth_color_rgb, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(color_png_path, depth_color_bgr):
        print(f"⚠️ Cannot write image: {color_png_path}")
def main(config_path, txt_path, out_root, vis_range="100m", cmap="Spectral"):
    """Load the config and model, then run inference over a list of images.

    Args:
        config_path: YAML file describing the model and weights location.
        txt_path: Text file with one image path per line; blank lines are
            ignored and surrounding whitespace is stripped.
        out_root: Root directory that receives all outputs.
        vis_range: Visualization range, ``"100m"`` or ``"10m"``.
        cmap: Colormap name used for the colorized PNGs.
    """
    with open(config_path, "r") as cfg_file:
        # NOTE(review): yaml.FullLoader can construct more than plain data
        # types; this assumes the config file comes from a trusted source.
        config = yaml.load(cfg_file, Loader=yaml.FullLoader)
    print("✅ Config loaded.")

    model, device = load_model(config)

    img_list = []
    with open(txt_path, "r") as list_file:
        for raw_line in list_file:
            entry = raw_line.strip()
            if entry:
                img_list.append(entry)

    print(f"🔹 Total images to infer: {len(img_list)}")
    print(f"🔹 Visualization: {vis_range}, colormap: {cmap}\n")

    # Output files are numbered from 1 in list order.
    progress = tqdm(img_list, desc="Inferencing")
    for idx, img_path in enumerate(progress, start=1):
        infer_and_save(model, device, img_path, out_root, idx, vis_range=vis_range, cmap=cmap)

    print("\n🎯 推理完成!")
    print(f" depth npy: {out_root}/depth_npy")
    print(f" depth gray png: {out_root}/depth_vis_gray_{vis_range}")
    print(f" depth color png: {out_root}/depth_vis_color_{vis_range}")
if __name__ == "__main__":
    # Command-line entry point: parse options, select the GPU, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="config/infer.yaml")
    parser.add_argument("--txt", default="datasets/test.txt")
    parser.add_argument("--output", default="test_output")
    # --gpu: index of the GPU to use.
    parser.add_argument("--gpu", default="0", help="使用的GPU编号")
    # --vis: visualization range; per its help text it affects only the
    # PNGs, not the saved .npy.
    parser.add_argument("--vis", default="100m", choices=["100m", "10m"], help="可视化范围(只影响png,不影响npy)")
    # matplotlib colormap names are case-sensitive: "Spectral" is capitalised
    # but "turbo" and "viridis" are registered in lower case.
    parser.add_argument("--cmap", default="Spectral", help="matplotlib colormap name (case-sensitive), e.g. Spectral, turbo, viridis")
    args = parser.parse_args()

    # These must be set before the first CUDA call (made in load_model) so
    # that the selected GPU is the only one visible to torch.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    main(args.config, args.txt, args.output, vis_range=args.vis, cmap=args.cmap)