makekali commited on
Commit
adea141
Β·
verified Β·
1 Parent(s): 096a420

Create model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +18 -0
model_loader.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+
4
+ MODELS = {
5
+ "Tiny-GPT2": "sshleifer/tiny-gpt2"
6
+ }
7
+
8
+ def load_models(model_map):
9
+ all_models = {}
10
+ for name, hf_id in model_map.items():
11
+ tokenizer = AutoTokenizer.from_pretrained(hf_id)
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ hf_id,
14
+ torch_dtype=torch.float32,
15
+ )
16
+ model.eval()
17
+ all_models[name] = {"tokenizer": tokenizer, "model": model}
18
+ return all_models