import torch.nn as nn
import torch.nn.functional as F
import torch
from transformers import EsmModel, AutoModel, PreTrainedModel, AutoConfig
import evaluate
import numpy as np
from sklearn.metrics import accuracy_score, classification_report
import wandb

# Hugging Face `evaluate` metric objects, created once at import time.
# NOTE(review): none of these four are used in this file; kept in case other modules
# import them. `evaluate.load` may fetch metric scripts the first time it runs.
accuracy_metric = evaluate.load("accuracy")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")
f1_metric = evaluate.load("f1")


class CleavageSiteModel(nn.Module):
    """ESM encoder with a linear classification head on the first-token embedding.

    The hidden state at sequence position 0 of ``EsmModel``'s last layer is mapped
    to ``num_classes`` logits by a single ``nn.Linear`` layer.
    """

    def __init__(self, base_model, num_classes=75, class_weights=None):
        """Build the encoder, the classification head and the loss function.

        Args:
            base_model: Checkpoint name or path passed to ``EsmModel.from_pretrained``.
            num_classes: Number of output logits / target classes.
            class_weights: Optional mapping ``{class_index: weight}`` used as the
                per-class ``weight`` of ``nn.CrossEntropyLoss``. Class indices missing
                from the mapping get weight 0 and therefore contribute nothing to the
                loss. Keys may be ints or int-like strings (e.g. read from JSON).
        """
        super().__init__()
        self.model = EsmModel.from_pretrained(base_model)
        self.classifier = nn.Linear(self.model.config.hidden_size, num_classes)

        if class_weights is not None:
            # Full-length weight vector; classes absent from `class_weights` stay at 0.
            # NOTE(review): with the default mean reduction, a batch whose labels all
            # have weight 0 gives a 0/0 = NaN loss — confirm every training class has
            # a non-zero entry in `class_weights`.
            weight_tensor = torch.zeros(num_classes)
            for class_idx, weight in class_weights.items():
                # int()/float() accept JSON string keys and numpy scalar values.
                weight_tensor[int(class_idx)] = float(weight)
            # CrossEntropyLoss stores `weight` as a buffer, so it follows .to(device).
            self.loss_fn = nn.CrossEntropyLoss(weight=weight_tensor)
        else:
            self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode the sequences and classify them from the first-token embedding.

        Args:
            input_ids: Token ids for ``EsmModel`` — presumably (batch, seq_len).
            attention_mask: Attention mask matching ``input_ids``.
            labels: Optional target class indices, one per sequence.

        Returns:
            ``{"loss": loss, "logits": logits}`` when ``labels`` is given, otherwise
            ``{"logits": logits}``; ``logits`` has shape (batch, num_classes).
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Last-layer hidden state of the first token of every sequence.
        cls_output = outputs.last_hidden_state[:, 0]
        logits = self.classifier(cls_output)
        if labels is None:
            return {"logits": logits}
        return {"loss": self.loss_fn(logits, labels), "logits": logits}


def compute_metrics(eval_pred):
    """Compute overall and per-class accuracy for an evaluation pass.

    Args:
        eval_pred: ``(logits, labels)`` pair, e.g. a ``transformers.EvalPrediction``.
            ``logits`` is expected as (num_samples, num_classes) and ``labels`` as
            (num_samples,) integer class indices.

    Returns:
        ``{"accuracy": float, "accuracy_class_<k>": float, ...}`` with one per-class
        entry for every class value that occurs in ``labels``.

    Side effects:
        When a wandb run is active, logs the sklearn classification report (as HTML)
        together with the accuracy values in a single ``wandb.log`` call.
    """
    logits, labels = eval_pred
    # If the model returns extra outputs, predictions may arrive as a tuple whose
    # first element holds the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    predictions = np.argmax(logits, axis=1)

    accuracy = float(accuracy_score(labels, predictions))

    # Accuracy over the samples of each class present in `labels` (i.e. per-class recall).
    per_class_acc = {}
    for cls in np.unique(labels):
        class_mask = labels == cls
        per_class_acc[f"accuracy_class_{cls}"] = float(
            (predictions[class_mask] == cls).mean()
        )

    # wandb.log raises without an active run; skip logging in that case.
    if wandb.run is not None:
        # zero_division=0 keeps the same values but silences warnings for classes
        # that are never predicted.
        report = classification_report(labels, predictions, digits=4, zero_division=0)
        # <pre> preserves the report's space-aligned columns, which HTML would collapse.
        # One log call so the report and the accuracies share the same wandb step.
        wandb.log(
            {
                "classification_report": wandb.Html(f"<pre>{report}</pre>"),
                "overall_accuracy": accuracy,
                **per_class_acc,
            }
        )

    return {"accuracy": accuracy, **per_class_acc}