Spaces:
Sleeping
Sleeping
import logging
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch

from .deeplob import DeepLOB
# Module-level logger for this wrapper. It uses the fixed channel name
# "PYTORCH_INFERENCE" rather than __name__, so configure handlers/levels
# under that name.
logger = logging.getLogger("PYTORCH_INFERENCE")
class PyTorchModel:
    """
    Wrapper for running PyTorch models (.pt) in NautilusAI.

    Matches the interface of ONNXModel for seamless switching: ``predict``
    accepts either a raw array or an ONNX-style ``{name: array}`` feed dict.
    """

    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        model_class: Any = None,
        **model_args,
    ):
        """
        Initialize PyTorch model.

        Args:
            model_path: Path to the .pt file holding the model's state dict.
            device: 'cpu' or 'cuda' (if available). Auto-detect if None/empty.
            model_class: The class of the model architecture (e.g., DeepLOB, TRM).
                When not given, falls back to ``DeepLOB()``.
            **model_args: Arguments passed to ``model_class``'s constructor
                (not used by the DeepLOB fallback).

        Raises:
            Exception: Any error while building the model or loading its
                weights is logged and re-raised unchanged.
        """
        self.model_path = model_path
        if device:
            self.device = torch.device(device)
        else:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = None
        try:
            # Instantiate model architecture.
            if model_class:
                self.model = model_class(**model_args)
            else:
                # Legacy fallback; DeepLOB is imported at module level.
                self.model = DeepLOB()

            # NOTE(review): torch.load unpickles the file and can execute
            # arbitrary code -- only load checkpoints from trusted sources.
            # On torch >= 1.13 consider passing weights_only=True.
            state_dict = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(state_dict)
            self.model.to(self.device)
            # Inference mode for layers such as dropout / batch-norm.
            self.model.eval()
            logger.info(f"Loaded PyTorch Model: {model_path} [{self.device}]")
        except Exception as e:
            logger.error(f"Failed to load PyTorch model {model_path}: {e}")
            raise

    def predict(
        self, input_data: Union[np.ndarray, Dict[str, np.ndarray]]
    ) -> np.ndarray:
        """
        Run inference on numpy input.

        Args:
            input_data: Array (or array-like, e.g. nested lists) matching the
                model input shape, or an ONNX-style feed dict. For a dict, the
                ``"input"`` entry is used if present, otherwise the first value.

        Returns:
            Numpy array of predictions for the first sample of the batch
            (``output[0]``). If inference fails, the error is logged with its
            traceback and ``np.zeros(3)`` is returned as a safe "Hold" signal.

        Raises:
            RuntimeError: If the model was never initialized.
        """
        if self.model is None:
            raise RuntimeError("Model not initialized.")

        try:
            if isinstance(input_data, dict):
                # ONNX-style feed dict: prefer the "input" key, otherwise take
                # the first value (assumes a single-input model).
                if "input" in input_data:
                    input_data = input_data["input"]
                else:
                    input_data = next(iter(input_data.values()))

            # torch.from_numpy rejects lists and arrays with negative strides
            # (e.g. x[::-1]); normalize to a contiguous float32 ndarray first.
            array_in = np.ascontiguousarray(input_data, dtype=np.float32)
            tensor_in = torch.from_numpy(array_in).to(self.device)

            with torch.no_grad():
                output = self.model(tensor_in)

            return output.cpu().numpy()[0]
        except Exception as e:
            # Deliberate best-effort: a bad input must not crash the caller,
            # but keep the traceback so the failure is diagnosable.
            logger.exception("Inference Error: %s", e)
            return np.zeros(3)  # Return safe zeros (Hold)

    def warmup(self, input_shape: tuple = (1, 2, 100, 40)):
        """
        Run a dummy inference to warm up.

        Args:
            input_shape: Shape of the random float32 dummy input.
                NOTE(review): the default (1, 2, 100, 40) must match the loaded
                model's expected input -- confirm against the architecture.

        Note:
            ``predict`` swallows inference errors, so the completion message
            logged here does not by itself prove the warmup inference succeeded.
        """
        dummy_input = np.random.randn(*input_shape).astype(np.float32)
        self.predict(dummy_input)
        logger.info("PyTorch Model Warmup Complete.")