"""
Inference pipeline for Owl IDM models.

Example usage (local):
    pipeline = InferencePipeline.from_pretrained(
        "configs/simple.yml",
        checkpoint_path="checkpoints/simple/ema/step_50000.pt"
    )

Example usage (HF Hub):
    pipeline = InferencePipeline.from_pretrained(
        "username/owl-idm-simple-v0"
    )

    # video: [b, n, c, h, w] tensor in range [-1, 1]
    wasd_preds, mouse_preds = pipeline(video)
"""

import os
from pathlib import Path

import torch
from tqdm import tqdm

from owl_idms.configs import load_config
from owl_idms.models import get_model_cls


class InferencePipeline:
    """
    Inference pipeline for IDM models with sliding window prediction.

    The pipeline holds a model cast to bfloat16 in eval mode and, when
    called, runs it once per frame on a window of frames centred on that
    frame.
    """

    def __init__(self, model, config, device='cuda', compile_model=True):
        """
        Initialize the inference pipeline.

        Args:
            model: The IDM model
            config: Full config object (with train and model sections).
                ``config.train.window_length`` is read as the default window
                size; ``config.train.use_log1p_scaling`` is optional and
                defaults to True when absent.
            device: Device to run inference on (default: 'cuda')
            compile_model: Whether to compile the model with torch.compile
                (default: True)
        """
        self.config = config
        self.device = device
        self.window_length = config.train.window_length
        self.use_log1p_scaling = getattr(config.train, 'use_log1p_scaling', True)

        # Move model to device, convert to bfloat16, and set to eval mode
        self.model = model.to(device=device, dtype=torch.bfloat16)
        self.model.eval()

        # Compile for faster inference
        if compile_model:
            print("Compiling model for inference...")
            self.model = torch.compile(self.model, mode='max-autotune')
            print("Model compiled!")

    @classmethod
    def from_pretrained(cls, model_id_or_path, checkpoint_path=None, device='cuda',
                        compile_model=True, token=None):
        """
        Load a pretrained model from local files or Hugging Face Hub.

        Args:
            model_id_or_path: Either:
                - HF Hub repo ID (e.g., "username/owl-idm-simple-v0")
                - Local path (str or os.PathLike) to config YAML file
            checkpoint_path: Path to checkpoint .pt file (only for local loading)
            device: Device to run inference on (default: 'cuda')
            compile_model: Whether to compile the model (default: True)
            token: HF API token (optional, for private repos)

        Returns:
            InferencePipeline instance ready for inference

        Raises:
            ValueError: If a local config is given without ``checkpoint_path``.
            ImportError: If loading from the Hub and ``huggingface_hub`` is
                not installed.

        Examples:
            # Load from HF Hub
            pipeline = InferencePipeline.from_pretrained("username/owl-idm-simple-v0")

            # Load from local files
            pipeline = InferencePipeline.from_pretrained(
                "configs/simple.yml",
                checkpoint_path="checkpoints/simple/ema/step_17100.pt"
            )
        """
        # Normalise pathlib.Path (or any os.PathLike) to str so the suffix
        # check below works for both kinds of input.
        model_id_or_path = os.fspath(model_id_or_path)

        # Treat the argument as local if it exists on disk or looks like a
        # YAML config path; otherwise assume it is a Hub repo ID.
        is_local = (
            os.path.exists(model_id_or_path)
            or model_id_or_path.endswith(('.yml', '.yaml'))
        )

        if is_local:
            # Local loading
            if checkpoint_path is None:
                raise ValueError("checkpoint_path required when loading from local files")

            config_path = model_id_or_path
            print("Loading from local files...")
            print(f" Config: {config_path}")
            print(f" Checkpoint: {checkpoint_path}")
        else:
            # HF Hub loading; imported lazily so the dependency is only
            # needed when this branch is taken.
            try:
                from huggingface_hub import hf_hub_download
            except ImportError as err:
                raise ImportError(
                    "huggingface_hub is required to load from HF Hub. "
                    "Install with: pip install huggingface_hub"
                ) from err

            print(f"Loading from Hugging Face Hub: {model_id_or_path}")

            # Download config and checkpoint
            config_path = hf_hub_download(
                repo_id=model_id_or_path,
                filename="config.yml",
                token=token
            )
            checkpoint_path = hf_hub_download(
                repo_id=model_id_or_path,
                filename="model.pt",
                token=token
            )
            print("✓ Downloaded files from HF Hub")

        # Load config
        config = load_config(config_path)

        # Initialize model from config
        model_cls = get_model_cls(config.model.model_id)
        model = model_cls(config.model)

        # Load checkpoint weights (weights_only=True restricts unpickling
        # to tensor data)
        print("Loading checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint)
        print("✓ Checkpoint loaded successfully!")

        # Create and return pipeline
        return cls(model, config, device=device, compile_model=compile_model)

    @torch.no_grad()
    def __call__(self, videos, window_size=None, show_progress=True):
        """
        Run inference on videos using sliding window.

        Each frame is predicted from a window of ``window_size`` frames
        positioned so the frame sits at index ``(window_size - 1) // 2``;
        the sequence is padded at both ends by repeating the first and last
        frames so every frame gets a full window.

        Args:
            videos: Input videos of shape [b, n, c, h, w] in range [-1, 1]
            window_size: Override window size (default: use config.train.window_length)
            show_progress: Whether to show progress bar (default: True)

        Returns:
            Tuple of (wasd_predictions, mouse_predictions) where:
                - wasd_predictions: shape [b, n, 4] with boolean values
                - mouse_predictions: shape [b, n, 2] with float values (raw scale)

        Raises:
            ValueError: If ``videos`` is not 5-dimensional, contains no
                frames, or ``window_size`` is smaller than 1.
        """
        if window_size is None:
            window_size = self.window_length
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")

        if videos.ndim != 5:
            raise ValueError(
                f"videos must have shape [b, n, c, h, w], got {tuple(videos.shape)}"
            )
        n = videos.shape[1]
        if n == 0:
            # Without this check the failure surfaces later as an opaque
            # error from expand()/stack().
            raise ValueError("videos must contain at least one frame")

        # Move to device and convert to bfloat16
        videos = videos.to(device=self.device, dtype=torch.bfloat16)

        # Calculate middle index for predictions
        middle_idx = (window_size - 1) // 2

        # Calculate padding needed
        pad_start = middle_idx
        pad_end = window_size - 1 - middle_idx

        # Pad videos by duplicating first and last frames
        first_frame = videos[:, 0:1].expand(-1, pad_start, -1, -1, -1)
        last_frame = videos[:, -1:].expand(-1, pad_end, -1, -1, -1)
        padded_videos = torch.cat([first_frame, videos, last_frame], dim=1)

        # Sliding window inference
        wasd_preds = []
        mouse_preds = []

        iterator = range(n)
        if show_progress:
            iterator = tqdm(iterator, desc="Running inference")

        for i in iterator:
            # Extract window
            window = padded_videos[:, i:i + window_size]

            # Model inference
            wasd_logits, mouse_pred = self.model(window)

            # Clone tensors to avoid CUDA graph memory reuse issues
            wasd_preds.append(wasd_logits.clone())
            mouse_preds.append(mouse_pred.clone())

        # Stack predictions: [n, b, ...] -> [b, n, ...]
        wasd_preds = torch.stack(wasd_preds, dim=1)
        mouse_preds = torch.stack(mouse_preds, dim=1)

        # Convert WASD logits to boolean predictions
        wasd_preds = torch.sigmoid(wasd_preds) > 0.5

        # Convert mouse predictions from log1p space if needed
        # (inverse of sign(x) * log1p(|x|))
        if self.use_log1p_scaling:
            mouse_preds = torch.sign(mouse_preds) * torch.expm1(torch.abs(mouse_preds))

        return wasd_preds, mouse_preds


def main():
    """Command-line entry point: build a pipeline from local files and report its settings."""
    import argparse

    parser = argparse.ArgumentParser(description="Run inference with Owl IDM")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint .pt file")
    parser.add_argument("--video", type=str, help="Path to video file (optional, for testing)")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use")
    parser.add_argument("--no-compile", action="store_true", help="Disable model compilation")

    args = parser.parse_args()

    # Load pipeline
    pipeline = InferencePipeline.from_pretrained(
        args.config,
        args.checkpoint,
        device=args.device,
        compile_model=not args.no_compile
    )

    print("\nPipeline ready!")
    print(f" Window length: {pipeline.window_length}")
    print(f" Log1p scaling: {pipeline.use_log1p_scaling}")
    print(f" Device: {pipeline.device}")

    if args.video:
        print(f"\nLoading video from {args.video}...")
        # TODO: Add video loading code
        print("Video inference not yet implemented in main()")
    else:
        print("\nNo video provided. Use --video to run inference on a video file.")
        print("Example: python inference.py --config configs/simple.yml --checkpoint checkpoints/simple/ema/step_50000.pt")


if __name__ == "__main__":
    main()