|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
|
class SummaryClassifier(nn.Module):
    """Classify raw text by running a frozen sentence embedder into a small MLP head.

    The wrapped SentenceTransformer is used purely as a fixed feature
    extractor: its parameters are frozen at construction time, so an
    optimizer built from ``self.parameters()`` only updates the head.
    """

    def __init__(self, embedder, num_classes, dropout=0.1):
        """Build the classifier on top of a pre-loaded embedder.

        Args:
            embedder: A pre-loaded SentenceTransformer model (only its
                ``encode`` / ``get_sentence_embedding_dimension`` /
                ``parameters`` interface is used).
            num_classes (int): The number of output classes.
            dropout (float): Dropout probability applied before the head.
        """
        super().__init__()
        self.embedder = embedder

        # Freeze the embedder so only the classification head is trainable.
        for frozen_param in embedder.parameters():
            frozen_param.requires_grad = False

        in_features = embedder.get_sentence_embedding_dimension()
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_features, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, texts, return_embeddings=False):
        """Encode *texts* and score them with the classification head.

        Args:
            texts (list[str]): A list of input strings.
            return_embeddings (bool): Whether to return embeddings alongside logits.

        Returns:
            torch.Tensor: The output logits.
            (Optional) torch.Tensor: The sentence embeddings, when
            ``return_embeddings`` is True.
        """
        # Encode directly onto whatever device the head lives on, so the
        # matmul below never crosses devices.
        target = next(self.head.parameters()).device
        sentence_vecs = self.embedder.encode(
            texts,
            convert_to_tensor=True,
            show_progress_bar=False,
            device=str(target),
        )

        logits = self.head(sentence_vecs)
        return (logits, sentence_vecs) if return_embeddings else logits