# turkish-sentence-encoder / modeling_turkish_encoder.py
# Basar2004's picture
# Upload folder using huggingface_hub
# af8602a verified
"""Turkish Sentence Encoder Model."""
import torch
import torch.nn as nn
from torch import Tensor
from typing import Optional
import torch.nn.functional as F
class InputEmbeddings(nn.Module):
    """Sum of token embeddings and learned position embeddings, then dropout.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Width of every embedding vector.
        max_len: Number of rows in the position table, and therefore the
            longest sequence ``forward`` accepts.
        padding_idx: Token id passed to ``nn.Embedding`` as ``padding_idx``.
        dropout: Dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size: int, d_model: int, max_len: int, padding_idx: int = 0, dropout: float = 0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)
        self.d_model = d_model

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of token ids.

        Args:
            x: Integer token ids. The sequence length is read from dim 1, so
                the tensor is indexed as ``(batch, seq_len)``.

        Returns:
            Token embeddings plus position embeddings, after dropout; the
            result has one extra trailing dimension of size ``d_model``.

        Raises:
            IndexError: If the sequence is longer than the position table.
        """
        seq_len = x.size(1)
        max_len = self.pos_embed.num_embeddings
        # Fail early with an actionable message instead of letting the
        # position lookup die with an opaque out-of-range error.
        if seq_len > max_len:
            raise IndexError(
                f"sequence length {seq_len} exceeds max_len {max_len}; "
                "truncate the input or build the model with a larger max_len"
            )
        # Positions 0..seq_len-1, shaped (1, seq_len) so they broadcast over the batch.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.token_embed(x) + self.pos_embed(positions)
        return self.dropout(x)
class TransformerEncoderLayer(nn.Module):
    """One pre-norm Transformer block: self-attention, then a GELU feed-forward.

    Each sub-layer normalises its input first and adds its (dropped-out)
    output back onto the residual stream.

    Args:
        d_model: Width of the residual stream.
        n_heads: Number of attention heads.
        dropout: Probability used for attention dropout and for the three
            dropout applications in ``forward``.
        ffn_mult: Feed-forward hidden width as a multiple of ``d_model``.
        layer_idx: Accepted but not used by this block.
        n_layers: Accepted but not used by this block.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1, ffn_mult: int = 4, layer_idx: int = 0, n_layers: int = 1):
        super().__init__()
        hidden_dim = d_model * ffn_mult
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn_fc1 = nn.Linear(d_model, hidden_dim)
        self.ffn_fc2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Run the block on ``x``.

        Args:
            x: Input activations; batch-first, since the attention module is
                built with ``batch_first=True``.
            key_padding_mask: Optional mask forwarded unchanged to the
                attention module.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Self-attention sub-layer (query, key and value are the same normed input).
        normed = self.ln1(x)
        attended = self.attn(normed, normed, normed, key_padding_mask=key_padding_mask)[0]
        x = x + self.dropout(attended)

        # Position-wise feed-forward sub-layer.
        hidden = F.gelu(self.ffn_fc1(self.ln2(x)))
        projected = self.ffn_fc2(self.dropout(hidden))
        return x + self.dropout(projected)
class TransformerEncoder(nn.Module):
    """Embedding layer, a stack of encoder blocks, and a final LayerNorm.

    Args:
        vocab_size: Passed to ``InputEmbeddings``.
        d_model: Model width shared by the embeddings, blocks and final norm.
        max_len: Passed to ``InputEmbeddings``.
        n_layers: Number of ``TransformerEncoderLayer`` blocks to stack.
        n_heads: Attention heads per block.
        padding_idx: Passed to ``InputEmbeddings``.
        dropout: Dropout probability shared by the embeddings and all blocks.
        ffn_mult: Feed-forward width multiplier passed to every block.
    """

    def __init__(self, vocab_size: int, d_model: int, max_len: int, n_layers: int, n_heads: int,
                 padding_idx: int = 0, dropout: float = 0.1, ffn_mult: int = 4):
        super().__init__()
        self.emb = InputEmbeddings(vocab_size, d_model, max_len, padding_idx, dropout)
        blocks = []
        for idx in range(n_layers):
            blocks.append(TransformerEncoderLayer(d_model, n_heads, dropout, ffn_mult, idx, n_layers))
        self.layers = nn.ModuleList(blocks)
        self.final_ln = nn.LayerNorm(d_model)

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None) -> Tensor:
        """Encode a batch of token ids into per-token hidden states.

        Args:
            input_ids: Token ids handed to the embedding layer.
            attention_mask: Optional mask; positions equal to 0 are treated as
                padding and hidden from attention as keys.

        Returns:
            The layer-normalised output of the last block.
        """
        hidden = self.emb(input_ids)
        # The blocks expect True at positions to ignore, i.e. where the mask is 0.
        pad_mask = None if attention_mask is None else (attention_mask == 0)
        for block in self.layers:
            hidden = block(hidden, key_padding_mask=pad_mask)
        return self.final_ln(hidden)
class TurkishSentenceEncoder(nn.Module):
    """Turkish Sentence Encoder for generating sentence embeddings.

    Wraps a ``TransformerEncoder`` and reduces its per-token outputs to one
    L2-normalised vector per sequence using (mask-aware) mean pooling.

    Args:
        config: Optional dict of hyperparameters. Any key missing from it
            falls back to the value in ``DEFAULT_CONFIG``; ``None`` uses a
            copy of ``DEFAULT_CONFIG``. The dict is stored as ``self.config``.
    """

    # Single source of truth for hyperparameter defaults. The keys match the
    # keyword arguments of ``TransformerEncoder``.
    DEFAULT_CONFIG = {
        "vocab_size": 32000,
        "d_model": 512,
        "max_len": 64,
        "n_layers": 12,
        "n_heads": 8,
        "padding_idx": 0,
        "dropout": 0.1,
        "ffn_mult": 4,
    }

    def __init__(self, config=None):
        super().__init__()
        if config is None:
            config = dict(self.DEFAULT_CONFIG)
        self.config = config

        # Resolve every hyperparameter once; extra keys in ``config`` are ignored here.
        hparams = {key: config.get(key, default) for key, default in self.DEFAULT_CONFIG.items()}
        self.encoder = TransformerEncoder(**hparams)
        # MLM head (for compatibility with pretrained weights); not used by forward().
        self.mlm_head = nn.Linear(hparams["d_model"], hparams["vocab_size"], bias=True)

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None, **kwargs) -> Tensor:
        """
        Forward pass that returns sentence embeddings (mean pooled).

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Optional mask aligned with ``input_ids``; it is
                used both by the encoder and as the pooling weights, so
                positions with value 0 do not contribute to the mean.
            **kwargs: Accepted and ignored.

        Returns:
            One L2-normalised embedding per input sequence.
        """
        encoder_output = self.encoder(input_ids, attention_mask=attention_mask)

        # Mean pooling
        if attention_mask is not None:
            # (batch, seq, 1) weights broadcast across the hidden dimension.
            mask = attention_mask.unsqueeze(-1).float()
            summed = (encoder_output * mask).sum(dim=1)
            # Clamp keeps the division defined when a row's mask sums to zero.
            counted = mask.sum(dim=1).clamp(min=1e-9)
            embeddings = summed / counted
        else:
            embeddings = encoder_output.mean(dim=1)

        # Normalize embeddings
        return F.normalize(embeddings, p=2, dim=1)

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from pretrained weights.

        Reads ``config.json`` (if present) and ``pytorch_model.bin`` (if
        present) from ``model_path``. Loading is best-effort: a missing
        weights file or mismatched state-dict keys produce a warning rather
        than an error.

        The model is returned as constructed; this method does not call
        ``eval()``, so call it yourself before inference to disable dropout.

        Args:
            model_path: Directory containing the config and weights files.
            **kwargs: Accepted and ignored.

        Returns:
            A ``TurkishSentenceEncoder`` instance with weights on CPU.
        """
        import json
        import os
        import warnings

        config_path = os.path.join(model_path, "config.json")
        if os.path.exists(config_path):
            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = None
        model = cls(config)

        weights_path = os.path.join(model_path, "pytorch_model.bin")
        if not os.path.exists(weights_path):
            warnings.warn(
                f"No weights file found at {weights_path}; "
                "returning a randomly initialised model.",
                stacklevel=2,
            )
            return model

        # SECURITY NOTE: torch.load unpickles the file, and on PyTorch versions
        # where ``weights_only`` defaults to False that can execute arbitrary
        # code. Only load checkpoints from trusted sources, or pass
        # ``weights_only=True`` if the minimum supported PyTorch allows it.
        state_dict = torch.load(weights_path, map_location="cpu")
        result = model.load_state_dict(state_dict, strict=False)
        # strict=False tolerates key mismatches; surface them instead of hiding them.
        if result.missing_keys or result.unexpected_keys:
            warnings.warn(
                "State dict did not match the model exactly: "
                f"missing keys={list(result.missing_keys)}, "
                f"unexpected keys={list(result.unexpected_keys)}",
                stacklevel=2,
            )
        return model