Long Nguyen committed on
Commit
a702328
·
verified ·
1 Parent(s): ee22b18

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +34 -7
inference.py CHANGED
@@ -5,24 +5,50 @@ Simple inference script for TFv6 NavSim model.
5
  import torch
6
  import numpy as np
7
  import cv2
8
- from stand_alone_model import load_tf
 
 
 
 
 
 
9
 
10
 
11
  class TFv6NavSimInference:
12
  """Easy-to-use inference wrapper for TFv6 NavSim model."""
13
 
14
- def __init__(self, model_path="model_0060.pth", device=None):
15
  """
16
  Initialize the model.
17
 
18
  Args:
19
- model_path: Path to the model checkpoint (.pth file)
20
  device: torch.device or None (will auto-detect CUDA)
 
21
  """
22
  if device is None:
23
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24
  else:
25
  self.device = device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  print(f"Loading model on {self.device}...")
28
  self.model = load_tf(model_path, self.device)
@@ -120,9 +146,10 @@ class TFv6NavSimInference:
120
 
121
 
122
  def main():
123
- """Example usage."""
124
- # Initialize model
125
- model = TFv6NavSimInference()
 
126
 
127
  # Example: Create dummy input
128
  rgb = np.random.randint(0, 255, (900, 1600, 3), dtype=np.uint8)
@@ -133,7 +160,7 @@ def main():
133
  # Run prediction
134
  result = model.predict(rgb, command, speed, acceleration)
135
 
136
- print("Prediction results:")
137
  if result['waypoints'] is not None:
138
  print(f" Waypoints shape: {result['waypoints'].shape}")
139
  print(f" First 3 waypoints:\n{result['waypoints'][:3]}")
 
5
  import torch
6
  import numpy as np
7
  import cv2
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # Lazy import to avoid errors if not downloaded yet
11
+ try:
12
+ from stand_alone_model import load_tf
13
+ except ImportError:
14
+ load_tf = None
15
 
16
 
17
  class TFv6NavSimInference:
18
  """Easy-to-use inference wrapper for TFv6 NavSim model."""
19
 
20
+ def __init__(self, model_path=None, device=None, repo_id="longpollehn/tfv6_navsim"):
21
  """
22
  Initialize the model.
23
 
24
  Args:
25
+ model_path: Path to the model checkpoint (.pth file). If None, downloads from HF.
26
  device: torch.device or None (will auto-detect CUDA)
27
+ repo_id: Hugging Face repo to download from if model_path is None
28
  """
29
  if device is None:
30
  self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
31
  else:
32
  self.device = device
33
+
34
+ # Auto-download from Hugging Face if no path provided
35
+ if model_path is None:
36
+ print(f"Downloading model from {repo_id}...")
37
+ model_path = hf_hub_download(repo_id=repo_id, filename="model_0060.pth")
38
+ print(f"Downloaded to {model_path}")
39
+
40
+ # Import stand_alone_model here if not already imported
41
+ if load_tf is None:
42
+ import sys
43
+ import os
44
+ # Download stand_alone_model.py if needed
45
+ model_code_path = hf_hub_download(repo_id=repo_id, filename="stand_alone_model.py")
46
+ config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
47
+ # Add to path and import
48
+ sys.path.insert(0, os.path.dirname(model_code_path))
49
+ from stand_alone_model import load_tf as _load_tf
50
+ global load_tf
51
+ load_tf = _load_tf
52
 
53
  print(f"Loading model on {self.device}...")
54
  self.model = load_tf(model_path, self.device)
 
146
 
147
 
148
  def main():
149
+ """Example usage - automatically downloads from HuggingFace."""
150
+ # Initialize model (will auto-download if not present)
151
+ print("Example: Using model directly from HuggingFace")
152
+ model = TFv6NavSimInference() # No arguments needed!
153
 
154
  # Example: Create dummy input
155
  rgb = np.random.randint(0, 255, (900, 1600, 3), dtype=np.uint8)
 
160
  # Run prediction
161
  result = model.predict(rgb, command, speed, acceleration)
162
 
163
+ print("\nPrediction results:")
164
  if result['waypoints'] is not None:
165
  print(f" Waypoints shape: {result['waypoints'].shape}")
166
  print(f" First 3 waypoints:\n{result['waypoints'][:3]}")