PLM4 / PhysicsLM4.2-8B /tokenization_llama_canon.py
quockhangdev's picture
PhysicsLM4.2-8B
f5c1628 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Zeyuan's edit note: this is simply a thin wrapper around either the Llama2 or the Llama3 tokenizer, selected according to params.json
from transformers import PreTrainedTokenizerFast
class LlamaCanonTokenizer(PreTrainedTokenizerFast):
    """Thin wrapper that loads either a Llama2 or a Llama3 fast tokenizer.

    The choice is made by the ``data.tokenizer.name`` field of
    ``<variant>/params.json`` in the model directory or Hub repository.
    """

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, variant="default", **kwargs):
        """Load the Llama2 or Llama3 tokenizer selected by ``params.json``.

        Args:
            pretrained_model_name_or_path: Local directory or Hugging Face Hub
                repo id that contains ``<variant>/params.json``.
            *args: Accepted for API compatibility; not used.
            variant: Sub-directory that holds ``params.json``.
            **kwargs: Only ``revision``, ``cache_dir`` and ``token`` are read,
                and only when ``params.json`` has to be downloaded from the Hub.
                Nothing is forwarded to the underlying tokenizer load.

        Returns:
            A tokenizer loaded via ``PreTrainedTokenizerFast.from_pretrained``
            from the repo matching the tokenizer name in ``params.json``.

        Raises:
            ValueError: If ``data.tokenizer.name`` is neither ``"sp"`` nor
                ``"tiktoken"``.
            KeyError: If ``params.json`` has no ``data.tokenizer.name`` entry.
        """
        from huggingface_hub import hf_hub_download
        import os, json

        local_path = os.path.join(pretrained_model_name_or_path, variant, "params.json")
        if os.path.isfile(local_path):
            config_path = local_path
        else:
            # Honour the Hub options the caller passed so that the params.json
            # we read comes from the revision/cache/credentials they asked for.
            # All three default to None, which is hf_hub_download's own default.
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename=f"{variant}/params.json",
                revision=kwargs.get("revision"),
                cache_dir=kwargs.get("cache_dir"),
                token=kwargs.get("token"),
            )

        print("Please ignore the tokenizer name mismatch warning; this LlamaCanonTokenizer is simply a wrapper of either Llama2 or Llama3 tokenizer, depending on params.json")
        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(config_path, "r", encoding="utf-8") as f:
            params = json.load(f)

        tokenizer_name = params["data"]["tokenizer"]["name"]

        if tokenizer_name == "sp":
            print("Using Llama2 tokenizer")
            # NOTE: earlier revisions loaded meta-llama/Llama-2-7b-hf (with
            # *args/**kwargs forwarded); this mirror is presumably used to
            # avoid gated access.
            return super().from_pretrained("NousResearch/Llama-2-7b-hf")

        if tokenizer_name == "tiktoken":
            print("Using Llama3 tokenizer")
            # NOTE: earlier revisions loaded meta-llama/Meta-Llama-3-8B and
            # then Xenova/llama3-tokenizer before settling on this repo.
            return super().from_pretrained("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")

        raise ValueError(f"Unsupported tokenizer name: {tokenizer_name}")