
import argparse
import os
import torch
import numpy as np
import random
import logging

# Provides load_model, process_video, run_inference, and the visualization /
# saving helpers used in main().
from unish.utils.inference_utils import *


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def setup_logging(output_dir):
    os.makedirs(output_dir, exist_ok=True)

    # Create the root logger
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Create handlers: one for the console, one for a log file in the output directory
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(os.path.join(output_dir, 'inference.log'), mode='w')

    # Create formatters and add them to the handlers
    c_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    f_format = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    c_handler.setFormatter(c_format)
    f_handler.setFormatter(f_format)

    # Add handlers to the logger
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)

    return logger


def main():
    parser = argparse.ArgumentParser(description="Video Inference Script")
    parser.add_argument("--video_path", type=str, required=True,
                        help="Path to the input video file or directory containing images")
    parser.add_argument("--fps", type=float, default=6.0,
                        help="Target FPS for frame extraction (default: 6.0)")
    parser.add_argument("--original_fps", type=float, default=30.0,
                        help="Original FPS of the image sequence (default: 30.0, used only for directory input)")
    parser.add_argument("--target_size", type=int, default=518,
                        help="Target size for frame processing (default: 518)")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/unish_release.safetensors",
                        help="Path to the model checkpoint")
    parser.add_argument("--output_dir", type=str, default="inference_results_video",
                        help="Output directory for results")
    parser.add_argument("--body_models_path", type=str, default="body_models/",
                        help="Path to SMPL body models")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")
    parser.add_argument("--save_results", action="store_true", default=True,
                        help="Save additional results including smpl_points_for_camera (default: True)")
    parser.add_argument("--chunk_size", type=int, default=30,
                        help="Number of frames to process in each chunk during inference (default: 30)")
    parser.add_argument("--gpu_id", type=int, default=0,
                        help="GPU ID to use for inference (default: 0)")
| parser.add_argument("--camera_mode", type=str, default="fixed", | |
| choices=["predicted", "fixed"], | |
| help="Camera mode: 'predicted' uses model-predicted camera parameters, " | |
| "'fixed' uses a fixed camera angle (default: predicted)") | |
| parser.add_argument("--human_idx", type=int, default=0, | |
| help="Human index to process (default: 0)") | |
| parser.add_argument("--start_idx", type=int, default=None, | |
| help="Start frame index for processing (default: None, process from beginning)") | |
| parser.add_argument("--end_idx", type=int, default=None, | |
| help="End frame index for processing (default: None, process to end)") | |
| parser.add_argument("--bbox_scale", type=float, default=1.0, | |
| help="Scale factor for bounding box size (default: 1.0)") | |
| parser.add_argument("--conf_thres", type=float, default=0.1, | |
| help="Confidence threshold for point cloud generation (default: 0.1)") | |
| # New arguments | |
| parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility") | |
| parser.add_argument("--yolo_ckpt", type=str, default="ckpts/yolo11n.pt", help="Path to YOLO checkpoint") | |
| parser.add_argument("--sam2_model", type=str, default="facebook/sam2-hiera-large", help="SAM2 model name or path") | |
| args = parser.parse_args() | |

    # Setup seed
    setup_seed(args.seed)

    # Setup logging
    logger = setup_logging(args.output_dir)

    # Setup device
    if torch.cuda.is_available():
        if args.device == "cuda":
            # Use the specified GPU ID
            device = torch.device(f"cuda:{args.gpu_id}")
            # Set the current CUDA device
            torch.cuda.set_device(args.gpu_id)
            logger.info(
                f"Using GPU {args.gpu_id}: {torch.cuda.get_device_name(args.gpu_id)}")
        else:
            device = torch.device(args.device)
    else:
        device = torch.device("cpu")
        logger.info("CUDA not available, using CPU")
    logger.info(f"Using device: {device}")

    # Load model
    logger.info("Loading model...")
    model = load_model(args.checkpoint)
    model = model.to(device)
    model.eval()

    # Process video
    logger.info(f"Processing video: {args.video_path}")
    data_dict = process_video(
        args.video_path, args.fps, args.human_idx, args.target_size,
        bbox_scale=args.bbox_scale, start_idx=args.start_idx, end_idx=args.end_idx,
        original_fps=args.original_fps,
        yolo_ckpt=args.yolo_ckpt, sam2_model=args.sam2_model
    )

    # Run inference
    results = run_inference(model, data_dict, device, args.chunk_size)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera = generate_mixed_geometries_in_memory(
        results, args.body_models_path, fps=args.fps, conf_thres=args.conf_thres
    )

    # Determine camera mode based on arguments
    use_predicted_camera = (args.camera_mode == "predicted")
    logger.info(f"Using {args.camera_mode} camera mode")

    original_rgb_images = results['rgb_images']
    if original_rgb_images is not None:
        if hasattr(original_rgb_images, 'permute'):  # It's a torch tensor
            original_rgb_images = original_rgb_images.permute(
                0, 2, 3, 1).cpu().numpy()  # [S, H, W, 3]
        elif not isinstance(original_rgb_images, np.ndarray):
            original_rgb_images = np.array(original_rgb_images)

        # Ensure proper data type and range
        if original_rgb_images.max() <= 1.0:
            original_rgb_images = (
                original_rgb_images * 255).astype(np.uint8)

    original_human_boxes = data_dict['human_boxes']

    run_visualization(viz_scene_point_clouds, viz_smpl_meshes, smpl_points_for_camera,
                      args.output_dir, results['seq_name'],
                      fps=args.fps,  # Use original fps
                      rgb_images=original_rgb_images,
                      human_boxes=original_human_boxes,
                      chunk_size=args.chunk_size,  # Use original chunk size
                      results=results,
                      use_predicted_camera=use_predicted_camera,
                      scene_only_point_clouds=viz_scene_only_point_clouds,
                      conf_thres=args.conf_thres)

    if args.save_results:
        logger.info("Creating SMPL meshes per frame...")
        save_smpl_meshes_per_frame(
            results, args.output_dir, args.body_models_path)

        logger.info("Saving scene point clouds (without human)...")
        save_scene_only_point_clouds(
            viz_scene_only_point_clouds, args.output_dir, results['seq_name'])

        logger.info("Saving human point clouds...")
        save_human_point_clouds(viz_scene_point_clouds,
                                viz_scene_only_point_clouds, args.output_dir, results['seq_name'], results)

        logger.info("Saving camera parameters per frame...")
        save_camera_parameters_per_frame(
            results, args.output_dir, results['seq_name'])

    logger.info(f"Inference completed! Results saved to {args.output_dir}")


if __name__ == "__main__":
    main()
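
# Example invocation (illustrative; the script filename and input paths below are
# assumptions, adjust them to your checkout):
#   python inference.py --video_path demo/sample.mp4 --fps 6.0 \
#       --checkpoint checkpoints/unish_release.safetensors \
#       --output_dir inference_results_video --camera_mode fixed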