# alverciito
# upload safetensors and refactor research files
# dbd79bd
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import tokenizers
import sys
import subprocess
import logging
import spacy
import numpy as np
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.normalizers import NFKC
from transformers import PreTrainedTokenizerFast
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
class SegmentationTokenizer:
    """
    BPE tokenizer wrapper that covers training, persistence and inference.

    Training uses a raw ``tokenizers.Tokenizer`` (BPE model, NFKC normalizer,
    whitespace pre-tokenizer). Inference uses a ``PreTrainedTokenizerFast``
    that only exists after :meth:`load` has been called.
    """

    def __init__(
            self,
            vocab_size=32_768,
            min_frequency=2,
            max_length=1024
    ):
        """
        Build the untrained BPE tokenizer and its trainer.

        :param vocab_size: Target vocabulary size handed to the BPE trainer.
        :param min_frequency: Minimum frequency handed to the BPE trainer.
        :param max_length: Sequence length used for padding/truncation in ``__call__``.
        """
        self.max_length = max_length

        # Raw tokenizer (training)
        self.raw_tokenizer = tokenizers.Tokenizer(
            BPE(unk_token="[UNK]")
        )
        self.raw_tokenizer.normalizer = NFKC()
        self.raw_tokenizer.pre_tokenizer = Whitespace()

        self.trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=min_frequency,
            special_tokens=["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"]
        )

        self._hf_tokenizer = None  # created by load()

    # ---------- TRAINING ----------
    def build_iterator(self, dataset, batch_size=1024):
        """
        Yield lists of training texts, ``batch_size`` texts per list.

        Each dataset item must expose ``item["text"]``; its elements are
        joined with a newline (presumably a list of paragraphs/sentences --
        confirm against the dataset schema) and every ``"\\n\\n"`` pair is then
        collapsed to a single ``"\\n"`` in one pass.

        :param dataset: Iterable of mappings with a ``"text"`` entry.
        :param batch_size: Number of texts per yielded batch.
        :return: Generator of ``list[str]``; the last batch may be shorter.
        """
        batch = []
        for item in dataset:
            batch.append("\n".join(item["text"]).replace("\n\n", "\n"))
            if len(batch) == batch_size:
                yield batch
                batch = []
        # Flush the trailing, partially filled batch.
        if batch:
            yield batch

    def train_from_iterator(self, iterator):
        """
        Train the raw BPE tokenizer with the trainer configured in ``__init__``.

        :param iterator: Iterator of texts or batches of texts, e.g. the
            generator returned by :meth:`build_iterator`.
        """
        self.raw_tokenizer.train_from_iterator(
            iterator, trainer=self.trainer
        )

    # ---------- IO ----------
    def save(self, path):
        """
        Serialize the raw (trained) tokenizer to ``path``.

        :param path: Destination file path.
        """
        self.raw_tokenizer.save(path)

    def load(self, tokenizer_path):
        """
        Load a serialized tokenizer file for inference.

        :param tokenizer_path: Path to the tokenizer file (as written by :meth:`save`).
        :return: ``self``, so the call can be chained.
        """
        self._hf_tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=tokenizer_path,
            unk_token="[UNK]",
            pad_token="[PAD]",
            cls_token="[CLS]",
            sep_token="[SEP]",
            mask_token="[MASK]"
        )
        return self

    # ---------- TOKENIZATION ----------
    def compute_unk_rate(self, corpus):
        """
        Compute the fraction of ``[UNK]`` tokens produced over a corpus.

        :param corpus: Iterable of texts.
        :return: ``unk_tokens / total_tokens``, or ``0.0`` when no token was produced.
        :raises RuntimeError: If the tokenizer has not been loaded.
        """
        # Same guard as __call__: fail with a clear message instead of an
        # AttributeError on None.
        if self._hf_tokenizer is None:
            raise RuntimeError("Tokenizer not loaded. Call .load() first.")

        unk_id = self._hf_tokenizer.convert_tokens_to_ids("[UNK]")

        total_tokens = 0
        unk_tokens = 0

        for text in corpus:
            # Special tokens are excluded so they do not dilute the rate.
            enc = self._hf_tokenizer(
                text,
                add_special_tokens=False
            )["input_ids"]

            total_tokens += len(enc)
            unk_tokens += sum(1 for t in enc if t == unk_id)

        return unk_tokens / total_tokens if total_tokens > 0 else 0.0

    def __call__(
            self,
            text,
            return_tensors="pt",
            padding=True,
            truncation=True
    ):
        """
        Tokenize one text or a batch of texts.

        :param text: ``str`` or ``list[str]``.
        :param return_tensors: Forwarded to the underlying tokenizer (``"pt"`` by default).
        :param padding: When truthy, pad every sequence to ``self.max_length``;
            otherwise no padding is applied.
        :param truncation: Forwarded to the underlying tokenizer; the limit is
            ``self.max_length``.
        :return: Dict with the ``"input_ids"`` and ``"attention_mask"`` entries
            of the encoding.
        :raises RuntimeError: If the tokenizer has not been loaded.
        """
        if self._hf_tokenizer is None:
            raise RuntimeError("Tokenizer not loaded. Call .load() first.")

        enc = self._hf_tokenizer(
            text,
            padding="max_length" if padding else False,
            truncation=truncation,
            max_length=self.max_length,
            return_tensors=return_tensors
        )

        return {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"]
        }

    @property
    def vocab_size(self):
        """
        Vocabulary size reported by the loaded tokenizer.

        :raises RuntimeError: If the tokenizer has not been loaded.
        """
        if self._hf_tokenizer is None:
            raise RuntimeError("Tokenizer not loaded.")
        return self._hf_tokenizer.vocab_size

    def __repr__(self):
        # Reports the trainer's configured vocabulary size, which is available
        # even before a tokenizer file is loaded.
        return f"<SegmentationTokenizer vocab_size={self.trainer.vocab_size}>"
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# SENTENCE SEG #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
class SentenceSegmenter:
def __init__(
self,
max_sentences: int,
spacy_model: str = "es_core_news_sm",
logger: logging.Logger | None = None
):
self.max_sentences = max_sentences
self.logger = self._get_logger(logger)
self.nlp = self.__build_model__(spacy_model, logger=self.logger)
@staticmethod
def __build_model__(sentence_tokenizer_model: str, logger: logging.Logger) -> spacy.language.Language:
"""
Download the pre-trained sentence tokenizer model.
:param sentence_tokenizer_model: The sentence tokenizer model to download.
:return: The spacy language model.
"""
try:
spacy_model = spacy.load(sentence_tokenizer_model)
except OSError:
result = subprocess.run(
[sys.executable, "-m", "spacy", "download", sentence_tokenizer_model],
capture_output=True,
text=True
)
if result.returncode != 0:
logger.error(f'[BEAST-Tokenizer]: Loading {sentence_tokenizer_model} failed.')
raise RuntimeError(f"[BEAST-Tokenizer]: Error while downloading '{sentence_tokenizer_model}'")
spacy_model = spacy.load(sentence_tokenizer_model)
logger.info('[BEAST-Tokenizer]: Successfully downloaded the pre-trained sentence tokenizer model.')
if 'parser' not in spacy_model.pipe_names:
logger.error(f'[BEAST-Tokenizer]: The SpaCy model needs a parser installed.')
raise RuntimeError(f'[BEAST-Tokenizer]: The SpaCy model needs a parser installed.')
else:
spacy_model.add_pipe("newline_segmenter_keep_exact", before="parser")
return spacy_model
@staticmethod
def _get_logger(logger):
if logger is None:
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
return logger
def __call__(self, texts: list[str]) -> dict:
sentences = list()
sentence_candidates = list()
sentence_boundaries = list()
sentence_masking = list()
for article in texts:
doc = self.nlp(article)
for idx, sent in enumerate(doc.sents):
if idx == 0:
# Article opener
sentence_candidates.append(1)
sentence_boundaries.append(1)
elif sent.text.endswith("\n"):
# Paragraph break candidate
sentence_candidates.append(1)
sentence_boundaries.append(0)
else:
sentence_candidates.append(0)
sentence_boundaries.append(0)
sentences.append(sent.text.replace('\n', '').strip())
sentence_masking.append(1)
if len(sentences) >= self.max_sentences:
self.logger.warning(f"Maximum number of sentences reached: {self.max_sentences}")
break
if len(sentences) >= self.max_sentences:
break
# Pad with zeros:
while len(sentences) < self.max_sentences:
sentences.append("")
sentence_candidates.append(0)
sentence_boundaries.append(0)
sentence_masking.append(0)
return {
"sentences": sentences,
"sentence_candidates": np.array(sentence_candidates, dtype=np.int8),
"sentence_boundaries": np.array(sentence_boundaries, dtype=np.int8),
"sentence_mask": np.array(sentence_masking, dtype=np.int8)
}
@spacy.Language.component("newline_segmenter_keep_exact")
def newline_segmenter_keep_exact(doc):
    """
    Force a sentence start on the token that follows each newline token.

    Only tokens whose text is exactly ``"\\n"`` trigger a split. The last
    token of the document is never inspected because it has no successor.

    :param doc: The spaCy document being processed.
    :return: The same document, with ``is_sent_start`` set where needed.
    """
    final_index = len(doc) - 1
    for tok in doc:
        if tok.i == final_index:
            # Nothing follows the last token, so there is nothing to mark.
            break
        if tok.text == "\n":
            doc[tok.i + 1].is_sent_start = True
    return doc
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #