"""
SAP RPT-1 Wrapper
=================

Sklearn-compatible wrapper for SAP RPT-1-OSS.

SAP RPT-1 uses in-context learning with pretrained transformers.
Requires Python 3.11 and Hugging Face model access.

Author: UW MSIM Team
Date: November 2025
"""

import time
import logging
from typing import Optional, Union

import numpy as np
import pandas as pd

from .base_wrapper import BaseModelWrapper

logger = logging.getLogger(__name__)


class SAPRPT1Wrapper(BaseModelWrapper):
    """
    SAP RPT-1 (Relational Pretrained Transformer) wrapper.

    Parameters
    ----------
    task_type : str, default='classification'
        Task type: 'classification' or 'regression'
    context_size : int, default=4096
        Maximum context window size in tokens
    bagging_factor : int, default=4
        Number of bagging iterations for prediction stability
    model_size : str, default='small'
        Model size: 'small' or 'large'
    device : str, default='auto'
        Device to use: 'cpu', 'cuda', or 'auto'
    random_state : int, default=42
        Random seed for reproducibility
    """

    def __init__(
        self,
        task_type: str = 'classification',
        context_size: int = 4096,
        bagging_factor: int = 4,
        model_size: str = 'small',
        device: str = 'auto',
        random_state: int = 42
    ):
        super().__init__(task_type=task_type, random_state=random_state)
        self.context_size = context_size
        self.bagging_factor = bagging_factor
        self.model_size = model_size
        self.device = device
        # Class labels seen during fit(); only used by the one-hot fallback
        # in _predict_proba_impl. Stays None for regression / before fit().
        self._train_classes = None

    @staticmethod
    def _unique_labels(labels) -> Optional[np.ndarray]:
        """Return the sorted unique labels, or None if they cannot be sorted."""
        try:
            return np.unique(np.asarray(labels))
        except TypeError:
            # Mixed, non-comparable label types (e.g. str and int together).
            return None

    def fit(self, X: Union[pd.DataFrame, np.ndarray],
            y: Union[pd.Series, np.ndarray]) -> 'SAPRPT1Wrapper':
        """
        Train SAP RPT-1 model.

        Note: SAP RPT-1 uses in-context learning, so "training" is primarily
        about storing the training data for retrieval during inference.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Training features
        y : pd.Series or np.ndarray, shape (n_samples,)
            Training target

        Returns
        -------
        self : SAPRPT1Wrapper
            Fitted model

        Raises
        ------
        ImportError
            If the ``sap_rpt_1_oss`` package cannot be imported.
        Exception
            Any error raised while constructing or fitting the underlying
            model is logged and re-raised unchanged.
        """
        self._validate_input(X, y)

        logger.info(f"Fitting SAP RPT-1 ({self.model_size}) on {X.shape[0]} samples...")
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by system clock adjustments.
        start_time = time.perf_counter()

        # Import here to avoid import errors in environments without SAP RPT-1.
        # Kept in its own try block so that an ImportError raised later, while
        # fitting, is not misreported as "package not installed".
        # NOTE(review): confirm the module name and the constructor keyword
        # names below against the installed sap-rpt-1-oss release.
        try:
            from sap_rpt_1_oss import SAP_RPT_OSS_Classifier, SAP_RPT_OSS_Regressor
        except ImportError as e:
            logger.error(f"SAP RPT-1 not installed: {e}")
            raise ImportError(
                "SAP RPT-1 not found. Install with: "
                "pip install git+https://github.com/SAP-samples/sap-rpt-1-oss.git"
            ) from e

        # Any task type other than 'classification' is treated as regression.
        is_classification = self.task_type == 'classification'
        model_cls = SAP_RPT_OSS_Classifier if is_classification else SAP_RPT_OSS_Regressor

        try:
            model = model_cls(
                context_size=self.context_size,
                bagging_factor=self.bagging_factor,
                model_size=self.model_size,
                device=self.device
            )
            # Fit model (stores training data for in-context learning)
            model.fit(X, y)
        except Exception as e:
            logger.error(f"Error fitting SAP RPT-1: {e}")
            raise

        # Only publish state once fitting succeeded, so a failed fit never
        # leaves an unfitted model attached to the wrapper.
        self.model = model
        self._train_classes = self._unique_labels(y) if is_classification else None
        self.is_fitted = True
        self.fit_time = time.perf_counter() - start_time

        logger.info(f"SAP RPT-1 fitted in {self.fit_time:.2f} seconds")

        return self

    def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Make predictions with SAP RPT-1.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        predictions : np.ndarray, shape (n_samples,)
            Predicted values or class labels

        Raises
        ------
        ValueError
            If the model has not been fitted yet.
        """
        if not self.is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")

        self._validate_input(X)

        logger.info(f"Predicting on {X.shape[0]} samples with SAP RPT-1...")
        start_time = time.perf_counter()

        try:
            predictions = self.model.predict(X)
        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            raise

        self.predict_time = time.perf_counter() - start_time
        logger.info(f"Predictions complete in {self.predict_time:.2f} seconds")

        return predictions

    def _predict_proba_impl(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Implementation of predict_proba for SAP RPT-1.

        Uses the underlying model's ``predict_proba`` when it exists.
        Otherwise falls back to a one-hot encoding of the hard predictions,
        with one column per class label seen during ``fit`` (sorted), so the
        output width does not depend on which classes happen to be predicted.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        probabilities : np.ndarray, shape (n_samples, n_classes)
            Class probabilities

        Raises
        ------
        ValueError
            If the task type is not 'classification', or if the fallback
            encounters a predicted label that was not seen during ``fit``.
        """
        if self.task_type != 'classification':
            raise ValueError("predict_proba only available for classification")

        # Look the method up instead of catching AttributeError around the
        # call, so an AttributeError raised *inside* a working predict_proba
        # is not silently turned into one-hot output.
        native_proba = getattr(self.model, 'predict_proba', None)
        if callable(native_proba):
            return native_proba(X)

        # Fallback if predict_proba not available
        logger.warning("predict_proba not available, using one-hot encoding of predictions")
        predicted = np.asarray(self.model.predict(X)).ravel()

        # Prefer the label set captured at fit time; fall back to the labels
        # present in the predictions when it is unavailable.
        classes = getattr(self, '_train_classes', None)
        if classes is None:
            classes = self._unique_labels(predicted)
        if classes is None:
            # Non-sortable labels: keep order of first appearance.
            class_labels = list(dict.fromkeys(predicted.tolist()))
        else:
            class_labels = np.asarray(classes).tolist()

        # Map each label to its column; labels are not assumed to be 0..k-1
        # integers, so they are never used directly as indices.
        column_of = {label: col for col, label in enumerate(class_labels)}
        try:
            columns = np.asarray(
                [column_of[label] for label in predicted.tolist()],
                dtype=np.intp,
            )
        except KeyError as e:
            raise ValueError(
                f"Predicted label {e.args[0]!r} was not seen during fit"
            ) from e

        n_samples = predicted.shape[0]
        proba = np.zeros((n_samples, len(column_of)))
        proba[np.arange(n_samples), columns] = 1.0
        return proba

    def get_params(self, deep: bool = True) -> dict:
        """Get parameters for this estimator."""
        params = super().get_params(deep)
        params.update({
            'context_size': self.context_size,
            'bagging_factor': self.bagging_factor,
            'model_size': self.model_size,
            'device': self.device
        })
        return params