r"""
PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation

Official implementation of the paper:
"PRIMA: Boosting Animal Mesh Recovery with Biological Priors and Test-Time Adaptation"
by Xiaohang Yu, Ti Wang, and Mackenzie Weygandt Mathis

Licensed under a modified MIT license

Split the AcinoSet multiview_mapping.json into train and test sets (7:3 ratio by default).

Usage:
    python split_acinoset.py \
        --input_json /path/to/multiview_mapping.json \
        --output_dir /path/to/output \
        --train_ratio 0.7 \
        --seed 42
"""
import argparse
import json
import random
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, Union

# ------------------------------------------------------------------
# Dataset root. Edit this to point to your local copy; the default
# --input_json and --output_dir CLI arguments are resolved relative
# to this directory.
# ------------------------------------------------------------------
BASE_DIR = Path("datasets")


def _split_frame_keys(
    frame_keys: List[str], train_ratio: float, rng: random.Random
) -> Tuple[List[str], List[str]]:
    """Shuffle *frame_keys* in place with *rng* and cut it into (train, test).

    The train share is ``int(len(frame_keys) * train_ratio)`` (truncated),
    so any rounding remainder goes to the test split.
    """
    rng.shuffle(frame_keys)
    cut = int(len(frame_keys) * train_ratio)
    return frame_keys[:cut], frame_keys[cut:]


def _write_json(path: Path, payload: dict) -> None:
    """Write *payload* to *path* as UTF-8 JSON with 4-space indentation."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=4)


def _percent(part: int, whole: int) -> float:
    """Return *part* as a percentage of *whole*, or 0.0 when *whole* is 0."""
    return 100.0 * part / whole if whole else 0.0


def _print_summary(
    behaviors: Sequence[str], train_data: Dict[str, dict], test_data: Dict[str, dict]
) -> None:
    """Print overall and per-behavior frame counts for both splits."""
    total_train = sum(len(frames) for frames in train_data.values())
    total_test = sum(len(frames) for frames in test_data.values())
    total = total_train + total_test

    print("\n" + "=" * 50)
    print("Summary:")
    print("=" * 50)
    print(f"Total frames: {total}")
    print(f"Train frames: {total_train} ({_percent(total_train, total):.1f}%)")
    print(f"Test frames: {total_test} ({_percent(total_test, total):.1f}%)")

    print("\nPer behavior:")
    # Iterate every input behavior (not just those present in the train split)
    # so behaviors whose train share truncates to 0 are still reported.
    # `.get` avoids inserting keys into the defaultdicts as a side effect.
    for behavior in behaviors:
        train_count = len(train_data.get(behavior, {}))
        test_count = len(test_data.get(behavior, {}))
        print(
            f"  {behavior}: train={train_count}, test={test_count}, "
            f"total={train_count + test_count}"
        )


def split_multiview_data(
    input_json: Union[str, Path],
    output_dir: Union[str, Path],
    train_ratio: float = 0.7,
    seed: int = 42,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
    """Split multiview mapping data into train and test sets.

    Frames are shuffled and split independently within each behavior, so
    every behavior contributes roughly ``train_ratio`` of its frames to the
    train split. The splits are written to ``train.json`` and ``test.json``
    in *output_dir*.

    Args:
        input_json: Path to multiview_mapping.json, a JSON object mapping
            behavior name -> {frame key -> frame entry}.
        output_dir: Directory to save train.json and test.json (created if
            missing).
        train_ratio: Fraction of each behavior's frames used for training
            (default 0.7, i.e. 70%). Must be within [0, 1].
        seed: Random seed for reproducibility.

    Returns:
        ``(train_data, test_data)``: the two dicts that were written to disk.
        A behavior appears in a split only if it received at least one frame.

    Raises:
        ValueError: If *train_ratio* is outside [0, 1].
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be within [0, 1], got {train_ratio!r}")

    # A private RNG keeps the split reproducible without reseeding the global
    # `random` module; Random(seed).shuffle yields the same permutation as
    # random.seed(seed) followed by random.shuffle(...).
    rng = random.Random(seed)

    print(f"Loading data from {input_json}...")
    with open(input_json, "r", encoding="utf-8") as f:
        data = json.load(f)

    train_data = defaultdict(dict)
    test_data = defaultdict(dict)

    for behavior, frames in data.items():
        print(f"\nProcessing behavior: {behavior}")
        train_frames, test_frames = _split_frame_keys(list(frames), train_ratio, rng)

        print(f"  Total frames: {len(frames)}")
        print(f"  Train frames: {len(train_frames)}")
        print(f"  Test frames: {len(test_frames)}")

        for frame_idx in train_frames:
            train_data[behavior][frame_idx] = frames[frame_idx]
        for frame_idx in test_frames:
            test_data[behavior][frame_idx] = frames[frame_idx]

    train_split = dict(train_data)
    test_split = dict(test_data)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    train_json = output_dir / "train.json"
    test_json = output_dir / "test.json"

    print(f"\nSaving train data to {train_json}...")
    _write_json(train_json, train_split)
    print(f"Saving test data to {test_json}...")
    _write_json(test_json, test_split)

    _print_summary(list(data), train_split, test_split)
    print("\nDone!")
    return train_split, test_split


def _parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Build the CLI parser and parse *argv* (defaults to ``sys.argv[1:]``)."""
    parser = argparse.ArgumentParser(
        description="Split multiview_mapping.json into train/test sets (default 7:3)."
    )
    parser.add_argument(
        "--input_json",
        type=str,
        default=str(BASE_DIR / "acinoset" / "multiview_mapping.json"),
        help="Path to multiview_mapping.json (default: %(default)s).",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=str(BASE_DIR / "acinoset"),
        help="Directory to save train.json and test.json (default: %(default)s).",
    )
    parser.add_argument(
        "--train_ratio",
        type=float,
        default=0.7,
        help="Fraction of data for training (default: %(default)s).",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: %(default)s).",
    )
    args = parser.parse_args(argv)
    if not 0.0 <= args.train_ratio <= 1.0:
        parser.error(f"--train_ratio must be within [0, 1], got {args.train_ratio}")
    return args


def main(argv: Optional[Sequence[str]] = None) -> None:
    """CLI entry point: parse arguments and run the split."""
    args = _parse_args(argv)
    split_multiview_data(
        input_json=args.input_json,
        output_dir=args.output_dir,
        train_ratio=args.train_ratio,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()