File size: 8,633 Bytes
8942ffb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
"""
Inference pipeline for Owl IDM models.

Example usage (local):
    pipeline = InferencePipeline.from_pretrained(
        config_path="configs/simple.yml",
        checkpoint_path="checkpoints/simple/ema/step_50000.pt"
    )

Example usage (HF Hub):
    pipeline = InferencePipeline.from_pretrained(
        "username/owl-idm-simple-v0"
    )

    # video: [b, n, c, h, w] tensor in range [-1, 1]
    wasd_preds, mouse_preds = pipeline(video)
"""
import torch
import os
from pathlib import Path
from tqdm import tqdm

from owl_idms.configs import load_config
from owl_idms.models import get_model_cls


class InferencePipeline:
    """
    Inference pipeline for IDM models with sliding-window prediction.

    Wraps a trained IDM model and produces per-frame predictions for a whole
    video by running the model on a window of frames centered on each target
    frame; video edges are handled by duplicating the boundary frames.
    """

    def __init__(self, model, config, device='cuda', compile_model=True):
        """
        Initialize the inference pipeline.

        Args:
            model: The IDM model; called as ``model(window)`` and expected to
                return a ``(wasd_logits, mouse_pred)`` pair.
            config: Full config object (with train and model sections)
            device: Device to run inference on (default: 'cuda')
            compile_model: Whether to compile the model with torch.compile (default: True)
        """
        self.config = config
        self.device = device
        self.window_length = config.train.window_length
        # Older configs may not carry this flag; default True matches the
        # scaling applied at training time.
        self.use_log1p_scaling = getattr(config.train, 'use_log1p_scaling', True)

        # Move model to device, convert to bfloat16, and set to eval mode
        self.model = model.to(device=device, dtype=torch.bfloat16)
        self.model.eval()

        # Compile for faster inference
        if compile_model:
            print("Compiling model for inference...")
            self.model = torch.compile(self.model, mode='max-autotune')
            print("Model compiled!")

    @classmethod
    def from_pretrained(cls, model_id_or_path, checkpoint_path=None, device='cuda', compile_model=True, token=None):
        """
        Load a pretrained model from local files or Hugging Face Hub.

        Args:
            model_id_or_path: Either:
                - HF Hub repo ID (e.g., "username/owl-idm-simple-v0")
                - Local path to config YAML file (.yml or .yaml)
            checkpoint_path: Path to checkpoint .pt file (only for local loading)
            device: Device to run inference on (default: 'cuda')
            compile_model: Whether to compile the model (default: True)
            token: HF API token (optional, for private repos)

        Returns:
            InferencePipeline instance ready for inference

        Raises:
            ValueError: If loading locally without a checkpoint_path.
            ImportError: If loading from the Hub without huggingface_hub installed.

        Examples:
            # Load from HF Hub
            pipeline = InferencePipeline.from_pretrained("username/owl-idm-simple-v0")

            # Load from local files
            pipeline = InferencePipeline.from_pretrained(
                "configs/simple.yml",
                checkpoint_path="checkpoints/simple/ema/step_17100.pt"
            )
        """
        # Anything that exists on disk or looks like a YAML config path is a
        # local load; everything else is assumed to be a Hub repo ID.
        is_local = os.path.exists(model_id_or_path) or model_id_or_path.endswith(('.yml', '.yaml'))

        if is_local:
            # Local loading
            if checkpoint_path is None:
                raise ValueError("checkpoint_path required when loading from local files")

            config_path = model_id_or_path
            print("Loading from local files...")
            print(f"  Config: {config_path}")
            print(f"  Checkpoint: {checkpoint_path}")

        else:
            # HF Hub loading: huggingface_hub is an optional dependency, so
            # fail with an actionable message when it is missing.
            try:
                from huggingface_hub import hf_hub_download
            except ImportError:
                raise ImportError(
                    "huggingface_hub is required to load from HF Hub. "
                    "Install with: pip install huggingface_hub"
                )

            print(f"Loading from Hugging Face Hub: {model_id_or_path}")

            # Download config and checkpoint
            config_path = hf_hub_download(
                repo_id=model_id_or_path,
                filename="config.yml",
                token=token
            )
            checkpoint_path = hf_hub_download(
                repo_id=model_id_or_path,
                filename="model.pt",
                token=token
            )
            print("✓ Downloaded files from HF Hub")

        # Load config
        config = load_config(config_path)

        # Initialize model from config
        model_cls = get_model_cls(config.model.model_id)
        model = model_cls(config.model)

        # Load checkpoint weights (weights_only=True avoids unpickling
        # arbitrary code from the checkpoint file).
        print("Loading checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(checkpoint)
        print("✓ Checkpoint loaded successfully!")

        # Create and return pipeline
        return cls(model, config, device=device, compile_model=compile_model)

    @torch.no_grad()
    def __call__(self, videos, window_size=None, show_progress=True):
        """
        Run inference on videos using a sliding window.

        Args:
            videos: Input videos of shape [b, n, c, h, w] in range [-1, 1]
            window_size: Override window size (default: use config.train.window_length)
            show_progress: Whether to show progress bar (default: True)

        Returns:
            Tuple of (wasd_predictions, mouse_predictions) where:
                - wasd_predictions: shape [b, n, 4] with boolean values
                - mouse_predictions: shape [b, n, 2] with float values (raw scale)

        Raises:
            ValueError: If the video contains no frames.
        """
        if window_size is None:
            window_size = self.window_length

        b, n, c, h, w = videos.shape
        if n == 0:
            # Without this guard the empty prediction lists make torch.stack
            # fail later with an opaque RuntimeError.
            raise ValueError("videos must contain at least one frame")

        # Move to device and convert to bfloat16 to match the model
        videos = videos.to(device=self.device, dtype=torch.bfloat16)

        # Each window predicts the action at its middle frame
        middle_idx = (window_size - 1) // 2

        # Pad so every original frame can sit at a window's middle
        pad_start = middle_idx
        pad_end = window_size - 1 - middle_idx

        # Pad videos by duplicating first and last frames. expand() creates
        # views; torch.cat materializes the padded tensor once.
        first_frame = videos[:, 0:1].expand(-1, pad_start, -1, -1, -1)
        last_frame = videos[:, -1:].expand(-1, pad_end, -1, -1, -1)
        padded_videos = torch.cat([first_frame, videos, last_frame], dim=1)

        # Sliding window inference
        wasd_preds = []
        mouse_preds = []

        iterator = range(n)
        if show_progress:
            iterator = tqdm(iterator, desc="Running inference")

        for i in iterator:
            # Window i covers padded frames [i, i + window_size), centered on
            # original frame i.
            window = padded_videos[:, i:i+window_size]

            # Model inference
            wasd_logits, mouse_pred = self.model(window)

            # Clone tensors to avoid CUDA graph memory reuse issues
            # (torch.compile mode='max-autotune' may reuse output buffers)
            wasd_preds.append(wasd_logits.clone())
            mouse_preds.append(mouse_pred.clone())

        # Stack the n per-frame [b, ...] tensors along dim=1 -> [b, n, ...]
        wasd_preds = torch.stack(wasd_preds, dim=1)
        mouse_preds = torch.stack(mouse_preds, dim=1)

        # Convert WASD logits to boolean predictions (threshold at p=0.5)
        wasd_preds = torch.sigmoid(wasd_preds) > 0.5

        # Undo the signed log1p scaling applied during training, if any
        if self.use_log1p_scaling:
            mouse_preds = torch.sign(mouse_preds) * torch.expm1(torch.abs(mouse_preds))

        return wasd_preds, mouse_preds


def _main():
    """CLI entry point: build an InferencePipeline from --config/--checkpoint
    and report its settings. Video-file inference is not yet wired up."""
    import argparse

    parser = argparse.ArgumentParser(description="Run inference with Owl IDM")
    parser.add_argument("--config", type=str, required=True, help="Path to config YAML")
    parser.add_argument("--checkpoint", type=str, required=True, help="Path to checkpoint .pt file")
    parser.add_argument("--video", type=str, help="Path to video file (optional, for testing)")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use")
    parser.add_argument("--no-compile", action="store_true", help="Disable model compilation")

    args = parser.parse_args()

    # Load pipeline
    pipeline = InferencePipeline.from_pretrained(
        args.config,
        args.checkpoint,
        device=args.device,
        compile_model=not args.no_compile,
    )

    print("\nPipeline ready!")
    print(f"  Window length: {pipeline.window_length}")
    print(f"  Log1p scaling: {pipeline.use_log1p_scaling}")
    print(f"  Device: {pipeline.device}")

    if args.video:
        print(f"\nLoading video from {args.video}...")
        # TODO: Add video loading code
        print("Video inference not yet implemented in main()")
    else:
        print("\nNo video provided. Use --video to run inference on a video file.")
        print("Example: python inference.py --config configs/simple.yml --checkpoint checkpoints/simple/ema/step_50000.pt")


if __name__ == "__main__":
    _main()