"""Tokenization utilities used for transformer models."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

from transformers import AutoTokenizer


@dataclass
class RussianTextTokenizer:
    """Thin wrapper around a HuggingFace tokenizer with sane defaults."""

    model_name: str = "DeepPavlov/rubert-base-cased"
    max_length: int = 128
    padding: Union[bool, str] = "max_length"
    truncation: bool = True

    def __post_init__(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

    def get_vocab_size(self) -> int:
        return int(getattr(self.tokenizer, "vocab_size", len(self.tokenizer.get_vocab())))

    def get_special_tokens(self) -> Dict[str, Optional[int]]:
        return {
            "pad_token_id": self.tokenizer.pad_token_id,
            "cls_token_id": self.tokenizer.cls_token_id,
            "sep_token_id": self.tokenizer.sep_token_id,
            "unk_token_id": self.tokenizer.unk_token_id,
        }

    def tokenize(self, text: str, add_special_tokens: bool = True) -> List[str]:
        return self.tokenizer.tokenize(text or "", add_special_tokens=add_special_tokens)

    def encode(
        self,
        text: str,
        *,
        max_length: Optional[int] = None,
        padding: Optional[Union[bool, str]] = None,
        truncation: Optional[bool] = None,
        return_tensors: Optional[str] = "pt",
    ) -> Dict[str, Any]:
        """Encode a single text.

        Returns a dict containing `input_ids` and `attention_mask`.
        """
        max_length_eff = max_length or self.max_length
        padding_eff = self.padding if padding is None else padding
        truncation_eff = self.truncation if truncation is None else truncation

        if return_tensors is None:
            enc = self.tokenizer(
                text or "",
                max_length=max_length_eff,
                padding=padding_eff,
                truncation=truncation_eff,
                return_attention_mask=True,
                return_tensors=None,
            )
            # HuggingFace returns lists for a single example; standardize to batch-like shape.
            return {
                "input_ids": [enc["input_ids"]],
                "attention_mask": [enc["attention_mask"]],
            }

        return self.tokenizer(
            text or "",
            max_length=max_length_eff,
            padding=padding_eff,
            truncation=truncation_eff,
            return_attention_mask=True,
            return_tensors=return_tensors,
        )

    def encode_batch(
        self,
        texts: List[str],
        *,
        max_length: Optional[int] = None,
        padding: Optional[Union[bool, str]] = None,
        truncation: Optional[bool] = None,
        return_tensors: str = "pt",
    ) -> Dict[str, Any]:
        """Encode a list of texts in one call, using instance defaults for unset options."""
        max_length_eff = max_length or self.max_length
        padding_eff = self.padding if padding is None else padding
        truncation_eff = self.truncation if truncation is None else truncation
        return self.tokenizer(
            [t or "" for t in texts],
            max_length=max_length_eff,
            padding=padding_eff,
            truncation=truncation_eff,
            return_attention_mask=True,
            return_tensors=return_tensors,
        )

    def decode(self, token_ids: Union[List[int], Any], skip_special_tokens: bool = True) -> str:
        # Avoid importing torch at module import time; handle torch tensors via duck-typing.
        if hasattr(token_ids, "detach") and hasattr(token_ids, "cpu") and hasattr(token_ids, "tolist"):
            token_ids = token_ids.detach().cpu().tolist()
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def get_token_info(self, token_id: int) -> Dict[str, Any]:
        """Return the token string for `token_id` and whether it is a special token."""
        tok = self.tokenizer.convert_ids_to_tokens(int(token_id))
        specials = set(self.tokenizer.all_special_ids)
        return {
            "token_id": int(token_id),
            "token": tok,
            "is_special": int(token_id) in specials,
        }


def create_tokenizer(model_name: str = "DeepPavlov/rubert-base-cased", max_length: int = 128) -> RussianTextTokenizer:
    """Convenience factory that mirrors the dataclass defaults."""
    return RussianTextTokenizer(model_name=model_name, max_length=max_length)


def tokenize_text_pair(
    *,
    title: str,
    snippet: Optional[str],
    tokenizer: RussianTextTokenizer,
    max_title_len: int = 128,
    max_snippet_len: int = 256,
) -> Dict[str, Any]:
    """Tokenize (title, snippet) as two independent sequences (not a single pair encoding)."""
    title_enc = tokenizer.encode(title or "", max_length=max_title_len, return_tensors="pt")
    out: Dict[str, Any] = {
        "title_input_ids": title_enc["input_ids"].squeeze(0),
        "title_attention_mask": title_enc["attention_mask"].squeeze(0),
    }

    if snippet is not None:
        snip_enc = tokenizer.encode(snippet or "", max_length=max_snippet_len, return_tensors="pt")
        out.update(
            {
                "snippet_input_ids": snip_enc["input_ids"].squeeze(0),
                "snippet_attention_mask": snip_enc["attention_mask"].squeeze(0),
            }
        )

    return out
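

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API. It assumes network access
    # to download "DeepPavlov/rubert-base-cased" from the HuggingFace Hub and that
    # torch is installed (required for the default return_tensors="pt").
    tok = create_tokenizer(max_length=32)
    print("vocab size:", tok.get_vocab_size())
    print("special tokens:", tok.get_special_tokens())

    sample = "Пример русского текста для токенизации."  # "An example of Russian text for tokenization."
    print("tokens:", tok.tokenize(sample))

    enc = tok.encode(sample)  # dict-like with "input_ids" and "attention_mask" tensors
    print("input_ids shape:", tuple(enc["input_ids"].shape))
    print("decoded:", tok.decode(enc["input_ids"][0]))

    pair = tokenize_text_pair(
        title="Заголовок новости",  # "News headline"
        snippet="Краткое описание содержимого страницы.",  # "A short description of the page contents."
        tokenizer=tok,
        max_title_len=16,
        max_snippet_len=32,
    )
    print("pair keys:", sorted(pair.keys()))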