omarmalik347 committed on
Commit
48a1a10
·
verified ·
1 Parent(s): 0bcab7d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -59
app.py CHANGED
@@ -1,70 +1,74 @@
 
1
  import os
2
  import torch
3
- import streamlit as st
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from huggingface_hub import login
6
 
7
- # Access Hugging Face API token from environment variables
8
- api_key = os.getenv("llama3")
9
-
10
- if not api_key:
11
- st.error("Hugging Face API token is missing!")
12
- st.stop() # Stop execution if no API key is found
13
-
14
- # Authenticate with Hugging Face Hub
15
- try:
16
- login(api_key)
17
- except Exception as e:
18
- st.error(f"Authentication failed: {e}")
19
- st.stop()
20
 
21
- # Load the model and tokenizer
22
- model_id = "meta-llama/Llama-3.2-1B"
 
 
 
 
 
23
 
24
- try:
25
- tokenizer = AutoTokenizer.from_pretrained(model_id)
26
- model = AutoModelForCausalLM.from_pretrained(
27
- model_id, torch_dtype=torch.bfloat16, device_map="auto"
28
- )
29
 
30
- # Define a chat template if it's missing
31
- if not hasattr(tokenizer, "chat_template"):
32
- tokenizer.chat_template = """<s>[INST] {prompt} [/INST]"""
33
- except Exception as e:
34
- st.error(f"Error loading model: {e}")
35
- st.stop()
 
36
 
37
- # Streamlit interface
38
- st.title("Pirate Chatbot")
39
- st.write("Ask me anything, and I'll respond in pirate speak!")
 
 
 
 
 
40
 
41
- # Sidebar settings (removed API key input)
42
- with st.sidebar:
43
- st.title('Pirate Chatbot')
44
- st.write('This chatbot uses the Llama 2 model for chat. You can interact with it directly.')
45
 
46
- # Store conversation messages
47
- if "messages" not in st.session_state:
48
- st.session_state.messages = [{"role": "assistant", "content": "Ahoy, matey! How can I assist ye?"}]
49
 
50
- # Display conversation
51
  for message in st.session_state.messages:
52
  with st.chat_message(message["role"]):
53
  st.write(message["content"])
54
 
55
  def clear_chat_history():
56
- st.session_state.messages = [{"role": "assistant", "content": "Ahoy, matey! How can I assist ye?"}]
57
  st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
58
 
59
- # Generate response function
60
- def generate_pirate_response(user_input):
61
- prompt = tokenizer.chat_template.format(prompt=f"You are a pirate chatbot! Answer in pirate speak!\nUser: {user_input}")
62
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
63
-
 
 
 
 
 
 
64
  try:
65
  # Generate response from the model
66
  with torch.no_grad():
67
- outputs = model.generate(inputs["input_ids"], max_new_tokens=256, do_sample=True)
68
 
69
  # Decode the generated response
70
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -73,19 +77,22 @@ def generate_pirate_response(user_input):
73
  st.error(f"Error generating response: {e}")
74
  return "Oops! Something went wrong."
75
 
76
- # Handle user input and response generation
77
- if user_input := st.text_input("Your question:"):
78
- st.session_state.messages.append({"role": "user", "content": user_input})
79
  with st.chat_message("user"):
80
- st.write(user_input)
81
-
82
- # Generate and display assistant's response
83
- if st.session_state.messages[-1]["role"] != "assistant":
84
- with st.chat_message("assistant"):
85
- with st.spinner("Thinking..."):
86
- response = generate_pirate_response(user_input)
87
- st.write(response)
88
 
89
- # Store the assistant's response
90
- message = {"role": "assistant", "content": response}
91
- st.session_state.messages.append(message)
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  import os
3
  import torch
 
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from huggingface_hub import login
6
 
7
+ # App title
8
+ st.set_page_config(page_title="🦙💬 Llama 2 Chatbot")
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ # Hugging Face Credentials
11
+ with st.sidebar:
12
+ st.title('🦙💬 Llama 2 Chatbot')
13
+ st.write('This chatbot is created using the open-source Llama model from Meta.')
14
+
15
+ # Use Hugging Face API Key from secrets or environment
16
+ api_key = os.getenv("HUGGINGFACE_API_KEY")
17
 
18
+ if not api_key:
19
+ st.error("Hugging Face API key is missing!")
20
+ st.stop()
 
 
21
 
22
+ # Authenticate with Hugging Face Hub
23
+ try:
24
+ login(api_key)
25
+ st.success('API key successfully authenticated!', icon='✅')
26
+ except Exception as e:
27
+ st.error(f"Authentication failed: {e}")
28
+ st.stop()
29
 
30
+ st.subheader('Models and parameters')
31
+ selected_model = st.sidebar.selectbox('Choose a Llama model', ['Llama-3.2-1B', 'Llama-7B'], key='selected_model')
32
+
33
+ # Model configurations based on selection
34
+ if selected_model == 'Llama-3.2-1B':
35
+ model_id = "meta-llama/Llama-3.2-1B"
36
+ elif selected_model == 'Llama-7B':
37
+ model_id = "meta-llama/Llama-7B"
38
 
39
+ temperature = st.sidebar.slider('temperature', min_value=0.01, max_value=1.0, value=0.1, step=0.01)
40
+ top_p = st.sidebar.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
41
+ max_length = st.sidebar.slider('max_length', min_value=20, max_value=80, value=50, step=5)
42
+ st.markdown('📖 Learn how to build this app in this [blog](https://blog.streamlit.io/how-to-build-a-llama-2-chatbot/)!')
43
 
44
+ # Store LLM generated responses
45
+ if "messages" not in st.session_state.keys():
46
+ st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
47
 
48
+ # Display or clear chat messages
49
  for message in st.session_state.messages:
50
  with st.chat_message(message["role"]):
51
  st.write(message["content"])
52
 
53
  def clear_chat_history():
54
+ st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
55
  st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
56
 
57
+ # Load the tokenizer and model
58
+ try:
59
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
60
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
61
+ except Exception as e:
62
+ st.error(f"Error loading model: {e}")
63
+ st.stop()
64
+
65
+ # Function for generating response using Hugging Face model
66
+ def generate_huggingface_response(prompt_input):
67
+ inputs = tokenizer(prompt_input, return_tensors="pt").to(model.device)
68
  try:
69
  # Generate response from the model
70
  with torch.no_grad():
71
+ outputs = model.generate(inputs["input_ids"], max_new_tokens=max_length, temperature=temperature, top_p=top_p, do_sample=True)
72
 
73
  # Decode the generated response
74
  response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
77
  st.error(f"Error generating response: {e}")
78
  return "Oops! Something went wrong."
79
 
80
+ # User-provided prompt
81
+ if prompt := st.chat_input(disabled=not api_key):
82
+ st.session_state.messages.append({"role": "user", "content": prompt})
83
  with st.chat_message("user"):
84
+ st.write(prompt)
 
 
 
 
 
 
 
85
 
86
+ # Generate a new response if last message is not from assistant
87
+ if st.session_state.messages[-1]["role"] != "assistant":
88
+ with st.chat_message("assistant"):
89
+ with st.spinner("Thinking..."):
90
+ response = generate_huggingface_response(prompt)
91
+ placeholder = st.empty()
92
+ full_response = ''
93
+ for item in response:
94
+ full_response += item
95
+ placeholder.markdown(full_response)
96
+ placeholder.markdown(full_response)
97
+ message = {"role": "assistant", "content": full_response}
98
+ st.session_state.messages.append(message)