File size: 3,073 Bytes
3d5862f 8146abe 3d5862f 8146abe 3d5862f 8146abe 3d5862f 8146abe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
# Model Configuration
# Identifier handed to AutoTokenizer/AutoModelForCausalLM.from_pretrained.
model_name = "burman-ai/Meta-Llama-3.1-8B"
# Upper bound on prompt tokens: used as the tokenizer's max_length with
# truncation=True, so longer prompts are cut off at this length.
max_seq_length = 512
# Passed to from_pretrained as torch_dtype.
dtype = torch.float16
# Forwarded unchanged to from_pretrained as load_in_4bit.
load_in_4bit = False
# Initialize model and tokenizer (run only once using st.cache_resource)
@st.cache_resource
def load_model_and_tokenizer(model_name, dtype, load_in_4bit):
    """Load the tokenizer and causal language model for ``model_name``.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls.
        dtype: Value forwarded as ``torch_dtype`` when loading the model.
        load_in_4bit: Value forwarded as ``load_in_4bit`` when loading the
            model.

    Returns:
        A ``(model, tokenizer)`` tuple; ``eval()`` has already been called on
        the model.
    """
    # NOTE(review): trust_remote_code=True lets the hub repository supply
    # Python code that is executed locally -- confirm this is really required
    # for this checkpoint.
    hub_options = {"trust_remote_code": True}

    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, **hub_options)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        load_in_4bit=load_in_4bit,
        device_map="auto",
        **hub_options,
    )
    # Inference only: switch the module to evaluation mode before returning.
    loaded_model.eval()
    return loaded_model, loaded_tokenizer
# Load (or fetch from the st.cache_resource cache) the shared model/tokenizer pair.
model, tokenizer = load_model_and_tokenizer(model_name, dtype, load_in_4bit)
# Alpaca Prompt Template
# Filled in with str.format(); the placeholders are {instruction}, {input} and
# {output}. The UI code formats it with output="" so the text ends right after
# "### Response:" and the model continues from there.
# NOTE(review): the usual Alpaca layout has blank lines between the sections;
# confirm whether they were intentionally omitted here.
alpaca_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
{output}"""
# Streamlit UI
st.title("Chatbot UI")

# Seed the history with a greeting the first time this session runs.
if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {"role": "assistant", "content": "How can I help you today?"}
    ]

# The script is re-executed on every interaction, so replay the stored
# conversation before handling new input.
for message in st.session_state["messages"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask me anything"):
    st.session_state["messages"].append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()

        # The user's text is the instruction; the optional Alpaca "input"
        # section is left empty and the response slot is blank so the model
        # continues right after "### Response:".
        formatted_prompt = alpaca_prompt.format(
            instruction=prompt, input="", output=""
        )
        inputs = tokenizer(
            [formatted_prompt],
            return_tensors="pt",
            max_length=max_seq_length,
            truncation=True,
        ).to(model.device)

        # TextStreamer only echoes tokens to stdout (the server console); it
        # does not feed the Streamlit page, so the reply is decoded from the
        # generate() result below.
        text_streamer = TextStreamer(tokenizer, skip_prompt=True)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                streamer=text_streamer,
                max_new_tokens=256,  # Adjust as needed
                do_sample=True,
                top_p=0.8,
                top_k=50,
            )

        # generate() returns prompt tokens followed by the new tokens for a
        # causal LM; drop the prompt part so only the reply is shown.
        prompt_token_count = inputs["input_ids"].shape[1]
        full_response = tokenizer.decode(
            output[0][prompt_token_count:], skip_special_tokens=True
        ).strip()

        message_placeholder.markdown(full_response)

    # Persist the reply so it is replayed on the next rerun. Previously the
    # placeholder was filled with the last history entry, which was the
    # user's own prompt, and the assistant's answer was never stored.
    st.session_state["messages"].append(
        {"role": "assistant", "content": full_response}
    )