stlenc-arch / modeling_stlenc.py
saracandu's picture
Training in progress, step 23750
2ed168e verified
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_stlenc import STLEncoderConfig
class STLEncoderModel(PreTrainedModel):
config_class = STLEncoderConfig
def __init__(self, config):
super().__init__(config)
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
encoder_layer = nn.TransformerEncoderLayer(
d_model=config.hidden_size,
nhead=config.num_attention_heads,
dim_feedforward=config.intermediate_size,
activation="gelu", # GELU è standard per i Transformer moderni
batch_first=True
)
self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)
# --- POTENZIAMENTO ARCHITETTURALE ---
# Creiamo una testa di proiezione profonda (MLP)
self.projector = nn.Sequential(
nn.Linear(config.hidden_size, config.hidden_size),
nn.GELU(),
nn.LayerNorm(config.hidden_size),
nn.Dropout(0.1), # Aiuta a non overfittare sulle costanti numeriche
nn.Linear(config.hidden_size, config.hidden_size // 2),
nn.GELU(),
nn.Linear(config.hidden_size // 2, config.embedding_dim_target)
)
# ------------------------------------
self.post_init()
def forward(self, input_ids, attention_mask=None, **kwargs):
batch_size, seq_length = input_ids.size()
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, seq_length)
x = self.embeddings(input_ids) + self.position_embeddings(position_ids)
# Maschera per il padding (TransformerEncoder si aspetta True dove NON deve guardare)
padding_mask = (attention_mask == 0) if attention_mask is not None else None
# Encoding delle sequenze
sequence_output = self.encoder(x, src_key_padding_mask=padding_mask)
# Prendiamo il CLS (indice 0)
cls_token = sequence_output[:, 0, :]
# Passiamo per la testa di proiezione non-lineare
# Rimuoviamo la Tanh finale per lasciare che il kernel scalare respiri
pooled_output = self.projector(cls_token)
return pooled_output