# safe_rag/calibration/trainer.py
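"""Training utilities for the SafeRAG risk-calibration head.

Combines a RiskFeatureExtractor and a CalibrationHead to build training data,
fit the calibration model, and evaluate it via a held-out split or
cross-validation.
"""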
import logging
from typing import Any, Dict, List, Tuple

import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from sklearn.model_selection import StratifiedKFold, train_test_split

from .calibration_head import CalibrationHead
from .features import RiskFeatureExtractor

logger = logging.getLogger(__name__)


class CalibrationTrainer:
    def __init__(self, feature_extractor: RiskFeatureExtractor,
                 calibration_head: CalibrationHead):
        self.feature_extractor = feature_extractor
        self.calibration_head = calibration_head

    def prepare_training_data(self, qa_data: List[Dict[str, Any]],
                              retrieved_passages_list: List[List[Dict[str, Any]]],
                              labels: List[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Prepare training data from QA samples and retrieved passages."""
        # Extract features for each question / retrieved-passage pair
        features_list = self.feature_extractor.extract_batch_features(
            [item['question'] for item in qa_data],
            retrieved_passages_list
        )
        # Convert feature dicts to a dense matrix
        X = np.array([self.feature_extractor._features_to_array(f) for f in features_list])
        y = np.array(labels)
        logger.info(f"Prepared training data: {X.shape[0]} samples, {X.shape[1]} features")
        return X, y

    def train(self, X: np.ndarray, y: np.ndarray,
              test_size: float = 0.2, random_state: int = 42) -> Dict[str, Any]:
        """Train the calibration model and report train/test metrics."""
        # Stratified train/test split
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )
        # Train model
        train_metrics = self.calibration_head.train(X_train, y_train)
        # Evaluate on the held-out test set
        test_metrics = self.evaluate(X_test, y_test)
        # Combine metrics
        all_metrics = {
            'train': train_metrics,
            'test': test_metrics,
            'train_size': len(X_train),
            'test_size': len(X_test)
        }
        logger.info(f"Training completed. Test metrics: {test_metrics}")
        return all_metrics

    def evaluate(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
        """Evaluate the calibration model."""
        if not self.calibration_head.is_trained:
            raise ValueError("Model not trained yet")
        # Get predictions; fall back to hard labels if the model has no predict_proba
        if hasattr(self.calibration_head.model, 'predict_proba'):
            y_proba = self.calibration_head.model.predict_proba(X)[:, 1]
            y_pred = (y_proba > 0.5).astype(int)
        else:
            y_pred = self.calibration_head.model.predict(X)
            y_proba = y_pred
        # Calculate metrics
        accuracy = accuracy_score(y, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(y, y_pred, average='binary')
        try:
            auc = roc_auc_score(y, y_proba)
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present in y
            auc = 0.0
        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'auc': auc
        }

    def create_synthetic_labels(self, qa_data: List[Dict[str, Any]],
                                retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[int]:
        """Create synthetic risk labels for training (placeholder implementation)."""
        labels = []
        for qa_item, passages in zip(qa_data, retrieved_passages_list):
            # Simple heuristic for risk labeling.
            # In practice, this would be based on human annotations or automated evaluation.
            question = qa_item['question']
            answer = qa_item['answer']
            # Risk factors
            risk_score = 0.0
            # Low similarity scores = high risk
            if passages:
                avg_similarity = np.mean([p.get('score', 0.0) for p in passages])
                if avg_similarity < 0.3:
                    risk_score += 0.3
            # Few passages = high risk
            if len(passages) < 3:
                risk_score += 0.2
            # Question complexity (length, question words)
            if len(question.split()) > 20:
                risk_score += 0.1
            if any(word in question.lower() for word in ['why', 'how', 'explain', 'compare']):
                risk_score += 0.1
            # Answer length (very short or very long answers might be risky)
            if len(answer.split()) < 5 or len(answer.split()) > 100:
                risk_score += 0.1
            # Convert to binary label
            label = 1 if risk_score > 0.3 else 0
            labels.append(label)
        logger.info(f"Created {sum(labels)} high-risk labels out of {len(labels)} total")
        return labels

    def cross_validate(self, X: np.ndarray, y: np.ndarray,
                       cv_folds: int = 5) -> Dict[str, float]:
        """Perform stratified k-fold cross-validation."""
        skf = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)
        fold_metrics = {
            'accuracy': [],
            'precision': [],
            'recall': [],
            'f1': [],
            'auc': []
        }
        for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
            logger.info(f"Training fold {fold + 1}/{cv_folds}")
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            # Train on fold
            self.calibration_head.train(X_train, y_train)
            # Evaluate on validation set
            val_metrics = self.evaluate(X_val, y_val)
            for metric, value in val_metrics.items():
                fold_metrics[metric].append(value)
        # Calculate mean and std across folds
        cv_results = {}
        for metric, values in fold_metrics.items():
            cv_results[f'{metric}_mean'] = np.mean(values)
            cv_results[f'{metric}_std'] = np.std(values)
        logger.info(f"Cross-validation results: {cv_results}")
        return cv_results
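

# Illustrative usage sketch (not part of the original module): it assumes
# RiskFeatureExtractor() and CalibrationHead() can be constructed without
# arguments and that qa_data / retrieved_passages_list follow the dict shapes
# the methods above already read ('question', 'answer', 'score').
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    extractor = RiskFeatureExtractor()   # hypothetical no-arg construction
    head = CalibrationHead()             # hypothetical no-arg construction
    trainer = CalibrationTrainer(extractor, head)

    # Two small templates, replicated so the stratified split has enough samples per class.
    base_qa = [
        {'question': 'What is retrieval-augmented generation?',
         'answer': 'It combines retrieval with generation.'},
        {'question': 'Why does low retrieval similarity increase answer risk?',
         'answer': 'Because the model must rely more on its parametric knowledge.'},
    ]
    base_passages = [
        [{'text': 'RAG retrieves passages and conditions generation on them.', 'score': 0.82}],
        [{'text': 'An unrelated passage with little overlap.', 'score': 0.12}],
    ]
    qa_data = base_qa * 10
    retrieved_passages_list = base_passages * 10

    # Heuristic labels -> feature matrix -> train/test metrics
    labels = trainer.create_synthetic_labels(qa_data, retrieved_passages_list)
    X, y = trainer.prepare_training_data(qa_data, retrieved_passages_list, labels)
    metrics = trainer.train(X, y)
    print(metrics)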