# arkadiko-v4-ablation / code/laser_encoder.py
"""Standalone LASER2 encoder — no fairseq dependency.
LASER2 architecture (from checkpoint params):
- 50,004 vocab, 320 embed dim
- 5-layer BiLSTM, hidden 512, bidirectional (= 1024 output dim)
- Left-padded input with padding_idx=1
- Sentence embedding = max-pool over BiLSTM final layer outputs → 1024-dim
We bypass fairseq by loading weights directly into nn.LSTM.
A minimal usage sketch (with assumed checkpoint/SPM paths) appears at the bottom of this file.
"""
import torch
import torch.nn as nn
import sentencepiece as spm
class LaserEncoder(nn.Module):
"""Pure PyTorch LASER2 encoder. Compatible with the original checkpoint."""
def __init__(self, vocab_size=50004, embed_dim=320, hidden_size=512,
num_layers=5, padding_idx=1):
super().__init__()
self.padding_idx = padding_idx
self.hidden_size = hidden_size
self.num_layers = num_layers
self.output_dim = hidden_size * 2 # bidirectional
self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
self.lstm = nn.LSTM(
input_size=embed_dim,
hidden_size=hidden_size,
num_layers=num_layers,
bidirectional=True,
batch_first=False,
)
@classmethod
def from_checkpoint(cls, checkpoint_path, device="cuda"):
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
params = ckpt["params"]
model = cls(
vocab_size=params["num_embeddings"],
embed_dim=params["embed_dim"],
hidden_size=params["hidden_size"],
num_layers=params["num_layers"],
padding_idx=params["padding_idx"],
)
model.load_state_dict(ckpt["model"])
model = model.to(device)
        model.eval()  # inference mode
for p in model.parameters():
p.requires_grad_(False)
return model
@torch.no_grad()
def forward(self, src_tokens, return_token_states=False):
"""
Args:
src_tokens: [B, T] token IDs, left-padded
            return_token_states: if True, return (token states [B, T, 1024], pad mask [B, T]);
                otherwise return the max-pooled sentence embedding [B, 1024]
"""
embeds = self.embed_tokens(src_tokens)
embeds = embeds.transpose(0, 1) # [T, B, 320]
pad_mask = (src_tokens == self.padding_idx) # [B, T]
output, _ = self.lstm(embeds) # [T, B, 1024]
output = output.transpose(0, 1) # [B, T, 1024]
if return_token_states:
return output, pad_mask
output_masked = output.masked_fill(pad_mask.unsqueeze(-1), float("-inf"))
sentence_emb = output_masked.max(dim=1)[0] # [B, 1024]
return sentence_emb
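# Hedged sketch, not part of the original LASER2 recipe: the per-token states returned by
# forward(..., return_token_states=True) can also be pooled in other ways. The helper below
# (a hypothetical name, not from the LASER codebase) shows a masked mean-pool, mainly to
# illustrate how the [B, T, 1024] states and the pad mask fit together.
def mean_pool_token_states(token_states, pad_mask):
    """Masked mean over time. token_states: [B, T, D]; pad_mask: [B, T], True at pad positions."""
    keep = (~pad_mask).unsqueeze(-1).to(token_states.dtype)  # [B, T, 1], 1.0 at real tokens
    summed = (token_states * keep).sum(dim=1)                # [B, D]
    counts = keep.sum(dim=1).clamp(min=1.0)                  # [B, 1], avoid divide-by-zero
    return summed / counts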
class LaserTokenizer:
"""LASER2 SentencePiece tokenizer with left-padding."""
# fairseq dictionary order: bos=0, pad=1, eos=2, unk=3, then SPM tokens from id=4
EOS_ID = 2
VOCAB_OFFSET = 4 # SPM ids shifted by 4 to match fairseq dict
def __init__(self, spm_path):
self.sp = spm.SentencePieceProcessor(model_file=spm_path)
def encode(self, text, add_eos=True):
spm_ids = self.sp.encode(text, out_type=int)
shifted = [x + self.VOCAB_OFFSET for x in spm_ids]
if add_eos:
shifted.append(self.EOS_ID)
return shifted
def encode_batch(self, texts, padding_idx=1, device="cuda", left_pad=True, max_len=None):
"""Encode texts into padded batch.
Args:
            left_pad: if True (LASER2 encoder convention), pad on the left; use False to right-pad (e.g. for decoder alignment).
max_len: if set, truncate and pad to this length (for fixed context).
"""
encoded = [self.encode(t) for t in texts]
if max_len is None:
max_len = max(len(e) for e in encoded)
batch = []
for ids in encoded:
            ids = ids[:max_len]  # note: truncation may drop the trailing EOS token
pad_len = max_len - len(ids)
if left_pad:
padded = [padding_idx] * pad_len + ids
else:
padded = ids + [padding_idx] * pad_len
batch.append(padded)
return torch.tensor(batch, dtype=torch.long, device=device)
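# Illustration of encode_batch's padding convention (assuming padding_idx=1, a batch
# max length of 5, and hypothetical token ids [5, 6, 7]):
#   left_pad=True  -> [1, 1, 5, 6, 7]   (LASER2 encoder convention)
#   left_pad=False -> [5, 6, 7, 1, 1]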
def encode_texts_laser(encoder, tokenizer, texts, device="cuda"):
"""Encode a list of texts to L2-normalized embeddings."""
tokens = tokenizer.encode_batch(texts, padding_idx=encoder.padding_idx, device=device)
embs = encoder(tokens)
embs = torch.nn.functional.normalize(embs, p=2, dim=-1)
return embs.cpu().numpy()
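# Minimal usage sketch. The file names below are assumptions (LASER2 distributes a
# checkpoint and SentencePiece model commonly named laser2.pt / laser2.spm); adjust
# the paths to wherever your copies live.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = LaserEncoder.from_checkpoint("laser2.pt", device=device)
    tokenizer = LaserTokenizer("laser2.spm")
    embeddings = encode_texts_laser(
        encoder, tokenizer, ["Hello, world.", "Bonjour le monde."], device=device
    )
    print(embeddings.shape)  # expected: (2, 1024), rows are L2-normalized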