Spaces:
Sleeping
Sleeping
File size: 3,136 Bytes
b7f3196 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from typing import Dict
from torch import nn
import torch
from huggingface_hub import PyTorchModelHubMixin
class ClassifierHead(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://huggingface.co/davidgray/health-query-triage",
    pipeline_tag="text-classification",
    library_name="PyTorch",
    tags=["medical", "classification"],
):
    """Feed-forward classification head applied to sentence embeddings.

    The head is a three-layer MLP (Linear -> ELU -> Dropout, twice, followed
    by a final Linear projection to ``num_classes``). ``forward`` returns raw
    logits for training; ``predict`` / ``predict_proba`` are inference
    helpers that run with dropout disabled and gradients turned off.
    """

    def __init__(self, num_classes: int, embedding_dim: int = 768):  # Embedding-Gemma-300M has a 768-dimensional output
        """
        Builds the MLP stack and the softmax used for probability outputs.

        Args:
            num_classes (int): Number of output classes (size of the logits' last dimension).
            embedding_dim (int): Size of the incoming embedding vectors. Defaults to 768.
        """
        super().__init__()
        self.linear_elu_stack = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )
        # Softmax over the last (class) dimension; only used by predict_proba.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Calculates logits from the sentence embedding.

        Args:
            features (Dict[str, torch.Tensor]): Output dictionary from the Sentence Transformer body,
                containing 'sentence_embedding'.

        Returns:
            Dict[str, torch.Tensor]: Dictionary with the 'logits' key.

        Raises:
            KeyError: If 'sentence_embedding' is missing from ``features``.
        """
        embeddings = features['sentence_embedding']
        logits = self.linear_elu_stack(embeddings)
        return {"logits": logits}

    def predict(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Classifies embeddings into integer labels in the range [0, num_classes).

        Args:
            embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].

        Returns:
            torch.Tensor: Integer labels with shape [num_inputs].
        """
        # Get probabilities and find the class with the highest probability
        proba = self.predict_proba(embeddings)
        return torch.argmax(proba, dim=-1)

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Classifies embeddings into probabilities for each class (summing to 1).

        Inference runs in eval mode (dropout disabled) with gradient tracking
        turned off. The module's previous train/eval mode is restored
        afterwards, so calling this on a model that is already in eval mode
        does not silently re-enable dropout for later calls.

        Args:
            embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].

        Returns:
            torch.Tensor: Float probabilities with shape [num_inputs, num_classes].
        """
        # Remember the caller's mode instead of assuming the model was training.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Apply the forward pass of the head to get logits
                logits = self.linear_elu_stack(embeddings)
                # Convert logits to probabilities using Softmax
                probabilities = self.softmax(logits)
        finally:
            # Restore the original mode even if the forward pass raised.
            self.train(was_training)
        return probabilities

    def get_loss_fn(self) -> nn.Module:
        """
        Returns an initialized loss function for training.

        Returns:
            nn.Module: An initialized loss function (CrossEntropyLoss).
        """
        # CrossEntropyLoss expects logits (raw scores) as input
        return nn.CrossEntropyLoss()
|