File size: 505 Bytes
6c4b07c
 
74999df
6c4b07c
 
 
 
 
 
 
 
 
 
 
 
 
74999df
6c4b07c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model_and_tok():
    """Load the Llama-3.2-3B-Instruct causal LM together with its tokenizer.

    The Hugging Face access token is read from the ``HF_TOKEN`` environment
    variable. When the variable is unset or empty, ``None`` is passed so the
    Hugging Face libraries can fall back to locally stored credentials
    (e.g. from ``huggingface-cli login``) instead of having authentication
    explicitly disabled.

    Returns:
        A ``(model, tok)`` tuple: the model moved to CUDA when
        ``torch.cuda.is_available()`` is true (CPU otherwise) and switched to
        eval mode, and the matching tokenizer.
    """
    model_id = "meta-llama/Llama-3.2-3B-Instruct"

    # NOTE: `token=False` tells the Hub client to send no credentials at all,
    # which ignores a cached login; `None` keeps the default lookup in place.
    auth_token = os.environ.get("HF_TOKEN") or None

    model = AutoModelForCausalLM.from_pretrained(model_id, token=auth_token)
    tok = AutoTokenizer.from_pretrained(model_id, token=auth_token)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()
    return model, tok