Spaces:
Runtime error
Runtime error
File size: 6,936 Bytes
90948b7 ab12b63 b7d7e4b 1ad8736 b7d7e4b ab12b63 ee2d79b ab12b63 b7d7e4b ab12b63 b7d7e4b ab12b63 2568443 ab12b63 2568443 ab12b63 2568443 ab12b63 2568443 ab12b63 2568443 ab12b63 b7d7e4b ab12b63 b7d7e4b ab12b63 b7d7e4b ab12b63 b7d7e4b 2568443 ab12b63 b7d7e4b ad51c6c 2568443 b7d7e4b ab12b63 b7d7e4b ab12b63 b7d7e4b 2568443 ab12b63 b7d7e4b ab12b63 b7d7e4b ab12b63 b7d7e4b ab12b63 b7d7e4b ab12b63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
import os
import streamlit as st
import json
from datetime import datetime, timedelta
from src.helper import download_hugging_face_embeddings
from langchain_community.vectorstores import Pinecone
from langchain_openai import OpenAI
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from dotenv import load_dotenv
from src.prompt import system_prompt
# Set up cache directories
# Create the scratch directory first, then point both Hugging Face cache
# variables at it.
os.makedirs('/tmp/model_cache', exist_ok=True)
os.environ.update({
    'TRANSFORMERS_CACHE': '/tmp/model_cache',
    'HF_HOME': '/tmp/model_cache',
})

# Load environment variables. This runs after the explicit assignments above,
# exactly as before, so their relative order is unchanged.
load_dotenv()

# Rate limiting configuration: where the per-user counters are persisted and
# how many requests a single user id may make per calendar day.
RATE_LIMIT_FILE = "/tmp/rate_limits.json"
MAX_REQUESTS_PER_DAY = 5
# Initialize rate limiting storage
def init_rate_limiting():
    """Create RATE_LIMIT_FILE containing an empty JSON object if it is absent.

    An existing file is left untouched, so counters already recorded are
    preserved.
    """
    try:
        # Mode 'x' creates the file and fails if it already exists in a single
        # step. There is therefore no window between "check" and "create" in
        # which counters written by another session could be truncated.
        with open(RATE_LIMIT_FILE, 'x') as f:
            json.dump({}, f)
    except FileExistsError:
        # Already initialised; nothing to do.
        pass
# Check if a user has exceeded their daily limit
def check_rate_limit(user_id):
    """Consume one request from *user_id*'s daily quota if any is left.

    The counters live in RATE_LIMIT_FILE as JSON shaped like
    ``{user_id: {"YYYY-MM-DD": count}}``. Every counter that is not for
    today is discarded before the check, so the file only ever holds the
    current day's usage.

    Args:
        user_id: Key identifying the caller in the counters file.

    Returns:
        A ``(allowed, count)`` tuple. ``allowed`` is False when today's count
        had already reached MAX_REQUESTS_PER_DAY; in that case nothing is
        written and ``count`` is the stored value. Otherwise the counter is
        incremented, the file is rewritten, and ``count`` is the new value.
    """
    # Read the clock once so every comparison below uses the same date.
    today = datetime.now().strftime('%Y-%m-%d')

    try:
        with open(RATE_LIMIT_FILE, 'r') as f:
            rate_limits = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        rate_limits = {}
    # Valid JSON of the wrong shape (e.g. a list) is treated like a
    # corrupt file rather than crashing the request.
    if not isinstance(rate_limits, dict):
        rate_limits = {}

    # Keep only today's counters. Dropping just "yesterday" would leave
    # entries behind whenever nobody used the app for a day or more.
    rate_limits = {
        uid: {today: per_day[today]}
        for uid, per_day in rate_limits.items()
        if isinstance(per_day, dict) and today in per_day
    }

    used = rate_limits.get(user_id, {}).get(today, 0)

    # Check if limit exceeded
    if used >= MAX_REQUESTS_PER_DAY:
        return False, used

    # Increment count and save
    used += 1
    rate_limits.setdefault(user_id, {})[today] = used
    with open(RATE_LIMIT_FILE, 'w') as f:
        json.dump(rate_limits, f)
    return True, used
def get_user_id():
    """Return the rate-limiting identifier for the current Streamlit session.

    The identifier is generated on first use and stored in
    ``st.session_state.user_id``; later calls in the same session return the
    stored value.

    Returns:
        The session's identifier as a string.
    """
    # Local import keeps this function self-contained; uuid is stdlib.
    import uuid

    if 'user_id' not in st.session_state:
        # A random UUID rather than a hash of the current second: two
        # sessions that start within the same second must not end up sharing
        # one id (and therefore one quota).
        st.session_state.user_id = uuid.uuid4().hex
    return st.session_state.user_id
def get_remaining_queries(user_id):
    """Return how many requests *user_id* may still make today.

    Reads RATE_LIMIT_FILE without modifying it. A missing or unparseable
    file, or content that does not have the expected
    ``{user_id: {"YYYY-MM-DD": count}}`` shape, is treated as "nothing used
    yet".

    Args:
        user_id: Key identifying the caller in the counters file.

    Returns:
        An integer between 0 and MAX_REQUESTS_PER_DAY inclusive.
    """
    today = datetime.now().strftime('%Y-%m-%d')
    try:
        with open(RATE_LIMIT_FILE, 'r') as f:
            rate_limits = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        return MAX_REQUESTS_PER_DAY

    # Guard each level: valid JSON is not necessarily a dict of dicts.
    per_day = rate_limits.get(user_id) if isinstance(rate_limits, dict) else None
    used = per_day.get(today, 0) if isinstance(per_day, dict) else 0

    # Never report a negative balance, e.g. when the stored count is larger
    # than the currently configured limit.
    return max(0, MAX_REQUESTS_PER_DAY - used)
# Set up page configuration
st.set_page_config(page_title="Medical Assistant RAG Chatbot", page_icon="🩺", layout="centered")

# Initialize session state for chat history: start with an empty list the
# first time this session runs the script.
if 'messages' not in st.session_state:
    st.session_state['messages'] = []

# Initialize rate limiting storage before anything reads it.
init_rate_limiting()

# Show the current session's remaining quota in the sidebar.
user_id = get_user_id()
remaining_queries = get_remaining_queries(user_id)
st.sidebar.write(f"Remaining queries today: {remaining_queries}/{MAX_REQUESTS_PER_DAY}")

# Both API keys are required; halt the script with an error if either is
# missing or empty.
PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
if not (PINECONE_API_KEY and OPENAI_API_KEY):
    st.error("Missing API keys. Please set PINECONE_API_KEY and OPENAI_API_KEY environment variables.")
    st.stop()

# Write the validated values back into the process environment.
os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
# Cache the RAG chain initialization
@st.cache_resource
def _build_rag_chain():
    """Build the retrieval-augmented QA chain; raise if any step fails.

    Caching is applied here rather than on ``initialize_rag_chain`` so that
    only a successfully built chain is stored. An exception propagates to
    the caller and leaves nothing cached, so the next call tries again.
    """
    st.sidebar.write("Loading embeddings model...")
    embeddings = download_hugging_face_embeddings()

    st.sidebar.write("Connecting to Pinecone...")
    index_name = "medprep"
    docsearch = Pinecone.from_existing_index(
        index_name=index_name,
        embedding=embeddings
    )
    # Retrieve the 3 most similar documents for each query.
    retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})

    st.sidebar.write("Initializing OpenAI...")
    llm = OpenAI(temperature=0.4, max_tokens=500)
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", "{input}")
    ])
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    rag_chain = create_retrieval_chain(retriever, question_answer_chain)

    st.sidebar.success("✅ System initialized successfully!")
    return rag_chain


def initialize_rag_chain():
    """Return the RAG chain, or None if it could not be initialised.

    The chain itself is built and cached by ``_build_rag_chain``. On failure
    the error message and traceback are written to the sidebar and None is
    returned. Because the failure is handled outside the cached function, a
    transient error (for example a network problem) is not remembered: the
    next script run attempts initialisation again.

    Returns:
        The object produced by ``create_retrieval_chain``, or None on error.
    """
    try:
        return _build_rag_chain()
    except Exception as e:
        # Top-level boundary: surface everything to the operator in the
        # sidebar instead of crashing the page.
        st.sidebar.error(f"Error initializing system: {str(e)}")
        import traceback
        st.sidebar.text(traceback.format_exc())
        return None
# Main app title
st.title("Medical Assistant Chatbot")
st.write("Ask me any medical question, and I'll try to help!")

# Initialize the RAG chain; nothing below can work without it.
rag_chain = initialize_rag_chain()
if rag_chain is None:
    st.error("Failed to initialize the system. Please check the sidebar for error details.")
    st.stop()

# Display chat history accumulated in this session.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# Get user input
if prompt := st.chat_input("Ask a question..."):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # check_rate_limit both tests the quota and, when allowed, consumes one
    # request from it.
    user_id = get_user_id()
    allowed, count = check_rate_limit(user_id)

    # Render the reply inside an assistant bubble for BOTH outcomes, so a
    # rate-limit notice is visible immediately rather than only when the
    # history is redrawn on a later run.
    with st.chat_message("assistant"):
        if not allowed:
            response = f"⚠️ Daily limit reached. You've used {count} queries today. Please try again tomorrow."
        else:
            # Process the query with the RAG chain
            with st.spinner("Thinking..."):
                try:
                    result = rag_chain.invoke({"input": prompt})
                    response = result.get("answer", "Sorry, I couldn't find an answer to that.")
                    remaining = MAX_REQUESTS_PER_DAY - count
                    response += f"\n\n\n_You have {remaining} queries remaining today._"
                except Exception as e:
                    # Keep the chat usable: report the failure as the reply.
                    response = f"Error processing your request: {str(e)}"
        st.markdown(response)

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})

# Footer
st.markdown("---")
st.markdown("*This is a RAG-based medical assistant chatbot. It retrieves information from a medical knowledge base to answer your questions.*")