from typing import List, Dict, Any, Tuple
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import logging

from .features import RiskFeatureExtractor
from .calibration_head import CalibrationHead

logger = logging.getLogger(__name__)


class CalibrationTrainer:
    def __init__(self, feature_extractor: RiskFeatureExtractor, calibration_head: CalibrationHead):
        self.feature_extractor = feature_extractor
        self.calibration_head = calibration_head

    def prepare_training_data(self, qa_data: List[Dict[str, Any]],
                              retrieved_passages_list: List[List[Dict[str, Any]]],
                              labels: List[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Prepare training data from QA samples and retrieved passages"""
        # Extract features
        features_list = self.feature_extractor.extract_batch_features(
            [item['question'] for item in qa_data],
            retrieved_passages_list
        )

        # Convert features to arrays
        X = np.array([self.feature_extractor._features_to_array(f) for f in features_list])
        y = np.array(labels)

        logger.info(f"Prepared training data: {X.shape[0]} samples, {X.shape[1]} features")
        return X, y

    def train(self, X: np.ndarray, y: np.ndarray, test_size: float = 0.2,
              random_state: int = 42) -> Dict[str, Any]:
        """Train the calibration model"""
        # Split data
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )

        # Train model
        train_metrics = self.calibration_head.train(X_train, y_train)

        # Evaluate on test set
        test_metrics = self.evaluate(X_test, y_test)

        # Combine metrics
        all_metrics = {
            'train': train_metrics,
            'test': test_metrics,
            'train_size': len(X_train),
            'test_size': len(X_test)
        }

        logger.info(f"Training completed. Test metrics: {test_metrics}")
        return all_metrics

    def evaluate(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
        """Evaluate the calibration model"""
        if not self.calibration_head.is_trained:
            raise ValueError("Model not trained yet")

        # Get predictions
        if hasattr(self.calibration_head.model, 'predict_proba'):
            y_proba = self.calibration_head.model.predict_proba(X)[:, 1]
            y_pred = (y_proba > 0.5).astype(int)
        else:
            y_pred = self.calibration_head.model.predict(X)
            y_proba = y_pred

        # Calculate metrics
        accuracy = accuracy_score(y, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(y, y_pred, average='binary')

        try:
            auc = roc_auc_score(y, y_proba)
        except ValueError:
            # AUC is undefined when only one class is present in y
            auc = 0.0

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'auc': auc
        }

    def create_synthetic_labels(self, qa_data: List[Dict[str, Any]],
                                retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[int]:
        """Create synthetic risk labels for training (placeholder implementation)"""
        labels = []

        for qa_item, passages in zip(qa_data, retrieved_passages_list):
            # Simple heuristic for risk labeling.
            # In practice, this would be based on human annotations or automated evaluation.
            question = qa_item['question']
            answer = qa_item['answer']

            # Risk factors
            risk_score = 0.0

            # Low similarity scores = high risk
            if passages:
                avg_similarity = np.mean([p.get('score', 0.0) for p in passages])
                if avg_similarity < 0.3:
                    risk_score += 0.3

            # Few passages = high risk
            if len(passages) < 3:
                risk_score += 0.2

            # Question complexity (length, question words)
            if len(question.split()) > 20:
                risk_score += 0.1
            if any(word in question.lower() for word in ['why', 'how', 'explain', 'compare']):
                risk_score += 0.1

            # Answer length (very short or very long answers might be risky)
            if len(answer.split()) < 5 or len(answer.split()) > 100:
                risk_score += 0.1

            # Convert to binary label
            label = 1 if risk_score > 0.3 else 0
            labels.append(label)

        logger.info(f"Created {sum(labels)} high-risk labels out of {len(labels)} total")
        return labels

    def cross_validate(self, X: np.ndarray, y: np.ndarray, cv_folds: int = 5) -> Dict[str, float]:
        """Perform cross-validation and return per-metric mean and std across folds"""
        from sklearn.model_selection import StratifiedKFold

        skf = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)

        fold_metrics = {
            'accuracy': [],
            'precision': [],
            'recall': [],
            'f1': [],
            'auc': []
        }

        for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
            logger.info(f"Training fold {fold + 1}/{cv_folds}")

            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]

            # Train on fold (note: retrains the shared calibration head in place,
            # so the head holds the last fold's model after cross-validation)
            self.calibration_head.train(X_train, y_train)

            # Evaluate on validation set
            val_metrics = self.evaluate(X_val, y_val)

            for metric, value in val_metrics.items():
                fold_metrics[metric].append(value)

        # Calculate mean and std
        cv_results = {}
        for metric, values in fold_metrics.items():
            cv_results[f'{metric}_mean'] = np.mean(values)
            cv_results[f'{metric}_std'] = np.std(values)

        logger.info(f"Cross-validation results: {cv_results}")
        return cv_results
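

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch only).
# Assumption: RiskFeatureExtractor() and CalibrationHead() can be constructed
# without arguments, QA items carry 'question'/'answer' keys, and retrieved
# passages carry a 'score' field, as the methods above expect. Adjust the
# constructor calls to match features.py and calibration_head.py.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Toy QA data and retrieved passages in the format expected above
    qa_data = [
        {'question': 'Why does the sky appear blue?', 'answer': 'Rayleigh scattering.'},
        {'question': 'What is the capital of France?', 'answer': 'Paris is the capital of France.'},
    ]
    retrieved_passages_list = [
        [{'text': 'Scattering of sunlight ...', 'score': 0.21}],
        [{'text': 'Paris is the capital ...', 'score': 0.87},
         {'text': 'France is in Europe ...', 'score': 0.74},
         {'text': 'The Seine runs through Paris ...', 'score': 0.69}],
    ]

    trainer = CalibrationTrainer(RiskFeatureExtractor(), CalibrationHead())

    # Synthetic labels stand in for human annotations during prototyping
    labels = trainer.create_synthetic_labels(qa_data, retrieved_passages_list)
    X, y = trainer.prepare_training_data(qa_data, retrieved_passages_list, labels)

    # With a realistically sized, class-balanced dataset (the stratified split
    # and 5-fold CV both need several examples per class), training would follow:
    # metrics = trainer.train(X, y)
    # cv_results = trainer.cross_validate(X, y, cv_folds=5)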