"""Command-line video inference script: extracts frames, runs human/scene inference, and saves visualizations and results."""
import argparse
import os
import torch
import numpy as np
import random
import logging
from unish.utils.inference_utils import *
def setup_seed(seed):
    """Seed every relevant RNG (Python, NumPy, PyTorch) for reproducible runs.

    Also switches cuDNN into deterministic mode so repeated GPU runs with
    the same seed yield identical results.
    """
    # Seed the pure-Python and NumPy generators first ...
    random.seed(seed)
    np.random.seed(seed)
    # ... then PyTorch on CPU and on every visible CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade a little speed for bitwise-reproducible cuDNN kernels.
    torch.backends.cudnn.deterministic = True
def setup_logging(output_dir):
    """Configure the root logger to log INFO+ to stderr and to a file.

    Creates ``output_dir`` if needed and writes to ``output_dir/inference.log``
    (truncated on each call, ``mode='w'``).

    Args:
        output_dir: Directory that will hold ``inference.log``.

    Returns:
        The configured root ``logging.Logger``.
    """
    os.makedirs(output_dir, exist_ok=True)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Drop handlers left over from any previous call; otherwise each call
    # appends a new console/file handler pair and every record is duplicated
    # (and old log files are kept open).
    for handler in list(logger.handlers):
        logger.removeHandler(handler)
        handler.close()
    # One formatter is enough — both handlers use the identical format.
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    c_handler = logging.StreamHandler()
    f_handler = logging.FileHandler(os.path.join(output_dir, 'inference.log'), mode='w')
    c_handler.setFormatter(formatter)
    f_handler.setFormatter(formatter)
    logger.addHandler(c_handler)
    logger.addHandler(f_handler)
    return logger
def main():
    """Entry point: parse CLI arguments, run video inference, and save results.

    Pipeline: seed RNGs -> configure logging -> pick device -> load model ->
    extract/track frames -> chunked inference -> geometry generation ->
    visualization -> optional per-frame result dumps.
    """
    parser = argparse.ArgumentParser(description="Video Inference Script")
    parser.add_argument("--video_path", type=str, required=True,
                        help="Path to the input video file or directory containing images")
    parser.add_argument("--fps", type=float, default=6.0,
                        help="Target FPS for frame extraction (default: 6.0)")
    parser.add_argument("--original_fps", type=float, default=30.0,
                        help="Original FPS of the image sequence (default: 30.0, used only for directory input)")
    parser.add_argument("--target_size", type=int, default=518,
                        help="Target size for frame processing (default: 518)")
    parser.add_argument("--checkpoint", type=str, default="checkpoints/unish_release.safetensors",
                        help="Path to the model checkpoint")
    parser.add_argument("--output_dir", type=str, default="inference_results_video",
                        help="Output directory for results")
    parser.add_argument("--body_models_path", type=str, default="body_models/",
                        help="Path to SMPL body models")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Device to run inference on")
    # BooleanOptionalAction keeps --save_results working and adds
    # --no-save_results; the previous store_true with default=True made the
    # flag a no-op (there was no way to turn saving off).
    parser.add_argument("--save_results", action=argparse.BooleanOptionalAction, default=True,
                        help="Save additional results including smpl_points_for_camera (default: True)")
    parser.add_argument("--chunk_size", type=int, default=30,
                        help="Number of frames to process in each chunk during inference (default: 30)")
    parser.add_argument("--gpu_id", type=int, default=0,
                        help="GPU ID to use for inference (default: 0)")
    parser.add_argument("--camera_mode", type=str, default="fixed",
                        choices=["predicted", "fixed"],
                        help="Camera mode: 'predicted' uses model-predicted camera parameters, "
                             "'fixed' uses a fixed camera angle (default: fixed)")
    parser.add_argument("--human_idx", type=int, default=0,
                        help="Human index to process (default: 0)")
    parser.add_argument("--start_idx", type=int, default=None,
                        help="Start frame index for processing (default: None, process from beginning)")
    parser.add_argument("--end_idx", type=int, default=None,
                        help="End frame index for processing (default: None, process to end)")
    parser.add_argument("--bbox_scale", type=float, default=1.0,
                        help="Scale factor for bounding box size (default: 1.0)")
    parser.add_argument("--conf_thres", type=float, default=0.1,
                        help="Confidence threshold for point cloud generation (default: 0.1)")
    # New arguments
    parser.add_argument("--seed", type=int, default=42, help="Random seed for reproducibility")
    parser.add_argument("--yolo_ckpt", type=str, default="ckpts/yolo11n.pt", help="Path to YOLO checkpoint")
    parser.add_argument("--sam2_model", type=str, default="facebook/sam2-hiera-large", help="SAM2 model name or path")
    args = parser.parse_args()

    # Setup seed
    setup_seed(args.seed)
    # Setup logging
    logger = setup_logging(args.output_dir)

    # Setup device: honor --gpu_id only when CUDA was requested and exists.
    if torch.cuda.is_available():
        if args.device == "cuda":
            # Use specified GPU ID
            device = torch.device(f"cuda:{args.gpu_id}")
            # Set the current CUDA device
            torch.cuda.set_device(args.gpu_id)
            logger.info(
                f"Using GPU {args.gpu_id}: {torch.cuda.get_device_name(args.gpu_id)}")
        else:
            device = torch.device(args.device)
    else:
        device = torch.device("cpu")
        logger.info("CUDA not available, using CPU")
    logger.info(f"Using device: {device}")

    # Load model
    logger.info("Loading model...")
    model = load_model(args.checkpoint)
    model = model.to(device)
    model.eval()

    # Process video: frame extraction + human detection/segmentation.
    logger.info(f"Processing video: {args.video_path}")
    data_dict = process_video(
        args.video_path, args.fps, args.human_idx, args.target_size,
        bbox_scale=args.bbox_scale, start_idx=args.start_idx, end_idx=args.end_idx,
        original_fps=args.original_fps,
        yolo_ckpt=args.yolo_ckpt, sam2_model=args.sam2_model
    )

    # Run inference
    results = run_inference(model, data_dict, device, args.chunk_size)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    viz_scene_point_clouds, viz_smpl_meshes, viz_scene_only_point_clouds, smpl_points_for_camera = generate_mixed_geometries_in_memory(
        results, args.body_models_path, fps=args.fps, conf_thres=args.conf_thres
    )

    # Determine camera mode based on arguments
    use_predicted_camera = (args.camera_mode == "predicted")
    logger.info(f"Using {args.camera_mode} camera mode")

    # Normalize RGB frames to uint8 HxWx3 numpy for visualization.
    original_rgb_images = results['rgb_images']
    if original_rgb_images is not None:
        if hasattr(original_rgb_images, 'permute'):  # It's a torch tensor
            original_rgb_images = original_rgb_images.permute(
                0, 2, 3, 1).cpu().numpy()  # [S, H, W, 3]
        elif not isinstance(original_rgb_images, np.ndarray):
            original_rgb_images = np.array(original_rgb_images)
        # Ensure proper data type and range
        if original_rgb_images.max() <= 1.0:
            original_rgb_images = (
                original_rgb_images * 255).astype(np.uint8)
    original_human_boxes = data_dict['human_boxes']

    run_visualization(viz_scene_point_clouds, viz_smpl_meshes, smpl_points_for_camera,
                      args.output_dir, results['seq_name'],
                      fps=args.fps,  # Use original fps
                      rgb_images=original_rgb_images,
                      human_boxes=original_human_boxes,
                      chunk_size=args.chunk_size,  # Use original chunk size
                      results=results,
                      use_predicted_camera=use_predicted_camera,
                      scene_only_point_clouds=viz_scene_only_point_clouds,
                      conf_thres=args.conf_thres)

    if args.save_results:
        logger.info("Creating SMPL meshes per frame...")
        save_smpl_meshes_per_frame(
            results, args.output_dir, args.body_models_path)
        logger.info("Saving scene point clouds (without human)...")
        save_scene_only_point_clouds(
            viz_scene_only_point_clouds, args.output_dir, results['seq_name'])
        logger.info("Saving human point clouds...")
        save_human_point_clouds(viz_scene_point_clouds,
                                viz_scene_only_point_clouds, args.output_dir, results['seq_name'], results)
        logger.info("Saving camera parameters per frame...")
        save_camera_parameters_per_frame(
            results, args.output_dir, results['seq_name'])

    logger.info(f"Inference completed! Results saved to {args.output_dir}")


if __name__ == "__main__":
    main()