File size: 2,446 Bytes
aec9df8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import EsmModel, AutoModel, PreTrainedModel, AutoConfig
import evaluate
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import wandb

# Hugging Face `evaluate` metric handles, loaded once at module import time.
# NOTE(review): compute_metrics below uses sklearn's accuracy_score instead
# of these handles — precision/recall/f1 appear unused in this file; confirm
# they are referenced elsewhere (e.g. a Trainer config) before removing.
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")


class CleavageSiteModel(nn.Module):
    """ESM-based sequence classifier for cleavage-site prediction.

    Wraps a pretrained ESM encoder with a linear head over the first-token
    ([CLS]-style) representation. Cross-entropy loss optionally applies
    per-class weights supplied as a sparse ``{class_idx: weight}`` mapping.
    """

    def __init__(self, base_model, num_classes=75, class_weights=None):
        super().__init__()
        # Pretrained ESM encoder, loaded by model name or local path.
        self.model = EsmModel.from_pretrained(base_model)
        self.classifier = nn.Linear(self.model.config.hidden_size, num_classes)

        if class_weights is None:
            self.loss_fn = nn.CrossEntropyLoss()
        else:
            # Expand the sparse mapping into a dense weight vector; classes
            # absent from the mapping keep weight 0 and therefore contribute
            # nothing to the loss.
            weights = torch.zeros(num_classes)
            for idx, w in class_weights.items():
                weights[idx] = w
            self.loss_fn = nn.CrossEntropyLoss(weight=weights)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode the batch and classify; return a dict as HF Trainer expects.

        Returns {"logits": ...} and additionally "loss" when labels are given.
        """
        encoder_out = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the first token's hidden state as the sequence representation.
        pooled = encoder_out.last_hidden_state[:, 0]
        logits = self.classifier(pooled)

        if labels is None:
            return {"logits": logits}
        return {"loss": self.loss_fn(logits, labels), "logits": logits}


def compute_metrics(eval_pred):
    """Compute overall and per-class accuracy for a HF Trainer eval pass.

    Args:
        eval_pred: ``(logits, labels)`` pair as supplied by Trainer's
            ``compute_metrics`` hook. ``logits`` may be a tuple when the
            model returns extra outputs; the first element is taken.

    Returns:
        Dict with ``"accuracy"`` plus one ``"accuracy_class_<c>"`` entry per
        class present in ``labels``. When a wandb run is active, also logs
        the metrics and an HTML classification report to Weights & Biases.
    """
    logits, labels = eval_pred
    # HF models can return a tuple of outputs; the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)  # class dim is last

    # Overall accuracy
    accuracy = accuracy_score(labels, predictions)

    # Per-class accuracy: restrict to the samples of each observed class.
    per_class_acc = {}
    for cls in np.unique(labels):
        mask = labels == cls
        # float() keeps logged/returned values as plain Python floats
        # rather than np.float64 scalars.
        per_class_acc[f"accuracy_class_{cls}"] = float(
            (predictions[mask] == labels[mask]).mean()
        )

    # Guard: wandb.log raises when no run has been initialised, which would
    # abort evaluation. Only log inside an active run.
    if wandb.run is not None:
        report = classification_report(labels, predictions, digits=4)
        wandb.log({"classification_report": wandb.Html(report.replace('\n', '<br>'))})
        wandb.log({"overall_accuracy": accuracy, **per_class_acc})

    return {"accuracy": accuracy, **per_class_acc}