# Classifier head for the health-query-triage model.
# (Repo uploaded via huggingface_hub; revision b7f3196 — viewer metadata removed
# so this file is valid Python.)
from typing import Dict
from torch import nn
import torch
from huggingface_hub import PyTorchModelHubMixin
class ClassifierHead(
    nn.Module,
    PyTorchModelHubMixin,
    repo_url="https://huggingface.co/davidgray/health-query-triage",
    pipeline_tag="text-classification",
    library_name="PyTorch",
    tags=["medical", "classification"],
):
    """MLP classification head applied on top of sentence embeddings.

    Intended as a Sentence Transformers head module: ``forward`` consumes the
    body's output dictionary (reading ``'sentence_embedding'``) and returns a
    dictionary with a ``'logits'`` key. The scikit-learn-style helpers
    ``predict`` / ``predict_proba`` operate directly on embedding tensors.
    """

    def __init__(self, num_classes: int, embedding_dim: int = 768):
        """
        Args:
            num_classes (int): Number of output classes.
            embedding_dim (int): Dimensionality of the input sentence
                embeddings. Defaults to 768 (Embedding-Gemma-300M's output).
        """
        super().__init__()
        # Two hidden ELU layers with heavy dropout; the final Linear emits raw
        # (un-normalized) logits, as expected by CrossEntropyLoss.
        self.linear_elu_stack = nn.Sequential(
            nn.Linear(embedding_dim, 512),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(512, 512),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )
        # Kept separate from the stack so training uses logits while
        # predict_proba can still produce normalized probabilities.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Calculates logits from the sentence embedding.

        Args:
            features (Dict[str, torch.Tensor]): Output dictionary from the
                Sentence Transformer body, containing 'sentence_embedding'.

        Returns:
            Dict[str, torch.Tensor]: Dictionary with the 'logits' key.
        """
        embeddings = features['sentence_embedding']
        logits = self.linear_elu_stack(embeddings)
        return {"logits": logits}

    def predict(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Classifies embeddings into integer labels in the range [0, num_classes).

        Args:
            embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].

        Returns:
            torch.Tensor: Integer labels with shape [num_inputs].
        """
        # Get probabilities and pick the most likely class per row.
        proba = self.predict_proba(embeddings)
        return torch.argmax(proba, dim=-1)

    def predict_proba(self, embeddings: torch.Tensor) -> torch.Tensor:
        """
        Classifies embeddings into probabilities for each class (summing to 1).

        Args:
            embeddings (torch.Tensor): Tensor with shape [num_inputs, embedding_size].

        Returns:
            torch.Tensor: Float probabilities with shape [num_inputs, num_classes].
        """
        # FIX: the previous implementation unconditionally called self.train()
        # afterwards, which flipped an eval-mode model back into training mode
        # (re-enabling dropout) as a hidden side effect. Save and restore the
        # caller's mode instead.
        was_training = self.training
        self.eval()  # disable dropout for deterministic inference
        try:
            with torch.no_grad():
                logits = self.linear_elu_stack(embeddings)
                # Convert logits to probabilities using Softmax.
                probabilities = self.softmax(logits)
        finally:
            # Restore whatever mode the module was in before this call.
            self.train(was_training)
        return probabilities

    def get_loss_fn(self) -> nn.Module:
        """
        Returns an initialized loss function for training.

        Returns:
            nn.Module: An initialized loss function (e.g., CrossEntropyLoss).
        """
        # CrossEntropyLoss expects logits (raw scores) as input, which matches
        # what forward() produces.
        return nn.CrossEntropyLoss()