import logging
import os
import urllib.request
from pathlib import Path

import tokenizers
import tokenizers.processors

log = logging.getLogger(__name__)
# Model artifacts live under $PROJECT/models when PROJECT is set, else ./models.
MODELS_DIR = Path(p if (p := os.getenv("PROJECT")) else Path.cwd()) / "models"
def fetch_file(url: str, filepath: str) -> None:
    """Stream `url` to `filepath` in 8 KiB chunks rather than buffering it all in memory."""
    with urllib.request.urlopen(url) as response, open(filepath, "wb") as out_file:
        while chunk := response.read(8192):
            out_file.write(chunk)
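# Usage sketch (hypothetical URL and path, not part of this module):
#   fetch_file("https://example.com/data.bin", "/tmp/data.bin")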
def t5() -> tuple[tokenizers.Tokenizer, int, int]:
    """Load the google-t5/t5-small tokenizer, downloading it on first use."""
    tokenizer_path = MODELS_DIR / "google-t5--t5-small"
    tokenizer_file = tokenizer_path / "tokenizer.json"
    if not tokenizer_file.exists():
        tokenizer_path.mkdir(parents=True, exist_ok=True)
        log.info(f"Tokenizer: downloading to {tokenizer_file}...")
        fetch_file(
            "https://huggingface.co/google-t5/t5-small/resolve/main/tokenizer.json",
            str(tokenizer_file),
        )
        log.info("Tokenizer: download complete.")
    tok = tokenizers.Tokenizer.from_file(str(tokenizer_file))
    # T5 ships no BOS token of its own; registering <s> assigns it a new id,
    # while </s> and <unk> already exist and are only re-marked as special.
    tok.add_special_tokens(["<s>", "</s>", "<unk>"])
    bos_id = tok.token_to_id("<s>")
    eos_id = tok.token_to_id("</s>")
    # Prepend <s> to every single-sequence encoding; appending EOS is left to the caller.
    tok.post_processor = tokenizers.processors.TemplateProcessing(
        single="<s> $A", pair=None, special_tokens=[("<s>", bos_id)]
    )
    return tok, bos_id, eos_id
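
# A minimal usage sketch (an assumption, not part of the original module):
# encode a string, confirm the post-processor prepended BOS, and append EOS
# manually before decoding, since the template leaves EOS to the caller.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    tok, bos_id, eos_id = t5()
    enc = tok.encode("translate English to German: hello world")
    assert enc.ids[0] == bos_id  # <s> inserted by TemplateProcessing
    ids = enc.ids + [eos_id]     # caller is responsible for EOS
    print(tok.decode(ids, skip_special_tokens=True))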