# MyGPT2 / app.py — Streamlit chat UI for burman-ai/Meta-Llama-3.1-8B
# (Hugging Face Space by Burman-AI; last update commit 8146abe, verified)
import streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
# Model Configuration
model_name = "burman-ai/Meta-Llama-3.1-8B"  # Hugging Face Hub repo id of the base model
max_seq_length = 512  # max prompt length in tokens; tokenizer truncates beyond this
dtype = torch.float16  # half precision to reduce GPU memory footprint
load_in_4bit = False  # 4-bit quantization off; requires bitsandbytes when enabled
# Initialize model and tokenizer (run only once using st.cache_resource)
@st.cache_resource
def load_model_and_tokenizer(model_name, dtype, load_in_4bit):
    """Load and cache the causal-LM model and its tokenizer.

    Wrapped in ``st.cache_resource`` so the (large) model is loaded once per
    Streamlit server process instead of on every script rerun.

    Args:
        model_name: Hugging Face Hub repo id of the model to load.
        dtype: torch dtype for the model weights (e.g. ``torch.float16``).
        load_in_4bit: whether to load weights with 4-bit quantization.

    Returns:
        ``(model, tokenizer)`` tuple, with the model in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model_kwargs = {
        "torch_dtype": dtype,
        "device_map": "auto",  # let accelerate place layers on available devices
        "trust_remote_code": True,
    }
    # `load_in_4bit` as a bare kwarg is deprecated in recent transformers
    # (a BitsAndBytesConfig via `quantization_config` is preferred); only
    # forward it when actually requested so the False default adds nothing.
    if load_in_4bit:
        model_kwargs["load_in_4bit"] = load_in_4bit
    model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    model.eval()  # inference only: disable dropout etc.
    return model, tokenizer

model, tokenizer = load_model_and_tokenizer(model_name, dtype, load_in_4bit)
# Alpaca Prompt Template
# Instruction-following format; at inference time `output` is passed as ""
# so the model generates its completion after the "### Response:" marker.
alpaca_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Input:
{input}
### Response:
{output}"""
# Streamlit UI
st.title("Chatbot UI")

# Seed the conversation with a greeting on the first run.
if "messages" not in st.session_state:
    st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you today?"}]

# Replay the stored conversation on every rerun.
for message in st.session_state["messages"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask me anything"):
    st.session_state["messages"].append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        # The user's message is the instruction; this UI has no separate input field.
        formatted_prompt = alpaca_prompt.format(instruction=prompt, input="", output="")
        inputs = tokenizer(
            [formatted_prompt],
            return_tensors="pt",
            max_length=max_seq_length,
            truncation=True,
        ).to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                max_new_tokens=256,  # Adjust as needed
                do_sample=True,
                top_p=0.8,
                top_k=50,
            )
        # BUG FIX: the original used TextStreamer, which writes to the server's
        # stdout (never the browser), and it never stored the assistant reply —
        # the placeholder then re-rendered messages[-1], i.e. the USER message.
        # Instead, decode only the newly generated tokens (skip the prompt),
        # render them, and persist the reply so history stays consistent.
        prompt_len = inputs["input_ids"].shape[1]
        full_response = tokenizer.decode(
            output[0][prompt_len:], skip_special_tokens=True
        ).strip()
        message_placeholder.markdown(full_response)
        st.session_state["messages"].append({"role": "assistant", "content": full_response})