ModelMatrix / matrix /code /models /sap_rpt1_wrapper.py
Akshay4506's picture
Fix deployment entry point and merge requirements
c4ff02d
"""
SAP RPT-1 Wrapper
=================
Sklearn-compatible wrapper for SAP RPT-1-OSS.
SAP RPT-1 uses in-context learning with pretrained transformers.
Requires Python 3.11 and Hugging Face model access.
Author: UW MSIM Team
Date: November 2025
"""
import time
import logging
from typing import Optional, Union
import numpy as np
import pandas as pd
from .base_wrapper import BaseModelWrapper
# Module-level logger, named after this module so its records can be
# filtered or configured separately from the rest of the package.
logger = logging.getLogger(__name__)
class SAPRPT1Wrapper(BaseModelWrapper):
    """
    SAP RPT-1 (Relational Pretrained Transformer) wrapper.

    Thin sklearn-style adapter around the ``sap_rpt_1_oss`` classifier and
    regressor. The underlying package is imported lazily inside ``fit`` so
    that this module can be imported in environments where SAP RPT-1 is not
    installed.

    Parameters
    ----------
    task_type : str, default='classification'
        Task type: 'classification' or 'regression'. Any value other than
        'classification' selects the regressor.
    context_size : int, default=4096
        Maximum context window size in tokens
    bagging_factor : int, default=4
        Number of bagging iterations for prediction stability
    model_size : str, default='small'
        Model size: 'small' or 'large'
    device : str, default='auto'
        Device to use: 'cpu', 'cuda', or 'auto'
    random_state : int, default=42
        Random seed for reproducibility. Forwarded to the base wrapper only;
        it is not passed to the SAP RPT-1 model by this class.

    Attributes
    ----------
    model : object
        The underlying SAP RPT-1 estimator, created in ``fit``.
    is_fitted : bool
        Set to True once ``fit`` completes successfully.
    fit_time : float
        Wall-clock seconds spent in the last successful ``fit``.
    predict_time : float
        Wall-clock seconds spent in the last successful ``predict``.
    """

    def __init__(
        self,
        task_type: str = 'classification',
        context_size: int = 4096,
        bagging_factor: int = 4,
        model_size: str = 'small',
        device: str = 'auto',
        random_state: int = 42
    ):
        super().__init__(task_type=task_type, random_state=random_state)
        self.context_size = context_size
        self.bagging_factor = bagging_factor
        self.model_size = model_size
        self.device = device

    def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'SAPRPT1Wrapper':
        """
        Train SAP RPT-1 model.

        Note: SAP RPT-1 uses in-context learning, so "training" is primarily
        about storing the training data for retrieval during inference.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Training features
        y : pd.Series or np.ndarray, shape (n_samples,)
            Training target

        Returns
        -------
        self : SAPRPT1Wrapper
            Fitted model

        Raises
        ------
        ImportError
            If the ``sap_rpt_1_oss`` package cannot be imported.
        Exception
            Any error raised while constructing or fitting the underlying
            model is logged and re-raised unchanged.
        """
        self._validate_input(X, y)
        logger.info(f"Fitting SAP RPT-1 ({self.model_size}) on {X.shape[0]} samples...")
        start_time = time.time()

        # Import here to avoid import errors in environments without SAP RPT-1.
        # The import has its own try block so that only a genuinely missing
        # package produces the "not found" message; ImportErrors raised later
        # (e.g. by a missing transitive dependency during model construction)
        # are reported as fitting errors instead of being mislabelled.
        # NOTE(review): confirm the module name and the constructor keyword
        # names below against the installed sap-rpt-1-oss release.
        try:
            from sap_rpt_1_oss import SAP_RPT_OSS_Classifier, SAP_RPT_OSS_Regressor
        except ImportError as e:
            logger.error(f"SAP RPT-1 not installed: {e}")
            raise ImportError(
                "SAP RPT-1 not found. Install with: "
                "pip install git+https://github.com/SAP-samples/sap-rpt-1-oss.git"
            ) from e

        try:
            # Both estimators take the same constructor arguments; only the
            # class differs by task type.
            if self.task_type == 'classification':
                model_cls = SAP_RPT_OSS_Classifier
            else:
                model_cls = SAP_RPT_OSS_Regressor
            self.model = model_cls(
                context_size=self.context_size,
                bagging_factor=self.bagging_factor,
                model_size=self.model_size,
                device=self.device
            )
            # Fit model (stores training data for in-context learning)
            self.model.fit(X, y)
            self.is_fitted = True
            self.fit_time = time.time() - start_time
            logger.info(f"SAP RPT-1 fitted in {self.fit_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Error fitting SAP RPT-1: {e}")
            raise
        return self

    def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Make predictions with SAP RPT-1.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        predictions : np.ndarray, shape (n_samples,)
            Predicted values or class labels, exactly as returned by the
            underlying model.

        Raises
        ------
        ValueError
            If the model has not been fitted.
        Exception
            Any error raised by the underlying model is logged and re-raised.
        """
        if not self.is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")
        self._validate_input(X)
        logger.info(f"Predicting on {X.shape[0]} samples with SAP RPT-1...")
        start_time = time.time()
        try:
            predictions = self.model.predict(X)
            self.predict_time = time.time() - start_time
            logger.info(f"Predictions complete in {self.predict_time:.2f} seconds")
            return predictions
        except Exception as e:
            logger.error(f"Error during prediction: {e}")
            raise

    def _predict_proba_impl(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Implementation of predict_proba for SAP RPT-1.

        Delegates to the underlying model's ``predict_proba``. If that call
        raises ``AttributeError``, falls back to a one-hot encoding of the
        hard predictions.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        probabilities : np.ndarray, shape (n_samples, n_classes)
            Class probabilities. In the fallback path, columns follow the
            model's ``classes_`` when it exposes them and they cover every
            predicted label; otherwise they follow the sorted distinct
            predicted labels.

        Raises
        ------
        ValueError
            If the wrapper is not configured for classification.
        """
        if self.task_type != 'classification':
            raise ValueError("predict_proba only available for classification")
        try:
            return self.model.predict_proba(X)
        except AttributeError:
            # Fallback if predict_proba not available
            logger.warning("predict_proba not available, using one-hot encoding of predictions")
            predictions = self.model.predict(X)
            return self._one_hot_from_predictions(
                predictions, getattr(self.model, 'classes_', None)
            )

    @staticmethod
    def _one_hot_from_predictions(predictions, classes=None) -> np.ndarray:
        """Build a one-hot (n_samples, n_classes) matrix from hard predictions.

        Labels are mapped to column positions explicitly rather than being
        used as indices, so string labels and integer labels that are not
        exactly ``0..n_classes-1`` are handled correctly.

        Parameters
        ----------
        predictions : array-like, shape (n_samples,)
            Predicted class labels.
        classes : array-like or None
            Known class labels defining the column order. Ignored when it
            does not contain every predicted label.

        Returns
        -------
        np.ndarray, shape (n_samples, n_classes)
            Matrix with a single 1.0 per row.
        """
        labels = np.asarray(predictions).ravel()
        observed = np.unique(labels)
        ordered = observed
        if classes is not None:
            known = np.asarray(classes).ravel()
            # Only trust the supplied classes if they cover what was predicted;
            # otherwise some rows would have no column to land in.
            if np.isin(observed, known).all():
                ordered = known
        column_of = {label: col for col, label in enumerate(ordered.tolist())}
        columns = np.fromiter(
            (column_of[label] for label in labels.tolist()),
            dtype=np.intp,
            count=labels.size,
        )
        proba = np.zeros((labels.size, len(ordered)))
        proba[np.arange(labels.size), columns] = 1.0
        return proba

    def get_params(self, deep: bool = True) -> dict:
        """Get parameters for this estimator.

        Extends the base wrapper's parameters with the SAP RPT-1 specific
        constructor arguments stored on this instance.
        """
        params = super().get_params(deep)
        params.update({
            'context_size': self.context_size,
            'bagging_factor': self.bagging_factor,
            'model_size': self.model_size,
            'device': self.device
        })
        return params