omarmalik347 commited on
Commit
0bcab7d
·
verified ·
1 Parent(s): df61fc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -11
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  import torch
3
  import streamlit as st
4
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
5
  from huggingface_hub import login
6
 
7
  # Access Hugging Face API token from environment variables
@@ -28,10 +28,8 @@ try:
28
  )
29
 
30
  # Define a chat template if it's missing
31
- if not tokenizer.chat_template:
32
  tokenizer.chat_template = """<s>[INST] {prompt} [/INST]"""
33
-
34
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
35
  except Exception as e:
36
  st.error(f"Error loading model: {e}")
37
  st.stop()
@@ -40,17 +38,54 @@ except Exception as e:
40
  st.title("Pirate Chatbot")
41
  st.write("Ask me anything, and I'll respond in pirate speak!")
42
 
43
- # User input
44
- user_input = st.text_input("Your question:", "")
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- if user_input:
47
- # Format input using chat template
 
 
 
 
48
  prompt = tokenizer.chat_template.format(prompt=f"You are a pirate chatbot! Answer in pirate speak!\nUser: {user_input}")
49
- st.write(f"Formatted Prompt: {prompt}") # Debugging: Print the formatted prompt
50
 
51
  try:
52
- outputs = pipe(prompt, max_new_tokens=256, do_sample=True)
53
- st.write(outputs[0]["generated_text"])
 
 
 
 
 
54
  except Exception as e:
55
  st.error(f"Error generating response: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
 
 
 
 
1
  import os
2
  import torch
3
  import streamlit as st
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from huggingface_hub import login
6
 
7
  # Access Hugging Face API token from environment variables
 
28
  )
29
 
30
# Install a minimal fallback chat template when the tokenizer ships without one.
# BUG FIX: `hasattr(tokenizer, "chat_template")` is always True -- transformers
# tokenizers expose a `chat_template` attribute that merely defaults to None --
# so the original check never fired and the fallback was never set, leaving
# later `.format()` calls to fail on None.  Test the value, not the attribute.
if not getattr(tokenizer, "chat_template", None):
    tokenizer.chat_template = """<s>[INST] {prompt} [/INST]"""
 
 
33
  except Exception as e:
34
  st.error(f"Error loading model: {e}")
35
  st.stop()
 
38
# Page header.
st.title("Pirate Chatbot")
st.write("Ask me anything, and I'll respond in pirate speak!")

# Sidebar settings (removed API key input)
with st.sidebar:
    st.title('Pirate Chatbot')
    st.write('This chatbot uses the Llama 2 model for chat. You can interact with it directly.')

# Seed the transcript with a greeting the first time this session runs.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Ahoy, matey! How can I assist ye?"}
    ]

# Streamlit re-executes the script top-to-bottom on every interaction,
# so replay the stored conversation each run.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.write(msg["content"])
55
def clear_chat_history():
    """Reset the transcript back to the assistant's opening greeting."""
    st.session_state.messages = [
        {"role": "assistant", "content": "Ahoy, matey! How can I assist ye?"}
    ]

# Sidebar control that wipes the conversation via the callback above.
st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
58
+
59
# Generate response function
def generate_pirate_response(user_input):
    """Generate a pirate-speak reply to ``user_input`` with the loaded model.

    Returns the decoded model continuation (prompt stripped), or a fallback
    string if tokenization or generation raises.
    """
    # NOTE(review): a real `chat_template` is a Jinja template meant for
    # tokenizer.apply_chat_template(); str.format() only works with the
    # simple "{prompt}" fallback template this app installs -- confirm the
    # loaded tokenizer has no genuine Jinja template.
    prompt = tokenizer.chat_template.format(prompt=f"You are a pirate chatbot! Answer in pirate speak!\nUser: {user_input}")

    try:
        # Tokenization moved inside the try: the original left it unguarded,
        # so a tokenizer failure escaped the error handling below.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate response from the model
        with torch.no_grad():
            # Pass the full encoding (**inputs) so the attention mask goes
            # along with input_ids instead of being silently dropped.
            outputs = model.generate(**inputs, max_new_tokens=256, do_sample=True)

        # BUG FIX: generate() returns prompt tokens followed by the
        # continuation; decoding outputs[0] whole echoed the entire prompt
        # back to the user.  Decode only the newly generated tokens.
        new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True)
    except Exception as e:
        st.error(f"Error generating response: {e}")
        return "Oops! Something went wrong."
75
+
76
# Handle user input and response generation.
# BUG FIX: st.text_input returns its current value on EVERY Streamlit rerun,
# so the original re-appended and re-answered the same question on each rerun,
# duplicating messages.  Remember the last handled input in session_state and
# only process a question once.
if user_input := st.text_input("Your question:"):
    if st.session_state.get("last_handled_input") != user_input:
        st.session_state["last_handled_input"] = user_input
        st.session_state.messages.append({"role": "user", "content": user_input})
        with st.chat_message("user"):
            st.write(user_input)

        # Generate and display assistant's response
        if st.session_state.messages[-1]["role"] != "assistant":
            with st.chat_message("assistant"):
                with st.spinner("Thinking..."):
                    response = generate_pirate_response(user_input)
                    st.write(response)

                # Store the assistant's response
                message = {"role": "assistant", "content": response}
                st.session_state.messages.append(message)