# NautilusAI / models / pytorch_wrapper.py
import logging
import torch
import numpy as np
from typing import List, Dict, Any, Union
from .deeplob import DeepLOB
logger = logging.getLogger("PYTORCH_INFERENCE")
class PyTorchModel:
    """
    Wrapper for running PyTorch models (.pt) in NautilusAI.

    Matches the interface of ONNXModel for seamless switching between
    backends (same ``predict`` / ``warmup`` surface).
    """

    def __init__(self, model_path: str, device: str = None, model_class: Any = None, **model_args):
        """
        Initialize PyTorch model.

        Args:
            model_path: Path to the .pt file (a saved ``state_dict``).
            device: 'cpu' or 'cuda' (if available). Auto-detect if None.
            model_class: The class of the model architecture (e.g., DeepLOB, TRM).
            **model_args: Arguments to pass to the model constructor.

        Raises:
            Exception: Re-raises any failure during model construction or
                state-dict loading, after logging it.
        """
        self.model_path = model_path
        # Resolve the target device: honor an explicit request, otherwise
        # prefer CUDA when available.
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = None
        try:
            # Instantiate the model architecture.
            if model_class:
                self.model = model_class(**model_args)
            else:
                # Legacy fallback: silently assuming an architecture is risky,
                # so make the default visible in the logs.
                logger.warning(
                    "No model_class provided for %s. Defaulting to DeepLOB.",
                    model_path,
                )
                self.model = DeepLOB()
            # Load the saved weights directly onto the target device.
            # NOTE(security): torch.load uses pickle under the hood — only load
            # trusted checkpoint files. Consider weights_only=True once the
            # minimum supported torch version allows it.
            state_dict = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()  # freeze dropout / batch-norm for inference
            logger.info(f"Loaded PyTorch Model: {model_path} [{self.device}]")
        except Exception as e:
            logger.error(f"Failed to load PyTorch model {model_path}: {e}")
            raise

    def predict(self, input_data: Union[np.ndarray, Dict[str, np.ndarray]]) -> np.ndarray:
        """
        Run inference on numpy input.

        Args:
            input_data: Numpy array matching the model input shape, or an
                ONNX-style dict mapping input names to arrays (the "input"
                key is preferred; otherwise the first value is used).

        Returns:
            Numpy array of predictions for the first batch element, or
            ``np.zeros(3)`` (interpreted as "Hold") if inference fails.

        Raises:
            RuntimeError: If the model was never initialized.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized.")
        try:
            # Accept ONNX-style dictionary input for interface parity.
            if isinstance(input_data, dict):
                if "input" in input_data:
                    input_data = input_data["input"]
                else:
                    # Single-input assumption: take the first provided array.
                    input_data = next(iter(input_data.values()))
            # Convert to a float tensor on the model's device.
            tensor_in = torch.from_numpy(input_data).float().to(self.device)
            # Run inference without building the autograd graph.
            with torch.no_grad():
                output = self.model(tensor_in)
            # Strip the batch dimension and hand back a numpy array.
            return output.cpu().numpy()[0]
        except Exception as e:
            # Deliberate best-effort: never crash the caller's loop; a zero
            # vector is treated downstream as a safe "Hold" signal.
            logger.error(f"Inference Error: {e}")
            return np.zeros(3)  # Return safe zeros (Hold)

    def warmup(self, input_shape: tuple = (1, 2, 100, 40)):
        """
        Run a dummy inference to warm up (cuDNN autotune, allocator, caches).

        Args:
            input_shape: Shape of the random dummy batch. The default matches
                DeepLOB's expected input layout — presumably
                (batch, channel, time, feature); confirm for other
                architectures before relying on it.
        """
        dummy_input = np.random.randn(*input_shape).astype(np.float32)
        self.predict(dummy_input)
        logger.info("PyTorch Model Warmup Complete.")