File size: 508 Bytes
adea141 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Registry of models keyed by a short display name; each value is the
# identifier string for that model.
# NOTE(review): presumably this is the mapping handed to load_models(), whose
# values are forwarded to `from_pretrained` — confirm against the caller.
MODELS = {
    "Tiny-GPT2": "sshleifer/tiny-gpt2",
}
def load_models(model_map):
    """Load a tokenizer and a causal language model for each mapping entry.

    Args:
        model_map: Mapping whose ``.items()`` yields ``(name, hf_id)`` pairs.
            ``hf_id`` is passed unchanged to ``from_pretrained`` for both the
            tokenizer and the model.

    Returns:
        A dict keyed by the same names as ``model_map``. Each value is a dict
        with two keys: ``"tokenizer"`` and ``"model"``. An empty mapping
        yields an empty dict.
    """
    # Weights are requested in float32 for every model.
    model_kwargs = {"torch_dtype": torch.float32}

    loaded = {}
    for label, checkpoint in model_map.items():
        # Tokenizer first, then the model, so a tokenizer failure aborts
        # before the model load is attempted.
        tok = AutoTokenizer.from_pretrained(checkpoint)
        lm = AutoModelForCausalLM.from_pretrained(checkpoint, **model_kwargs)
        # Switch the module to evaluation mode before handing it back.
        lm.eval()
        loaded[label] = {"tokenizer": tok, "model": lm}
    return loaded
|