studfaceval / app5_selectbox / load_llama2.py
MENG21's picture
Upload 68 files
e4fe207 verified
raw
history blame contribute delete
560 Bytes
from transformers import AutoTokenizer, pipeline
import torch
import streamlit as st
@st.cache_resource(show_spinner="Loading models.. please wait")
def _load_pipeline():
    """Load the Llama-2 chat tokenizer and text-generation pipeline once per server.

    Uses ``st.cache_resource`` (not ``st.cache_data``): the tokenizer and
    pipeline are global, unserializable resources, and ``cache_resource`` is
    Streamlit's documented cache for ML models — ``cache_data`` would try to
    pickle them on every call.

    Returns:
        tuple: (model_id: str, tokenizer, pipeline_generator)
    """
    model = "meta-llama/Llama-2-13b-chat-hf"
    tokenizer = AutoTokenizer.from_pretrained(model)
    pipeline_generator = pipeline(
        "text-generation",
        model=model,
        torch_dtype=torch.float16,  # half precision to fit the 13B model in memory
        device_map="auto",          # let accelerate place layers on available devices
    )
    return model, tokenizer, pipeline_generator


def load():
    """Populate ``st.session_state`` with the cached model, tokenizer and pipeline.

    Deliberately NOT cached itself: the previous version cached this function,
    so on a cache hit the body was skipped and ``st.session_state`` was never
    populated in sessions after the first. Splitting the expensive loading
    into a cached helper keeps the model load one-time while guaranteeing the
    session-state assignments run for every session.
    """
    model, tokenizer, pipeline_generator = _load_pipeline()
    st.session_state.model = model
    st.session_state.tokenizer = tokenizer
    st.session_state.pipeline_generator = pipeline_generator
# load()