| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import BaseModelOutputWithPooling |
| | from .configuration_pawan_embd import PawanEmbdConfig |
| |
|
class PawanEmbdModel(PreTrainedModel):
    """PawanEmbd Model - a lightweight Transformer encoder for sentence embeddings.

    Token embeddings plus learned absolute position embeddings feed a pre-norm
    (``norm_first=True``) ``nn.TransformerEncoder``; the first token's final
    hidden state is projected through a two-layer MLP to ``config.output_size``
    and (optionally) L2-normalized, so the pooled output is suitable for
    cosine-similarity / semantic search tasks.
    """

    config_class = PawanEmbdConfig
    base_model_prefix = "pawan_embd"

    def __init__(self, config):
        super().__init__(config)

        self.config = config
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        # Input embeddings: token ids + learned absolute positions.
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size)

        # Pre-norm Transformer encoder stack.
        # NOTE(review): with norm_first=True, nn.TransformerEncoder applies NO
        # final LayerNorm unless one is passed via its `norm=` argument; adding
        # one is the conventional pre-norm setup, but it would change the state
        # dict and break existing checkpoints — confirm before changing.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)

        # Projection head: hidden_size -> 2*hidden_size -> output_size.
        self.projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_size * 2, config.output_size)
        )

        # HF hook: runs _init_weights over all submodules (and any other
        # PreTrainedModel post-processing).
        self.post_init()

    def _init_weights(self, module):
        """Initialize weights: N(0, 0.02) for Linear/Embedding, identity for LayerNorm."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        return_dict: bool = True,
        normalize: bool = True
    ):
        """
        Args:
            input_ids: [batch_size, seq_len] token ids.
            attention_mask: [batch_size, seq_len]; 1/True marks real tokens,
                0/False marks padding (HF convention).
            return_dict: Whether to return a ModelOutput object.
            normalize: Whether to L2-normalize the pooled embeddings.

        Returns:
            If return_dict=True: BaseModelOutputWithPooling
            If return_dict=False: tuple of (last_hidden_state, pooler_output)

        Raises:
            ValueError: if seq_len exceeds config.max_position_embeddings
                (the position-embedding table would otherwise be indexed
                out of range).
        """
        batch_size, seq_len = input_ids.shape

        # Fail fast with a clear message instead of an opaque IndexError
        # from deep inside nn.Embedding.
        if seq_len > self.config.max_position_embeddings:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"({self.config.max_position_embeddings})"
            )

        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        token_embeds = self.token_embedding(input_ids)
        position_embeds = self.position_embedding(position_ids)
        embeddings = self.dropout(self.layer_norm(token_embeds + position_embeds))

        # PyTorch's src_key_padding_mask is the INVERSE of the HF mask:
        # True marks positions to ignore. Use a local so the caller's
        # attention_mask tensor/binding is left untouched.
        padding_mask = None
        if attention_mask is not None:
            padding_mask = attention_mask == 0

        encoded = self.encoder(embeddings, src_key_padding_mask=padding_mask)

        # Pool on the first token ([CLS]-style), then project to output_size.
        cls_output = encoded[:, 0]
        pooler_output = self.projection(cls_output)

        if normalize:
            pooler_output = F.normalize(pooler_output, p=2, dim=-1)

        if not return_dict:
            return (encoded, pooler_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=encoded,
            pooler_output=pooler_output,
            hidden_states=None,
            attentions=None
        )

    def count_parameters(self):
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
| |
|
| |
|