| | import os |
| | import sys |
| | import json |
| | import gzip |
| | import argparse |
| | import numpy as np |
| | from PIL import Image |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torchvision |
| | from einops import rearrange |
| | from lpips import LPIPS |
| |
|
| | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| |
|
| | from src.model.model.anysplat import AnySplat |
| | from src.model.encoder.vggt.utils.pose_enc import pose_encoding_to_extri_intri |
| | from src.model.encoder.vggt.utils.load_fn import load_and_preprocess_images |
| | from src.utils.pose import align_to_first_camera, calculate_auc_np, convert_pt3d_RT_to_opencv, se3_to_relative_pose_error |
| | from src.misc.cam_utils import camera_normalization, pose_auc, rotation_6d_to_matrix, update_pose, get_pnp_pose |
| |
|
def setup_args():
    """Parse and return command-line arguments for the CO3D evaluation script."""
    arg_parser = argparse.ArgumentParser(description='Test AnySplat on CO3D dataset')
    # (flag, argparse keyword options) pairs, registered in order below.
    flag_specs = [
        ('--debug', {'action': 'store_true', 'help': 'Enable debug mode (only test on specific category)'}),
        ('--use_ba', {'action': 'store_true', 'default': False, 'help': 'Enable bundle adjustment'}),
        ('--fast_eval', {'action': 'store_true', 'default': False, 'help': 'Only evaluate 10 sequences per category'}),
        ('--min_num_images', {'type': int, 'default': 50, 'help': 'Minimum number of images for a sequence'}),
        ('--num_frames', {'type': int, 'default': 10, 'help': 'Number of frames to use for testing'}),
        ('--co3d_dir', {'type': str, 'required': True, 'help': 'Path to CO3D dataset'}),
        ('--co3d_anno_dir', {'type': str, 'required': True, 'help': 'Path to CO3D annotations'}),
        ('--seed', {'type': int, 'default': 0, 'help': 'Random seed for reproducibility'}),
    ]
    for flag, options in flag_specs:
        arg_parser.add_argument(flag, **options)
    return arg_parser.parse_args()
| |
|
# Perceptual-similarity metric used by `rendering_loss`, instantiated once at
# module import time with the VGG backbone.
# NOTE(review): it is never moved to a device or set to eval mode here, while
# `rendering_loss` is called on tensors living on `device` — presumably this
# relies on CPU inference or an external `.to()` call; confirm device placement.
lpips = LPIPS(net="vgg")
| |
|
def rendering_loss(pred_image, image):
    """Return the combined MSE + 0.05 * LPIPS rendering loss.

    Args:
        pred_image: rendered images, shape (b, v, c, h, w), values in [0, 1].
        image: ground-truth images, shape (b, v, c, h, w), values in [-1, 1]
            (callers pass `images * 2.0 - 1` — see `process_sequence`).

    Returns:
        Scalar tensor: mean squared error plus 0.05 times the mean LPIPS
        distance over all (batch, view) pairs.
    """
    # Map the ground truth from [-1, 1] back to [0, 1] once, and use it for
    # BOTH terms. BUG FIX: the original passed the raw [-1, 1] `image` to
    # LPIPS while `normalize=True` expects inputs in [0, 1] (only the MSE
    # term performed the conversion), making the two terms inconsistent.
    target = (image + 1) / 2
    lpips_loss = lpips.forward(
        rearrange(pred_image, "b v c h w -> (b v) c h w"),
        rearrange(target, "b v c h w -> (b v) c h w"),
        normalize=True,
    )
    mse_loss = ((pred_image - target) ** 2).mean()
    return mse_loss + 0.05 * lpips_loss.mean()
| |
|
def _predict_extrinsics(model, images, dtype):
    """Run standard VGGT camera-head inference; return (v, 3, 4) extrinsics.

    Extracted helper: this code was previously duplicated verbatim in both
    the BA-failure fallback and the no-BA path of `process_sequence`.
    """
    with torch.no_grad(), torch.cuda.amp.autocast(dtype=dtype):
        aggregated_tokens_list, patch_start_idx = model.encoder.aggregator(
            images, intermediate_layer_idx=model.encoder.cfg.intermediate_layer_idx
        )
        # Camera head is run in fp32 for numerical stability.
        with torch.cuda.amp.autocast(dtype=torch.float32):
            fp32_tokens = [token.float() for token in aggregated_tokens_list]
            pred_all_pose_enc = model.encoder.camera_head(fp32_tokens)[-1]
            pred_all_extrinsic, pred_all_intrinsic = pose_encoding_to_extri_intri(pred_all_pose_enc, images.shape[-2:])
    return pred_all_extrinsic[0]


def process_sequence(model, seq_name, seq_data, category, co3d_dir, min_num_images, num_frames, use_ba, device, dtype):
    """
    Process a single sequence and compute pose errors.

    Args:
        model: AnySplat model
        seq_name: Sequence name
        seq_data: Sequence data (iterable of per-frame dicts with "R", "T", "filepath")
        category: Category name
        co3d_dir: CO3D dataset directory
        min_num_images: Minimum number of images required
        num_frames: Number of frames to sample
        use_ba: Whether to use bundle adjustment
        device: Device to run on
        dtype: Data type for model inference

    Returns:
        (rError, tError): per-pair rotation / translation errors in degrees as
        numpy arrays, or (None, None) when the sequence is skipped (too few
        frames, corrupt annotations, or images smaller than 448 px).
    """
    if len(seq_data) < min_num_images:
        return None, None

    metadata = []
    for data in seq_data:
        # Reject sequences with implausibly large translations (corrupt annotations).
        if data["T"][0] + data["T"][1] + data["T"][2] > 1e5:
            return None, None
        extri_opencv = convert_pt3d_RT_to_opencv(data["R"], data["T"])
        metadata.append({
            "filepath": data["filepath"],
            "extri": extri_opencv,
        })

    # Randomly sample `num_frames` distinct frames from the sequence.
    ids = np.random.choice(len(metadata), num_frames, replace=False)
    image_names = [os.path.join(co3d_dir, metadata[i]["filepath"]) for i in ids]
    gt_extri = [np.array(metadata[i]["extri"]) for i in ids]
    gt_extri = np.stack(gt_extri, axis=0)

    # Resolution gate, checked on the first sampled image only.
    max_size = max(Image.open(image_names[0]).size)
    if max_size < 448:
        return None, None
    images = load_and_preprocess_images(image_names)[None].to(device)  # (1, v, c, h, w), values in [0, 1]

    batch = {
        "context": {
            "image": images * 2.0 - 1,  # encoder consumes images in [-1, 1]
            "image_names": image_names,
            "index": ids,
        },
        "scene": "co3d"
    }

    if use_ba:
        try:
            encoder_output = model.encoder(
                batch,
                global_step=0,
                visualization_dump={},
            )
            gaussians, pred_context_pose = encoder_output.gaussians, encoder_output.pred_context_pose
            pred_extrinsic = pred_context_pose['extrinsic']
            pred_intrinsic = pred_context_pose['intrinsic']

            b, v, _, h, w = images.shape
            with torch.set_grad_enabled(True), torch.cuda.amp.autocast(enabled=False, dtype=torch.float32):
                # Optimize small per-camera pose deltas (6D rotation + translation).
                cam_rot_delta = nn.Parameter(torch.zeros([b, v, 6], requires_grad=True, device=pred_extrinsic.device, dtype=torch.float32))
                cam_trans_delta = nn.Parameter(torch.zeros([b, v, 3], requires_grad=True, device=pred_extrinsic.device, dtype=torch.float32))
                # Identity rotation in the 6D representation. Kept as a local
                # tensor: the original called model.register_buffer(...) here,
                # mutating the shared model on every sequence.
                identity_6d = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], dtype=torch.float32, device=pred_extrinsic.device)
                opt_params = [
                    {"params": [cam_rot_delta], "lr": 0.005},
                    {"params": [cam_trans_delta], "lr": 0.005},
                ]
                pose_optimizer = torch.optim.Adam(opt_params)
                extrinsics = pred_extrinsic.clone().float()

                for i in range(100):
                    pose_optimizer.zero_grad()
                    dx, drot = cam_trans_delta, cam_rot_delta
                    rot = rotation_6d_to_matrix(
                        drot + identity_6d.expand(b, v, -1)
                    )

                    # Build the (b, v, 4, 4) SE(3) correction from the deltas.
                    transform = torch.eye(4, device=extrinsics.device).repeat((b, v, 1, 1))
                    transform[..., :3, :3] = rot
                    transform[..., :3, 3] = dx

                    new_extrinsics = torch.matmul(extrinsics, transform)

                    output = model.decoder.forward(
                        gaussians,
                        new_extrinsics,
                        pred_intrinsic.float(),
                        0.1,    # near plane
                        100.0,  # far plane
                        (h, w),
                    )

                    # BUG FIX: the original wrote `rendering_loss = rendering_loss(...)`,
                    # which makes `rendering_loss` function-local and raises
                    # UnboundLocalError on the first call, so BA always fell
                    # through to the fallback path.
                    loss = rendering_loss(output.color, images * 2.0 - 1)
                    torchvision.utils.save_image(output.color[0], f"outputs/vis/output_co3d_{i}.png")
                    print(f"Rendering loss: {loss.item()}")

                    loss.backward()
                    pose_optimizer.step()
                torchvision.utils.save_image(images[0], f"outputs/vis/gt_co3d.png")
                # Invert the optimized poses and drop the homogeneous row -> (v, 3, 4).
                pred_extrinsic = new_extrinsics.inverse()[0][:, :-1, :]

        except Exception as e:
            print(f"BA failed with error: {e}. Falling back to standard VGGT inference.")
            pred_extrinsic = _predict_extrinsics(model, images, dtype)
    else:
        pred_extrinsic = _predict_extrinsics(model, images, dtype)

    with torch.cuda.amp.autocast(dtype=torch.float32):
        gt_extrinsic = torch.from_numpy(gt_extri).to(device)
        # Homogeneous bottom row, expanded over all views. Float literals so
        # the concatenation below never mixes integer and floating dtypes.
        add_row = torch.tensor([0.0, 0.0, 0.0, 1.0], device=device).expand(pred_extrinsic.size(0), 1, 4)

        pred_se3 = torch.cat((pred_extrinsic, add_row), dim=1)
        gt_se3 = torch.cat((gt_extrinsic, add_row), dim=1)

        # Express ground-truth poses relative to the first camera before comparison.
        gt_se3 = align_to_first_camera(gt_se3)

        rel_rangle_deg, rel_tangle_deg = se3_to_relative_pose_error(pred_se3, gt_se3, num_frames)
        print(f"{category} sequence {seq_name} Rot Error: {rel_rangle_deg.mean().item():.4f}")
        print(f"{category} sequence {seq_name} Trans Error: {rel_tangle_deg.mean().item():.4f}")

    return rel_rangle_deg.cpu().numpy(), rel_tangle_deg.cpu().numpy()
| |
|
# AUC thresholds (degrees) reported per category and in the mean summary.
_AUC_THRESHOLDS = [5, 10, 20, 30]


def _print_summary(per_category_results, mean_aucs):
    """Print per-category and mean AUC results to stdout in threshold order."""
    print("\nSummary of AUC results:")
    print("-" * 50)
    for category in sorted(per_category_results):
        for t in _AUC_THRESHOLDS:
            print(f"{category:<15} AUC_{t}: {per_category_results[category][f'Auc_{t}']:.4f}")
    if mean_aucs:
        print("-" * 50)
        for t in _AUC_THRESHOLDS:
            print(f"Mean AUC_{t}: {mean_aucs[t]:.4f}")


def _save_results(per_category_results, mean_aucs):
    """Write results to a timestamped text file and return its path."""
    import datetime  # local import: only needed to build the filename

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = f"co3d_results_{timestamp}.txt"
    with open(results_file, "w") as f:
        f.write("CO3D Evaluation Results\n")
        f.write("=" * 50 + "\n\n")
        f.write("Per-category results:\n")
        f.write("-" * 50 + "\n")
        for category in sorted(per_category_results):
            for t in _AUC_THRESHOLDS:
                f.write(f"{category:<15} AUC_{t}: {per_category_results[category][f'Auc_{t}']:.4f}\n")
            f.write("\n")
        if mean_aucs:
            f.write("-" * 50 + "\n")
            for t in _AUC_THRESHOLDS:
                f.write(f"Mean AUC_{t}: {mean_aucs[t]:.4f}\n")
            f.write("\n" + "=" * 50 + "\n")
    return results_file


def evaluate(args: argparse.Namespace):
    """Evaluate AnySplat camera-pose accuracy on the CO3D test set.

    For each seen category, loads the gzipped JSON pose annotations, runs
    `process_sequence` on each sequence, accumulates rotation/translation
    errors, and reports per-category and mean AUC at the thresholds in
    `_AUC_THRESHOLDS`. Results are printed and written to a timestamped file.

    Args:
        args: Parsed command-line arguments from `setup_args`.
    """
    # BUG FIX: `--seed` was parsed but never applied, so the random frame
    # sampling in `process_sequence` (np.random.choice) was not reproducible.
    np.random.seed(args.seed)

    model = AnySplat.from_pretrained("lhjiang/anysplat")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    SEEN_CATEGORIES = [
        "apple", "backpack", "banana", "baseballbat", "baseballglove",
        "bench", "bicycle", "bottle", "bowl", "broccoli",
        "cake", "car", "carrot", "cellphone", "chair",
        "cup", "donut", "hairdryer", "handbag", "hydrant",
        "keyboard", "laptop", "microwave", "motorcycle", "mouse",
        "orange", "parkingmeter", "pizza", "plant", "stopsign",
        "teddybear", "toaster", "toilet", "toybus", "toyplane",
        "toytrain", "toytruck", "tv", "umbrella", "vase", "wineglass",
    ]

    if args.debug:
        SEEN_CATEGORIES = ["apple"]

    per_category_results = {}

    for category in SEEN_CATEGORIES:
        print(f"Loading annotation for {category} test set")
        annotation_file = os.path.join(args.co3d_anno_dir, f"{category}_test.jgz")

        try:
            with gzip.open(annotation_file, "r") as fin:
                annotation = json.loads(fin.read())
        except FileNotFoundError:
            print(f"Annotation file not found for {category}, skipping")
            continue

        rError = []
        tError = []
        num_evaluated = 0  # sequences that produced valid errors

        for seq_name, seq_data in annotation.items():
            # BUG FIX: `--fast_eval` was parsed but never honored; cap at 10
            # successfully evaluated sequences per category, as documented.
            if args.fast_eval and num_evaluated >= 10:
                break

            print("-" * 50)
            print(f"Processing {seq_name} for {category} test set")
            if args.debug and not os.path.exists(os.path.join(args.co3d_dir, category, seq_name)):
                print(f"Skipping {seq_name} (not found)")
                continue

            seq_rError, seq_tError = process_sequence(
                model, seq_name, seq_data, category, args.co3d_dir,
                args.min_num_images, args.num_frames, args.use_ba, device, torch.bfloat16
            )
            print("-" * 50)

            if seq_rError is not None and seq_tError is not None:
                rError.extend(seq_rError)
                tError.extend(seq_tError)
                num_evaluated += 1

        if not rError:
            print(f"No valid sequences found for {category}, skipping")
            continue

        rError = np.array(rError)
        tError = np.array(tError)

        aucs = {}
        for threshold in _AUC_THRESHOLDS:
            auc, _ = calculate_auc_np(rError, tError, max_threshold=threshold)
            aucs[threshold] = auc

        print("=" * 80)
        print(f"AUC of {category} test set: {aucs[30]:.4f}")
        print("=" * 80)

        per_category_results[category] = {
            "rError": rError,
            "tError": tError,
            **{f"Auc_{t}": aucs[t] for t in _AUC_THRESHOLDS},
        }

    # Mean AUCs across categories (empty dict when nothing was evaluated).
    mean_aucs = {}
    if per_category_results:
        mean_aucs = {
            t: np.mean([res[f"Auc_{t}"] for res in per_category_results.values()])
            for t in _AUC_THRESHOLDS
        }

    _print_summary(per_category_results, mean_aucs)
    results_file = _save_results(per_category_results, mean_aucs)
    print(f"Results saved to {results_file}")
| |
|
| |
|
| | if __name__ == "__main__": |
| | args = setup_args() |
| | evaluate(args) |
| |
|