# owl-idm-v0-tiny / inference.py
# Uploaded by shahbuland via huggingface_hub (commit 8942ffb, verified)
"""
Inference pipeline for Owl IDM models.
Example usage (local):
pipeline = InferencePipeline.from_pretrained(
config_path="configs/simple.yml",
checkpoint_path="checkpoints/simple/ema/step_50000.pt"
)
Example usage (HF Hub):
pipeline = InferencePipeline.from_pretrained(
"username/owl-idm-simple-v0"
)
# video: [b, n, c, h, w] tensor in range [-1, 1]
wasd_preds, mouse_preds = pipeline(video)
"""
import torch
import os
from pathlib import Path
from tqdm import tqdm
from owl_idms.configs import load_config
from owl_idms.models import get_model_cls
class InferencePipeline:
    """
    Inference pipeline for IDM models with sliding window prediction.

    Each input frame gets a prediction from a window centered on it; the
    video is padded at both ends (by repeating the first/last frame) so the
    window never falls off the sequence.
    """

    def __init__(self, model, config, device='cuda', compile_model=True):
        """
        Initialize the inference pipeline.

        Args:
            model: The IDM model
            config: Full config object (with train and model sections)
            device: Device to run inference on (default: 'cuda')
            compile_model: Whether to compile the model with torch.compile (default: True)
        """
        self.config = config
        self.device = device
        self.window_length = config.train.window_length
        # Older configs may lack this flag; default True to match training.
        self.use_log1p_scaling = getattr(config.train, 'use_log1p_scaling', True)

        # Move model to device, convert to bfloat16, and set to eval mode
        self.model = model.to(device=device, dtype=torch.bfloat16)
        self.model.eval()

        # Compile for faster inference
        if compile_model:
            print("Compiling model for inference...")
            self.model = torch.compile(self.model, mode='max-autotune')
            print("Model compiled!")

    @staticmethod
    def _resolve_files(model_id_or_path, checkpoint_path, token):
        """Resolve (config_path, checkpoint_path), downloading from HF Hub if needed.

        Raises:
            ValueError: Local loading requested without a checkpoint_path.
            ImportError: HF Hub loading requested but huggingface_hub missing.
        """
        # Treat existing paths or YAML filenames as local; anything else is
        # assumed to be a Hub repo ID. Accept both .yml and .yaml extensions.
        is_local = (
            os.path.exists(model_id_or_path)
            or model_id_or_path.endswith(('.yml', '.yaml'))
        )

        if is_local:
            if checkpoint_path is None:
                raise ValueError("checkpoint_path required when loading from local files")
            print("Loading from local files...")
            print(f" Config: {model_id_or_path}")
            print(f" Checkpoint: {checkpoint_path}")
            return model_id_or_path, checkpoint_path

        # HF Hub loading — import lazily so the dependency stays optional
        try:
            from huggingface_hub import hf_hub_download
        except ImportError:
            raise ImportError(
                "huggingface_hub is required to load from HF Hub. "
                "Install with: pip install huggingface_hub"
            )

        print(f"Loading from Hugging Face Hub: {model_id_or_path}")
        config_path = hf_hub_download(
            repo_id=model_id_or_path,
            filename="config.yml",
            token=token
        )
        checkpoint_path = hf_hub_download(
            repo_id=model_id_or_path,
            filename="model.pt",
            token=token
        )
        print("✓ Downloaded files from HF Hub")
        return config_path, checkpoint_path

    @classmethod
    def from_pretrained(cls, model_id_or_path, checkpoint_path=None, device='cuda', compile_model=True, token=None):
        """
        Load a pretrained model from local files or Hugging Face Hub.

        Args:
            model_id_or_path: Either:
                - HF Hub repo ID (e.g., "username/owl-idm-simple-v0")
                - Local path to config YAML file (.yml or .yaml)
            checkpoint_path: Path to checkpoint .pt file (only for local loading)
            device: Device to run inference on (default: 'cuda')
            compile_model: Whether to compile the model (default: True)
            token: HF API token (optional, for private repos)

        Returns:
            InferencePipeline instance ready for inference

        Raises:
            ValueError: Local loading without a checkpoint_path.
            ImportError: HF Hub loading without huggingface_hub installed.

        Examples:
            # Load from HF Hub
            pipeline = InferencePipeline.from_pretrained("username/owl-idm-simple-v0")
            # Load from local files
            pipeline = InferencePipeline.from_pretrained(
                "configs/simple.yml",
                checkpoint_path="checkpoints/simple/ema/step_17100.pt"
            )
        """
        config_path, checkpoint_path = cls._resolve_files(model_id_or_path, checkpoint_path, token)

        # Load config and build the architecture it describes
        config = load_config(config_path)
        model_cls = get_model_cls(config.model.model_id)
        model = model_cls(config.model)

        # Load checkpoint weights. weights_only=True refuses arbitrary
        # pickled objects, so untrusted checkpoints can't execute code.
        print("Loading checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint)
        print("✓ Checkpoint loaded successfully!")

        # Create and return pipeline
        return cls(model, config, device=device, compile_model=compile_model)

    @torch.no_grad()
    def __call__(self, videos, window_size=None, show_progress=True):
        """
        Run inference on videos using a sliding window.

        Args:
            videos: Input videos of shape [b, n, c, h, w] in range [-1, 1]
            window_size: Override window size (default: use config.train.window_length)
            show_progress: Whether to show progress bar (default: True)

        Returns:
            Tuple of (wasd_predictions, mouse_predictions) where:
                - wasd_predictions: shape [b, n, 4] with boolean values
                - mouse_predictions: shape [b, n, 2] with float values (raw scale)

        Raises:
            ValueError: If videos is not a 5D tensor.
        """
        if videos.dim() != 5:
            raise ValueError(
                f"Expected videos of shape [b, n, c, h, w], got {tuple(videos.shape)}"
            )

        if window_size is None:
            window_size = self.window_length

        b, n, c, h, w = videos.shape

        # Move to device and match the model's bfloat16 weights
        videos = videos.to(device=self.device, dtype=torch.bfloat16)

        # Each window predicts the action at its middle frame, so pad both
        # ends with repeated edge frames to get one prediction per input frame.
        middle_idx = (window_size - 1) // 2
        pad_start = middle_idx
        pad_end = window_size - 1 - middle_idx
        first_frame = videos[:, 0:1].expand(-1, pad_start, -1, -1, -1)
        last_frame = videos[:, -1:].expand(-1, pad_end, -1, -1, -1)
        padded_videos = torch.cat([first_frame, videos, last_frame], dim=1)

        # Sliding window inference: one window per input frame
        wasd_preds = []
        mouse_preds = []
        iterator = range(n)
        if show_progress:
            iterator = tqdm(iterator, desc="Running inference")

        for i in iterator:
            window = padded_videos[:, i:i+window_size]
            wasd_logits, mouse_pred = self.model(window)
            # Clone: compiled models (CUDA graphs) may reuse output buffers
            wasd_preds.append(wasd_logits.clone())
            mouse_preds.append(mouse_pred.clone())

        # Stack n per-frame predictions of shape [b, ...] into [b, n, ...]
        wasd_preds = torch.stack(wasd_preds, dim=1)
        mouse_preds = torch.stack(mouse_preds, dim=1)

        # Threshold WASD logits into boolean key presses
        wasd_preds = torch.sigmoid(wasd_preds) > 0.5

        # Undo the symmetric log1p scaling applied during training:
        # inverse of sign(x) * log1p(|x|) is sign(y) * expm1(|y|)
        if self.use_log1p_scaling:
            mouse_preds = torch.sign(mouse_preds) * torch.expm1(torch.abs(mouse_preds))

        return wasd_preds, mouse_preds
if __name__ == "__main__":
    import argparse

    # CLI entry point: build a pipeline from a local config + checkpoint,
    # then (eventually) run it on a video file.
    arg_parser = argparse.ArgumentParser(description="Run inference with Owl IDM")
    arg_parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    arg_parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint .pt file")
    arg_parser.add_argument("--video", type=str, help="Path to video file (optional, for testing)")
    arg_parser.add_argument("--device", type=str, default="cuda", help="Device to use")
    arg_parser.add_argument("--no-compile", action="store_true", help="Disable model compilation")
    cli_args = arg_parser.parse_args()

    # Build the pipeline from local files
    pipeline = InferencePipeline.from_pretrained(
        cli_args.config,
        cli_args.checkpoint,
        device=cli_args.device,
        compile_model=not cli_args.no_compile,
    )

    print(f"\nPipeline ready!")
    print(f" Window length: {pipeline.window_length}")
    print(f" Log1p scaling: {pipeline.use_log1p_scaling}")
    print(f" Device: {pipeline.device}")

    if not cli_args.video:
        print("\nNo video provided. Use --video to run inference on a video file.")
        print("Example: python inference.py --config configs/simple.yml --checkpoint checkpoints/simple/ema/step_50000.pt")
    else:
        print(f"\nLoading video from {cli_args.video}...")
        # TODO: Add video loading code
        print("Video inference not yet implemented in main()")