"""Tokenizer training script: trains a byte-level BPE tokenizer on SMILES data
and uploads it to the Hugging Face Hub.

Expects an HF_TOKEN environment variable with write access for the Hub uploads.
"""
|
|
import json
import os

from datasets import DatasetDict
from huggingface_hub import HfApi, create_repo
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, trainers

from config import CACHE_DIR, MIN_FREQUENCY, SPECIAL_TOKENS, TOKENIZER_NAME, VOCAB_SIZE
|
|
def iter_text(ds: DatasetDict):
    """Iterate over source and target text from every split of the dataset."""
    for split in ds:
        for row in ds[split]:
            yield row["source"]
            yield row["target"]
|
|
def train_and_upload_tokenizer():
    """Train a BPE tokenizer and upload it to the Hugging Face Hub."""
    print("=" * 60)
    print("Training Tokenizer")
    print("=" * 60)

    print(f"Loading forward dataset from {CACHE_DIR / 'forward'}...")
    forward = DatasetDict.load_from_disk(str(CACHE_DIR / "forward"))
|
    print("Creating byte-level BPE tokenizer...")
    tokenizer = Tokenizer(models.BPE(unk_token=SPECIAL_TOKENS["unk_token"]))
    # SMILES strings have no word boundaries, so skip the prefix space that
    # ByteLevel would otherwise prepend.
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    # Pair the byte-level pre-tokenizer with its decoder so decoded text
    # round-trips cleanly instead of surfacing byte-level markers.
    tokenizer.decoder = decoders.ByteLevel()

    trainer = trainers.BpeTrainer(
        vocab_size=VOCAB_SIZE,
        min_frequency=MIN_FREQUENCY,
        special_tokens=list(SPECIAL_TOKENS.values()),
        # Seed with the full byte-level alphabet so every byte stays encodable.
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    )
|
    print("Training tokenizer on dataset...")
    # iter_text yields two strings (source and target) per row in every split,
    # so the progress length must count both, not just the train/validation rows.
    total_texts = 2 * sum(len(forward[split]) for split in forward)
    tokenizer.train_from_iterator(iter_text(forward), trainer=trainer, length=total_texts)
|
    local_path = "tokenizer.json"
    tokenizer.save(local_path)
    print(f"Saved tokenizer to {local_path}")
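
    # Quick sanity check before uploading (illustrative addition: "CCO",
    # ethanol, is an arbitrary SMILES string, not taken from the dataset).
    sample = "CCO"
    print(f"Sanity check: {sample!r} -> {tokenizer.encode(sample).tokens}")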
|
    print(f"Creating/accessing tokenizer repo: {TOKENIZER_NAME}")
    try:
        create_repo(
            TOKENIZER_NAME,
            repo_type="model",
            exist_ok=True,
            private=False,
            token=os.environ.get("HF_TOKEN"),
        )
    except Exception as e:
        print(f"Note: {e}")
|
    # Authenticated client used for all uploads below.
    api = HfApi(token=os.environ.get("HF_TOKEN"))
|
    # Minimal config that transformers expects alongside tokenizer.json; all
    # three files are written first and uploaded together below.
    config = {
        # "PreTrainedTokenizerFast" is resolvable by transformers' AutoTokenizer;
        # the tokenizers-library name "ByteLevelBPETokenizer" is not.
        "tokenizer_class": "PreTrainedTokenizerFast",
        "unk_token": SPECIAL_TOKENS["unk_token"],
        "bos_token": SPECIAL_TOKENS["bos_token"],
        "eos_token": SPECIAL_TOKENS["eos_token"],
        "pad_token": SPECIAL_TOKENS["pad_token"],
    }
|
    config_path = "tokenizer_config.json"
    with open(config_path, "w") as f:
        json.dump(config, f, indent=2)
|
    # Mark special tokens as not normalized so they are matched verbatim.
    special_tokens_map = {
        name: {"content": SPECIAL_TOKENS[name], "lstrip": False, "normalized": False, "rstrip": False}
        for name in ("unk_token", "bos_token", "eos_token", "pad_token")
    }
|
    special_tokens_path = "special_tokens_map.json"
    with open(special_tokens_path, "w") as f:
        json.dump(special_tokens_map, f, indent=2)
|
    print(f"Uploading tokenizer files to {TOKENIZER_NAME}...")
    # The local filenames match the in-repo filenames, so one loop replaces
    # the three near-identical upload blocks.
    for path in (local_path, config_path, special_tokens_path):
        try:
            api.upload_file(
                path_or_fileobj=path,
                path_in_repo=path,
                repo_id=TOKENIZER_NAME,
                repo_type="model",
            )
            print(f"{path} uploaded successfully!")
        except Exception as e:
            print(f"Error uploading {path}: {e}")
|
    print("\nTokenizer training complete!")
    print(f"Access your tokenizer at: https://huggingface.co/{TOKENIZER_NAME}")
|
|
if __name__ == "__main__":
    train_and_upload_tokenizer()
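
# Loading sketch (illustrative, not part of the original script; assumes the
# upload succeeded, and "CCO" is an arbitrary SMILES string):
#
#   from tokenizers import Tokenizer
#   tok = Tokenizer.from_pretrained(TOKENIZER_NAME)
#   print(tok.encode("CCO").tokens)
#
# or, through transformers, using the tokenizer_config.json written above:
#
#   from transformers import PreTrainedTokenizerFast
#   tok = PreTrainedTokenizerFast.from_pretrained(TOKENIZER_NAME)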