|
|
"""Turkish Sentence Encoder Model.""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch import Tensor |
|
|
from typing import Optional |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class InputEmbeddings(nn.Module):
    """Sum of learned token embeddings and learned absolute position embeddings, then dropout.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding width.
        max_len: Number of learned positions; longer inputs cannot be embedded.
        padding_idx: Token id passed as ``padding_idx`` to ``nn.Embedding``
            (that row receives no gradient updates).
        dropout: Dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size: int, d_model: int, max_len: int, padding_idx: int = 0, dropout: float = 0.1):
        super().__init__()
        self.token_embed = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
        self.pos_embed = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout)
        # Kept for reference only; forward applies no sqrt(d_model) scaling.
        self.d_model = d_model

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of token ids.

        Args:
            x: Integer token ids; dimension 1 is the sequence axis
                (assumed shape (batch, seq_len)).

        Returns:
            Tensor of shape ``x.shape + (d_model,)``.

        Raises:
            IndexError: If the sequence is longer than ``max_len``. This is checked
                up front so the failure is a readable message, not an opaque
                out-of-range lookup (which on CUDA is a device-side assert).
        """
        seq_len = x.size(1)
        max_len = self.pos_embed.num_embeddings
        if seq_len > max_len:
            raise IndexError(
                f"Input sequence length {seq_len} exceeds max_len={max_len} of the "
                "position embeddings; truncate the input or increase max_len."
            )
        # Shape (1, seq_len) so the position embeddings broadcast over the batch.
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        return self.dropout(self.token_embed(x) + self.pos_embed(positions))
|
|
|
|
|
|
|
|
class TransformerEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder block: self-attention, then a GELU feed-forward network.

    Each sub-layer LayerNorms its input, applies the sub-layer, applies dropout,
    and adds the result back onto the residual stream.

    Args:
        d_model: Hidden width.
        n_heads: Number of attention heads for ``nn.MultiheadAttention``.
        dropout: Dropout probability used inside attention, between the two FFN
            projections, and on each sub-layer output.
        ffn_mult: Expansion factor of the FFN hidden layer.
        layer_idx: Accepted for interface compatibility; unused.
        n_layers: Accepted for interface compatibility; unused.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1, ffn_mult: int = 4, layer_idx: int = 0, n_layers: int = 1):
        super().__init__()
        hidden_dim = d_model * ffn_mult
        # Attribute names and creation order are significant: they fix the
        # state_dict keys and the order of random parameter initialisation.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn_fc1 = nn.Linear(d_model, hidden_dim)
        self.ffn_fc2 = nn.Linear(hidden_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Apply the block to ``x`` (batch-first, as configured on ``self.attn``).

        Args:
            x: Input hidden states.
            key_padding_mask: Optional mask forwarded to ``nn.MultiheadAttention``;
                True entries are keys that attention ignores.

        Returns:
            Hidden states with the same shape as ``x``.
        """
        # Attention sub-layer (pre-norm, residual).
        normed = self.ln1(x)
        context = self.attn(normed, normed, normed, key_padding_mask=key_padding_mask)[0]
        residual = x + self.dropout(context)

        # Feed-forward sub-layer (pre-norm, residual).
        hidden = F.gelu(self.ffn_fc1(self.ln2(residual)))
        projected = self.ffn_fc2(self.dropout(hidden))
        return residual + self.dropout(projected)
|
|
|
|
|
|
|
|
class TransformerEncoder(nn.Module):
    """Token + position embeddings, a stack of encoder layers, and a final LayerNorm.

    Args:
        vocab_size: Size of the token vocabulary.
        d_model: Hidden width.
        max_len: Number of learned positions.
        n_layers: Number of stacked ``TransformerEncoderLayer`` blocks.
        n_heads: Attention heads per layer.
        padding_idx: Padding token id for the token embedding.
        dropout: Dropout probability used throughout.
        ffn_mult: FFN expansion factor per layer.
    """

    def __init__(self, vocab_size: int, d_model: int, max_len: int, n_layers: int, n_heads: int,
                 padding_idx: int = 0, dropout: float = 0.1, ffn_mult: int = 4):
        super().__init__()
        self.emb = InputEmbeddings(vocab_size, d_model, max_len, padding_idx, dropout)
        stack = []
        for depth in range(n_layers):
            stack.append(TransformerEncoderLayer(d_model, n_heads, dropout, ffn_mult, depth, n_layers))
        self.layers = nn.ModuleList(stack)
        self.final_ln = nn.LayerNorm(d_model)

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None) -> Tensor:
        """Encode token ids into per-token hidden states.

        Args:
            input_ids: Token ids.
            attention_mask: Optional mask; positions where it equals 0 are
                excluded as attention keys.

        Returns:
            Final-LayerNormed hidden states, one vector of width ``d_model`` per token.
        """
        # MultiheadAttention's key_padding_mask uses True for keys to IGNORE,
        # so zero entries of the attention mask become True.
        ignore = None if attention_mask is None else attention_mask.eq(0)
        hidden = self.emb(input_ids)
        for block in self.layers:
            hidden = block(hidden, key_padding_mask=ignore)
        return self.final_ln(hidden)
|
|
|
|
|
|
|
|
class TurkishSentenceEncoder(nn.Module):
    """Turkish Sentence Encoder for generating sentence embeddings.

    Runs a ``TransformerEncoder`` over token ids and reduces the per-token
    outputs to one L2-normalised vector per sequence by (masked) mean pooling.
    An ``mlm_head`` linear projection onto the vocabulary is also created, so
    its weights are part of the state dict; ``forward`` does not use it.
    """

    def __init__(self, config=None):
        """Build the encoder from a hyper-parameter dict.

        Args:
            config: Optional dict with any of ``vocab_size``, ``d_model``,
                ``max_len``, ``n_layers``, ``n_heads``, ``padding_idx``,
                ``dropout`` and ``ffn_mult``. Missing keys fall back to the
                defaults below; ``None`` means "all defaults". The dict object
                itself (not a copy) is stored as ``self.config``.
        """
        super().__init__()
        if config is None:
            config = {
                "vocab_size": 32000,
                "d_model": 512,
                "max_len": 64,
                "n_layers": 12,
                "n_heads": 8,
                "padding_idx": 0,
                "dropout": 0.1,
                "ffn_mult": 4,
            }
        self.config = config

        # Submodule names and creation order define the state_dict keys that
        # checkpoints loaded by from_pretrained must match; keep them stable.
        self.encoder = TransformerEncoder(
            vocab_size=config.get("vocab_size", 32000),
            d_model=config.get("d_model", 512),
            max_len=config.get("max_len", 64),
            n_layers=config.get("n_layers", 12),
            n_heads=config.get("n_heads", 8),
            padding_idx=config.get("padding_idx", 0),
            dropout=config.get("dropout", 0.1),
            ffn_mult=config.get("ffn_mult", 4),
        )
        self.mlm_head = nn.Linear(config.get("d_model", 512), config.get("vocab_size", 32000), bias=True)

    def forward(self, input_ids: Tensor, attention_mask: Optional[Tensor] = None, **kwargs) -> Tensor:
        """
        Forward pass that returns sentence embeddings (mean pooled).

        Args:
            input_ids: Token ids, shape (batch, seq_len).
            attention_mask: Optional mask with the same shape as ``input_ids``,
                presumably 1 for real tokens and 0 for padding. Zero positions
                are ignored by attention. The mask values are used directly as
                pooling weights.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape (batch, d_model), L2-normalised along dim 1.
        """
        encoder_output = self.encoder(input_ids, attention_mask=attention_mask)

        if attention_mask is not None:
            # Masked mean: zero out padded positions, divide by the number of kept tokens.
            mask = attention_mask.unsqueeze(-1).expand(encoder_output.size()).float()
            summed = torch.sum(encoder_output * mask, dim=1)
            # Clamp guards against division by zero for rows with no kept tokens.
            counted = torch.clamp(mask.sum(dim=1), min=1e-9)
            embeddings = summed / counted
        else:
            embeddings = torch.mean(encoder_output, dim=1)

        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    @classmethod
    def from_pretrained(cls, model_path: str, **kwargs):
        """Load model from pretrained weights.

        Reads ``config.json`` and ``pytorch_model.bin`` from ``model_path``.

        Args:
            model_path: Directory containing the files above.
            **kwargs: Accepted for API compatibility; currently ignored.

        Returns:
            A new ``cls`` instance. Its behaviour depends on which files exist:

            - If ``config.json`` is absent, the default configuration is used.
            - If ``pytorch_model.bin`` is absent, the weights stay randomly
              initialised and a ``UserWarning`` is issued.
            - Loading is non-strict, so partial checkpoints are tolerated, but
              any missing or unexpected state-dict keys are reported with a
              ``UserWarning`` rather than silently discarded.

            The model is left in training mode (this method never calls
            ``.eval()``). Call ``.eval()`` before inference so that dropout
            is disabled.
        """
        import json
        import os
        import warnings

        config_path = os.path.join(model_path, "config.json")
        if os.path.exists(config_path):
            with open(config_path, encoding="utf-8") as f:
                config = json.load(f)
        else:
            config = None

        model = cls(config)

        weights_path = os.path.join(model_path, "pytorch_model.bin")
        if os.path.exists(weights_path):
            # SECURITY NOTE: torch.load unpickles the file, which can execute
            # arbitrary code. Only load checkpoints from trusted sources (or use
            # weights_only=True on torch versions that support it).
            state_dict = torch.load(weights_path, map_location="cpu")
            # strict=False tolerates partial checkpoints, but mismatches are
            # surfaced: a wrong key prefix would otherwise leave the model random.
            incompatible = model.load_state_dict(state_dict, strict=False)
            missing = list(incompatible.missing_keys)
            unexpected = list(incompatible.unexpected_keys)
            if missing or unexpected:
                warnings.warn(
                    f"Checkpoint {weights_path!r} did not match the model's state_dict: "
                    f"{len(missing)} missing key(s) {missing[:10]}, "
                    f"{len(unexpected)} unexpected key(s) {unexpected[:10]}.",
                    stacklevel=2,
                )
        else:
            warnings.warn(
                f"No pytorch_model.bin found in {model_path!r}; "
                "the model keeps its random initialisation.",
                stacklevel=2,
            )

        return model
|
|
|