from transformers import AutoTokenizer, pipeline
import torch
import streamlit as st

# Cache the heavy objects with st.cache_resource: st.cache_data is meant for
# serializable data and would try to hash/pickle the pipeline, while
# st.cache_resource keeps a single shared instance across reruns and sessions.
@st.cache_resource(show_spinner="Loading models... please wait")
def load():
    """Load the Llama-2 13B chat tokenizer and text-generation pipeline once."""
    model = "meta-llama/Llama-2-13b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(model)
    pipeline_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return model, tokenizer, pipeline_generator


# Expose the cached objects to the rest of the app via session_state, e.g.:
# st.session_state.model, st.session_state.tokenizer, st.session_state.pipeline_generator = load()
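
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative assumption, not part of the original
# file): run with `streamlit run <this_file>.py`. The text_input widget and the
# generation parameters below are assumptions chosen for demonstration.
model, tokenizer, generator = load()

prompt = st.text_input("Enter a prompt")
if prompt:
    # Generate a single sampled completion; sampling settings are illustrative.
    outputs = generator(
        prompt,
        do_sample=True,
        top_k=10,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=256,
    )
    st.write(outputs[0]["generated_text"])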