import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling

from .configuration_pawan_embd import PawanEmbdConfig


class PawanEmbdModel(PreTrainedModel):
    """
    PawanEmbd Model - A lightweight embedding model for sentence embeddings.

    This model outputs normalized embeddings suitable for semantic similarity tasks.
    """

    config_class = PawanEmbdConfig
    base_model_prefix = "pawan_embd"

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.hidden_size = config.hidden_size
        self.output_size = config.output_size

        # Token + Position embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.hidden_size,
            nhead=config.num_heads,
            dim_feedforward=config.intermediate_size,
            dropout=config.dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)

        # Projection to output size
        self.projection = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size * 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.hidden_size * 2, config.output_size)
        )

        # Initialize weights
        self.post_init()

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor = None,
        return_dict: bool = True,
        normalize: bool = True
    ):
        """
        Args:
            input_ids: [batch_size, seq_len]
            attention_mask: [batch_size, seq_len]
            return_dict: Whether to return a ModelOutput object
            normalize: Whether to L2-normalize the embeddings

        Returns:
            If return_dict=True: BaseModelOutputWithPooling
            If return_dict=False: tuple of (last_hidden_state, pooler_output)
        """
        batch_size, seq_len = input_ids.shape

        # Generate position IDs
        position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Embeddings
        token_embeds = self.token_embedding(input_ids)
        position_embeds = self.position_embedding(position_ids)
        embeddings = self.dropout(self.layer_norm(token_embeds + position_embeds))

        # Attention mask for transformer (convert 1/0 to True/False)
        if attention_mask is not None:
            attention_mask = attention_mask == 0  # True = masked position

        # Transformer encoding
        encoded = self.encoder(embeddings, src_key_padding_mask=attention_mask)

        # CLS pooling (take first token)
        cls_output = encoded[:, 0]

        # Project to output dimension
        pooler_output = self.projection(cls_output)

        # Normalize embeddings
        if normalize:
            pooler_output = F.normalize(pooler_output, p=2, dim=-1)

        if not return_dict:
            return (encoded, pooler_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=encoded,
            pooler_output=pooler_output,
            hidden_states=None,
            attentions=None
        )

    def count_parameters(self):
        """Count trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
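

# --- Usage sketch (illustrative only, assumption-laden) ---
# A minimal smoke test showing how the model is intended to be called: build a
# config, run a batch of token IDs through forward(), and check that the pooled
# embeddings come back L2-normalized. The keyword arguments passed to
# PawanEmbdConfig below (vocab_size, hidden_size, num_heads, num_layers,
# intermediate_size, max_position_embeddings, dropout, output_size) are assumed
# to be accepted by configuration_pawan_embd.py; adjust them to match the real
# config class. Random input_ids stand in for a tokenizer, which this file does
# not define. Because of the relative import above, run this as a module
# (e.g. `python -m <package>.modeling_pawan_embd`, package name assumed) rather
# than as a standalone script.
if __name__ == "__main__":
    config = PawanEmbdConfig(
        vocab_size=30522,
        hidden_size=256,
        num_heads=4,
        num_layers=4,
        intermediate_size=1024,
        max_position_embeddings=512,
        dropout=0.1,
        output_size=128,
    )
    model = PawanEmbdModel(config)
    model.eval()

    # Fake batch: 2 sequences of 16 token IDs; the second is padded after position 10
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    attention_mask[1, 10:] = 0

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask)

    print("last_hidden_state:", tuple(out.last_hidden_state.shape))  # (2, 16, 256)
    print("pooler_output:", tuple(out.pooler_output.shape))          # (2, 128)
    print("embedding norms:", out.pooler_output.norm(dim=-1))        # ~1.0 each (L2-normalized)
    print("trainable params:", model.count_parameters())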