File size: 4,503 Bytes
eeec1cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Standalone LASER2 encoder — no fairseq dependency.

LASER2 architecture (from checkpoint params):
- 50,004 vocab, 320 embed dim
- 5-layer BiLSTM, hidden 512, bidirectional (= 1024 output dim)
- Left-padded input with padding_idx=1
- Sentence embedding = max-pool over BiLSTM final layer outputs → 1024-dim

We bypass fairseq by loading weights directly into nn.LSTM.

Note: inference mode is set with the unbound call `torch.nn.Module.eval(model)`
rather than the `model.eval()` method spelling.
"""

import os
import torch
import torch.nn as nn
import sentencepiece as spm


class LaserEncoder(nn.Module):
    """Pure PyTorch LASER2 encoder. Compatible with the original checkpoint.

    Padding is excluded from the recurrence by packing the sequences, so a
    sentence's embedding does not depend on how much padding the batch adds.
    """

    def __init__(self, vocab_size=50004, embed_dim=320, hidden_size=512,
                 num_layers=5, padding_idx=1):
        super().__init__()
        self.padding_idx = padding_idx
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_dim = hidden_size * 2  # bidirectional

        # Attribute names (embed_tokens, lstm) must match the checkpoint's
        # state_dict keys for load_state_dict in from_checkpoint to succeed.
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            bidirectional=True,
            batch_first=False,
        )

    @classmethod
    def from_checkpoint(cls, checkpoint_path, device="cuda"):
        """Build a frozen, eval-mode encoder from a checkpoint file.

        Args:
            checkpoint_path: path to a checkpoint dict with "params" (constructor
                hyper-parameters) and "model" (state_dict) entries.
            device: device to load the weights onto.

        Returns:
            LaserEncoder on ``device`` with gradients disabled.
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
        params = ckpt["params"]
        model = cls(
            vocab_size=params["num_embeddings"],
            embed_dim=params["embed_dim"],
            hidden_size=params["hidden_size"],
            num_layers=params["num_layers"],
            padding_idx=params["padding_idx"],
        )
        model.load_state_dict(ckpt["model"])
        model = model.to(device)
        torch.nn.Module.eval(model)  # inference mode
        for p in model.parameters():
            p.requires_grad_(False)
        return model

    @torch.no_grad()
    def forward(self, src_tokens, return_token_states=False):
        """Encode a padded batch of token ids.

        Args:
            src_tokens: [B, T] token IDs. Padding (``padding_idx``) is expected to
                be contiguous at one end of each row (left-padded for LASER2;
                right-padding also works).
            return_token_states: if True, return ``(states, pad_mask)`` where states
                is [B, T, 2*hidden] aligned with ``src_tokens`` (zeros at pad
                positions) and pad_mask is [B, T] bool; otherwise return the
                max-pooled sentence embeddings [B, 2*hidden].
        """
        pad_mask = src_tokens.eq(self.padding_idx)  # [B, T]
        _, seqlen = src_tokens.shape
        is_token = ~pad_mask
        lengths = is_token.sum(dim=1)  # [B] real tokens per row
        # First real-token index per row: 0 when right-padded, T - len when
        # left-padded (argmax returns the first maximal index).
        starts = is_token.long().argmax(dim=1)  # [B]
        positions = torch.arange(seqlen, device=src_tokens.device).unsqueeze(0)  # [1, T]

        # Shift each row so its real tokens start at position 0, which is the
        # layout pack_padded_sequence expects. Trailing slots are ignored by packing.
        to_right = (positions + starts.unsqueeze(1)).clamp(max=seqlen - 1)
        right_tokens = src_tokens.gather(1, to_right)

        embeds = self.embed_tokens(right_tokens).transpose(0, 1)  # [T, B, E]
        # Packing keeps pad steps out of both LSTM directions. clamp(min=1) guards
        # all-padding rows, which packing rejects; they are masked out below.
        packed = nn.utils.rnn.pack_padded_sequence(
            embeds, lengths.clamp(min=1).cpu(), enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        right_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, total_length=seqlen, padding_value=0.0
        )  # [T, B, 2H]
        right_out = right_out.transpose(0, 1)  # [B, T, 2H]

        # Undo the shift so outputs line up with src_tokens / pad_mask again.
        from_right = (positions - starts.unsqueeze(1)).clamp(min=0)
        output = right_out.gather(
            1, from_right.unsqueeze(-1).expand(-1, -1, right_out.size(-1))
        )
        output = output.masked_fill(pad_mask.unsqueeze(-1), 0.0)

        if return_token_states:
            return output, pad_mask

        # Exclude pad positions from the max-pool.
        output_masked = output.masked_fill(pad_mask.unsqueeze(-1), float("-inf"))
        sentence_emb = output_masked.max(dim=1)[0]  # [B, 2H]
        return sentence_emb


class LaserTokenizer:
    """LASER2 SentencePiece tokenizer with left-padding."""

    # fairseq dictionary order: bos=0, pad=1, eos=2, unk=3, then SPM tokens from id=4
    PAD_ID = 1
    EOS_ID = 2
    VOCAB_OFFSET = 4  # SPM ids shifted by 4 to match fairseq dict

    def __init__(self, spm_path):
        self.sp = spm.SentencePieceProcessor(model_file=spm_path)

    def encode(self, text, add_eos=True):
        """Return fairseq-dictionary ids for ``text`` (SPM ids + VOCAB_OFFSET).

        Args:
            text: input string.
            add_eos: append EOS_ID at the end when True.
        """
        spm_ids = self.sp.encode(text, out_type=int)
        shifted = [x + self.VOCAB_OFFSET for x in spm_ids]
        if add_eos:
            shifted.append(self.EOS_ID)
        return shifted

    def encode_batch(self, texts, padding_idx=PAD_ID, device="cuda", left_pad=True, max_len=None):
        """Encode texts into padded batch.

        Args:
            texts: iterable of strings.
            padding_idx: id used for pad slots.
            device: device of the returned tensor.
            left_pad: LASER2 default. For decoder alignment use False (right-pad).
            max_len: if set, truncate and pad to this length (for fixed context).
                Truncation keeps the leading ids, so it can drop the trailing EOS.

        Returns:
            LongTensor of shape [len(texts), max_len]. An empty ``texts`` yields
            shape [0, max_len] (max_len defaults to 0 in that case).

        Raises:
            ValueError: if ``max_len`` is negative.
        """
        encoded = [self.encode(t) for t in texts]
        if max_len is None:
            # default=0 lets an empty batch through instead of raising.
            max_len = max((len(e) for e in encoded), default=0)
        elif max_len < 0:
            raise ValueError(f"max_len must be non-negative, got {max_len}")

        rows = []
        for ids in encoded:
            ids = ids[:max_len]
            fill = [padding_idx] * (max_len - len(ids))
            rows.append(fill + ids if left_pad else ids + fill)

        if not rows:
            # torch.tensor([]) would be 1-D; keep the documented 2-D shape.
            return torch.empty((0, max_len), dtype=torch.long, device=device)
        return torch.tensor(rows, dtype=torch.long, device=device)


def encode_texts_laser(encoder, tokenizer, texts, device="cuda", batch_size=None):
    """Encode a list of texts to L2-normalized embeddings.

    Args:
        encoder: LaserEncoder (uses its ``padding_idx`` and ``output_dim``).
        tokenizer: LaserTokenizer used to build padded token batches.
        texts: sequence of strings.
        device: device the token batches are created on.
        batch_size: if set, encode at most this many texts per forward pass
            to bound memory; None encodes everything in one batch.

    Returns:
        numpy array of shape [len(texts), encoder.output_dim].

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    texts = list(texts)
    if not texts:
        # Nothing to encode; return a correctly-shaped empty result instead
        # of tripping over an empty padded batch.
        return torch.empty((0, encoder.output_dim), dtype=torch.float32).numpy()

    step = len(texts) if batch_size is None else batch_size
    if step <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    chunks = []
    for start in range(0, len(texts), step):
        tokens = tokenizer.encode_batch(
            texts[start:start + step], padding_idx=encoder.padding_idx, device=device
        )
        chunks.append(encoder(tokens))
    embs = torch.cat(chunks, dim=0)
    embs = torch.nn.functional.normalize(embs, p=2, dim=-1)
    return embs.cpu().numpy()