"""
SAP RPT-1 Wrapper
=================
Sklearn-compatible wrapper for SAP RPT-1-OSS.
SAP RPT-1 uses in-context learning with pretrained transformers.
Requires Python 3.11 and Hugging Face model access.
Author: UW MSIM Team
Date: November 2025
"""
import time
import logging
from typing import Optional, Union
import numpy as np
import pandas as pd
from .base_wrapper import BaseModelWrapper
# Module-scoped logger; handlers and levels are left to the application.
logger = logging.getLogger(name=__name__)
class SAPRPT1Wrapper(BaseModelWrapper):
    """
    SAP RPT-1 (Retrieval Pretrained Transformer) wrapper.

    Lazily imports ``sap_rpt_1_oss`` on ``fit`` and delegates ``fit`` /
    ``predict`` / ``predict_proba`` to ``SAP_RPT_OSS_Classifier`` or
    ``SAP_RPT_OSS_Regressor`` depending on ``task_type``.

    Parameters
    ----------
    task_type : str, default='classification'
        Task type: 'classification' or 'regression'. Any value other than
        'classification' selects the regressor.
    context_size : int, default=4096
        Maximum context window size, forwarded as ``context_size`` to the
        underlying estimator. NOTE(review): earlier docs described this in
        tokens; confirm the unit against the sap-rpt-1-oss API.
    bagging_factor : int, default=4
        Number of bagging iterations for prediction stability
    model_size : str, default='small'
        Model size: 'small' or 'large'
    device : str, default='auto'
        Device to use: 'cpu', 'cuda', or 'auto'
    random_state : int, default=42
        Random seed for reproducibility. Passed to ``BaseModelWrapper`` only;
        it is not forwarded to the SAP RPT-1 estimator constructor.
    """

    def __init__(
        self,
        task_type: str = 'classification',
        context_size: int = 4096,
        bagging_factor: int = 4,
        model_size: str = 'small',
        device: str = 'auto',
        random_state: int = 42
    ):
        super().__init__(task_type=task_type, random_state=random_state)
        self.context_size = context_size
        self.bagging_factor = bagging_factor
        self.model_size = model_size
        self.device = device

    def fit(self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray]) -> 'SAPRPT1Wrapper':
        """
        Train SAP RPT-1 model.

        Note: SAP RPT-1 uses in-context learning, so "training" is primarily
        about storing the training data for retrieval during inference.

        The new estimator is attached to ``self`` only after its ``fit``
        succeeds. A failed refit therefore leaves any previously fitted model
        and ``is_fitted`` untouched.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Training features
        y : pd.Series or np.ndarray, shape (n_samples,)
            Training target

        Returns
        -------
        self : SAPRPT1Wrapper
            Fitted model

        Raises
        ------
        ImportError
            If the ``sap_rpt_1_oss`` package cannot be imported.
        Exception
            Any error from the estimator's constructor or ``fit`` is logged
            and re-raised unchanged.
        """
        self._validate_input(X, y)
        logger.info("Fitting SAP RPT-1 (%s) on %s samples...", self.model_size, X.shape[0])

        try:
            # Imported here so this module loads in environments without SAP RPT-1.
            # NOTE(review): module name and constructor kwargs below are not
            # verified against the upstream sap-rpt-1-oss package here; confirm.
            from sap_rpt_1_oss import SAP_RPT_OSS_Classifier, SAP_RPT_OSS_Regressor
        except ImportError as e:
            logger.error("SAP RPT-1 not installed: %s", e)
            raise ImportError(
                "SAP RPT-1 not found. Install with: "
                "pip install git+https://github.com/SAP-samples/sap-rpt-1-oss.git"
            ) from e

        estimator_cls = (
            SAP_RPT_OSS_Classifier if self.task_type == 'classification'
            else SAP_RPT_OSS_Regressor
        )

        # Time construction + fit only; the first-time import can be slow.
        start_time = time.perf_counter()
        try:
            model = estimator_cls(
                context_size=self.context_size,
                bagging_factor=self.bagging_factor,
                model_size=self.model_size,
                device=self.device,
            )
            # Fit model (stores training data for in-context learning)
            model.fit(X, y)
        except Exception as e:
            logger.error("Error fitting SAP RPT-1: %s", e)
            raise

        # Commit state only once the new estimator is successfully fitted.
        self.model = model
        self.is_fitted = True
        self.fit_time = time.perf_counter() - start_time
        logger.info("SAP RPT-1 fitted in %.2f seconds", self.fit_time)
        return self

    def predict(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Make predictions with SAP RPT-1.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        predictions : np.ndarray, shape (n_samples,)
            Predicted values or class labels, as returned by the underlying
            estimator's ``predict``.

        Raises
        ------
        ValueError
            If the wrapper has not been fitted.
        """
        if not self.is_fitted:
            raise ValueError("Model not fitted. Call fit() first.")
        self._validate_input(X)
        logger.info("Predicting on %s samples with SAP RPT-1...", X.shape[0])

        start_time = time.perf_counter()
        try:
            predictions = self.model.predict(X)
        except Exception as e:
            logger.error("Error during prediction: %s", e)
            raise

        self.predict_time = time.perf_counter() - start_time
        logger.info("Predictions complete in %.2f seconds", self.predict_time)
        return predictions

    def _predict_proba_impl(self, X: Union[pd.DataFrame, np.ndarray]) -> np.ndarray:
        """
        Implementation of predict_proba for SAP RPT-1.

        Delegates to the underlying estimator's ``predict_proba`` when it has
        one. Otherwise returns a one-hot encoding of ``predict`` output. The
        one-hot columns follow the estimator's ``classes_`` if it exposes
        them, or else the sorted set of predicted labels.

        Parameters
        ----------
        X : pd.DataFrame or np.ndarray, shape (n_samples, n_features)
            Test features

        Returns
        -------
        probabilities : np.ndarray, shape (n_samples, n_classes)
            Class probabilities

        Raises
        ------
        ValueError
            If ``task_type`` is not 'classification'.
        """
        if self.task_type != 'classification':
            raise ValueError("predict_proba only available for classification")

        # Look the method up explicitly. Wrapping the call in
        # `except AttributeError` would also hide AttributeErrors raised
        # inside a real predict_proba implementation.
        model_predict_proba = getattr(self.model, 'predict_proba', None)
        if callable(model_predict_proba):
            return model_predict_proba(X)

        # Fallback if predict_proba not available
        logger.warning("predict_proba not available, using one-hot encoding of predictions")
        predictions = np.asarray(self.model.predict(X)).ravel()

        classes = getattr(self.model, 'classes_', None)
        if classes is None:
            # Without the estimator's class list, columns cover only the labels
            # that actually occur in `predictions`, in sorted order.
            classes, class_idx = np.unique(predictions, return_inverse=True)
        else:
            # Map arbitrary labels (strings, non-contiguous ints) to columns.
            # Raw labels must not be used as column indices directly.
            column_of = {label: i for i, label in enumerate(np.asarray(classes).tolist())}
            class_idx = np.array([column_of[label] for label in predictions.tolist()], dtype=int)

        n_samples = predictions.shape[0]
        proba = np.zeros((n_samples, len(classes)))
        proba[np.arange(n_samples), class_idx] = 1.0
        return proba

    def get_params(self, deep: bool = True) -> dict:
        """Get parameters for this estimator, including the SAP RPT-1 settings."""
        params = super().get_params(deep)
        params.update({
            'context_size': self.context_size,
            'bagging_factor': self.bagging_factor,
            'model_size': self.model_size,
            'device': self.device
        })
        return params