Long Nguyen
commited on
Upload inference.py with huggingface_hub
Browse files- inference.py +34 -7
inference.py
CHANGED
|
@@ -5,24 +5,50 @@ Simple inference script for TFv6 NavSim model.
|
|
| 5 |
import torch
|
| 6 |
import numpy as np
|
| 7 |
import cv2
|
| 8 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class TFv6NavSimInference:
|
| 12 |
"""Easy-to-use inference wrapper for TFv6 NavSim model."""
|
| 13 |
|
| 14 |
-
def __init__(self, model_path=
|
| 15 |
"""
|
| 16 |
Initialize the model.
|
| 17 |
|
| 18 |
Args:
|
| 19 |
-
model_path: Path to the model checkpoint (.pth file)
|
| 20 |
device: torch.device or None (will auto-detect CUDA)
|
|
|
|
| 21 |
"""
|
| 22 |
if device is None:
|
| 23 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 24 |
else:
|
| 25 |
self.device = device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
print(f"Loading model on {self.device}...")
|
| 28 |
self.model = load_tf(model_path, self.device)
|
|
@@ -120,9 +146,10 @@ class TFv6NavSimInference:
|
|
| 120 |
|
| 121 |
|
| 122 |
def main():
|
| 123 |
-
"""Example usage."""
|
| 124 |
-
# Initialize model
|
| 125 |
-
model
|
|
|
|
| 126 |
|
| 127 |
# Example: Create dummy input
|
| 128 |
rgb = np.random.randint(0, 255, (900, 1600, 3), dtype=np.uint8)
|
|
@@ -133,7 +160,7 @@ def main():
|
|
| 133 |
# Run prediction
|
| 134 |
result = model.predict(rgb, command, speed, acceleration)
|
| 135 |
|
| 136 |
-
print("
|
| 137 |
if result['waypoints'] is not None:
|
| 138 |
print(f" Waypoints shape: {result['waypoints'].shape}")
|
| 139 |
print(f" First 3 waypoints:\n{result['waypoints'][:3]}")
|
|
|
|
| 5 |
import torch
|
| 6 |
import numpy as np
|
| 7 |
import cv2
|
| 8 |
+
from huggingface_hub import hf_hub_download
|
| 9 |
+
|
| 10 |
+
# Lazy import to avoid errors if not downloaded yet
|
| 11 |
+
try:
|
| 12 |
+
from stand_alone_model import load_tf
|
| 13 |
+
except ImportError:
|
| 14 |
+
load_tf = None
|
| 15 |
|
| 16 |
|
| 17 |
class TFv6NavSimInference:
|
| 18 |
"""Easy-to-use inference wrapper for TFv6 NavSim model."""
|
| 19 |
|
| 20 |
+
def __init__(self, model_path=None, device=None, repo_id="longpollehn/tfv6_navsim"):
|
| 21 |
"""
|
| 22 |
Initialize the model.
|
| 23 |
|
| 24 |
Args:
|
| 25 |
+
model_path: Path to the model checkpoint (.pth file). If None, downloads from HF.
|
| 26 |
device: torch.device or None (will auto-detect CUDA)
|
| 27 |
+
repo_id: Hugging Face repo to download from if model_path is None
|
| 28 |
"""
|
| 29 |
if device is None:
|
| 30 |
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 31 |
else:
|
| 32 |
self.device = device
|
| 33 |
+
|
| 34 |
+
# Auto-download from Hugging Face if no path provided
|
| 35 |
+
if model_path is None:
|
| 36 |
+
print(f"Downloading model from {repo_id}...")
|
| 37 |
+
model_path = hf_hub_download(repo_id=repo_id, filename="model_0060.pth")
|
| 38 |
+
print(f"Downloaded to {model_path}")
|
| 39 |
+
|
| 40 |
+
# Import stand_alone_model here if not already imported
|
| 41 |
+
if load_tf is None:
|
| 42 |
+
import sys
|
| 43 |
+
import os
|
| 44 |
+
# Download stand_alone_model.py if needed
|
| 45 |
+
model_code_path = hf_hub_download(repo_id=repo_id, filename="stand_alone_model.py")
|
| 46 |
+
config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
|
| 47 |
+
# Add to path and import
|
| 48 |
+
sys.path.insert(0, os.path.dirname(model_code_path))
|
| 49 |
+
from stand_alone_model import load_tf as _load_tf
|
| 50 |
+
global load_tf
|
| 51 |
+
load_tf = _load_tf
|
| 52 |
|
| 53 |
print(f"Loading model on {self.device}...")
|
| 54 |
self.model = load_tf(model_path, self.device)
|
|
|
|
| 146 |
|
| 147 |
|
| 148 |
def main():
|
| 149 |
+
"""Example usage - automatically downloads from HuggingFace."""
|
| 150 |
+
# Initialize model (will auto-download if not present)
|
| 151 |
+
print("Example: Using model directly from HuggingFace")
|
| 152 |
+
model = TFv6NavSimInference() # No arguments needed!
|
| 153 |
|
| 154 |
# Example: Create dummy input
|
| 155 |
rgb = np.random.randint(0, 255, (900, 1600, 3), dtype=np.uint8)
|
|
|
|
| 160 |
# Run prediction
|
| 161 |
result = model.predict(rgb, command, speed, acceleration)
|
| 162 |
|
| 163 |
+
print("\nPrediction results:")
|
| 164 |
if result['waypoints'] is not None:
|
| 165 |
print(f" Waypoints shape: {result['waypoints'].shape}")
|
| 166 |
print(f" First 3 waypoints:\n{result['waypoints'][:3]}")
|