|
|
""" |
|
|
Inference pipeline for Owl IDM models. |
|
|
|
|
|
Example usage (local): |
|
|
pipeline = InferencePipeline.from_pretrained( |
|
|
config_path="configs/simple.yml", |
|
|
checkpoint_path="checkpoints/simple/ema/step_50000.pt" |
|
|
) |
|
|
|
|
|
Example usage (HF Hub): |
|
|
pipeline = InferencePipeline.from_pretrained( |
|
|
"username/owl-idm-simple-v0" |
|
|
) |
|
|
|
|
|
# video: [b, n, c, h, w] tensor in range [-1, 1] |
|
|
wasd_preds, mouse_preds = pipeline(video) |
|
|
""" |
|
|
import torch |
|
|
import os |
|
|
from pathlib import Path |
|
|
from tqdm import tqdm |
|
|
|
|
|
from owl_idms.configs import load_config |
|
|
from owl_idms.models import get_model_cls |
|
|
|
|
|
|
|
|
class InferencePipeline:
    """
    Inference pipeline for IDM models with sliding window prediction.

    Wraps a trained IDM model and produces per-frame WASD and mouse
    predictions for arbitrarily long videos by sliding a fixed-length
    temporal window across the (edge-padded) input.
    """

    def __init__(self, model, config, device='cuda', compile_model=True):
        """
        Initialize the inference pipeline.

        Args:
            model: The IDM model
            config: Full config object (with train and model sections)
            device: Device to run inference on (default: 'cuda')
            compile_model: Whether to compile the model with torch.compile (default: True)
        """
        self.config = config
        self.device = device
        self.window_length = config.train.window_length
        # Older configs may not carry this flag; default True matches the
        # documented training-time behavior.
        self.use_log1p_scaling = getattr(config.train, 'use_log1p_scaling', True)

        # Inference runs in bfloat16; inputs are cast to match in __call__.
        self.model = model.to(device=device, dtype=torch.bfloat16)
        self.model.eval()

        if compile_model:
            print("Compiling model for inference...")
            self.model = torch.compile(self.model, mode='max-autotune')
            print("Model compiled!")

    @classmethod
    def from_pretrained(cls, model_id_or_path, checkpoint_path=None, device='cuda', compile_model=True, token=None):
        """
        Load a pretrained model from local files or Hugging Face Hub.

        Args:
            model_id_or_path: Either:
                - HF Hub repo ID (e.g., "username/owl-idm-simple-v0")
                - Local path to config YAML file
            checkpoint_path: Path to checkpoint .pt file (only for local loading)
            device: Device to run inference on (default: 'cuda')
            compile_model: Whether to compile the model (default: True)
            token: HF API token (optional, for private repos)

        Returns:
            InferencePipeline instance ready for inference

        Raises:
            ValueError: If loading locally without a checkpoint_path.
            ImportError: If loading from the Hub without huggingface_hub installed.

        Examples:
            # Load from HF Hub
            pipeline = InferencePipeline.from_pretrained("username/owl-idm-simple-v0")

            # Load from local files
            pipeline = InferencePipeline.from_pretrained(
                "configs/simple.yml",
                checkpoint_path="checkpoints/simple/ema/step_17100.pt"
            )
        """
        # Treat as local if the path exists, or if it looks like a YAML config
        # path (covers both common extensions, even if the file is missing so
        # the user gets a clear file-not-found error instead of a Hub lookup).
        is_local = os.path.exists(model_id_or_path) or model_id_or_path.endswith(('.yml', '.yaml'))

        if is_local:
            if checkpoint_path is None:
                raise ValueError("checkpoint_path required when loading from local files")

            config_path = model_id_or_path
            print("Loading from local files...")
            print(f"  Config: {config_path}")
            print(f"  Checkpoint: {checkpoint_path}")

        else:
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                raise ImportError(
                    "huggingface_hub is required to load from HF Hub. "
                    "Install with: pip install huggingface_hub"
                )

            print(f"Loading from Hugging Face Hub: {model_id_or_path}")

            # Repo layout convention: config.yml + model.pt at the repo root.
            config_path = hf_hub_download(
                repo_id=model_id_or_path,
                filename="config.yml",
                token=token
            )
            checkpoint_path = hf_hub_download(
                repo_id=model_id_or_path,
                filename="model.pt",
                token=token
            )
            print("✓ Downloaded files from HF Hub")

        config = load_config(config_path)

        # Instantiate the architecture from the config, then load the weights.
        model_cls = get_model_cls(config.model.model_id)
        model = model_cls(config.model)

        print("Loading checkpoint...")
        # weights_only=True avoids arbitrary-code execution from untrusted
        # checkpoints; the file must be a plain state_dict.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint)
        print("✓ Checkpoint loaded successfully!")

        return cls(model, config, device=device, compile_model=compile_model)

    @torch.no_grad()
    def __call__(self, videos, window_size=None, show_progress=True):
        """
        Run inference on videos using sliding window.

        Args:
            videos: Input videos of shape [b, n, c, h, w] in range [-1, 1]
            window_size: Override window size (default: use config.train.window_length)
            show_progress: Whether to show progress bar (default: True)

        Returns:
            Tuple of (wasd_predictions, mouse_predictions) where:
                - wasd_predictions: shape [b, n, 4] with boolean values
                - mouse_predictions: shape [b, n, 2] with float values (raw scale)
        """
        if window_size is None:
            window_size = self.window_length

        b, n, c, h, w = videos.shape

        # Match the model's dtype/device (set in __init__).
        videos = videos.to(device=self.device, dtype=torch.bfloat16)

        # Each window's prediction is attributed to its center frame; pad both
        # ends by replicating edge frames so every original frame can sit at
        # the center of some window.
        middle_idx = (window_size - 1) // 2
        pad_start = middle_idx
        pad_end = window_size - 1 - middle_idx

        first_frame = videos[:, 0:1].expand(-1, pad_start, -1, -1, -1)
        last_frame = videos[:, -1:].expand(-1, pad_end, -1, -1, -1)
        padded_videos = torch.cat([first_frame, videos, last_frame], dim=1)

        wasd_preds = []
        mouse_preds = []

        iterator = range(n)
        if show_progress:
            iterator = tqdm(iterator, desc="Running inference")

        for i in iterator:
            # Window i is centered on original frame i after padding.
            window = padded_videos[:, i:i+window_size]

            wasd_logits, mouse_pred = self.model(window)

            # Clone so the per-window outputs are not aliased to buffers that
            # a compiled model may reuse on the next iteration.
            wasd_preds.append(wasd_logits.clone())
            mouse_preds.append(mouse_pred.clone())

        # [b, n, 4] and [b, n, 2]: one prediction per original frame.
        wasd_preds = torch.stack(wasd_preds, dim=1)
        mouse_preds = torch.stack(mouse_preds, dim=1)

        # Logits -> boolean key presses.
        wasd_preds = torch.sigmoid(wasd_preds) > 0.5

        # Undo training-time log1p compression to recover raw mouse deltas.
        if self.use_log1p_scaling:
            mouse_preds = torch.sign(mouse_preds) * torch.expm1(torch.abs(mouse_preds))

        return wasd_preds, mouse_preds
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    # Command-line driver: build a pipeline from a local config + checkpoint
    # and report its settings.
    arg_parser = argparse.ArgumentParser(description="Run inference with Owl IDM")
    arg_parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    arg_parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint .pt file")
    arg_parser.add_argument("--video", type=str, help="Path to video file (optional, for testing)")
    arg_parser.add_argument("--device", type=str, default="cuda", help="Device to use")
    arg_parser.add_argument("--no-compile", action="store_true", help="Disable model compilation")
    cli = arg_parser.parse_args()

    pipeline = InferencePipeline.from_pretrained(
        cli.config,
        cli.checkpoint,
        device=cli.device,
        compile_model=not cli.no_compile,
    )

    # Summarize the resolved pipeline configuration.
    print("\nPipeline ready!")
    print(f"  Window length: {pipeline.window_length}")
    print(f"  Log1p scaling: {pipeline.use_log1p_scaling}")
    print(f"  Device: {pipeline.device}")

    if not cli.video:
        print("\nNo video provided. Use --video to run inference on a video file.")
        print("Example: python inference.py --config configs/simple.yml --checkpoint checkpoints/simple/ema/step_50000.pt")
    else:
        print(f"\nLoading video from {cli.video}...")
        # Video decoding is not wired up yet in this entry point.
        print("Video inference not yet implemented in main()")
|
|
|