# Hugging Face Hub page residue (kept as a comment so the module parses):
# Other · English · Shanci's picture
# "Upload folder using huggingface_hub" — commit 28a0087 (verified)
import os
import laspy
import torch
import logging
import numpy as np
from tqdm import tqdm
@torch.no_grad()  # inference only: no autograd graph is recorded
def run_test_inference(model, test_loader, config, save_dir):
    """Run ``model`` over ``test_loader`` and write one LAS file per zone.

    Consecutive batches sharing the same zone name are accumulated. They are
    flushed to ``write_las`` as soon as the zone name changes, and once more
    after the loop for the final zone. Batches are therefore expected to
    arrive grouped by zone. If a zone's batches were interleaved with another
    zone's, that zone would be written several times, each write replacing
    the previous file.

    Each zone's LAS header receives the offset of that zone's most recent
    batch. It never receives the offset of the batch that triggered the flush.

    Args:
        model: Segmentation network. ``model(x)`` must return logits whose
            class axis is dim 1 (``argmax(dim=1)`` is taken).
        test_loader: Iterable yielding ``(x, coords, offset, zone_name)``.
            Only element ``[0]`` of each item is used, so the loader is
            presumably configured with ``batch_size=1`` — TODO confirm.
        config: Mapping providing ``config["training"]["device"]``.
        save_dir: Directory that receives ``<zone>_prediction.las`` files.
    """
    model.eval()
    device = config["training"]["device"]

    current_zone_name = None
    current_offset = None  # offset belonging to the zone being accumulated
    coords_buffer = []
    preds_buffer = []

    for x, coords, offset, zone_name in tqdm(test_loader, desc="Running inference"):
        x = x[0].to(device)
        # float64 so that adding large georeferenced offsets keeps precision.
        coords = coords[0].numpy().astype(np.float64)
        offset = offset[0].numpy().astype(np.float64)
        zone_name = zone_name[0]

        # Zone boundary: flush the previous zone with *its own* offset before
        # adopting this batch's offset. Updating the offset first would stamp
        # the previous zone's header with the new zone's origin.
        if current_zone_name is not None and zone_name != current_zone_name:
            write_las(current_zone_name, coords_buffer, preds_buffer, current_offset, save_dir)
            coords_buffer.clear()
            preds_buffer.clear()

        current_zone_name = zone_name
        current_offset = offset

        logits = model(x)
        preds = logits.argmax(dim=1).cpu().numpy()
        # Add the offset back. The loader presumably yields offset-relative
        # coordinates — confirm against the dataset.
        coords_buffer.append(coords + offset)
        preds_buffer.append(preds)

    # Flush the final zone, which never sees a "zone changed" event.
    if current_zone_name is not None and coords_buffer:
        write_las(current_zone_name, coords_buffer, preds_buffer, current_offset, save_dir)
    logging.info("All las files written.")
def write_las(zone_name, coords_list, preds_list, offset, save_dir):
    """Write one zone's predictions to ``<save_dir>/<zone_name>_prediction.las``.

    The file is LAS 1.2, point format 3, with a 1 mm coordinate scale.
    Predicted classes are stored in an extra-bytes dimension named
    ``classif`` of type ``uint8``. Class values outside 0-255 would wrap on
    the ``astype`` cast.

    Args:
        zone_name: Zone identifier, used to build the output file name.
        coords_list: Per-batch coordinate arrays. They are stacked with
            ``np.vstack`` and read as x, y, z columns, so each is presumably
            shaped ``(N_i, 3)`` — confirm.
        preds_list: Per-batch predicted class arrays, concatenated in the
            same order as ``coords_list``.
        offset: LAS header offsets (x, y, z).
        save_dir: Output directory. It is created if it does not exist.
    """
    coords = np.vstack(coords_list)
    preds = np.concatenate(preds_list)

    header = laspy.LasHeader(point_format=3, version="1.2")
    header.offsets = offset
    header.scales = np.array([0.001, 0.001, 0.001])  # 1 mm
    # Declare the extra dimension on the header before allocating points.
    # The point record is then built once, instead of being reallocated
    # after x/y/z are filled.
    header.add_extra_dim(laspy.ExtraBytesParams(name="classif", type=np.uint8, description="Predicted class"))

    las = laspy.LasData(header)
    las.x, las.y, las.z = coords[:, 0], coords[:, 1], coords[:, 2]
    las.classif = preds.astype(np.uint8)

    # Without this, a missing output directory makes las.write fail.
    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f"{zone_name}_prediction.las")
    las.write(out_path)
    logging.info("Saved %s (%d points) in %s", zone_name, len(coords), out_path)