File size: 5,258 Bytes
090e11e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
"""Tokenization utilities used for transformer models."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from transformers import AutoTokenizer
@dataclass
class RussianTextTokenizer:
    """Thin wrapper around a HuggingFace tokenizer with sane defaults.

    The underlying tokenizer is loaded in ``__post_init__`` with
    ``AutoTokenizer.from_pretrained`` and stored on ``self.tokenizer``.
    The dataclass fields below are the defaults used by ``encode`` and
    ``encode_batch``; each can be overridden per call.
    """

    # Model identifier handed to ``AutoTokenizer.from_pretrained``.
    model_name: str = "DeepPavlov/rubert-base-cased"
    # Default ``max_length`` forwarded to the tokenizer call.
    max_length: int = 128
    # Default ``padding`` forwarded to the tokenizer call.
    padding: Union[bool, str] = "max_length"
    # Default ``truncation`` forwarded to the tokenizer call.
    truncation: bool = True

    def __post_init__(self) -> None:
        """Load the tokenizer for ``model_name`` (requesting the fast variant)."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, use_fast=True)

    def get_vocab_size(self) -> int:
        """Return the vocabulary size reported by the wrapped tokenizer.

        The tokenizer's ``vocab_size`` attribute is used when it exists.
        ``get_vocab()`` is only called as a fallback when the attribute is
        missing, so the vocabulary mapping is not built on every call.
        """
        try:
            size = self.tokenizer.vocab_size
        except AttributeError:
            size = len(self.tokenizer.get_vocab())
        return int(size)

    def get_special_tokens(self) -> Dict[str, Optional[int]]:
        """Return the pad/cls/sep/unk token ids exposed by the tokenizer."""
        return {
            "pad_token_id": self.tokenizer.pad_token_id,
            "cls_token_id": self.tokenizer.cls_token_id,
            "sep_token_id": self.tokenizer.sep_token_id,
            "unk_token_id": self.tokenizer.unk_token_id,
        }

    def tokenize(self, text: str, add_special_tokens: bool = True) -> List[str]:
        """Split ``text`` into tokens; a falsy ``text`` is treated as ``""``."""
        return self.tokenizer.tokenize(text or "", add_special_tokens=add_special_tokens)

    def _call_options(
        self,
        max_length: Optional[int],
        padding: Optional[Union[bool, str]],
        truncation: Optional[bool],
        return_tensors: Optional[str],
    ) -> Dict[str, Any]:
        """Resolve per-call overrides against the instance defaults."""
        return {
            # ``or`` (not ``is None``) is intentional for backward
            # compatibility: ``max_length=0`` also falls back to the default.
            "max_length": max_length or self.max_length,
            "padding": self.padding if padding is None else padding,
            "truncation": self.truncation if truncation is None else truncation,
            "return_attention_mask": True,
            "return_tensors": return_tensors,
        }

    def encode(
        self,
        text: str,
        *,
        max_length: Optional[int] = None,
        padding: Optional[Union[bool, str]] = None,
        truncation: Optional[bool] = None,
        return_tensors: Optional[str] = "pt",
    ) -> Dict[str, Any]:
        """Encode a single text.

        Args:
            text: Text to encode; a falsy value is encoded as ``""``.
            max_length: Overrides ``self.max_length`` when truthy.
            padding: Overrides ``self.padding`` when not ``None``.
            truncation: Overrides ``self.truncation`` when not ``None``.
            return_tensors: Forwarded to the tokenizer. With ``None`` the
                result is a plain dict of nested lists.

        Returns:
            With ``return_tensors=None``: a dict holding only ``input_ids``
            and ``attention_mask``, each wrapped in an outer list. Otherwise
            the tokenizer's own return value, unchanged.
        """
        encoded = self.tokenizer(
            text or "",
            **self._call_options(max_length, padding, truncation, return_tensors),
        )
        if return_tensors is not None:
            return encoded
        # HuggingFace returns lists for a single example; standardize to batch-like shape.
        return {
            "input_ids": [encoded["input_ids"]],
            "attention_mask": [encoded["attention_mask"]],
        }

    def encode_batch(
        self,
        texts: List[str],
        *,
        max_length: Optional[int] = None,
        padding: Optional[Union[bool, str]] = None,
        truncation: Optional[bool] = None,
        return_tensors: str = "pt",
    ) -> Dict[str, Any]:
        """Encode several texts in one tokenizer call.

        Falsy entries of ``texts`` are replaced by ``""``. The keyword
        overrides behave exactly as in ``encode``. The tokenizer's return
        value is passed through unchanged.
        """
        return self.tokenizer(
            [t or "" for t in texts],
            **self._call_options(max_length, padding, truncation, return_tensors),
        )

    def decode(self, token_ids: Union[List[int], Any], skip_special_tokens: bool = True) -> str:
        """Decode token ids back to text via the wrapped tokenizer."""
        # Avoid importing torch at module import time; handle torch tensors via duck-typing.
        looks_like_tensor = all(
            hasattr(token_ids, attr) for attr in ("detach", "cpu", "tolist")
        )
        if looks_like_tensor:
            token_ids = token_ids.detach().cpu().tolist()
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)

    def get_token_info(self, token_id: int) -> Dict[str, Any]:
        """Describe one token id: its id, its token string and whether it is special.

        ``token_id`` is converted with ``int()`` before use. ``is_special``
        is membership in the tokenizer's ``all_special_ids``.
        """
        tid = int(token_id)
        return {
            "token_id": tid,
            "token": self.tokenizer.convert_ids_to_tokens(tid),
            "is_special": tid in set(self.tokenizer.all_special_ids),
        }
def create_tokenizer(
    model_name: str = "DeepPavlov/rubert-base-cased",
    max_length: int = 128,
) -> RussianTextTokenizer:
    """Build a ``RussianTextTokenizer`` for ``model_name``.

    Only ``model_name`` and ``max_length`` are set here; every other
    setting keeps the class default.
    """
    settings = {"model_name": model_name, "max_length": max_length}
    return RussianTextTokenizer(**settings)
def tokenize_text_pair(
    *,
    title: str,
    snippet: Optional[str],
    tokenizer: RussianTextTokenizer,
    max_title_len: int = 128,
    max_snippet_len: int = 256,
) -> Dict[str, Any]:
    """Encode a title and an optional snippet as two separate sequences.

    Each text goes through its own ``tokenizer.encode`` call with
    ``return_tensors="pt"``; they are not joined into one pair encoding.
    ``squeeze(0)`` is applied to every ``input_ids`` / ``attention_mask``
    value before it is stored.

    Returns:
        A dict with ``title_input_ids`` and ``title_attention_mask``. The
        ``snippet_*`` counterparts are added only when ``snippet`` is not
        ``None``. A falsy title or snippet is encoded as ``""``.
    """
    encoded_title = tokenizer.encode(
        title or "",
        max_length=max_title_len,
        return_tensors="pt",
    )
    result: Dict[str, Any] = {}
    result["title_input_ids"] = encoded_title["input_ids"].squeeze(0)
    result["title_attention_mask"] = encoded_title["attention_mask"].squeeze(0)

    # No snippet supplied: the title entries are the whole result.
    if snippet is None:
        return result

    encoded_snippet = tokenizer.encode(
        snippet or "",
        max_length=max_snippet_len,
        return_tensors="pt",
    )
    result["snippet_input_ids"] = encoded_snippet["input_ids"].squeeze(0)
    result["snippet_attention_mask"] = encoded_snippet["attention_mask"].squeeze(0)
    return result
|