"""
Inference pipeline for Owl IDM models.
Example usage (local):
pipeline = InferencePipeline.from_pretrained(
config_path="configs/simple.yml",
checkpoint_path="checkpoints/simple/ema/step_50000.pt"
)
Example usage (HF Hub):
pipeline = InferencePipeline.from_pretrained(
"username/owl-idm-simple-v0"
)
# video: [b, n, c, h, w] tensor in range [-1, 1]
wasd_preds, mouse_preds = pipeline(video)
"""
import torch
import os
from pathlib import Path
from tqdm import tqdm
from owl_idms.configs import load_config
from owl_idms.models import get_model_cls
class InferencePipeline:
    """
    Inference pipeline for IDM models with sliding window prediction.

    The wrapped model maps a window of frames [b, window, c, h, w] to a pair
    (wasd_logits [b, 4], mouse_pred [b, 2]); ``__call__`` slides that window
    across the full video to produce one prediction per frame.
    """

    def __init__(self, model, config, device='cuda', compile_model=True):
        """
        Initialize the inference pipeline.

        Args:
            model: The IDM model.
            config: Full config object (with train and model sections).
            device: Device to run inference on (default: 'cuda').
            compile_model: Whether to compile the model with torch.compile
                (default: True).
        """
        self.config = config
        self.device = device
        self.window_length = config.train.window_length
        # Older configs may predate this flag; True matches the training default.
        self.use_log1p_scaling = getattr(config.train, 'use_log1p_scaling', True)
        # Move model to device, convert to bfloat16, and set to eval mode.
        self.model = model.to(device=device, dtype=torch.bfloat16)
        self.model.eval()
        # Compile for faster inference.
        if compile_model:
            print("Compiling model for inference...")
            self.model = torch.compile(self.model, mode='max-autotune')
            print("Model compiled!")

    @classmethod
    def from_pretrained(cls, model_id_or_path, checkpoint_path=None, device='cuda', compile_model=True, token=None):
        """
        Load a pretrained model from local files or Hugging Face Hub.

        Args:
            model_id_or_path: Either:
                - HF Hub repo ID (e.g., "username/owl-idm-simple-v0")
                - Local path to config YAML file
            checkpoint_path: Path to checkpoint .pt file (only for local loading).
            device: Device to run inference on (default: 'cuda').
            compile_model: Whether to compile the model (default: True).
            token: HF API token (optional, for private repos).

        Returns:
            InferencePipeline instance ready for inference.

        Raises:
            ValueError: Local loading without a checkpoint_path.
            FileNotFoundError: Local config path does not exist.
            ImportError: HF Hub loading without huggingface_hub installed.

        Examples:
            # Load from HF Hub
            pipeline = InferencePipeline.from_pretrained("username/owl-idm-simple-v0")

            # Load from local files
            pipeline = InferencePipeline.from_pretrained(
                "configs/simple.yml",
                checkpoint_path="checkpoints/simple/ema/step_17100.pt"
            )
        """
        # A path that exists (or at least looks like a YAML file) is treated
        # as local; anything else is assumed to be an HF Hub repo ID.
        is_local = os.path.exists(model_id_or_path) or model_id_or_path.endswith('.yml')

        if is_local:
            # Local loading
            if checkpoint_path is None:
                raise ValueError("checkpoint_path required when loading from local files")
            if not os.path.exists(model_id_or_path):
                # Fail early with a clear message instead of a confusing
                # error from the config loader below.
                raise FileNotFoundError(f"Config file not found: {model_id_or_path}")
            config_path = model_id_or_path
            print("Loading from local files...")
            print(f" Config: {config_path}")
            print(f" Checkpoint: {checkpoint_path}")
        else:
            # HF Hub loading
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                raise ImportError(
                    "huggingface_hub is required to load from HF Hub. "
                    "Install with: pip install huggingface_hub"
                )
            print(f"Loading from Hugging Face Hub: {model_id_or_path}")
            # Download config and checkpoint
            config_path = hf_hub_download(
                repo_id=model_id_or_path,
                filename="config.yml",
                token=token
            )
            checkpoint_path = hf_hub_download(
                repo_id=model_id_or_path,
                filename="model.pt",
                token=token
            )
            print("✓ Downloaded files from HF Hub")

        # Load config
        config = load_config(config_path)

        # Initialize model from config
        model_cls = get_model_cls(config.model.model_id)
        model = model_cls(config.model)

        # Load checkpoint weights (weights_only=True avoids arbitrary
        # pickle execution from untrusted checkpoints).
        print("Loading checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint)
        print("✓ Checkpoint loaded successfully!")

        # Create and return pipeline
        return cls(model, config, device=device, compile_model=compile_model)

    @torch.no_grad()
    def __call__(self, videos, window_size=None, show_progress=True):
        """
        Run inference on videos using sliding window.

        Args:
            videos: Input videos of shape [b, n, c, h, w] in range [-1, 1].
            window_size: Override window size (default: use config.train.window_length).
            show_progress: Whether to show progress bar (default: True).

        Returns:
            Tuple of (wasd_predictions, mouse_predictions) where:
                - wasd_predictions: shape [b, n, 4] with boolean values
                - mouse_predictions: shape [b, n, 2] with float values (raw scale)

        Raises:
            ValueError: If videos is not a 5-dimensional tensor.
        """
        if window_size is None:
            window_size = self.window_length
        if videos.ndim != 5:
            raise ValueError(
                f"Expected videos of shape [b, n, c, h, w], got {videos.ndim} dims"
            )
        b, n, c, h, w = videos.shape

        # Move to device and convert to bfloat16 (model weights are bfloat16).
        videos = videos.to(device=self.device, dtype=torch.bfloat16)

        # The model predicts actions for the middle frame of each window.
        middle_idx = (window_size - 1) // 2

        # Pad so every original frame can sit at the middle of some window.
        pad_start = middle_idx
        pad_end = window_size - 1 - middle_idx

        # Pad videos by duplicating first and last frames (expand creates
        # views, so no extra frame copies are made before the cat).
        first_frame = videos[:, 0:1].expand(-1, pad_start, -1, -1, -1)
        last_frame = videos[:, -1:].expand(-1, pad_end, -1, -1, -1)
        padded_videos = torch.cat([first_frame, videos, last_frame], dim=1)

        # Sliding window inference
        wasd_preds = []
        mouse_preds = []

        iterator = range(n)
        if show_progress:
            iterator = tqdm(iterator, desc="Running inference")

        for i in iterator:
            # Extract window
            window = padded_videos[:, i:i+window_size]

            # Model inference
            wasd_logits, mouse_pred = self.model(window)

            # Clone tensors to avoid CUDA graph memory reuse issues
            wasd_preds.append(wasd_logits.clone())
            mouse_preds.append(mouse_pred.clone())

        # Stack the n per-frame predictions along a new time axis -> [b, n, ...]
        wasd_preds = torch.stack(wasd_preds, dim=1)
        mouse_preds = torch.stack(mouse_preds, dim=1)

        # Convert WASD logits to boolean predictions
        wasd_preds = torch.sigmoid(wasd_preds) > 0.5

        # Convert mouse predictions from log1p space if needed
        # (inverse of sign(x) * log1p(|x|) used during training).
        if self.use_log1p_scaling:
            mouse_preds = torch.sign(mouse_preds) * torch.expm1(torch.abs(mouse_preds))

        return wasd_preds, mouse_preds
if __name__ == "__main__":
    import argparse

    # CLI entry point: build a pipeline from a local config/checkpoint pair
    # and report its settings. Video inference itself is not wired up yet.
    cli = argparse.ArgumentParser(description="Run inference with Owl IDM")
    cli.add_argument("--config", type=str, required=True, help="Path to config YAML")
    cli.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint .pt file")
    cli.add_argument("--video", type=str, help="Path to video file (optional, for testing)")
    cli.add_argument("--device", type=str, default="cuda", help="Device to use")
    cli.add_argument("--no-compile", action="store_true", help="Disable model compilation")
    opts = cli.parse_args()

    # Load pipeline
    pipeline = InferencePipeline.from_pretrained(
        opts.config,
        opts.checkpoint,
        device=opts.device,
        compile_model=not opts.no_compile,
    )

    print("\nPipeline ready!")
    print(f" Window length: {pipeline.window_length}")
    print(f" Log1p scaling: {pipeline.use_log1p_scaling}")
    print(f" Device: {pipeline.device}")

    if not opts.video:
        print("\nNo video provided. Use --video to run inference on a video file.")
        print("Example: python inference.py --config configs/simple.yml --checkpoint checkpoints/simple/ema/step_50000.pt")
    else:
        print(f"\nLoading video from {opts.video}...")
        # TODO: Add video loading code
        print("Video inference not yet implemented in main()")
|