"""Tokenization utilities used for transformer models.""" from __future__ import annotations from dataclasses import dataclass from typing import Any, Dict, List, Optional, Union from transformers import AutoTokenizer @dataclass class RussianTextTokenizer: """Thin wrapper around a HuggingFace tokenizer with sane defaults.""" model_name: str = "DeepPavlov/rubert-base-cased" max_length: int = 128 padding: Union[bool, str] = "max_length" truncation: bool = True def __post_init__(self) -> None: self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True) def get_vocab_size(self) -> int: return int(getattr(self.tokenizer, "vocab_size", len(self.tokenizer.get_vocab()))) def get_special_tokens(self) -> Dict[str, Optional[int]]: return { "pad_token_id": self.tokenizer.pad_token_id, "cls_token_id": self.tokenizer.cls_token_id, "sep_token_id": self.tokenizer.sep_token_id, "unk_token_id": self.tokenizer.unk_token_id, } def tokenize(self, text: str, add_special_tokens: bool = True) -> List[str]: return self.tokenizer.tokenize(text or "", add_special_tokens=add_special_tokens) def encode( self, text: str, *, max_length: Optional[int] = None, padding: Optional[Union[bool, str]] = None, truncation: Optional[bool] = None, return_tensors: Optional[str] = "pt", ) -> Dict[str, Any]: """Encode a single text. Returns a dict containing `input_ids` and `attention_mask`. """ max_length_eff = max_length or self.max_length padding_eff = self.padding if padding is None else padding truncation_eff = self.truncation if truncation is None else truncation if return_tensors is None: enc = self.tokenizer( text or "", max_length=max_length_eff, padding=padding_eff, truncation=truncation_eff, return_attention_mask=True, return_tensors=None, ) # HuggingFace returns lists for a single example; standardize to batch-like shape. return { "input_ids": [enc["input_ids"]], "attention_mask": [enc["attention_mask"]], } return self.tokenizer( text or "", max_length=max_length_eff, padding=padding_eff, truncation=truncation_eff, return_attention_mask=True, return_tensors=return_tensors, ) def encode_batch( self, texts: List[str], *, max_length: Optional[int] = None, padding: Optional[Union[bool, str]] = None, truncation: Optional[bool] = None, return_tensors: str = "pt", ) -> Dict[str, Any]: max_length_eff = max_length or self.max_length padding_eff = self.padding if padding is None else padding truncation_eff = self.truncation if truncation is None else truncation return self.tokenizer( [t or "" for t in texts], max_length=max_length_eff, padding=padding_eff, truncation=truncation_eff, return_attention_mask=True, return_tensors=return_tensors, ) def decode(self, token_ids: Union[List[int], Any], skip_special_tokens: bool = True) -> str: # Avoid importing torch at module import time; handle torch tensors via duck-typing. 

def create_tokenizer(model_name: str = "DeepPavlov/rubert-base-cased", max_length: int = 128) -> RussianTextTokenizer:
    return RussianTextTokenizer(model_name=model_name, max_length=max_length)


def tokenize_text_pair(
    *,
    title: str,
    snippet: Optional[str],
    tokenizer: RussianTextTokenizer,
    max_title_len: int = 128,
    max_snippet_len: int = 256,
) -> Dict[str, Any]:
    """Tokenize (title, snippet) as two independent sequences (not a single pair encoding)."""
    title_enc = tokenizer.encode(title or "", max_length=max_title_len, return_tensors="pt")
    out: Dict[str, Any] = {
        "title_input_ids": title_enc["input_ids"].squeeze(0),
        "title_attention_mask": title_enc["attention_mask"].squeeze(0),
    }
    if snippet is not None:
        snip_enc = tokenizer.encode(snippet or "", max_length=max_snippet_len, return_tensors="pt")
        out.update(
            {
                "snippet_input_ids": snip_enc["input_ids"].squeeze(0),
                "snippet_attention_mask": snip_enc["attention_mask"].squeeze(0),
            }
        )
    return out
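

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the public API; the example
    # strings below are placeholders, and the model weights are downloaded on
    # first use).
    tok = create_tokenizer()
    features = tokenize_text_pair(
        title="Пример заголовка",
        snippet="Короткий фрагмент текста для проверки.",
        tokenizer=tok,
    )
    for name, tensor in features.items():
        print(name, tuple(tensor.shape))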