# Provenance: uploaded to Hugging Face by Prahaladha via huggingface_hub
# ("Upload model.py with huggingface_hub", commit 50da3a7, verified)
# Save this as model.py
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer
class SummaryClassifier(nn.Module):
    """Sentence-level text classifier.

    Wraps a frozen SentenceTransformer embedder and learns only a small
    MLP head (dropout -> linear -> ReLU -> linear) on top of the pooled
    sentence embeddings.
    """

    def __init__(self, embedder, num_classes, dropout=0.1):
        """Build the classifier.

        Args:
            embedder: A pre-loaded SentenceTransformer model.
            num_classes (int): The number of output classes.
            dropout (float): Dropout probability applied before the head.
        """
        super().__init__()
        self.embedder = embedder

        # Only the classification head receives gradients; the embedder
        # stays fixed at its pre-trained weights.
        for frozen_param in self.embedder.parameters():
            frozen_param.requires_grad = False

        sentence_dim = embedder.get_sentence_embedding_dimension()
        self.head = nn.Sequential(
            nn.Dropout(p=dropout),
            nn.Linear(sentence_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, texts, return_embeddings=False):
        """Classify a batch of raw strings.

        Args:
            texts (list[str]): A list of input strings.
            return_embeddings (bool): Also return the sentence embeddings.

        Returns:
            torch.Tensor: Logits of shape (len(texts), num_classes), or a
            (logits, embeddings) pair when ``return_embeddings`` is True.
        """
        # Encode directly onto whichever device holds the head's weights,
        # so the matmul below never crosses devices.
        head_device = next(self.head.parameters()).device
        sentence_vecs = self.embedder.encode(
            texts,
            convert_to_tensor=True,
            show_progress_bar=False,
            device=str(head_device),
        )
        logits = self.head(sentence_vecs)
        return (logits, sentence_vecs) if return_embeddings else logits