import logging
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from .deeplob import DeepLOB

logger = logging.getLogger("PYTORCH_INFERENCE")


class PyTorchModel:
    """
    Wrapper for running PyTorch models (.pt) in NautilusAI.

    Matches the interface of ONNXModel for seamless switching.
    """

    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        model_class: Any = None,
        **model_args: Any,
    ):
        """
        Initialize PyTorch model.

        Args:
            model_path: Path to the .pt file (a state_dict checkpoint).
            device: 'cpu' or 'cuda' (if available). Auto-detect if None.
            model_class: The class of the model architecture (e.g., DeepLOB, TRM).
            **model_args: Arguments to pass to the model constructor.

        Raises:
            Exception: Re-raised after logging if the model cannot be
                instantiated or the checkpoint cannot be loaded.
        """
        self.model_path = model_path
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model: Optional[torch.nn.Module] = None
        try:
            # Instantiate model architecture.
            if model_class:
                self.model = model_class(**model_args)
            else:
                # Fallback implementation (legacy): warn instead of silently
                # defaulting, so misconfigured callers are visible in logs.
                logger.warning(
                    "No model_class provided for %s. Defaulting to DeepLOB.",
                    model_path,
                )
                self.model = DeepLOB()

            # Load state dict.
            # NOTE(security): torch.load unpickles arbitrary objects — only load
            # trusted checkpoints (or pass weights_only=True on torch >= 1.13).
            state_dict = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            self.model.eval()
            logger.info("Loaded PyTorch Model: %s [%s]", model_path, self.device)
        except Exception:
            # logger.exception preserves the full traceback for diagnosis.
            logger.exception("Failed to load PyTorch model %s", model_path)
            raise

    def predict(
        self, input_data: Union[np.ndarray, Dict[str, np.ndarray]]
    ) -> np.ndarray:
        """
        Run inference on numpy input.

        Args:
            input_data: Numpy array matching the model input shape, or an
                ONNX-style dict of named inputs — the "input" key is preferred,
                otherwise the first value is used.

        Returns:
            Numpy array of predictions for the first batch element, or
            ``np.zeros(3)`` (interpreted downstream as "Hold" — TODO confirm)
            if inference fails.

        Raises:
            RuntimeError: If the model was never initialized.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized.")
        try:
            # Handle dictionary input (ONNX style).
            if isinstance(input_data, dict):
                if "input" in input_data:
                    input_data = input_data["input"]
                else:
                    # Assume a single-input model; take the first value.
                    input_data = next(iter(input_data.values()))

            # Convert to Tensor on the target device.
            tensor_in = torch.from_numpy(input_data).float().to(self.device)

            # Run inference without building an autograd graph.
            with torch.no_grad():
                output = self.model(tensor_in)

            # Strip the batch dimension; callers expect a single prediction.
            return output.cpu().numpy()[0]
        except Exception:
            logger.exception("Inference Error")
            # Deliberate best-effort fallback: return safe zeros (Hold).
            return np.zeros(3)

    def warmup(self, input_shape: tuple = (1, 2, 100, 40)) -> None:
        """
        Run a dummy inference to warm up the model.

        Args:
            input_shape: Shape of the random dummy batch. The default
                presumably matches DeepLOB's (batch, channels, time, features)
                input — TODO confirm against the model architecture.
        """
        dummy_input = np.random.randn(*input_shape).astype(np.float32)
        self.predict(dummy_input)
        logger.info("PyTorch Model Warmup Complete.")