# USMLEPrepAI / app.py
# Author: Nahiyan14
# Last commit: "Update app.py"
# Commit bda186e (verified)
import html
import json
import os
import uuid
from datetime import datetime, timedelta

import streamlit as st
from dotenv import load_dotenv
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import Pinecone
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAI

from src.helper import download_hugging_face_embeddings
from src.prompt import system_prompt
# Set up cache directories
# Both cache variables point at the same directory under /tmp, which is
# created here so it exists before the embeddings model is loaded later on.
# NOTE(review): these variables are assigned after the imports above; a
# library that reads them at import time may not see them -- confirm.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/model_cache'
os.environ['HF_HOME'] = '/tmp/model_cache'
os.makedirs('/tmp/model_cache', exist_ok=True)
# Load environment variables
# Called with no arguments; presumably this picks up a local .env file
# (e.g. for the API keys checked further down) -- confirm against deployment.
load_dotenv()
# Rate limiting configuration
# JSON file that stores the per-user, per-day request counters.
RATE_LIMIT_FILE = "/tmp/rate_limits.json"
# Maximum number of queries one user id may make per calendar day.
MAX_REQUESTS_PER_DAY = 5
# Initialize rate limiting storage
def init_rate_limiting():
    """Create the rate-limit store as an empty JSON object if it is missing.

    An already existing file is left untouched.
    """
    if os.path.exists(RATE_LIMIT_FILE):
        return
    with open(RATE_LIMIT_FILE, 'w') as store:
        json.dump({}, store)
# Check if a user has exceeded their daily limit
def check_rate_limit(user_id, *, max_requests=None, rate_limit_file=None):
    """Count one request for ``user_id`` and report whether it is allowed.

    The store is a JSON object shaped ``{user_id: {"YYYY-MM-DD": count}}``.
    Every counter that is not for today is dropped on each call, so stale
    days and inactive users do not pile up in the file.

    Args:
        user_id: Key identifying the caller in the store.
        max_requests: Daily quota; defaults to ``MAX_REQUESTS_PER_DAY``.
        rate_limit_file: Path of the JSON store; defaults to
            ``RATE_LIMIT_FILE``.

    Returns:
        ``(allowed, count)``. When allowed, ``count`` includes this request
        and the store has been rewritten. When denied, ``count`` is the
        number already used today and nothing is written.
    """
    limit = MAX_REQUESTS_PER_DAY if max_requests is None else max_requests
    path = RATE_LIMIT_FILE if rate_limit_file is None else rate_limit_file
    today = datetime.now().strftime('%Y-%m-%d')

    try:
        with open(path, 'r') as f:
            stored = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        stored = {}
    # Valid JSON that is not an object (e.g. ``[]``) is treated as empty
    # rather than crashing on the lookups below.
    if not isinstance(stored, dict):
        stored = {}

    # Keep only today's well-formed counters; everything older is stale.
    rate_limits = {}
    for uid, per_day in stored.items():
        if not isinstance(per_day, dict):
            continue
        todays_count = per_day.get(today)
        if isinstance(todays_count, int):
            rate_limits[uid] = {today: todays_count}

    used = rate_limits.get(user_id, {}).get(today, 0)

    # Check if limit exceeded
    if used >= limit:
        return False, used

    # Increment count and save
    used += 1
    rate_limits[user_id] = {today: used}
    with open(path, 'w') as f:
        json.dump(rate_limits, f)
    return True, used
def get_user_id():
    """Return the identifier used to track this session's daily quota.

    The id is created once per Streamlit session and kept in
    ``st.session_state``. A random UUID is used instead of a hash of the
    current second: two sessions started within the same second of the same
    process hashed to the same value and therefore shared one quota.

    Returns:
        The session's user id as a string.
    """
    if 'user_id' not in st.session_state:
        st.session_state.user_id = uuid.uuid4().hex
    return st.session_state.user_id
def get_remaining_queries(user_id, *, max_requests=None, rate_limit_file=None):
    """Return how many queries ``user_id`` may still make today.

    This only reads the store; it never changes a counter.

    Args:
        user_id: Key identifying the caller in the store.
        max_requests: Daily quota; defaults to ``MAX_REQUESTS_PER_DAY``.
        rate_limit_file: Path of the JSON store; defaults to
            ``RATE_LIMIT_FILE``.

    Returns:
        The remaining number of queries, never below zero. The full quota
        is returned when the store is missing, unreadable, or malformed.
    """
    limit = MAX_REQUESTS_PER_DAY if max_requests is None else max_requests
    path = RATE_LIMIT_FILE if rate_limit_file is None else rate_limit_file
    today = datetime.now().strftime('%Y-%m-%d')

    try:
        with open(path, 'r') as f:
            rate_limits = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        return limit

    # Tolerate valid JSON of the wrong shape instead of raising.
    per_day = rate_limits.get(user_id) if isinstance(rate_limits, dict) else None
    count = per_day.get(today, 0) if isinstance(per_day, dict) else 0
    if not isinstance(count, int):
        count = 0

    # Clamp so a counter above the quota never shows as a negative number.
    return max(0, limit - count)
# Set up page configuration
st.set_page_config(
    page_title="USMLE Step 1 AI",
    page_icon="🩺",
    layout="centered",
    initial_sidebar_state="expanded",
)

# Global stylesheet for the headers, alerts, footer and chat bubbles.
# It is injected as raw HTML, hence unsafe_allow_html below.
_PAGE_CSS = """
<style>
.main-header {
font-size: 2.5rem !important;
margin-bottom: 1rem !important;
color: #2c3e50;
}
.sub-header {
font-size: 1.2rem !important;
color: #34495e;
margin-bottom: 2rem !important;
}
.stAlert {
padding: 15px !important;
border-radius: 8px !important;
}
.footer-text {
font-size: 0.85rem !important;
color: #7f8c8d;
}
.stChatMessage div[data-testid="stChatMessageContent"] {
border-radius: 15px !important;
padding: 15px !important;
}
.user-message {
background-color: #f1f8ff !important;
}
.assistant-message {
background-color: #f9f9f9 !important;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Start with an empty chat history when the session has none yet
if 'messages' not in st.session_state:
    st.session_state.messages = []

# Make sure the rate-limit store exists before the sidebar reads it
init_rate_limiting()
# Sidebar content: branding, the daily-usage indicator and two help expanders.
with st.sidebar:
    st.image("https://online.flipbuilder.com/clinical-library/vxes/files/shot.png", width=80)
    st.markdown("### USMLE Step 1 Assistant")
    st.markdown("---")

    # Display remaining queries with visual indicator
    user_id = get_user_id()
    remaining_queries = get_remaining_queries(user_id)

    # Pick the accent colour from the remaining quota:
    # red at 2 or fewer, amber at 3, green otherwise.
    if remaining_queries <= 2:
        status_color = "#F44336"
    elif remaining_queries <= 3:
        status_color = "#FFC107"
    else:
        status_color = "#4CAF50"

    # Styles for the usage card. The border colour comes from the
    # --status-color custom property set inline on the card below.
    st.markdown("""
<style>
.usage-container {
border-radius: 8px;
padding: 15px;
margin-bottom: 20px;
border-left: 5px solid var(--status-color);
background-color: rgba(240, 240, 240, 0.3);
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
.usage-title {
font-weight: 600;
margin-bottom: 8px;
color: #333333;
}
.usage-value {
font-size: 1.2rem;
font-weight: 700;
color: #333333;
}
/* Dark mode specific styles */
@media (prefers-color-scheme: dark) {
.usage-container {
background-color: rgba(70, 70, 70, 0.2);
}
.usage-title, .usage-value {
color: #FFFFFF;
}
}
</style>
""", unsafe_allow_html=True)
    st.markdown(f"""
<div class="usage-container" style="--status-color: {status_color}">
<div class="usage-title">Daily Usage</div>
<div class="usage-value">{remaining_queries}/{MAX_REQUESTS_PER_DAY} queries remaining</div>
</div>
""", unsafe_allow_html=True)

    # Help section in sidebar
    with st.expander("ℹ️ How to use"):
        # The quota is read from MAX_REQUESTS_PER_DAY so the help text cannot
        # drift from the enforced limit. The blank line before "Best for"
        # ends the numbered list; without it Markdown folds that line into
        # item 3.
        st.markdown(f"""
1. Type your USMLE Step 1 question in the chat input
2. The AI will search First Aid content and respond
3. You have {MAX_REQUESTS_PER_DAY} queries per day

**Best for:**
- Fact checking First Aid content
- Understanding complex topics
- Quick reference during study
""")
    with st.expander("🔍 Example Questions"):
        st.markdown("""
- "Explain the Krebs cycle"
- "What are the symptoms of Parkinson's disease?"
- "Differentiate between type 1 and type 2 diabetes"
- "What antibiotics are used for MRSA?"
""")
# Check for API keys: both must be present in the environment, otherwise the
# app shows an error and stops rendering.
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
if not (PINECONE_API_KEY and OPENAI_API_KEY):
    st.error("⚠️ Missing API keys. Please set PINECONE_API_KEY and OPENAI_API_KEY environment variables.")
    st.stop()
# Write the keys back under the same names they were just read from.
os.environ.update(
    PINECONE_API_KEY=PINECONE_API_KEY,
    OPENAI_API_KEY=OPENAI_API_KEY,
)
# Cache the RAG chain initialization
@st.cache_resource
def initialize_rag_chain():
    """Build the retrieval-augmented generation chain used to answer questions.

    Steps, each reported through a sidebar progress bar:
      1. load the embeddings via ``download_hugging_face_embeddings``;
      2. open the existing Pinecone index ``"medprep"``;
      3. create a similarity retriever returning the top 3 documents;
      4. create the OpenAI LLM (temperature 0.4, at most 500 tokens) and
         combine it with ``system_prompt`` into a retrieval chain.

    The function is wrapped in ``st.cache_resource``, so its return value is
    reused across reruns instead of being rebuilt.

    Returns:
        The retrieval chain, or ``None`` if any step raised. In that case
        the error and its traceback are written to the sidebar.
    """
    progress_text = None
    progress_bar = None
    chain = None
    try:
        progress_text = st.sidebar.empty()
        progress_bar = st.sidebar.progress(0)

        # Step 1: Load embeddings
        progress_text.text("Loading embeddings model... (1/4)")
        embeddings = download_hugging_face_embeddings()
        progress_bar.progress(25)

        # Step 2: Connect to Pinecone
        progress_text.text("Connecting to Pinecone database... (2/4)")
        index_name = "medprep"
        docsearch = Pinecone.from_existing_index(
            index_name=index_name,
            embedding=embeddings
        )
        progress_bar.progress(50)

        # Step 3: Set up retriever
        progress_text.text("Setting up retrieval system... (3/4)")
        retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        progress_bar.progress(75)

        # Step 4: Initialize LLM and chain
        progress_text.text("Initializing language model... (4/4)")
        llm = OpenAI(temperature=0.4, max_tokens=500)
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}")
        ])
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)
        progress_bar.progress(100)
        # Only publish the chain once every step, including the final
        # progress update, has succeeded.
        chain = rag_chain
    except Exception as e:
        st.sidebar.error(f"⚠️ Error initializing system: {str(e)}")
        import traceback
        st.sidebar.text(traceback.format_exc())
    finally:
        # Clean up progress indicators on failure as well, so a half-filled
        # bar and a stale step label do not linger next to the error.
        for placeholder in (progress_text, progress_bar):
            if placeholder is not None:
                placeholder.empty()

    if chain is not None:
        st.sidebar.success("✅ System initialized successfully!")
    return chain
# Main app content
st.markdown('<h1 class="main-header">First Aid USMLE Step 1 Assistant</h1>', unsafe_allow_html=True)
st.markdown('<p class="sub-header">Ask me any question from First Aid USMLE Step 1 book, and I\'ll try to help!</p>', unsafe_allow_html=True)

# Initialize the RAG chain
rag_chain = initialize_rag_chain()
if rag_chain is None:
    # The initializer is wrapped in st.cache_resource, which caches the None
    # failure result like any other return value. Drop it so the next rerun
    # retries instead of replaying the failure until the process restarts.
    initialize_rag_chain.clear()
    st.error("⚠️ Failed to initialize the system. Please check the sidebar for error details.")
    st.stop()
# Display chat history with improved styling
for message in st.session_state.messages:
    role = message["role"]
    message_class = "user-message" if role == "user" else "assistant-message"
    # The text is placed inside raw HTML rendered with unsafe_allow_html, so
    # escape &, < and > to stop user- or model-supplied markup from being
    # interpreted as HTML.
    safe_content = html.escape(message["content"], quote=False)
    with st.chat_message(role):
        st.markdown(f'<div class="{message_class}">{safe_content}</div>', unsafe_allow_html=True)
# Get user input
if prompt := st.chat_input("Ask a USMLE Step 1 question..."):
    # Add user message to chat history (stored raw; escaped only when rendered)
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display user message. It is embedded in raw HTML, so escape it first.
    with st.chat_message("user"):
        st.markdown(
            f'<div class="user-message">{html.escape(prompt, quote=False)}</div>',
            unsafe_allow_html=True,
        )

    # Check rate limit (this also counts the request when it is allowed)
    user_id = get_user_id()
    allowed, count = check_rate_limit(user_id)

    # The assistant bubble is created for both outcomes so that the
    # "limit reached" notice appears immediately, not only after the next
    # rerun replays the history.
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        if not allowed:
            response = f"⚠️ **Daily limit reached**\n\nYou've used {count} queries today. Please try again tomorrow."
        else:
            # Process the query with the RAG chain
            with st.spinner("Searching First Aid content..."):
                try:
                    result = rag_chain.invoke({"input": prompt})
                    response = result.get("answer", "Sorry, I couldn't find an answer to that.")

                    # Format the remaining queries notification
                    remaining = MAX_REQUESTS_PER_DAY - count
                    noun = "query" if remaining == 1 else "queries"
                    if remaining <= 1:
                        usage_note = f"⚠️ **{remaining} {noun} remaining today**"
                    else:
                        usage_note = f"ℹ️ {remaining} {noun} remaining today"

                    # Add a separator and the usage note
                    response += f"\n\n---\n\n{usage_note}"
                except Exception as e:
                    response = f"⚠️ **Error processing your request**\n\n{str(e)}"

        # Model output and error text are untrusted as well; escape before
        # embedding them in the HTML wrapper.
        message_placeholder.markdown(
            f'<div class="assistant-message">{html.escape(response, quote=False)}</div>',
            unsafe_allow_html=True,
        )

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})
# Footer with improved styling: a rule followed by the "about" blurb,
# injected as raw HTML.
st.markdown("---")
_FOOTER_HTML = """
<div class="footer-text">
<p><strong>About this assistant</strong></p>
<p>This AI assistant uses retrieval augmented generation to provide information from First Aid USMLE Step 1 content.
It's designed to help with studying, but should not replace professional medical advice.</p>
<p><strong>Performance Data</strong></p>
<p>Our RAG-based system has been rigorously evaluated for accuracy and response quality.
<a href="https://github.com/Nahiyan140212/MedPrepAI-RAG" target="_blank">View detailed performance metrics on GitHub</a>
to learn about our testing methodology and results.</p>
<p>© 2025 USMLE Step 1 Assistant - Created by Nahiyan Noor</p>
</div>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)

# Reset button at the bottom: empties the chat history and reruns the script
clear_clicked = st.button("Clear Conversation")
if clear_clicked:
    st.session_state.messages = []
    st.rerun()