dmedhi committed on
Commit
20091ac
·
verified ·
1 Parent(s): 5e4eab4

Upload modeling_pawan_embd.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_pawan_embd.py +124 -0
modeling_pawan_embd.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling

# NOTE(review): PawanEmbdConfig is used as PawanEmbdModel.config_class but was
# never imported, which raises NameError when this module is imported. The
# import below assumes the HF-conventional companion module name
# (configuration_pawan_embd.py) — confirm against the repo's auto_map.
from .configuration_pawan_embd import PawanEmbdConfig
7
class PawanEmbdModel(PreTrainedModel):
    """PawanEmbd: a lightweight Transformer encoder for sentence embeddings.

    Embeds token ids with learned token + absolute position embeddings,
    encodes them with a pre-norm ``nn.TransformerEncoder``, CLS-pools the
    first token, and projects through a 2-layer MLP to
    ``config.output_size``. Outputs are L2-normalized by default, so they
    are directly usable for cosine-similarity / semantic search tasks.
    """

    # NOTE(review): PawanEmbdConfig must be imported at the top of this file
    # (it was originally referenced without any import) — see the import block.
    config_class = PawanEmbdConfig
    base_model_prefix = "pawan_embd"

    def __init__(self, config):
        # PreTrainedModel.__init__ already stores `config` on self.config,
        # so no separate assignment is needed.
        super().__init__(config)

        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        # Token + learned absolute position embeddings.
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(
            config.max_position_embeddings, config.hidden_size
        )
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size)

        # Pre-norm Transformer encoder; batch_first keeps tensors [B, T, H].
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)

        # 2-layer MLP head: hidden_size -> 2*hidden_size -> output_size.
        self.projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_size * 2, config.output_size),
        )

        # HF hook: applies _init_weights to all submodules and ties weights.
        self.post_init()

    def _init_weights(self, module):
        """Initialize weights: normal(0, 0.02) for Linear/Embedding, identity LayerNorm."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        return_dict: bool = True,
        normalize: bool = True,
    ):
        """Encode token ids into sentence embeddings.

        Args:
            input_ids: [batch_size, seq_len] token ids.
            attention_mask: optional [batch_size, seq_len]; 1 = real token,
                0 = padding (standard HF convention).
            return_dict: whether to return a ModelOutput object.
            normalize: whether to L2-normalize the pooled embeddings.

        Returns:
            If return_dict=True: BaseModelOutputWithPooling with
            last_hidden_state [B, T, hidden_size] and pooler_output
            [B, output_size].
            If return_dict=False: tuple of (last_hidden_state, pooler_output).

        Raises:
            ValueError: if seq_len exceeds config.max_position_embeddings
                (previously this surfaced as an opaque embedding index error).
        """
        batch_size, seq_len = input_ids.shape

        # Fail fast with a clear message instead of an out-of-range index
        # error inside the position embedding lookup.
        max_positions = self.position_embedding.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"({max_positions})."
            )

        # Position ids 0..seq_len-1, broadcast across the batch.
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Sum embeddings, then LayerNorm + dropout (BERT-style embedding stack).
        token_embeds = self.token_embedding(input_ids)
        position_embeds = self.position_embedding(position_ids)
        embeddings = self.dropout(self.layer_norm(token_embeds + position_embeds))

        # nn.TransformerEncoder expects True = position to IGNORE, the
        # inverse of the HF 1/0 convention.
        padding_mask = None
        if attention_mask is not None:
            padding_mask = attention_mask == 0

        encoded = self.encoder(embeddings, src_key_padding_mask=padding_mask)

        # CLS pooling: take the first token's hidden state.
        cls_output = encoded[:, 0]

        # Project to the output embedding dimension.
        pooler_output = self.projection(cls_output)

        if normalize:
            pooler_output = F.normalize(pooler_output, p=2, dim=-1)

        if not return_dict:
            return (encoded, pooler_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=encoded,
            pooler_output=pooler_output,
            hidden_states=None,
            attentions=None,
        )

    def count_parameters(self):
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)