# Save this as model.py
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer


class SummaryClassifier(nn.Module):
    """Trainable MLP classification head on top of a frozen sentence embedder.

    Input texts are encoded with the wrapped SentenceTransformer (whose
    parameters are frozen in ``__init__``), and the resulting sentence
    embeddings are passed through a small dropout + two-layer MLP head that
    outputs class logits.
    """

    def __init__(
        self,
        embedder: SentenceTransformer,
        num_classes: int,
        dropout: float = 0.1,
        hidden_dim: int = 128,
    ):
        """Initializes the classifier.

        Args:
            embedder: A pre-loaded SentenceTransformer model. Its parameters
                have ``requires_grad`` set to False, so only the head trains.
            num_classes (int): The number of output classes.
            dropout (float): Dropout probability applied to the embeddings
                before the first linear layer.
            hidden_dim (int): Width of the head's hidden layer (default 128,
                matching the previous hard-coded value).

        Raises:
            ValueError: If the embedder cannot report its embedding dimension.
        """
        super().__init__()
        self.embedder = embedder
        embedding_dim = embedder.get_sentence_embedding_dimension()
        if embedding_dim is None:
            # get_sentence_embedding_dimension() returns None when the
            # dimension is unknown; fail clearly instead of letting
            # nn.Linear raise an opaque TypeError.
            raise ValueError(
                "Could not determine the embedder's sentence embedding "
                "dimension (get_sentence_embedding_dimension() returned None)."
            )

        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

        # Freeze the embedder parameters; only the head is trained.
        for p in self.embedder.parameters():
            p.requires_grad = False

    def forward(self, texts, return_embeddings=False):
        """Forward pass.

        Args:
            texts (list[str]): A list of input strings.
            return_embeddings (bool): Whether to return embeddings alongside
                logits.

        Returns:
            torch.Tensor: The output logits.
            (Optional) torch.Tensor: The sentence embeddings, returned as the
                second element of a tuple when ``return_embeddings`` is True.
        """
        # Encode on the same device as the model's 'head'.
        target_device = next(self.head.parameters()).device
        embeddings = self.embedder.encode(
            texts,
            convert_to_tensor=True,
            show_progress_bar=False,
            device=str(target_device),
        )
        # encode() runs with autograd disabled. Depending on the
        # sentence-transformers version this is torch.inference_mode(), and
        # inference tensors "cannot be saved for backward": feeding them
        # straight into nn.Linear raises a RuntimeError during training.
        # Cloning (outside inference mode) yields a normal tensor that the
        # head can be trained on; it is a cheap no-op semantically otherwise.
        embeddings = embeddings.clone()

        logits = self.head(embeddings)
        if return_embeddings:
            return logits, embeddings
        return logits