tanusrich committed on
Commit
cfef0de
·
verified ·
1 Parent(s): 315a1b3

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +124 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import streamlit as st
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+
5
# --- Page configuration -------------------------------------------------
st.set_page_config(page_title="Therapy Chatbot", layout="wide")

# Custom CSS for the chat interface, injected as raw HTML so the rules
# apply to Streamlit's generated widgets (input box, buttons, messages).
_CHAT_CSS = """
<style>
.stTextInput > div > div > input {
border-radius: 20px;
}
.stButton > button {
border-radius: 20px;
float: right;
}
.chat-message {
padding: 1.5rem; border-radius: 0.5rem; margin-bottom: 1rem; display: flex
}
.chat-message.user {
background-color: #2b313e
}
.chat-message.bot {
background-color: #475063
}
.chat-message .avatar {
width: 20%;
}
.chat-message .avatar img {
max-width: 78px;
max-height: 78px;
border-radius: 50%;
object-fit: cover;
}
.chat-message .message {
width: 80%;
padding: 0 1.5rem;
color: #fff;
}
</style>
"""
st.markdown(_CHAT_CSS, unsafe_allow_html=True)
43
+
44
# Model loading: st.cache_resource keeps the weights in memory across
# Streamlit reruns, so the download/load happens once per process.
@st.cache_resource
def load_model():
    """Load the chatbot model and tokenizer.

    Returns a (model, tokenizer, device) triple; the model is moved to
    CUDA when available, otherwise it stays on CPU.
    """
    lm = AutoModelForCausalLM.from_pretrained(
        "tanusrich/Mental_Health_Chatbot", torch_dtype=torch.float16
    )
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lm.to(target)
    tok = AutoTokenizer.from_pretrained("tanusrich/Mental_Health_Chatbot")
    return lm, tok, target

model, tokenizer, device = load_model()
54
+
55
# Prompt construction for the LLaMA-style [INST]/<<SYS>> chat format.
def format_prompt(prompt, chat_history):
    """Build the full instruction prompt from prior turns plus the new message.

    Each history entry is a dict with 'user' and 'ai' keys; the current
    *prompt* is stripped of surrounding whitespace before insertion.
    """
    turns = []
    for entry in chat_history:
        turns.append(f"User: {entry['user']}\nAI: {entry['ai']}\n")
    history = "".join(turns)
    system = (
        "You are a virtual AI therapy assistant. Your role is to provide "
        "thoughtful and supportive responses. Always ensure that you "
        "complete your last sentence with a period."
    )
    return f"[INST] <<SYS>> {system}<</SYS>> {history}User: {prompt.strip()} [/INST]"
59
+
60
def clean_output(output_text, input_text):
    """Strip the echoed prompt and control markers from a model completion.

    Removes the formatted input prompt, the [INST]/[/INST] markers and
    stray "(period)" artifacts, then inserts blank lines before numbered
    list items 1)–9) for readability. Returns the stripped result.
    """
    cleaned = output_text.replace(input_text, "")
    # Drop special tokens / artifacts but keep meaningful text.
    for marker in ("[INST]", "[/INST]", "(period)", "(Period)"):
        cleaned = cleaned.replace(marker, "")
    # Space out numbered list items; digits processed in ascending order.
    for digit in "123456789":
        cleaned = cleaned.replace(f"{digit})", f"\n\n{digit})")
    return cleaned.strip()
67
+
68
# Conversation history lives in Streamlit session state so it survives reruns.
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

# "New Chat" wipes the history to begin a fresh session.
if st.button("New Chat"):
    st.session_state.chat_history = []

# Page header.
st.markdown("<h1 style='text-align: center;'>Therapy Chatbot 🤗</h1>", unsafe_allow_html=True)

# Replay the conversation so far, one user/assistant pair per turn.
for turn in st.session_state.chat_history:
    with st.chat_message("user"):
        st.write(turn["user"])
    with st.chat_message("assistant"):
        st.write(turn["ai"])

# Input box pinned to the bottom of the page.
user_input = st.chat_input("Type your message here...")
88
+
89
if user_input:
    # Record the user turn immediately; the assistant slot is filled below.
    st.session_state.chat_history.append({"user": user_input, "ai": ""})

    with st.chat_message("user"):
        st.write(user_input)

    # Build the prompt from every earlier turn; the entry just appended is
    # excluded so the current message is not duplicated in the history.
    formatted_prompt = format_prompt(user_input, st.session_state.chat_history[:-1])
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(device)

    # Sampling-based generation; no_grad avoids building autograd state.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            temperature=0.6,
            max_new_tokens=500,
            top_k=50,
            top_p=0.9,
            repetition_penalty=1.2,
            no_repeat_ngram_size=3,
            pad_token_id=tokenizer.eos_token_id,
        )

    decoded = tokenizer.decode(output[0], skip_special_tokens=True)
    clean_response = clean_output(decoded, formatted_prompt)

    # Fill in the assistant half of the turn recorded above.
    st.session_state.chat_history[-1]["ai"] = clean_response

    with st.chat_message("assistant"):
        st.write(clean_response)

    # Release cached GPU memory between turns (no-op when CUDA is unused).
    torch.cuda.empty_cache()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ torch
3
+ transformers