File size: 4,219 Bytes
ef814bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import confusion_matrix
from models import SingleTransformer
from utils.helpers import create_multimodal_model


def compute_confusion_matrices(id, model_config, fold_results, dataset, device,
                               *, threshold=0.5, batch_size=32):
    """
    Evaluate each fold's best checkpoint on its validation split and collect
    binary confusion matrices.

    Args:
        id (str): Modality identifier; one of 'RNA', 'ATAC', 'Flux', 'Multi'.
            'Multi' builds the model with ``create_multimodal_model``; any other
            value builds a ``SingleTransformer`` for that modality.
        model_config (dict): Model configuration passed to the model constructor.
        fold_results (list): One dict per fold with keys 'best_model_path'
            (path of a saved state dict) and 'val_idx' (indices into
            ``dataset`` forming that fold's validation split).
        dataset (torch.utils.data.Dataset): Full dataset; batches drawn from it
            must unpack as ``(inputs, bi, y)``.
        device (str or torch.device): Device to run evaluation on.
        threshold (float): Predictions ``>= threshold`` are counted as class 1.
        batch_size (int): Batch size of the per-fold validation DataLoader.
    Returns:
        list: One 2x2 ``np.ndarray`` per fold (in ``fold_results`` order),
            followed by the element-wise sum of all fold matrices as the last
            entry. Layout follows sklearn: rows are true labels, columns are
            predicted labels, i.e. ``[[tn, fp], [fn, tp]]``.
    Raises:
        ValueError: If ``id`` is not a supported modality.
    """
    if id not in ('RNA', 'ATAC', 'Flux', 'Multi'):
        raise ValueError("id must be one of 'RNA', 'ATAC', 'Flux', 'Multi'")

    # Running element-wise sum of the per-fold matrices.
    agg_cm = np.zeros((2, 2), dtype=int)
    cms = []

    for fold in fold_results:
        # map_location loads the checkpoint tensors directly onto ``device``,
        # so a checkpoint saved on a GPU can be evaluated on a CPU-only host.
        state_dict = torch.load(fold['best_model_path'], map_location=device)
        val_subset = Subset(dataset, fold['val_idx'])
        cls_valid_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        if id == 'Multi':
            model = create_multimodal_model(model_config, device, use_mlm=False)
        else:
            model = SingleTransformer(id, **model_config).to(device)

        model.load_state_dict(state_dict, strict=True)
        model.eval()

        val_preds, val_labels = [], []
        with torch.no_grad():
            for inputs, bi, y in cls_valid_loader:
                if isinstance(inputs, list):
                    # Multimodal batch: the first three entries are the RNA,
                    # ATAC and Flux tensors, passed to the model as a tuple.
                    inputs = tuple(t.to(device) for t in inputs[:3])
                else:
                    inputs = inputs.to(device)
                bi = bi.to(device)

                # The model returns a pair; only the first element is used here.
                preds, _ = model(inputs, bi)
                val_preds.append(preds.cpu().numpy())
                val_labels.append(y.cpu().numpy())

        val_preds = np.concatenate(val_preds).ravel()
        val_labels = np.concatenate(val_labels).ravel()

        # NOTE(review): assumes the model outputs probabilities (e.g. after a
        # sigmoid), so thresholding yields hard 0/1 labels -- confirm.
        binary_preds = (val_preds >= threshold).astype(int)
        # Fixing labels=[0, 1] guarantees a 2x2 matrix even when a fold's
        # labels and predictions contain only one class. Without it sklearn
        # returns a 1x1 matrix, which silently broadcasts into every cell of
        # agg_cm and cannot be unpacked as tn, fp, fn, tp.
        cm = confusion_matrix(val_labels, binary_preds, labels=[0, 1])
        agg_cm += cm
        cms.append(cm)

    cms.append(agg_cm)
    return cms


def compute_metrics_from_confusion_matrix(cm):
    """
    Compute per-fold classification metrics and summarise them across folds.

    Args:
        cm (list of np.ndarray): Binary 2x2 confusion matrices in sklearn layout
            (``[[tn, fp], [fn, tp]]``), one per fold, with an aggregated matrix
            as the LAST element (the layout returned by
            ``compute_confusion_matrices``). The last element is ignored.
    Returns:
        dict: ``{'average': {...}, 'std': {...}}``. Each inner dict maps
            'precision', 'recall', 'f1' and 'accuracy' to the mean / population
            standard deviation (``np.std`` with its default ``ddof=0``) over the
            folds. A metric whose denominator is zero for a fold counts as 0
            for that fold. If ``cm`` holds no per-fold matrices, every value is
            NaN (NumPy emits a RuntimeWarning for the empty reduction).
    """
    metric_names = ('precision', 'recall', 'f1', 'accuracy')
    per_fold = {name: [] for name in metric_names}

    # The final entry is the aggregate over all folds; statistics are taken
    # over the individual folds only.
    for fold_cm in cm[:-1]:
        tn, fp, fn, tp = np.asarray(fold_cm).ravel()
        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        f1 = (2 * (precision * recall) / (precision + recall)
              if precision + recall > 0 else 0)
        total = tp + tn + fp + fn
        accuracy = (tp + tn) / total if total > 0 else 0

        for name, value in zip(metric_names, (precision, recall, f1, accuracy)):
            per_fold[name].append(value)

    return {
        'average': {name: np.mean(values) for name, values in per_fold.items()},
        'std': {name: np.std(values) for name, values in per_fold.items()},
    }