| | from __future__ import annotations |
| |
|
| | import argparse |
| | import json |
| |
|
| | import pandas as pd |
| | import torch |
| |
|
| | from .model import PointNetRegressor |
| | from .preprocess import prepare_points |
| | from .utils import get_device |
| |
|
| |
|
| | def load_points_from_csv( |
| | csv_path: str, |
| | x_col: str = "x", |
| | y_col: str = "y", |
| | z_col: str = "z", |
| | ) -> torch.Tensor: |
| | df = pd.read_csv(csv_path, usecols=[x_col, y_col, z_col]) |
| | for c in [x_col, y_col, z_col]: |
| | df[c] = pd.to_numeric(df[c], errors="coerce") |
| | df = df.dropna(subset=[x_col, y_col, z_col]) |
| | points = df[[x_col, y_col, z_col]].to_numpy(dtype="float32") |
| | return points |
| |
|
| |
|
| | def load_model(checkpoint_path: str, device: str | None = None): |
| | device = device or get_device() |
| | checkpoint = torch.load(checkpoint_path, map_location=device) |
| |
|
| | input_dim = int(checkpoint.get("input_dim", 3)) |
| | model = PointNetRegressor(input_dim=input_dim) |
| | model.load_state_dict(checkpoint["model_state"]) |
| | model.to(device) |
| | model.eval() |
| | return model, checkpoint, device |
| |
|
| |
|
| | @torch.no_grad() |
| | def predict( |
| | csv_path: str, |
| | checkpoint_path: str, |
| | num_points: int | None = None, |
| | device: str | None = None, |
| | ) -> dict: |
| | model, checkpoint, device = load_model(checkpoint_path, device=device) |
| |
|
| | if num_points is None: |
| | num_points = int(checkpoint.get("num_points", 2048)) |
| |
|
| | points = load_points_from_csv(csv_path) |
| | points = prepare_points(points, num_points=num_points) |
| |
|
| | x = torch.from_numpy(points).transpose(0, 1).unsqueeze(0).to(device) |
| | pred = model(x).item() |
| |
|
| | return { |
| | "free_space_score": float(pred), |
| | "num_points_used": int(num_points), |
| | "device": device, |
| | } |
| |
|
| |
|
| | def main(): |
| | parser = argparse.ArgumentParser(description="Inference for TerrainFreeSpaceNet") |
| | parser.add_argument("--input_csv", type=str, required=True, help="Input CSV path") |
| | parser.add_argument("--checkpoint", type=str, required=True, help="Checkpoint path") |
| | parser.add_argument("--num_points", type=int, default=None, help="Override checkpoint num_points") |
| | args = parser.parse_args() |
| |
|
| | result = predict( |
| | csv_path=args.input_csv, |
| | checkpoint_path=args.checkpoint, |
| | num_points=args.num_points, |
| | ) |
| | print(json.dumps(result, indent=2)) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |