"""Streamlit chat UI backed by a locally loaded causal language model.

Each user message is wrapped in an Alpaca-style prompt, passed to the model,
and the generated reply is streamed into the chat window and stored in the
session history so it survives Streamlit reruns.
"""

import threading

import streamlit as st
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    TextStreamer,  # noqa: F401 -- kept so other code relying on it still works
)

# Model configuration
model_name = "burman-ai/Meta-Llama-3.1-8B"
max_seq_length = 512
dtype = torch.float16
load_in_4bit = False


# Initialize model and tokenizer (run only once using st.cache_resource)
@st.cache_resource
def load_model_and_tokenizer(model_name, dtype, load_in_4bit):
    """Load the tokenizer and model, put the model in eval mode, return both.

    Args:
        model_name: Model identifier passed to ``from_pretrained``.
        dtype: Torch dtype used for the model weights.
        load_in_4bit: Forwarded to ``from_pretrained`` unchanged.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        load_in_4bit=load_in_4bit,
        device_map="auto",
        trust_remote_code=True,
    )
    model.eval()
    return model, tokenizer


model, tokenizer = load_model_and_tokenizer(model_name, dtype, load_in_4bit)

# Alpaca prompt template
alpaca_prompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:
{output}"""


def stream_response(user_prompt):
    """Yield the model's reply to *user_prompt* in chunks as it is generated.

    Generation runs in a worker thread and pushes decoded text into a
    ``TextIteratorStreamer``; this generator drains the streamer so the caller
    can update the UI incrementally.

    Args:
        user_prompt: The text the user typed into the chat input.

    Yields:
        Successive pieces of the decoded reply (prompt text excluded).

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread.
    """
    formatted_prompt = alpaca_prompt.format(
        instruction=user_prompt, input="", output=""
    )
    inputs = tokenizer(
        [formatted_prompt],
        return_tensors="pt",
        max_length=max_seq_length,
        truncation=True,
    ).to(model.device)

    # Unlike TextStreamer (which prints to stdout), TextIteratorStreamer
    # exposes the decoded text as an iterator that the UI can consume.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    errors = []

    def run_generation():
        try:
            # Grad mode is thread-local, so disable it inside the worker.
            with torch.no_grad():
                model.generate(
                    **inputs,
                    streamer=streamer,
                    max_new_tokens=256,  # Adjust as needed
                    do_sample=True,
                    top_p=0.8,
                    top_k=50,
                )
        except Exception as exc:  # re-raised in the calling thread below
            errors.append(exc)
            # Unblock the consumer; otherwise it would wait forever.
            streamer.end()

    worker = threading.Thread(target=run_generation, daemon=True)
    worker.start()
    yield from streamer
    worker.join()
    if errors:
        raise errors[0]


# Streamlit UI
st.title("Chatbot UI")

if "messages" not in st.session_state:
    st.session_state["messages"] = [
        {"role": "assistant", "content": "How can I help you today?"}
    ]

# Replay the stored conversation on every rerun.
for message in st.session_state["messages"]:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Ask me anything"):
    st.session_state["messages"].append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        for piece in stream_response(prompt):
            full_response += piece
            message_placeholder.markdown(full_response)

    # Persist the reply so it is shown again after the next rerun.
    st.session_state["messages"].append(
        {"role": "assistant", "content": full_response}
    )