| """Standalone LASER2 encoder — no fairseq dependency. |
| |
| LASER2 architecture (from checkpoint params): |
| - 50,004 vocab, 320 embed dim |
| - 5-layer BiLSTM, hidden 512, bidirectional (= 1024 output dim) |
| - Left-padded input with padding_idx=1 |
| - Sentence embedding = max-pool over BiLSTM final layer outputs → 1024-dim |
| |
| We bypass fairseq by loading weights directly into nn.LSTM. |
| |
| Note: torch.nn.Module class method used for inference mode (not bare function name). |
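
Typical usage (the checkpoint and SentencePiece paths are illustrative, not
files bundled with this module):

    encoder = LaserEncoder.from_checkpoint("laser2.pt", device="cpu")
    tokenizer = LaserTokenizer("laser2.spm")
    embs = encode_texts_laser(encoder, tokenizer, ["Hello world."], device="cpu")
    # embs: L2-normalized numpy array of shape (1, 1024)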
| """ |
|
|
| import os |
| import torch |
| import torch.nn as nn |
| import sentencepiece as spm |
|
|
|
|
class LaserEncoder(nn.Module):
    """Pure PyTorch LASER2 encoder. Compatible with the original checkpoint."""
|
|
    def __init__(self, vocab_size=50004, embed_dim=320, hidden_size=512,
                 num_layers=5, padding_idx=1):
        super().__init__()
        self.padding_idx = padding_idx
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_dim = hidden_size * 2  # bidirectional: forward + backward states

        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=False,  # time-major input: [T, B, E]
        )
|
|
    @classmethod
    def from_checkpoint(cls, checkpoint_path, device="cuda"):
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        params = ckpt["params"]
        model = cls(
            vocab_size=params["num_embeddings"],
            embed_dim=params["embed_dim"],
            hidden_size=params["hidden_size"],
            num_layers=params["num_layers"],
            padding_idx=params["padding_idx"],
        )
        model.load_state_dict(ckpt["model"])
        model = model.to(device)
        model.eval()  # inference mode: disables dropout
        for p in model.parameters():
            p.requires_grad_(False)
        return model
|
|
    @torch.no_grad()
    def forward(self, src_tokens, return_token_states=False):
        """
        Args:
            src_tokens: [B, T] token IDs, left-padded with padding_idx
            return_token_states: if True, return the per-token states
                [B, T, 1024] together with the pad mask [B, T]; otherwise
                return the max-pooled sentence embedding [B, 1024]
        """
        embeds = self.embed_tokens(src_tokens)  # [B, T, E]
        embeds = embeds.transpose(0, 1)  # [T, B, E], time-major for the LSTM

        pad_mask = (src_tokens == self.padding_idx)  # [B, T], True at pads

        output, _ = self.lstm(embeds)  # [T, B, 2*H]
        output = output.transpose(0, 1)  # [B, T, 2*H]

        if return_token_states:
            return output, pad_mask

        # Max-pool over time; pad positions are filled with -inf so the max
        # can never select them.
        output_masked = output.masked_fill(pad_mask.unsqueeze(-1), float("-inf"))
        sentence_emb = output_masked.max(dim=1)[0]
        return sentence_emb
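
    # The masked max-pool in isolation, on toy shapes (sizes and values are
    # illustrative only):
    #
    #     out = torch.randn(2, 4, 8)                        # [B, T, 2*H]
    #     mask = torch.tensor([[True, True, False, False],      # left-padded row
    #                          [False, False, False, False]])   # full-length row
    #     pooled = out.masked_fill(mask.unsqueeze(-1), float("-inf")).max(dim=1)[0]
    #     pooled.shape  # torch.Size([2, 8])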
|
|
|
|
class LaserTokenizer:
    """LASER2 SentencePiece tokenizer with left-padding."""

    # fairseq dictionaries reserve ids 0-3 for <s>, <pad>, </s>, <unk>, so raw
    # SentencePiece piece ids are shifted up by 4; EOS is fairseq's </s> (id 2).
    EOS_ID = 2
    VOCAB_OFFSET = 4

    def __init__(self, spm_path):
        self.sp = spm.SentencePieceProcessor(model_file=spm_path)
|
|
    def encode(self, text, add_eos=True):
        """Encode one string to fairseq-space token ids, appending EOS by default."""
        spm_ids = self.sp.encode(text, out_type=int)
        shifted = [x + self.VOCAB_OFFSET for x in spm_ids]
        if add_eos:
            shifted.append(self.EOS_ID)
        return shifted
|
|
    def encode_batch(self, texts, padding_idx=1, device="cuda", left_pad=True, max_len=None):
        """Encode texts into a padded batch.

        Args:
            left_pad: LASER2 default. For decoder alignment use False (right-pad).
            max_len: if set, truncate and pad to this length (for fixed context).
        """
        encoded = [self.encode(t) for t in texts]
        if max_len is None:
            max_len = max(len(e) for e in encoded)
        batch = []
        for ids in encoded:
            if len(ids) > max_len:
                # Truncate but keep the trailing EOS so the sequence stays
                # terminated.
                ids = ids[:max_len - 1] + [self.EOS_ID]
            pad_len = max_len - len(ids)
            if left_pad:
                padded = [padding_idx] * pad_len + ids
            else:
                padded = ids + [padding_idx] * pad_len
            batch.append(padded)
        return torch.tensor(batch, dtype=torch.long, device=device)
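
    # Padding layout sketch (token ids made up): with left_pad=True a batch of
    # two sentences comes out with pads in front and every EOS flush right,
    #
    #     [[   1,    1,  812,    2],
    #      [  97, 4501,   33,    2]],
    #
    # which matches the left-padded input the encoder expects; left_pad=False
    # gives trailing pads instead, the usual layout for decoder inputs.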
|
|
|
|
def encode_texts_laser(encoder, tokenizer, texts, device="cuda"):
    """Encode a list of texts to L2-normalized [N, 1024] numpy embeddings."""
    tokens = tokenizer.encode_batch(texts, padding_idx=encoder.padding_idx, device=device)
    embs = encoder(tokens)
    embs = torch.nn.functional.normalize(embs, p=2, dim=-1)
    return embs.cpu().numpy()
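

# Minimal smoke test. The checkpoint and SentencePiece paths below are
# placeholders for the real LASER2 artifacts, not files shipped with this
# module.
if __name__ == "__main__":
    enc = LaserEncoder.from_checkpoint("laser2.pt", device="cpu")
    tok = LaserTokenizer("laser2.spm")
    vecs = encode_texts_laser(enc, tok, ["Hello world.", "Bonjour le monde."],
                              device="cpu")
    print(vecs.shape)  # (2, 1024)
    # Embeddings are L2-normalized, so a dot product is cosine similarity;
    # cross-lingual paraphrases should score high.
    print(float(vecs[0] @ vecs[1]))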
|
|