Spaces:
Sleeping
Sleeping
| """ | |
SAP RPT-1 OSS Quick Test Script
=================================
Validates HuggingFace token authentication and runs a quick
classification test using the breast cancer dataset.

Usage:
    # Set your token first. Windows cmd syntax is shown; on macOS/Linux use
    # `export HUGGING_FACE_HUB_TOKEN=...`. HF_TOKEN is accepted as well, and
    # either variable may instead be placed in the project's .env file.
    set HUGGING_FACE_HUB_TOKEN=hf_xxxxxxxxxxxxxxxxxxxxxxxxxxxxx

    # Run test
    cd code
    python ../scripts/test_sap_rpt1.py

Requirements:
    - Python >= 3.11
    - pip install git+https://github.com/SAP-samples/sap-rpt-1-oss.git
    - Hugging Face token with access to SAP/sap-rpt-1-oss
      (accept the model license at https://huggingface.co/SAP/sap-rpt-1-oss)

Author: UW MSIM Team
Date: April 2026
| """ | |
| import os | |
| import sys | |
| import time | |
| import logging | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
project_root = Path(__file__).parent.parent
load_dotenv(project_root / ".env")

# Add code directory to path so project packages (e.g. ``models``) can be
# imported regardless of the directory the script is launched from.
sys.path.insert(0, str(project_root / "code"))

# Fix Windows emoji printing issues: switch stdout to UTF-8 when it uses some
# other encoding. ``sys.stdout`` can be None (e.g. under pythonw) or a stream
# whose ``encoding`` is None (e.g. io.StringIO), so probe defensively and only
# reconfigure streams that actually support ``reconfigure``.
_stdout_encoding = (getattr(sys.stdout, "encoding", None) or "").lower()
if hasattr(sys.stdout, "reconfigure") and _stdout_encoding not in ("utf-8", "utf8"):
    sys.stdout.reconfigure(encoding="utf-8")

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
def _mask_token(token):
    """Return a display-safe hint for *token* without echoing the secret.

    Tokens shorter than 16 characters are hidden entirely so a short or
    malformed secret is never printed; longer tokens show only their first
    3 characters (the ``hf_`` prefix for Hugging Face tokens) and last 4.
    """
    if len(token) < 16:
        return "****"
    return f"{token[:3]}...{token[-4:]}"


def check_prerequisites():
    """Check all prerequisites before running the test.

    Runs, in order: a Python version check (warning only), Hugging Face token
    lookup, an import check for ``sap_rpt_oss``, and a Hugging Face
    authentication check. Progress for each step is printed to stdout.

    Returns:
        bool: True if every hard prerequisite passed; False at the first
        failing one (missing token, missing package, or failed auth).
    """
    print("\n" + "=" * 60)
    print(" SAP RPT-1 OSS โ Quick Test")
    print("=" * 60)

    # 1. Check Python version. Advisory only: an older interpreter is
    #    reported but does not abort the run.
    py_version = sys.version_info
    print(f"\nโ Python version: {py_version.major}.{py_version.minor}.{py_version.micro}")
    if py_version < (3, 11):
        print("โ ๏ธ Warning: SAP RPT-1 OSS requires Python >= 3.11")
        print(f" Your version: {py_version.major}.{py_version.minor}")

    # 2. Check HF token. HF_TOKEN is consulted first because huggingface_hub
    #    gives it priority over HUGGING_FACE_HUB_TOKEN; using the same order
    #    means the token reported and validated here is the one the hub uses.
    #    Surrounding whitespace (e.g. from a .env file) is ignored.
    token = (os.getenv("HF_TOKEN") or os.getenv("HUGGING_FACE_HUB_TOKEN") or "").strip()
    if token:
        print(f"โ HF Token found: {_mask_token(token)}")
    else:
        print("โ No HF token found!")
        print(" Set it with: set HUGGING_FACE_HUB_TOKEN=hf_xxx")
        return False

    # 3. Check sap_rpt_oss package (imported only to confirm it is installed).
    try:
        import sap_rpt_oss  # noqa: F401
        print("โ sap_rpt_oss package installed")
    except ImportError:
        print("โ sap_rpt_oss not installed!")
        print(" Install with: pip install git+https://github.com/SAP-samples/sap-rpt-1-oss.git")
        return False

    # 4. Check HF authentication.
    try:
        from huggingface_hub import HfApi, login
        # NOTE(review): login() may persist the token in the local HF cache;
        # confirm that side effect is wanted for a quick check.
        login(token=token, add_to_git_credential=False)
        api = HfApi()
        # Pass the token explicitly so whoami() validates exactly the token
        # found above rather than one huggingface_hub resolves on its own.
        user_info = api.whoami(token=token)
        print(f"โ HF authenticated as: {user_info.get('name', 'unknown')}")
    except Exception as e:
        print(f"โ HF authentication failed: {e}")
        print(" Make sure you've accepted the license at:")
        print(" https://huggingface.co/SAP/sap-rpt-1-oss")
        return False

    return True
def run_classification_test():
    """Run a classification test on the breast cancer dataset.

    Splits scikit-learn's breast cancer data 70/30 (random_state=42), fits
    ``SAP_RPT_OSS_Classifier`` in a fast test configuration
    (max_context_size=2048, bagging=1), and prints accuracy, per-phase
    timings and a classification report.

    Durations are measured with ``time.perf_counter()``, a monotonic clock
    intended for elapsed-time measurement, so the reported timings are not
    affected by system clock adjustments.

    Returns:
        Accuracy score of the classifier on the held-out 30% split.
    """
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score, classification_report
    from sap_rpt_oss import SAP_RPT_OSS_Classifier

    print("\n" + "-" * 60)
    print(" Classification Test: Breast Cancer Dataset")
    print("-" * 60)

    # Load data
    X, y = load_breast_cancer(return_X_y=True, as_frame=True)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )
    print(f"\n๐ Dataset: {X_train.shape[0]} train / {X_test.shape[0]} test samples")
    print(f"๐ Features: {X.shape[1]}")

    # Initialize model (use small context for quick test)
    print("\n๐ง Initializing SAP RPT-1 OSS Classifier...")
    print(" max_context_size=2048, bagging=1 (fast test mode)")
    start_init = time.perf_counter()
    clf = SAP_RPT_OSS_Classifier(max_context_size=2048, bagging=1)
    init_time = time.perf_counter() - start_init
    print(f" Model loaded in {init_time:.2f}s")

    # Fit
    print("\n๐๏ธ Fitting model (in-context learning)...")
    start_fit = time.perf_counter()
    clf.fit(X_train, y_train)
    fit_time = time.perf_counter() - start_fit
    print(f" Fit completed in {fit_time:.2f}s")

    # Predict
    print("\n๐ฎ Making predictions...")
    start_pred = time.perf_counter()
    predictions = clf.predict(X_test)
    pred_time = time.perf_counter() - start_pred
    print(f" Predictions completed in {pred_time:.2f}s")

    # Evaluate
    accuracy = accuracy_score(y_test, predictions)
    print("\n" + "=" * 60)
    print(" RESULTS")
    print("=" * 60)
    print(f"\n Accuracy: {accuracy:.4f} ({accuracy * 100:.1f}%)")
    print(f" Init time: {init_time:.2f}s")
    print(f" Fit time: {fit_time:.2f}s")
    print(f" Predict time: {pred_time:.2f}s")
    print(f" Total time: {init_time + fit_time + pred_time:.2f}s")
    print()
    # Label 0 is malignant and 1 is benign in scikit-learn's breast cancer set.
    print(classification_report(y_test, predictions, target_names=['malignant', 'benign']))
    return accuracy
def run_wrapper_test():
    """Exercise the project's ``SAPRPT1HFWrapper`` end to end.

    Builds the wrapper for classification (max_context_size=2048, bagging=1),
    fits it on a 70/30 split of scikit-learn's breast cancer data
    (random_state=42), and prints the accuracy and the wrapper's ``fit_time``.
    ``predict_proba`` is then attempted; a failure there is printed as a
    warning and does not abort the test.

    Returns:
        Accuracy of ``predict`` on the held-out split.
    """
    from models.sap_rpt1_hf_wrapper import SAPRPT1HFWrapper
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import accuracy_score

    divider = "-" * 60
    print("\n" + divider)
    print(" Wrapper Integration Test: SAPRPT1HFWrapper")
    print(divider)

    features, labels = load_breast_cancer(return_X_y=True, as_frame=True)
    train_X, test_X, train_y, test_y = train_test_split(
        features, labels, test_size=0.3, random_state=42
    )

    wrapper_kwargs = {
        'task_type': 'classification',
        'max_context_size': 2048,
        'bagging': 1,
    }
    model = SAPRPT1HFWrapper(**wrapper_kwargs)
    model.fit(train_X, train_y)
    score = accuracy_score(test_y, model.predict(test_X))

    print(f"\n โ Wrapper test passed! Accuracy: {score:.4f}")
    print(f" โ Fit time: {model.fit_time:.2f}s")

    # predict_proba is best-effort: report a failure, but keep going.
    try:
        probabilities = model.predict_proba(test_X)
        print(f" โ predict_proba works! Shape: {probabilities.shape}")
    except Exception as err:
        print(f" โ ๏ธ predict_proba failed: {err}")

    return score
if __name__ == "__main__":
    # Check prerequisites; every failing check has already printed its reason.
    if not check_prerequisites():
        print("\nโ Prerequisites check failed. Fix the issues above and try again.")
        sys.exit(1)

    # Run tests. Any exception from either test is reported with a full
    # traceback and turned into a non-zero exit code.
    try:
        accuracy = run_classification_test()
        wrapper_accuracy = run_wrapper_test()
    except Exception as e:
        print(f"\nโ Test failed with error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    print("\n" + "=" * 60)
    print(" โ ALL TESTS PASSED!")
    print("=" * 60)
    print(f"\n Standalone classifier accuracy: {accuracy:.4f}")
    print(f" Wrapper accuracy: {wrapper_accuracy:.4f}")
    print("\n You can now run experiments with:")
    print(" python -m runners.run_experiment --dataset adult --model sap-rpt1-hf")
    print()