Spaces:
Sleeping
Sleeping
| from typing import List, Dict, Any, Tuple | |
| import numpy as np | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score | |
| import logging | |
| from .features import RiskFeatureExtractor | |
| from .calibration_head import CalibrationHead | |
| logger = logging.getLogger(__name__) | |
class CalibrationTrainer:
    """Train and evaluate a risk-calibration classifier for QA retrieval.

    Combines a ``RiskFeatureExtractor`` (turns question / retrieved-passage
    pairs into numeric feature vectors) with a ``CalibrationHead`` (the
    wrapped classifier) and provides train / evaluate / cross-validate
    workflows plus a heuristic synthetic-label generator.
    """

    def __init__(self, feature_extractor: RiskFeatureExtractor,
                 calibration_head: CalibrationHead):
        """Store the feature extractor and calibration head to operate on."""
        self.feature_extractor = feature_extractor
        self.calibration_head = calibration_head

    def prepare_training_data(self, qa_data: List[Dict[str, Any]],
                              retrieved_passages_list: List[List[Dict[str, Any]]],
                              labels: List[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Prepare training data from QA samples and retrieved passages.

        Args:
            qa_data: QA samples; each dict must contain a ``'question'`` key.
            retrieved_passages_list: Per-sample retrieved passages, aligned
                with ``qa_data``.
            labels: Binary risk labels aligned with ``qa_data``.

        Returns:
            Tuple ``(X, y)``: 2-D feature matrix and 1-D label array.
        """
        features_list = self.feature_extractor.extract_batch_features(
            [item['question'] for item in qa_data],
            retrieved_passages_list
        )
        # NOTE(review): relies on the extractor's private ``_features_to_array``;
        # consider promoting that helper to a public method.
        X = np.array([self.feature_extractor._features_to_array(f) for f in features_list])
        y = np.array(labels)
        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info("Prepared training data: %d samples, %d features",
                    X.shape[0], X.shape[1])
        return X, y

    def train(self, X: np.ndarray, y: np.ndarray,
              test_size: float = 0.2, random_state: int = 42) -> Dict[str, Any]:
        """Train the calibration model on a stratified train/test split.

        Args:
            X: Feature matrix, shape ``(n_samples, n_features)``.
            y: Binary labels, shape ``(n_samples,)``.
            test_size: Fraction of samples held out for evaluation.
            random_state: Seed for the reproducible split.

        Returns:
            Dict with ``'train'`` / ``'test'`` metric dicts and the
            ``'train_size'`` / ``'test_size'`` sample counts.
        """
        # Stratify so both splits keep the original class balance.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )
        train_metrics = self.calibration_head.train(X_train, y_train)
        test_metrics = self.evaluate(X_test, y_test)
        all_metrics = {
            'train': train_metrics,
            'test': test_metrics,
            'train_size': len(X_train),
            'test_size': len(X_test)
        }
        logger.info("Training completed. Test metrics: %s", test_metrics)
        return all_metrics

    def evaluate(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
        """Evaluate the trained calibration model.

        Args:
            X: Feature matrix to score.
            y: Ground-truth binary labels.

        Returns:
            Dict with ``accuracy``, ``precision``, ``recall``, ``f1`` and
            ``auc`` (``auc`` is 0.0 when it cannot be computed).

        Raises:
            ValueError: If the calibration head has not been trained yet.
        """
        if not self.calibration_head.is_trained:
            raise ValueError("Model not trained yet")
        model = self.calibration_head.model  # hoist repeated attribute lookup
        if hasattr(model, 'predict_proba'):
            y_proba = model.predict_proba(X)[:, 1]
            y_pred = (y_proba > 0.5).astype(int)
        else:
            # No probability estimates available: fall back to hard
            # predictions and reuse them as pseudo-probabilities for AUC.
            y_pred = model.predict(X)
            y_proba = y_pred
        accuracy = accuracy_score(y, y_pred)
        # zero_division=0 keeps the returned value identical to the previous
        # behavior but suppresses sklearn's UndefinedMetricWarning when a
        # class is never predicted.
        precision, recall, f1, _ = precision_recall_fscore_support(
            y, y_pred, average='binary', zero_division=0
        )
        try:
            auc = roc_auc_score(y, y_proba)
        except ValueError:
            # roc_auc_score raises ValueError when y contains only one class
            # (e.g. a tiny or degenerate evaluation fold). Was a bare
            # ``except:``, which also swallowed KeyboardInterrupt/SystemExit.
            auc = 0.0
        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'auc': auc
        }

    def create_synthetic_labels(self, qa_data: List[Dict[str, Any]],
                                retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[int]:
        """Create synthetic risk labels for training (placeholder heuristic).

        In practice these labels would come from human annotation or
        automated answer evaluation; here a simple additive risk score over
        retrieval quality, question complexity and answer length is
        thresholded into a binary label.

        Args:
            qa_data: QA samples; each dict must contain ``'question'`` and
                ``'answer'`` keys.
            retrieved_passages_list: Per-sample retrieved passages, aligned
                with ``qa_data``.

        Returns:
            List of binary labels (1 = high risk), one per QA sample.
        """
        labels = []
        for qa_item, passages in zip(qa_data, retrieved_passages_list):
            question = qa_item['question']
            answer = qa_item['answer']
            risk_score = 0.0
            # Weak retrieval (low mean similarity) is the strongest signal.
            if passages:
                avg_similarity = np.mean([p.get('score', 0.0) for p in passages])
                if avg_similarity < 0.3:
                    risk_score += 0.3
            # Sparse retrieval: too few supporting passages.
            if len(passages) < 3:
                risk_score += 0.2
            # Question complexity: long or open-ended/explanatory questions.
            if len(question.split()) > 20:
                risk_score += 0.1
            if any(word in question.lower() for word in ['why', 'how', 'explain', 'compare']):
                risk_score += 0.1
            # Very short or very long answers are treated as risky.
            answer_words = len(answer.split())  # hoist: was computed twice
            if answer_words < 5 or answer_words > 100:
                risk_score += 0.1
            # Threshold the additive score into a binary label.
            label = 1 if risk_score > 0.3 else 0
            labels.append(label)
        logger.info("Created %d high-risk labels out of %d total",
                    sum(labels), len(labels))
        return labels

    def cross_validate(self, X: np.ndarray, y: np.ndarray,
                       cv_folds: int = 5) -> Dict[str, List[float]]:
        """Perform stratified k-fold cross-validation.

        NOTE: each fold retrains ``self.calibration_head`` in place, so after
        this call the head holds the model fitted on the *last* fold.

        Args:
            X: Feature matrix.
            y: Binary labels.
            cv_folds: Number of stratified folds.

        Returns:
            Dict mapping ``'<metric>_mean'`` / ``'<metric>_std'`` to floats
            for accuracy, precision, recall, f1 and auc.
        """
        from sklearn.model_selection import StratifiedKFold
        skf = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)
        fold_metrics = {
            'accuracy': [],
            'precision': [],
            'recall': [],
            'f1': [],
            'auc': []
        }
        for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
            logger.info("Training fold %d/%d", fold + 1, cv_folds)
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            self.calibration_head.train(X_train, y_train)
            val_metrics = self.evaluate(X_val, y_val)
            for metric, value in val_metrics.items():
                fold_metrics[metric].append(value)
        # Aggregate per-fold scores into mean/std summary statistics.
        cv_results = {}
        for metric, values in fold_metrics.items():
            cv_results[f'{metric}_mean'] = np.mean(values)
            cv_results[f'{metric}_std'] = np.std(values)
        logger.info("Cross-validation results: %s", cv_results)
        return cv_results