import torch
import torch.nn as nn
import numpy as np
from pathlib import Path
from typing import Optional, Sequence


class BetaRegressor(nn.Module):
    """Small MLP mapping a measurement vector to a bounded beta vector.

    Architecture: ``[Linear -> ReLU -> Dropout(0.1)]`` for every entry of
    ``hidden_dims``, followed by ``Linear -> Tanh``, so every output value
    lies in ``[-1, 1]``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Number of outputs produced per input row.
        hidden_dims: Widths of the hidden layers, in order. Defaults to
            ``(64, 32)``. An empty sequence yields a single ``Linear`` layer.
    """

    def __init__(
        self,
        input_dim: int = 9,
        output_dim: int = 10,
        hidden_dims: Sequence[int] = (64, 32),  # immutable default: never shared/mutated
    ):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            layers.extend([nn.Linear(prev_dim, hidden_dim), nn.ReLU(), nn.Dropout(0.1)])
            prev_dim = hidden_dim
        layers.extend([nn.Linear(prev_dim, output_dim), nn.Tanh()])
        self.network = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform weights (gain 0.1) and zero biases for every Linear layer."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Small gain keeps initial weights small; presumably so an
                # untrained model predicts values near zero — TODO confirm intent.
                nn.init.xavier_uniform_(m.weight, gain=0.1)
                nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(..., input_dim)`` to ``(..., output_dim)`` with values in ``[-1, 1]``."""
        return self.network(x)


class MeasurementToBetaPredictor:
    """Inference wrapper around a default-configured :class:`BetaRegressor`.

    Args:
        model_path: Optional path to a checkpoint written with ``torch.save``.
            Either a raw ``state_dict`` or a dict holding a
            ``'model_state_dict'`` entry is accepted. If omitted or the file
            does not exist, the randomly initialised model is used and a
            warning is printed.
        device: Device the model runs on (e.g. ``"cpu"``, ``"cuda"``).
        output_scale: Factor applied to the model's ``Tanh`` output, so
            predictions lie in ``[-output_scale, output_scale]``.
    """

    def __init__(
        self,
        model_path: Optional[str] = None,
        device: str = "cpu",
        output_scale: float = 2.0,
    ):
        self.device = torch.device(device)
        # Requested checkpoint path (normalised to str); get_predictor compares
        # it to decide whether the cached instance matches a new request.
        self.model_path = None if model_path is None else str(model_path)
        self.output_scale = output_scale
        self.model = BetaRegressor().to(self.device)
        self.model.eval()  # inference mode: disables Dropout
        if model_path and Path(model_path).exists():
            self.load_model(model_path)
        else:
            if model_path:
                # An explicitly requested checkpoint is missing; say so
                # instead of silently falling back to random weights.
                print(f"Warning: model file not found: {model_path}")
            print("Warning: Using untrained model. Results may not be optimal.")
            print("Consider training the model or loading pretrained weights.")

    def load_model(self, model_path: str):
        """Load weights from ``model_path`` into ``self.model``.

        Raises whatever ``torch.load`` / ``load_state_dict`` raise on an
        unreadable file or mismatched keys/shapes.
        """
        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints (PyTorch >= 2.6 defaults to weights_only=True).
        checkpoint = torch.load(model_path, map_location=self.device)
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            checkpoint = checkpoint["model_state_dict"]
        self.model.load_state_dict(checkpoint)
        self.model_path = str(model_path)
        print(f"Loaded model from {model_path}")

    def predict(self, normalized_measurements: np.ndarray) -> np.ndarray:
        """Predict betas for one measurement vector or a batch of them.

        Args:
            normalized_measurements: Array-like whose last dimension equals
                the model's ``input_dim`` (9 by default). Normalisation is
                the caller's job — TODO confirm the expected scheme.

        Returns:
            float32 array of shape ``input.shape[:-1] + (10,)`` with values in
            ``[-output_scale, output_scale]``.
        """
        x = torch.as_tensor(
            normalized_measurements, dtype=torch.float32, device=self.device
        )
        with torch.no_grad():
            # Add a leading batch axis for the forward pass, then drop it.
            betas = self.model(x.unsqueeze(0)).squeeze(0).cpu().numpy()
        return betas * self.output_scale


_predictor_instance: Optional[MeasurementToBetaPredictor] = None


def get_predictor(
    model_path: Optional[str] = None, device: Optional[str] = None
) -> MeasurementToBetaPredictor:
    """Return the shared module-level predictor, creating it on first use.

    Calls with no arguments reuse the cached instance. A new instance replaces
    the cache when ``model_path`` names a different checkpoint than the cached
    one, or ``device`` names a different device. A predictor first created
    without weights is therefore not silently returned to a caller that
    supplies a checkpoint. Unspecified settings are inherited from the cached
    instance; a brand-new predictor defaults to ``"cpu"``.
    """
    global _predictor_instance
    cached = _predictor_instance
    if cached is None:
        _predictor_instance = MeasurementToBetaPredictor(
            model_path=model_path,
            device=device if device is not None else "cpu",
        )
        return _predictor_instance

    path_changed = model_path is not None and str(model_path) != cached.model_path
    device_changed = device is not None and torch.device(device) != cached.device
    if path_changed or device_changed:
        _predictor_instance = MeasurementToBetaPredictor(
            model_path=model_path if model_path is not None else cached.model_path,
            device=device if device is not None else str(cached.device),
        )
    return _predictor_instance


def predict_betas(
    normalized_measurements: np.ndarray, model_path: Optional[str] = None
) -> np.ndarray:
    """Predict betas with the shared predictor (see :func:`get_predictor`)."""
    predictor = get_predictor(model_path=model_path)
    return predictor.predict(normalized_measurements)