Springboardmen commited on
Commit
f179e9f
·
verified ·
1 Parent(s): bdc39b1

Create model_loader.py

Browse files
Files changed (1) hide show
  1. model_loader.py +13 -0
model_loader.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForCausalLM
2
+ import torch
3
+
4
+ MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"
5
+
6
+ def load_model():
7
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
8
+ model = AutoModelForCausalLM.from_pretrained(
9
+ MODEL_NAME,
10
+ torch_dtype=torch.float16,
11
+ device_map="auto"
12
+ )
13
+ return tokenizer, model