"""Standalone LASER2 encoder — no fairseq dependency. LASER2 architecture (from checkpoint params): - 50,004 vocab, 320 embed dim - 5-layer BiLSTM, hidden 512, bidirectional (= 1024 output dim) - Left-padded input with padding_idx=1 - Sentence embedding = max-pool over BiLSTM final layer outputs → 1024-dim We bypass fairseq by loading weights directly into nn.LSTM. Note: torch.nn.Module class method used for inference mode (not bare function name). """ import os import torch import torch.nn as nn import sentencepiece as spm class LaserEncoder(nn.Module): """Pure PyTorch LASER2 encoder. Compatible with the original checkpoint.""" def __init__(self, vocab_size=50004, embed_dim=320, hidden_size=512, num_layers=5, padding_idx=1): super().__init__() self.padding_idx = padding_idx self.hidden_size = hidden_size self.num_layers = num_layers self.output_dim = hidden_size * 2 # bidirectional self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx) self.lstm = nn.LSTM( input_size=embed_dim, hidden_size=hidden_size, num_layers=num_layers, bidirectional=True, batch_first=False, ) @classmethod def from_checkpoint(cls, checkpoint_path, device="cuda"): ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) params = ckpt["params"] model = cls( vocab_size=params["num_embeddings"], embed_dim=params["embed_dim"], hidden_size=params["hidden_size"], num_layers=params["num_layers"], padding_idx=params["padding_idx"], ) model.load_state_dict(ckpt["model"]) model = model.to(device) torch.nn.Module.eval(model) # inference mode for p in model.parameters(): p.requires_grad_(False) return model @torch.no_grad() def forward(self, src_tokens, return_token_states=False): """ Args: src_tokens: [B, T] token IDs, left-padded return_token_states: if True, return [B, T, 1024], else max-pooled [B, 1024] """ embeds = self.embed_tokens(src_tokens) embeds = embeds.transpose(0, 1) # [T, B, 320] pad_mask = (src_tokens == self.padding_idx) # [B, T] output, _ = self.lstm(embeds) # [T, B, 1024] output = output.transpose(0, 1) # [B, T, 1024] if return_token_states: return output, pad_mask output_masked = output.masked_fill(pad_mask.unsqueeze(-1), float("-inf")) sentence_emb = output_masked.max(dim=1)[0] # [B, 1024] return sentence_emb class LaserTokenizer: """LASER2 SentencePiece tokenizer with left-padding.""" # fairseq dictionary order: bos=0, pad=1, eos=2, unk=3, then SPM tokens from id=4 EOS_ID = 2 VOCAB_OFFSET = 4 # SPM ids shifted by 4 to match fairseq dict def __init__(self, spm_path): self.sp = spm.SentencePieceProcessor(model_file=spm_path) def encode(self, text, add_eos=True): spm_ids = self.sp.encode(text, out_type=int) shifted = [x + self.VOCAB_OFFSET for x in spm_ids] if add_eos: shifted.append(self.EOS_ID) return shifted def encode_batch(self, texts, padding_idx=1, device="cuda", left_pad=True, max_len=None): """Encode texts into padded batch. Args: left_pad: LASER2 default. For decoder alignment use False (right-pad). max_len: if set, truncate and pad to this length (for fixed context). """ encoded = [self.encode(t) for t in texts] if max_len is None: max_len = max(len(e) for e in encoded) batch = [] for ids in encoded: ids = ids[:max_len] pad_len = max_len - len(ids) if left_pad: padded = [padding_idx] * pad_len + ids else: padded = ids + [padding_idx] * pad_len batch.append(padded) return torch.tensor(batch, dtype=torch.long, device=device) def encode_texts_laser(encoder, tokenizer, texts, device="cuda"): """Encode a list of texts to L2-normalized embeddings.""" tokens = tokenizer.encode_batch(texts, padding_idx=encoder.padding_idx, device=device) embs = encoder(tokens) embs = torch.nn.functional.normalize(embs, p=2, dim=-1) return embs.cpu().numpy()