import argparse
import os
import torch
import numpy as np
import random
import logging

from unish.utils.inference_utils import *


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def setup_logging(output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Create logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Create handlers
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(os.path.join(output_dir, 'inference.log'), mode='w')

    # Create formatters and add them to the handlers
    c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    c_handler.setFormatter(c_format)
    f_handler.setFormatter(f_format)

    # Add handlers to the logger
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)

    return logger


def main():
    parser = argparse.ArgumentParser(description="Video Inference Script")
    parser.add_argument("--video_path", type=str, required=True,
                        help="Path to the input video file or directory containing images")
    parser.add_argument("--fps", type=float, default=6.0,
                        help="Target FPS for frame extraction (default: 6.0)")
    parser.add_argument("--original_fps", type=float, default=30.0,
                        help="Original FPS of the image sequence (default: 30.0, used only for directory input)")
    parser.add_argument("--target_size", type=int, default=518,
                        help="Target size for frame processing (default: 518)")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/unish_release.safetensors",
                        help="Path to the model checkpoint")
    parser.add_argument("--output_dir", type=str, default="inference_results_video",
                        help="Output directory for results")
    parser.add_argument("--body_models_path", type=str, default="body_models/",
                        help="Path to SMPL body models")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")
    parser.add_argument("--save_results", action="store_true", default=True,
                        help="Save additional results including smpl_points_for_camera (default: True)")
    parser.add_argument("--chunk_size", type=int, default=30,
                        help="Number of frames to process in each chunk during inference (default: 30)")
    parser.add_argument("--gpu_id", type=int, default=0,
                        help="GPU ID to use for inference (default: 0)")
    parser.add_argument("--camera_mode", type=str, default="fixed", choices=["predicted", "fixed"],
                        help="Camera mode: 'predicted' uses model-predicted camera parameters, "
                             "'fixed' uses a fixed camera angle (default: fixed)")
    parser.add_argument("--human_idx", type=int, default=0,
                        help="Human index to process (default: 0)")
    parser.add_argument("--start_idx", type=int, default=None,
                        help="Start frame index for processing (default: None, process from beginning)")
    parser.add_argument("--end_idx", type=int, default=None,
                        help="End frame index for processing (default: None, process to end)")
    parser.add_argument("--bbox_scale", type=float, default=1.0,
                        help="Scale factor for bounding box size (default: 1.0)")
    parser.add_argument("--conf_thres", type=float, default=0.1,
                        help="Confidence threshold for point cloud generation (default: 0.1)")
    # New arguments
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    parser.add_argument("--yolo_ckpt", type=str, default="ckpts/yolo11n.pt",
                        help="Path to YOLO checkpoint")
    parser.add_argument("--sam2_model", type=str, default="facebook/sam2-hiera-large",
                        help="SAM2 model name or path")
    args = parser.parse_args()

    # Setup seed
    setup_seed(args.seed)

    # Setup logging
    logger = setup_logging(args.output_dir)

    # Setup device
    if torch.cuda.is_available():
        if args.device == "cuda":
            # Use specified GPU ID
            device = torch.device(f"cuda:{args.gpu_id}")
            # Set the current CUDA device
            torch.cuda.set_device(args.gpu_id)
            logger.info(
                f"Using GPU {args.gpu_id}: {torch.cuda.get_device_name(args.gpu_id)}")
        else:
            device = torch.device(args.device)
    else:
        device = torch.device("cpu")
        logger.info("CUDA not available, using CPU")

    logger.info(f"Using device: {device}")

    # Load model
    logger.info("Loading model...")
    model = load_model(args.checkpoint)
    model = model.to(device)
    model.eval()

    # Process video
    logger.info(f"Processing video: {args.video_path}")
    data_dict = process_video(
        args.video_path,
        args.fps,
        args.human_idx,
        args.target_size,
        bbox_scale=args.bbox_scale,
        start_idx=args.start_idx,
        end_idx=args.end_idx,
        original_fps=args.original_fps,
        yolo_ckpt=args.yolo_ckpt,
        sam2_model=args.sam2_model
    )

    # Run inference
    results = run_inference(model, data_dict, device, args.chunk_size)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera = generate_mixed_geometries_in_memory(
        results,
        args.body_models_path,
        fps=args.fps,
        conf_thres=args.conf_thres
    )

    # Determine camera mode based on arguments
    use_predicted_camera = (args.camera_mode == "predicted")
    logger.info(f"Using {args.camera_mode} camera mode")

    original_rgb_images = results['rgb_images']
    if original_rgb_images is not None:
        if hasattr(original_rgb_images, 'permute'):
            # It's a torch tensor
            original_rgb_images = original_rgb_images.permute(
                0, 2, 3, 1).cpu().numpy()  # [S, H, W, 3]
        elif not isinstance(original_rgb_images, np.ndarray):
            original_rgb_images = np.array(original_rgb_images)

        # Ensure proper data type and range
        if original_rgb_images.max() <= 1.0:
            original_rgb_images = (
                original_rgb_images * 255).astype(np.uint8)

    original_human_boxes = data_dict['human_boxes']

    run_visualization(viz_scene_point_clouds,
                      viz_smpl_meshes,
                      smpl_points_for_camera,
                      args.output_dir,
                      results['seq_name'],
                      fps=args.fps,  # Use original fps
                      rgb_images=original_rgb_images,
                      human_boxes=original_human_boxes,
                      chunk_size=args.chunk_size,  # Use original chunk size
                      results=results,
                      use_predicted_camera=use_predicted_camera,
                      scene_only_point_clouds=viz_scene_only_point_clouds,
                      conf_thres=args.conf_thres)

    if args.save_results:
        logger.info("Creating SMPL meshes per frame...")
        save_smpl_meshes_per_frame(
            results, args.output_dir, args.body_models_path)

        logger.info("Saving scene point clouds (without human)...")
        save_scene_only_point_clouds(
            viz_scene_only_point_clouds, args.output_dir, results['seq_name'])

        logger.info("Saving human point clouds...")
        save_human_point_clouds(viz_scene_point_clouds,
                                viz_scene_only_point_clouds,
                                args.output_dir,
                                results['seq_name'],
                                results)

        logger.info("Saving camera parameters per frame...")
        save_camera_parameters_per_frame(
            results, args.output_dir, results['seq_name'])

    logger.info(f"Inference completed! Results saved to {args.output_dir}")


if __name__ == "__main__":
    main()
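# Example invocation (a minimal sketch): the script filename and the input
# video path below are placeholders, not defined by this file; the checkpoint
# and output paths simply echo the argument defaults above.
#
#   python inference.py \
#       --video_path data/example.mp4 \
#       --fps 6.0 \
#       --camera_mode predicted \
#       --checkpoint checkpoints/unish_release.safetensors \
#       --output_dir inference_results_video \
#       --gpu_id 0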