|
|
import streamlit as st |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer |
|
|
import torch |
|
|
|
|
|
|
|
|
# Hugging Face model id to load (community mirror of Meta-Llama-3.1-8B).
model_name = "burman-ai/Meta-Llama-3.1-8B"

# Maximum token length passed to the tokenizer; longer inputs are truncated.
max_seq_length = 512

# Weight precision used when loading the model.
dtype = torch.float16

# Whether to request 4-bit quantized loading (forwarded to from_pretrained;
# requires bitsandbytes when True).
load_in_4bit = False
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer(model_name, dtype, load_in_4bit):
    """Load and cache the causal LM and its tokenizer for the app's lifetime.

    Cached via ``st.cache_resource`` so the (expensive) load happens once per
    server process rather than on every Streamlit rerun.

    Args:
        model_name: Hugging Face model id or local path.
        dtype: torch dtype for the model weights (e.g. ``torch.float16``).
        load_in_4bit: if True, request 4-bit quantized loading
            (requires the ``bitsandbytes`` package).

    Returns:
        ``(model, tokenizer)`` tuple, with the model set to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    model_kwargs = {
        "torch_dtype": dtype,
        "device_map": "auto",
        "trust_remote_code": True,
    }
    # Only forward load_in_4bit when quantization is actually requested:
    # passing load_in_4bit=False still routes through the deprecated
    # quantization-kwarg path (and its bitsandbytes check) in recent
    # transformers releases.
    if load_in_4bit:
        model_kwargs["load_in_4bit"] = True

    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    model.eval()
    return model, tokenizer
|
|
|
|
|
model, tokenizer = load_model_and_tokenizer(model_name, dtype, load_in_4bit) |
|
|
|
|
|
|
|
|
# Alpaca-style instruction template. At inference time {output} is formatted
# as the empty string so the model completes the "### Response:" section.
# NOTE: the interior blank lines are part of the template string.
alpaca_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.




### Instruction:


{instruction}




### Input:


{input}




### Response:


{output}"""
|
|
|
|
|
|
|
|
st.title("Chatbot UI")

# Seed the conversation with a single assistant greeting on first run;
# session_state persists across Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {"role": "assistant", "content": "How can I help you today?"}
    ]

# Replay the stored conversation so the transcript survives each rerun.
for msg in st.session_state["messages"]:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
|
|
|
|
|
# Handle one user turn: echo the prompt, generate a completion, then render
# and persist the assistant's reply.
if prompt := st.chat_input("Ask me anything"):
    st.session_state["messages"].append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()

        # Fill the Alpaca template; the response slot is left empty so the
        # model continues from "### Response:".
        instruction = prompt
        input_text = ""
        formatted_prompt = alpaca_prompt.format(instruction=instruction, input=input_text, output="")

        inputs = tokenizer(
            [formatted_prompt],
            return_tensors="pt",
            max_length=max_seq_length,
            truncation=True,
        ).to(model.device)

        # NOTE(review): TextStreamer streams tokens to the server's stdout,
        # not to the Streamlit UI; kept for server-side visibility.
        text_streamer = TextStreamer(tokenizer, skip_prompt=True)

        with torch.no_grad():
            output = model.generate(
                **inputs,
                streamer=text_streamer,
                max_new_tokens=256,
                do_sample=True,
                top_p=0.8,
                top_k=50,
            )

        # BUG FIX: generate() returns prompt + completion, and previously the
        # result was never decoded or stored — the UI re-displayed
        # messages[-1], i.e. the user's own prompt. Decode only the newly
        # generated tokens, show them, and persist the assistant turn.
        prompt_len = inputs["input_ids"].shape[-1]
        full_response = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)

        message_placeholder.markdown(full_response)
        st.session_state["messages"].append({"role": "assistant", "content": full_response})