Prahaladha committed on
Commit
50da3a7
·
verified ·
1 Parent(s): 19e03b8

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +54 -0
model.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Save this as model.py
2
+ import torch
3
+ import torch.nn as nn
4
+ from sentence_transformers import SentenceTransformer
5
+
6
class SummaryClassifier(nn.Module):
    """Small classification head on top of a frozen sentence embedder.

    The embedder turns raw strings into fixed-size sentence embeddings via
    its ``encode`` method; only the ``head`` (dropout -> linear -> ReLU ->
    linear) has trainable parameters.
    """

    def __init__(self, embedder, num_classes, dropout=0.1):
        """
        Initializes the classifier.

        Args:
            embedder: A pre-loaded SentenceTransformer model (any object
                exposing ``get_sentence_embedding_dimension()``,
                ``parameters()`` and ``encode(...)`` works).
            num_classes (int): The number of output classes.
            dropout (float): Dropout probability applied to the embeddings
                before the first linear layer.
        """
        super().__init__()
        self.embedder = embedder
        embedding_dim = embedder.get_sentence_embedding_dimension()
        # Kept so forward() can build correctly shaped outputs for an
        # empty batch without calling the embedder.
        self.embedding_dim = embedding_dim

        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(embedding_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

        # Freeze the embedder parameters so only the head is trained.
        for p in self.embedder.parameters():
            p.requires_grad = False

    def forward(self, texts, return_embeddings=False):
        """
        Forward pass.

        Args:
            texts (list[str]): A list of input strings. An empty list
                yields empty outputs of shape ``(0, num_classes)`` and
                ``(0, embedding_dim)`` instead of an error.
            return_embeddings (bool): Whether to return embeddings alongside
                logits.

        Returns:
            torch.Tensor: The output logits.
            (Optional) torch.Tensor: The sentence embeddings, on the same
            device and with the same dtype as the head's parameters.
        """
        # Use the head's first parameter as the reference for both device
        # and dtype, so the embeddings always match what nn.Linear expects.
        head_param = next(self.head.parameters())
        target_device = head_param.device

        if isinstance(texts, (list, tuple)) and not texts:
            # encode() is not guaranteed to return a (0, dim) tensor for an
            # empty batch, so build the empty embedding matrix directly.
            embeddings = torch.empty(
                (0, self.embedding_dim),
                dtype=head_param.dtype,
                device=target_device,
            )
        else:
            embeddings = self.embedder.encode(
                texts,
                convert_to_tensor=True,
                show_progress_bar=False,
                device=str(target_device),
            )
            # No-op when device/dtype already match; otherwise avoids a
            # dtype/device mismatch inside the head (e.g. a half-precision
            # embedder feeding a float32 head).
            embeddings = embeddings.to(
                device=target_device, dtype=head_param.dtype
            )

        logits = self.head(embeddings)

        if return_embeddings:
            return logits, embeddings
        return logits