File size: 6,468 Bytes
db06013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from typing import List, Dict, Any, Tuple
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
import logging
from .features import RiskFeatureExtractor
from .calibration_head import CalibrationHead

logger = logging.getLogger(__name__)

class CalibrationTrainer:
    """Train, evaluate and cross-validate a binary risk-calibration model.

    Wires a ``RiskFeatureExtractor`` (question + retrieved passages ->
    feature vector) to a ``CalibrationHead`` (the underlying classifier)
    and provides the surrounding workflow: data preparation, train/test
    splitting, metric computation, synthetic labelling and k-fold CV.
    """

    def __init__(self, feature_extractor: RiskFeatureExtractor,
                 calibration_head: CalibrationHead):
        self.feature_extractor = feature_extractor
        self.calibration_head = calibration_head

    def prepare_training_data(self, qa_data: List[Dict[str, Any]],
                            retrieved_passages_list: List[List[Dict[str, Any]]],
                            labels: List[int]) -> Tuple[np.ndarray, np.ndarray]:
        """Turn QA samples plus retrieved passages into ``(X, y)`` arrays.

        Args:
            qa_data: samples; each dict must contain a ``'question'`` key.
            retrieved_passages_list: per-sample list of retrieved passages,
                aligned with ``qa_data``.
            labels: binary risk label per sample, aligned with ``qa_data``.

        Returns:
            ``(X, y)`` where ``X`` has shape ``(n_samples, n_features)``
            and ``y`` has shape ``(n_samples,)``.
        """
        features_list = self.feature_extractor.extract_batch_features(
            [item['question'] for item in qa_data],
            retrieved_passages_list
        )

        # NOTE(review): relies on the extractor's private _features_to_array;
        # consider promoting it to a public method on RiskFeatureExtractor.
        X = np.array([self.feature_extractor._features_to_array(f) for f in features_list])
        y = np.array(labels)

        # Lazy %-args avoid string formatting when INFO is disabled.
        logger.info("Prepared training data: %d samples, %d features",
                    X.shape[0], X.shape[1])
        return X, y

    def train(self, X: np.ndarray, y: np.ndarray,
              test_size: float = 0.2, random_state: int = 42) -> Dict[str, Any]:
        """Fit the calibration head on a stratified train split and evaluate.

        Args:
            X: feature matrix, shape ``(n_samples, n_features)``.
            y: binary labels, shape ``(n_samples,)``.
            test_size: fraction of data held out for evaluation.
            random_state: seed for the stratified split (reproducibility).

        Returns:
            Dict with ``'train'``/``'test'`` metric dicts plus split sizes.
        """
        # Stratify so both splits keep the (possibly imbalanced) class ratio.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=y
        )

        train_metrics = self.calibration_head.train(X_train, y_train)
        test_metrics = self.evaluate(X_test, y_test)

        all_metrics = {
            'train': train_metrics,
            'test': test_metrics,
            'train_size': len(X_train),
            'test_size': len(X_test)
        }

        logger.info("Training completed. Test metrics: %s", test_metrics)
        return all_metrics

    def evaluate(self, X: np.ndarray, y: np.ndarray) -> Dict[str, float]:
        """Compute accuracy / precision / recall / F1 / AUC on ``(X, y)``.

        Raises:
            ValueError: if the calibration head has not been trained yet.
        """
        if not self.calibration_head.is_trained:
            raise ValueError("Model not trained yet")

        # Prefer calibrated probabilities when the model exposes them;
        # otherwise fall back to hard predictions for both pred and proba.
        if hasattr(self.calibration_head.model, 'predict_proba'):
            y_proba = self.calibration_head.model.predict_proba(X)[:, 1]
            y_pred = (y_proba > 0.5).astype(int)
        else:
            y_pred = self.calibration_head.model.predict(X)
            y_proba = y_pred

        accuracy = accuracy_score(y, y_pred)
        # zero_division=0 silences UndefinedMetricWarning on degenerate
        # predictions (e.g. no positives); the returned value is unchanged.
        precision, recall, f1, _ = precision_recall_fscore_support(
            y, y_pred, average='binary', zero_division=0
        )

        try:
            auc = roc_auc_score(y, y_proba)
        except ValueError:
            # roc_auc_score raises ValueError when y has only one class
            # (e.g. a tiny or degenerate eval split); report 0.0 instead
            # of crashing. Previously a bare `except:` swallowed everything.
            auc = 0.0

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'auc': auc
        }

    def create_synthetic_labels(self, qa_data: List[Dict[str, Any]],
                              retrieved_passages_list: List[List[Dict[str, Any]]]) -> List[int]:
        """Create synthetic binary risk labels via a simple heuristic.

        Placeholder implementation: in practice labels would come from
        human annotation or automated answer evaluation.

        Returns:
            One 0/1 label per sample; 1 marks a "high-risk" sample.
        """
        labels = []

        for qa_item, passages in zip(qa_data, retrieved_passages_list):
            question = qa_item['question']
            answer = qa_item['answer']

            risk_score = 0.0

            # Low retrieval similarity => the answer is poorly grounded.
            if passages:
                avg_similarity = np.mean([p.get('score', 0.0) for p in passages])
                if avg_similarity < 0.3:
                    risk_score += 0.3

            # Sparse retrieval => little supporting evidence.
            if len(passages) < 3:
                risk_score += 0.2

            # Long questions tend to be harder.
            if len(question.split()) > 20:
                risk_score += 0.1

            # Open-ended question types (why/how/explain/compare) are riskier.
            if any(word in question.lower() for word in ['why', 'how', 'explain', 'compare']):
                risk_score += 0.1

            # Very short or very long answers are suspicious (compute the
            # word count once instead of splitting the answer twice).
            answer_words = len(answer.split())
            if answer_words < 5 or answer_words > 100:
                risk_score += 0.1

            labels.append(1 if risk_score > 0.3 else 0)

        logger.info("Created %d high-risk labels out of %d total",
                    sum(labels), len(labels))
        return labels

    def cross_validate(self, X: np.ndarray, y: np.ndarray,
                      cv_folds: int = 5) -> Dict[str, List[float]]:
        """Run stratified k-fold CV and return per-metric mean/std.

        NOTE: each fold re-trains ``self.calibration_head`` in place, so
        after this call the head holds the model fitted on the LAST fold.
        Re-train on the full data afterwards if you need a final model.

        Returns:
            Dict mapping ``'<metric>_mean'`` / ``'<metric>_std'`` to floats.
        """
        from sklearn.model_selection import StratifiedKFold

        skf = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=42)

        fold_metrics = {
            'accuracy': [],
            'precision': [],
            'recall': [],
            'f1': [],
            'auc': []
        }

        for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
            logger.info("Training fold %d/%d", fold + 1, cv_folds)

            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]

            self.calibration_head.train(X_train, y_train)
            val_metrics = self.evaluate(X_val, y_val)

            for metric, value in val_metrics.items():
                fold_metrics[metric].append(value)

        # Aggregate each metric across folds.
        cv_results = {}
        for metric, values in fold_metrics.items():
            cv_results[f'{metric}_mean'] = np.mean(values)
            cv_results[f'{metric}_std'] = np.std(values)

        logger.info("Cross-validation results: %s", cv_results)
        return cv_results