# PawanEmbd-68M / modeling_pawan_embd.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling

from .configuration_pawan_embd import PawanEmbdConfig
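
# Fields this module reads from PawanEmbdConfig (defined in
# configuration_pawan_embd.py): vocab_size, hidden_size, output_size,
# max_position_embeddings, num_heads, num_layers, intermediate_size, dropout.
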
class PawanEmbdModel(PreTrainedModel):
"""
    PawanEmbd Model - a lightweight Transformer encoder that produces sentence embeddings.
    By default it returns L2-normalized pooled embeddings suitable for semantic
    similarity tasks (pass normalize=False to disable normalization).
"""
config_class = PawanEmbdConfig
base_model_prefix = "pawan_embd"
    def __init__(self, config: PawanEmbdConfig):
super().__init__(config)
self.config = config
self.hidden_size = config.hidden_size
self.output_size = config.output_size
# Token + Position embeddings
self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embedding = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.dropout = nn.Dropout(config.dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size)
# Transformer encoder
encoder_layer = nn.TransformerEncoderLayer(
d_model=config.hidden_size,
nhead=config.num_heads,
dim_feedforward=config.intermediate_size,
dropout=config.dropout,
activation='gelu',
batch_first=True,
norm_first=True
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
# Projection to output size
self.projection = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size * 2),
nn.GELU(),
nn.Dropout(config.dropout),
nn.Linear(config.hidden_size * 2, config.output_size)
)
# Initialize weights
self.post_init()
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def forward(
self,
input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
normalize: bool = True
):
"""
Args:
input_ids: [batch_size, seq_len]
attention_mask: [batch_size, seq_len]
return_dict: Whether to return a ModelOutput object
normalize: Whether to L2-normalize the embeddings
        Returns:
            If return_dict=True: BaseModelOutputWithPooling with
                last_hidden_state [batch_size, seq_len, hidden_size] and
                pooler_output [batch_size, output_size]
            If return_dict=False: tuple of (last_hidden_state, pooler_output)
"""
batch_size, seq_len = input_ids.shape
# Generate position IDs
position_ids = torch.arange(seq_len, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Embeddings
token_embeds = self.token_embedding(input_ids)
position_embeds = self.position_embedding(position_ids)
embeddings = self.dropout(self.layer_norm(token_embeds + position_embeds))
        # Convert HF-style attention mask (1 = real token, 0 = padding) into a
        # key-padding mask for nn.TransformerEncoder (True = position is masked)
        key_padding_mask = None
        if attention_mask is not None:
            key_padding_mask = attention_mask == 0
        # Transformer encoding
        encoded = self.encoder(embeddings, src_key_padding_mask=key_padding_mask)
# CLS pooling (take first token)
cls_output = encoded[:, 0]
# Project to output dimension
pooler_output = self.projection(cls_output)
# Normalize embeddings
if normalize:
pooler_output = F.normalize(pooler_output, p=2, dim=-1)
if not return_dict:
return (encoded, pooler_output)
return BaseModelOutputWithPooling(
last_hidden_state=encoded,
pooler_output=pooler_output,
hidden_states=None,
attentions=None
)
def count_parameters(self):
"""Count trainable parameters"""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
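

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the released model code).
# The PawanEmbdConfig field values below are illustrative assumptions; the
# real values come from configuration_pawan_embd.py and the published
# config.json. It also assumes PawanEmbdConfig accepts these fields as
# keyword arguments, which is the usual PretrainedConfig convention.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = PawanEmbdConfig(
        vocab_size=30522,              # illustrative, not the checkpoint's value
        hidden_size=256,
        output_size=256,
        max_position_embeddings=512,
        num_heads=4,
        num_layers=4,
        intermediate_size=1024,
        dropout=0.1,
    )
    model = PawanEmbdModel(config)
    model.eval()

    # Two dummy sequences of length 16 with a full (no-padding) attention mask.
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)

    with torch.no_grad():
        out = model(input_ids, attention_mask=attention_mask)

    print(out.pooler_output.shape)     # torch.Size([2, output_size])
    print(model.count_parameters())    # trainable parameter count

    # Because pooler_output is L2-normalized by default, a dot product gives
    # cosine similarity between the two sentence embeddings.
    print(out.pooler_output[0] @ out.pooler_output[1])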