Nahiyan14 committed
Commit ab12b63 · verified · 1 Parent(s): 4aae3e5

Update app.py

Files changed (1)
  1. app.py +163 -120
app.py CHANGED
@@ -1,12 +1,7 @@
 import os
-import traceback
-
-# Try using a different directory path where you should have permissions
-os.environ['TRANSFORMERS_CACHE'] = '/tmp/model_cache'
-os.environ['HF_HOME'] = '/tmp/model_cache'
-os.makedirs('/tmp/model_cache', exist_ok=True)
-
-from flask import Flask, render_template, jsonify, request
+import streamlit as st
+import json
+from datetime import datetime, timedelta
 from src.helper import download_hugging_face_embeddings
 from langchain_community.vectorstores import Pinecone
 from langchain_openai import OpenAI
@@ -14,149 +9,197 @@ from langchain.chains import create_retrieval_chain
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain_core.prompts import ChatPromptTemplate
 from dotenv import load_dotenv
-from src.prompt import *
+from src.prompt import system_prompt
+
+# Set up cache directories
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/model_cache'
+os.environ['HF_HOME'] = '/tmp/model_cache'
+os.makedirs('/tmp/model_cache', exist_ok=True)
+
+# Load environment variables
+load_dotenv()
 
-app = Flask(__name__)
+# Rate limiting configuration
+RATE_LIMIT_FILE = "/tmp/rate_limits.json"
+MAX_REQUESTS_PER_DAY = 5
 
-# Load environment variables - these will be set in Hugging Face Space secrets
-load_dotenv()  # Still useful for local development
+# Initialize rate limiting storage
+def init_rate_limiting():
+    if not os.path.exists(RATE_LIMIT_FILE):
+        with open(RATE_LIMIT_FILE, 'w') as f:
+            json.dump({}, f)
 
-print("Starting application initialization")
-print(f"Python version: {os.sys.version}")
+# Check if a user has exceeded their daily limit
+def check_rate_limit(user_id):
+    today = datetime.now().strftime('%Y-%m-%d')
+
+    try:
+        with open(RATE_LIMIT_FILE, 'r') as f:
+            rate_limits = json.load(f)
+    except (json.JSONDecodeError, FileNotFoundError):
+        rate_limits = {}
+
+    # Clean up old entries
+    yesterday = (datetime.now() - timedelta(days=1)).strftime('%Y-%m-%d')
+    users_to_remove = []
+    for uid in rate_limits:
+        if yesterday in rate_limits[uid]:
+            del rate_limits[uid][yesterday]
+            if not rate_limits[uid]:  # If user has no other days, remove them
+                users_to_remove.append(uid)
+
+    for uid in users_to_remove:
+        del rate_limits[uid]
+
+    # Check and update current user's limit
+    if user_id not in rate_limits:
+        rate_limits[user_id] = {}
+
+    if today not in rate_limits[user_id]:
+        rate_limits[user_id][today] = 0
+
+    # Check if limit exceeded
+    if rate_limits[user_id][today] >= MAX_REQUESTS_PER_DAY:
+        return False, rate_limits[user_id][today]
+
+    # Increment count and save
+    rate_limits[user_id][today] += 1
+    with open(RATE_LIMIT_FILE, 'w') as f:
+        json.dump(rate_limits, f)
+
+    return True, rate_limits[user_id][today]
 
-# Add debugging endpoints
-@app.route("/test")
-def test():
-    return "Flask app is working. This is a test endpoint."
+def get_user_id():
+    # For Streamlit, we'll use session_id as user identifier
+    if not hasattr(st.session_state, 'user_id'):
+        st.session_state.user_id = str(hash(datetime.now().strftime("%Y%m%d%H%M%S")))
+    return st.session_state.user_id
 
-@app.route("/check-env")
-def check_env():
-    has_pinecone = "Yes" if os.environ.get("PINECONE_API_KEY") else "No"
-    has_openai = "Yes" if os.environ.get("OPENAI_API_KEY") else "No"
+def get_remaining_queries(user_id):
+    today = datetime.now().strftime('%Y-%m-%d')
 
-    # Check if keys appear valid (without revealing them)
-    pinecone_valid = len(os.environ.get("PINECONE_API_KEY", "")) > 10 if has_pinecone == "Yes" else "N/A"
-    openai_valid = os.environ.get("OPENAI_API_KEY", "").startswith("sk-") if has_openai == "Yes" else "N/A"
+    try:
+        with open(RATE_LIMIT_FILE, 'r') as f:
+            rate_limits = json.load(f)
+    except (json.JSONDecodeError, FileNotFoundError):
+        return MAX_REQUESTS_PER_DAY
 
-    return f"Pinecone key present: {has_pinecone} (appears valid: {pinecone_valid})<br>OpenAI key present: {has_openai} (appears valid: {openai_valid})"
+    count = rate_limits.get(user_id, {}).get(today, 0)
+    return MAX_REQUESTS_PER_DAY - count
+
+# Set up page configuration
+st.set_page_config(
+    page_title="Medical Assistant RAG Chatbot",
+    page_icon="🩺",
+    layout="centered"
+)
+
+# Initialize session state for chat history
+if 'messages' not in st.session_state:
+    st.session_state.messages = []
+
+# Initialize rate limiting
+init_rate_limiting()
+
+# Display remaining queries
+user_id = get_user_id()
+remaining_queries = get_remaining_queries(user_id)
+st.sidebar.write(f"Remaining queries today: {remaining_queries}/{MAX_REQUESTS_PER_DAY}")
 
-print("Checking environment variables...")
+# Check for API keys
 PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
 OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
 
-if not PINECONE_API_KEY:
-    print("WARNING: Missing PINECONE_API_KEY")
-if not OPENAI_API_KEY:
-    print("WARNING: Missing OPENAI_API_KEY")
-
 if not PINECONE_API_KEY or not OPENAI_API_KEY:
-    print("CRITICAL ERROR: Missing API keys")
-    # We'll continue anyway to allow debugging, but the app won't work properly
+    st.error("Missing API keys. Please set PINECONE_API_KEY and OPENAI_API_KEY environment variables.")
+    st.stop()
 
 os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
 os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
 
-# Initialize embeddings and chain at startup
-embeddings = None
-rag_chain = None
-
-def initialize_chain():
-    global embeddings, rag_chain
+# Cache the RAG chain initialization
+@st.cache_resource
+def initialize_rag_chain():
     try:
-        print("Step 1: Starting to download embeddings")
+        st.sidebar.write("Loading embeddings model...")
         embeddings = download_hugging_face_embeddings()
-        print("Step 2: Successfully downloaded embeddings")
 
+        st.sidebar.write("Connecting to Pinecone...")
        index_name = "medprep"
-        print(f"Step 3: Connecting to Pinecone index: {index_name}")
-
-        try:
-            from pinecone import Pinecone as PineconeClient
-            pc = PineconeClient(api_key=PINECONE_API_KEY)
-            # List available indexes to verify connection
-            indexes = pc.list_indexes()
-            print(f"Available Pinecone indexes: {indexes}")
-
-            if index_name not in [idx.name for idx in indexes]:
-                print(f"WARNING: Index '{index_name}' not found in your Pinecone account!")
-        except Exception as e:
-            print(f"Failed to connect to Pinecone API: {e}")
-
         docsearch = Pinecone.from_existing_index(
             index_name=index_name,
             embedding=embeddings
         )
-        print("Step 4: Successfully connected to Pinecone")
 
-        retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k":3})
-        print("Step 5: Created retriever")
+        retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
 
-        print("Step 6: Initializing OpenAI")
+        st.sidebar.write("Initializing OpenAI...")
         llm = OpenAI(temperature=0.4, max_tokens=500)
-        print("Step 7: OpenAI initialized")
 
-        print("Step 8: Creating prompt template")
-        prompt = ChatPromptTemplate.from_messages(
-            [
-                ("system", system_prompt),
-                ("human", "{input}"),
-            ]
-        )
+        prompt = ChatPromptTemplate.from_messages([
+            ("system", system_prompt),
+            ("human", "{input}")
+        ])
 
-        print("Step 9: Creating QA chain")
         question_answer_chain = create_stuff_documents_chain(llm, prompt)
-
-        print("Step 10: Creating RAG chain")
         rag_chain = create_retrieval_chain(retriever, question_answer_chain)
-        print("Step 11: RAG chain initialized successfully")
-        return True
+
+        st.sidebar.success("✅ System initialized successfully!")
+        return rag_chain
     except Exception as e:
-        print(f"Failed to initialize RAG chain: {e}")
-        print(f"Error type: {type(e)}")
-        traceback.print_exc()
-        return False
-
-# Initialize the chain when the application starts
-print("Starting chain initialization...")
-initialization_result = initialize_chain()
-print(f"Chain initialization result: {initialization_result}")
-
-@app.route("/")
-def index():
-    return render_template('chat.html')
-
-@app.route("/get", methods=["GET", "POST"])
-def chat():
-    global rag_chain
+        st.sidebar.error(f"Error initializing system: {str(e)}")
+        import traceback
+        st.sidebar.text(traceback.format_exc())
+        return None
+
+# Main app title
+st.title("Medical Assistant Chatbot")
+st.write("Ask me any medical question, and I'll try to help!")
+
+# Initialize the RAG chain
+rag_chain = initialize_rag_chain()
+
+if rag_chain is None:
+    st.error("Failed to initialize the system. Please check the sidebar for error details.")
+    st.stop()
+
+# Display chat history
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+# Get user input
+if prompt := st.chat_input("Ask a question..."):
+    # Add user message to chat history
+    st.session_state.messages.append({"role": "user", "content": prompt})
 
-    # Make sure chain is initialized
-    if rag_chain is None:
-        print("RAG chain not initialized, attempting to initialize again...")
-        if not initialize_chain():
-            return "Error: System not initialized properly. Please check the logs."
+    # Display user message
+    with st.chat_message("user"):
+        st.markdown(prompt)
 
-    msg = request.form["msg"]
-    try:
-        print(f"Processing message: {msg[:30]}...")  # Log only first 30 chars for privacy
-        response = rag_chain.invoke({"input": msg})
-        print("Successfully generated response")
-        return str(response["answer"])
-    except Exception as e:
-        error_msg = f"Error processing request: {e}"
-        print(error_msg)
-        traceback.print_exc()
-        return f"Error: {str(e)}"
-
-# Health check endpoint for monitoring
-@app.route("/health")
-def health_check():
-    is_initialized = rag_chain is not None
-    return jsonify({
-        "status": "healthy",
-        "rag_chain_initialized": is_initialized,
-        "embeddings_loaded": embeddings is not None
-    })
-
-if __name__ == '__main__':
-    port = int(os.environ.get("PORT", 7860))
-    app.run(host="0.0.0.0", port=port, debug=False)
+    # Check rate limit
+    user_id = get_user_id()
+    allowed, count = check_rate_limit(user_id)
+
+    if not allowed:
+        response = f"⚠️ Daily limit reached. You've used {count} queries today. Please try again tomorrow."
+    else:
+        # Process the query with the RAG chain
+        with st.chat_message("assistant"):
+            with st.spinner("Thinking..."):
+                try:
+                    result = rag_chain.invoke({"input": prompt})
+                    response = result.get("answer", "Sorry, I couldn't find an answer to that.")
+                    remaining = MAX_REQUESTS_PER_DAY - count
+                    response += f"\n\n\n_You have {remaining} queries remaining today._"
+                except Exception as e:
+                    response = f"Error processing your request: {str(e)}"
+
+            st.markdown(response)
+
+    # Add assistant response to chat history
+    st.session_state.messages.append({"role": "assistant", "content": response})
+
+# Footer
+st.markdown("---")
+st.markdown("*This is a RAG-based medical assistant chatbot. It retrieves information from a medical knowledge base to answer your questions.*")