Nahiyan14 commited on
Commit
c081489
·
verified ·
1 Parent(s): dacda30

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -0
app.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import json
4
+ from datetime import datetime, timedelta
5
+ from src.helper import download_hugging_face_embeddings
6
+ from langchain_community.vectorstores import Pinecone
7
+ from langchain_openai import OpenAI
8
+ from langchain.chains import create_retrieval_chain
9
+ from langchain.chains.combine_documents import create_stuff_documents_chain
10
+ from langchain_core.prompts import ChatPromptTemplate
11
+ from dotenv import load_dotenv
12
+ from src.prompt import system_prompt
13
+
14
+ # Set up cache directories
15
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/model_cache'
16
+ os.environ['HF_HOME'] = '/tmp/model_cache'
17
+ os.makedirs('/tmp/model_cache', exist_ok=True)
18
+
19
+ # Load environment variables
20
+ load_dotenv()
21
+
22
+ # Rate limiting configuration
23
+ RATE_LIMIT_FILE = "/tmp/rate_limits.json"
24
+ MAX_REQUESTS_PER_DAY = 5
25
+
26
+ # Initialize rate limiting storage
27
+ def init_rate_limiting():
28
+ if not os.path.exists(RATE_LIMIT_FILE):
29
+ with open(RATE_LIMIT_FILE, 'w') as f:
30
+ json.dump({}, f)
31
+
32
+ # Check if a user has exceeded their daily limit
33
+ def check_rate_limit(user_id):
34
+ today = datetime.now().strftime('%Y-%m-%d')
35
+
36
+ try:
37
+ with open(RATE_LIMIT_FILE, 'r') as f:
38
+ rate_limits = json.load(f)
39
+ except (json.JSONDecodeError, FileNotFoundError):
40
+ rate_limits = {}
41
+
42
+ # Clean up old entries
43
+ yesterday = (datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d')
44
+ users_to_remove = []
45
+ for uid in rate_limits:
46
+ if yesterday in rate_limits[uid]:
47
+ del rate_limits[uid][yesterday]
48
+ if not rate_limits[uid]: # If user has no other days, remove them
49
+ users_to_remove.append(uid)
50
+
51
+ for uid in users_to_remove:
52
+ del rate_limits[uid]
53
+
54
+ # Check and update current user's limit
55
+ if user_id not in rate_limits:
56
+ rate_limits[user_id] = {}
57
+
58
+ if today not in rate_limits[user_id]:
59
+ rate_limits[user_id][today] = 0
60
+
61
+ # Check if limit exceeded
62
+ if rate_limits[user_id][today] >= MAX_REQUESTS_PER_DAY:
63
+ return False, rate_limits[user_id][today]
64
+
65
+ # Increment count and save
66
+ rate_limits[user_id][today] += 1
67
+ with open(RATE_LIMIT_FILE, 'w') as f:
68
+ json.dump(rate_limits, f)
69
+
70
+ return True, rate_limits[user_id][today]
71
+
72
+ def get_user_id():
73
+ # For Streamlit, we'll use session_id as user identifier
74
+ if not hasattr(st.session_state, 'user_id'):
75
+ st.session_state.user_id = str(hash(datetime.now().strftime("%Y%m%d%H%M%S")))
76
+ return st.session_state.user_id
77
+
78
+ def get_remaining_queries(user_id):
79
+ today = datetime.now().strftime('%Y-%m-%d')
80
+
81
+ try:
82
+ with open(RATE_LIMIT_FILE, 'r') as f:
83
+ rate_limits = json.load(f)
84
+ except (json.JSONDecodeError, FileNotFoundError):
85
+ return MAX_REQUESTS_PER_DAY
86
+
87
+ count = rate_limits.get(user_id, {}).get(today, 0)
88
+ return MAX_REQUESTS_PER_DAY - count
89
+
90
+ # Set up page configuration
91
+ st.set_page_config(
92
+ page_title="Medical Assistant RAG Chatbot",
93
+ page_icon="🩺",
94
+ layout="centered"
95
+ )
96
+
97
+ # Initialize session state for chat history
98
+ if 'messages' not in st.session_state:
99
+ st.session_state.messages = []
100
+
101
+ # Initialize rate limiting
102
+ init_rate_limiting()
103
+
104
+ # Display remaining queries
105
+ user_id = get_user_id()
106
+ remaining_queries = get_remaining_queries(user_id)
107
+ st.sidebar.write(f"Remaining queries today: {remaining_queries}/{MAX_REQUESTS_PER_DAY}")
108
+
109
+ # Check for API keys
110
+ PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
111
+ OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
112
+
113
+ if not PINECONE_API_KEY or not OPENAI_API_KEY:
114
+ st.error("Missing API keys. Please set PINECONE_API_KEY and OPENAI_API_KEY environment variables.")
115
+ st.stop()
116
+
117
+ os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
118
+ os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
119
+
120
+ # Cache the RAG chain initialization
121
+ @st.cache_resource
122
+ def initialize_rag_chain():
123
+ try:
124
+ st.sidebar.write("Loading embeddings model...")
125
+ embeddings = download_hugging_face_embeddings()
126
+
127
+ st.sidebar.write("Connecting to Pinecone...")
128
+ index_name = "medprep"
129
+ docsearch = Pinecone.from_existing_index(
130
+ index_name=index_name,
131
+ embedding=embeddings
132
+ )
133
+
134
+ retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
135
+
136
+ st.sidebar.write("Initializing OpenAI...")
137
+ llm = OpenAI(temperature=0.4, max_tokens=500)
138
+
139
+ prompt = ChatPromptTemplate.from_messages([
140
+ ("system", system_prompt),
141
+ ("human", "{input}")
142
+ ])
143
+
144
+ question_answer_chain = create_stuff_documents_chain(llm, prompt)
145
+ rag_chain = create_retrieval_chain(retriever, question_answer_chain)
146
+
147
+ st.sidebar.success("✅ System initialized successfully!")
148
+ return rag_chain
149
+ except Exception as e:
150
+ st.sidebar.error(f"Error initializing system: {str(e)}")
151
+ import traceback
152
+ st.sidebar.text(traceback.format_exc())
153
+ return None
154
+
155
+ # Main app title
156
+ st.title("Medical Assistant Chatbot")
157
+ st.write("Ask me any medical question, and I'll try to help!")
158
+
159
+ # Initialize the RAG chain
160
+ rag_chain = initialize_rag_chain()
161
+
162
+ if rag_chain is None:
163
+ st.error("Failed to initialize the system. Please check the sidebar for error details.")
164
+ st.stop()
165
+
166
+ # Display chat history
167
+ for message in st.session_state.messages:
168
+ with st.chat_message(message["role"]):
169
+ st.markdown(message["content"])
170
+
171
+ # Get user input
172
+ if prompt := st.chat_input("Ask a question..."):
173
+ # Add user message to chat history
174
+ st.session_state.messages.append({"role": "user", "content": prompt})
175
+
176
+ # Display user message
177
+ with st.chat_message("user"):
178
+ st.markdown(prompt)
179
+
180
+ # Check rate limit
181
+ user_id = get_user_id()
182
+ allowed, count = check_rate_limit(user_id)
183
+
184
+ if not allowed:
185
+ response = f"⚠️ Daily limit reached. You've used {count} queries today. Please try again tomorrow."
186
+ else:
187
+ # Process the query with the RAG chain
188
+ with st.chat_message("assistant"):
189
+ with st.spinner("Thinking..."):
190
+ try:
191
+ result = rag_chain.invoke({"input": prompt})
192
+ response = result.get("answer", "Sorry, I couldn't find an answer to that.")
193
+ remaining = MAX_REQUESTS_PER_DAY - count
194
+ response += f"\n\n\n_You have {remaining} queries remaining today._"
195
+ except Exception as e:
196
+ response = f"Error processing your request: {str(e)}"
197
+
198
+ st.markdown(response)
199
+
200
+ # Add assistant response to chat history
201
+ st.session_state.messages.append({"role": "assistant", "content": response})
202
+
203
+ # Footer
204
+ st.markdown("---")
205
+ st.markdown("*This is a RAG-based medical assistant chatbot. It retrieves information from a medical knowledge base to answer your questions.*")