Spaces:
Sleeping
Sleeping
import html
import json
import os
import uuid
from datetime import datetime, timedelta

import streamlit as st
from dotenv import load_dotenv
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.vectorstores import Pinecone
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAI

from src.helper import download_hugging_face_embeddings
from src.prompt import system_prompt
# Point the transformers / Hugging Face caches at a writable temp directory.
# NOTE(review): these variables are set after the imports above; if any imported
# module (e.g. src.helper) loads transformers/huggingface_hub at import time,
# those libraries may already have read their cache settings -- confirm.
os.environ.update({
    'TRANSFORMERS_CACHE': '/tmp/model_cache',
    'HF_HOME': '/tmp/model_cache',
})
os.makedirs('/tmp/model_cache', exist_ok=True)

# Pull settings such as API keys from a local .env file, if one exists.
load_dotenv()

# Rate limiting configuration: per-user, per-day counters persisted as JSON.
RATE_LIMIT_FILE = "/tmp/rate_limits.json"
MAX_REQUESTS_PER_DAY = 5
def init_rate_limiting():
    """Create the rate-limit store as an empty JSON object if it is missing.

    An existing file is never overwritten, so counters already recorded in it
    are kept across reruns of the script.
    """
    if os.path.exists(RATE_LIMIT_FILE):
        return
    with open(RATE_LIMIT_FILE, 'w') as handle:
        json.dump({}, handle)
def check_rate_limit(user_id, *, path=None, max_requests=None):
    """Consume one of ``user_id``'s daily queries if any remain.

    The JSON store is loaded and pruned to today's counters only. Then either
    today's counter is incremented and the store is written back, or the
    request is refused without writing.

    Args:
        user_id: Key identifying the caller in the store.
        path: JSON file holding the counters; defaults to RATE_LIMIT_FILE.
        max_requests: Daily allowance; defaults to MAX_REQUESTS_PER_DAY.

    Returns:
        ``(allowed, count)``. ``allowed`` is False when the limit had already
        been reached. ``count`` is today's usage for ``user_id`` after this
        call.
    """
    path = RATE_LIMIT_FILE if path is None else path
    max_requests = MAX_REQUESTS_PER_DAY if max_requests is None else max_requests
    today = datetime.now().strftime('%Y-%m-%d')

    try:
        with open(path, 'r') as f:
            rate_limits = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        rate_limits = {}
    if not isinstance(rate_limits, dict):
        # A valid-JSON but wrongly shaped file is treated like a corrupt one.
        rate_limits = {}

    # Keep only today's counters. Deleting just the "yesterday" key would leave
    # counters two or more days old (and users who only have those) forever.
    rate_limits = {
        uid: {today: days[today]}
        for uid, days in rate_limits.items()
        if isinstance(days, dict) and today in days
    }

    used = rate_limits.setdefault(user_id, {}).setdefault(today, 0)
    if used >= max_requests:
        # Refusals are not persisted, matching the original behaviour.
        return False, used

    rate_limits[user_id][today] = used + 1
    with open(path, 'w') as f:
        json.dump(rate_limits, f)
    return True, used + 1
def get_user_id():
    """Return the identifier used as this Streamlit session's rate-limit key.

    The id is generated once and then cached in ``st.session_state``. It is a
    random UUID. The previous value was ``hash()`` of a one-second-resolution
    timestamp, so two sessions started in the same second got the same id and
    shared one quota.

    NOTE(review): the id lives in session state, so a fresh browser session
    presumably starts a fresh quota -- confirm that is acceptable.
    """
    if "user_id" not in st.session_state:
        st.session_state.user_id = uuid.uuid4().hex
    return st.session_state.user_id
def get_remaining_queries(user_id, *, path=None, max_requests=None):
    """Return how many queries ``user_id`` has left today, without consuming one.

    Args:
        user_id: Key identifying the caller in the store.
        path: JSON file holding the counters; defaults to RATE_LIMIT_FILE.
        max_requests: Daily allowance; defaults to MAX_REQUESTS_PER_DAY.

    Returns:
        A non-negative int. Returns the full allowance when the store is
        missing or unreadable.
    """
    path = RATE_LIMIT_FILE if path is None else path
    max_requests = MAX_REQUESTS_PER_DAY if max_requests is None else max_requests
    today = datetime.now().strftime('%Y-%m-%d')
    try:
        with open(path, 'r') as f:
            rate_limits = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        return max_requests
    if not isinstance(rate_limits, dict):
        return max_requests

    user_days = rate_limits.get(user_id, {})
    count = user_days.get(today, 0) if isinstance(user_days, dict) else 0
    # Clamp so a count recorded under a higher limit (e.g. after the limit was
    # lowered mid-day) never shows up as a negative balance.
    return max(0, max_requests - count)
# Browser tab title/icon and overall layout.
st.set_page_config(
    layout="centered",
    initial_sidebar_state="expanded",
    page_title="USMLE Step 1 AI",
    page_icon="🩺",
)
# Global CSS: headers, alerts, footer text and chat-bubble backgrounds.
st.markdown("""
<style>
.main-header {
font-size: 2.5rem !important;
margin-bottom: 1rem !important;
color: #2c3e50;
}
.sub-header {
font-size: 1.2rem !important;
color: #34495e;
margin-bottom: 2rem !important;
}
.stAlert {
padding: 15px !important;
border-radius: 8px !important;
}
.footer-text {
font-size: 0.85rem !important;
color: #7f8c8d;
}
.stChatMessage div[data-testid="stChatMessageContent"] {
border-radius: 15px !important;
padding: 15px !important;
}
.user-message {
background-color: #f1f8ff !important;
}
.assistant-message {
background-color: #f9f9f9 !important;
}
</style>
""", unsafe_allow_html=True)
# Chat history is kept in session state so it survives reruns.
if 'messages' not in st.session_state:
    st.session_state['messages'] = []
# Ensure the rate-limit store exists before the sidebar reads it.
init_rate_limiting()
# Sidebar: branding, remaining-quota indicator and usage help.
with st.sidebar:
    st.image("https://online.flipbuilder.com/clinical-library/vxes/files/shot.png", width=80)
    st.markdown("### USMLE Step 1 Assistant")
    st.markdown("---")

    # Read-only quota lookup; this does not consume a query.
    # NOTE(review): this runs before the chat input further down is processed,
    # so right after a query it still shows the pre-query count.
    user_id = get_user_id()
    remaining_queries = get_remaining_queries(user_id)

    # Traffic-light colour: red at <= 2 left, amber at 3, green otherwise.
    if remaining_queries <= 2:
        status_color = "#F44336"
    elif remaining_queries <= 3:
        status_color = "#FFC107"
    else:
        status_color = "#4CAF50"

    # Usage-card CSS, with a dark-mode override for the text colours.
    st.markdown("""
<style>
.usage-container {
border-radius: 8px;
padding: 15px;
margin-bottom: 20px;
border-left: 5px solid var(--status-color);
background-color: rgba(240, 240, 240, 0.3);
box-shadow: 0 2px 5px rgba(0, 0, 0, 0.1);
}
.usage-title {
font-weight: 600;
margin-bottom: 8px;
color: #333333;
}
.usage-value {
font-size: 1.2rem;
font-weight: 700;
color: #333333;
}
/* Dark mode specific styles */
@media (prefers-color-scheme: dark) {
.usage-container {
background-color: rgba(70, 70, 70, 0.2);
}
.usage-title, .usage-value {
color: #FFFFFF;
}
}
</style>
""", unsafe_allow_html=True)
    st.markdown(f"""
<div class="usage-container" style="--status-color: {status_color}">
<div class="usage-title">Daily Usage</div>
<div class="usage-value">{remaining_queries}/{MAX_REQUESTS_PER_DAY} queries remaining</div>
</div>
""", unsafe_allow_html=True)

    # Help text takes the allowance from the constant so it cannot drift from
    # the limit that is actually enforced.
    with st.expander("ℹ️ How to use"):
        st.markdown(f"""
1. Type your USMLE Step 1 question in the chat input
2. The AI will search First Aid content and respond
3. You have {MAX_REQUESTS_PER_DAY} queries per day

**Best for:**
- Fact checking First Aid content
- Understanding complex topics
- Quick reference during study
""")
    with st.expander("🔍 Example Questions"):
        st.markdown("""
- "Explain the Krebs cycle"
- "What are the symptoms of Parkinson's disease?"
- "Differentiate between type 1 and type 2 diabetes"
- "What antibiotics are used for MRSA?"
""")
# Both keys are required. The Pinecone store and the OpenAI LLM are constructed
# later without explicit keys (presumably reading them from the environment).
# Writing the values back into os.environ under the same names would be a
# no-op, so that step is omitted.
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
if not PINECONE_API_KEY or not OPENAI_API_KEY:
    st.error("⚠️ Missing API keys. Please set PINECONE_API_KEY and OPENAI_API_KEY environment variables.")
    st.stop()
@st.cache_resource(show_spinner="Loading embeddings and connecting to the knowledge base...")
def _build_rag_chain():
    """Build the retrieval-augmented QA chain; the result is shared across reruns.

    Without caching, every widget interaction reran the whole script and rebuilt
    the embeddings, Pinecone connection and LLM. Errors propagate to the caller
    and are not cached, so a failed build is retried on the next rerun.
    """
    embeddings = download_hugging_face_embeddings()
    docsearch = Pinecone.from_existing_index(
        index_name="medprep",
        embedding=embeddings,
    )
    # Top-3 nearest chunks are stuffed into the prompt as context.
    retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    llm = OpenAI(temperature=0.4, max_tokens=500)
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", "{input}"),
    ])
    question_answer_chain = create_stuff_documents_chain(llm, prompt)
    return create_retrieval_chain(retriever, question_answer_chain)


def initialize_rag_chain():
    """Return the shared RAG chain, or None after reporting the failure in the sidebar.

    Returns:
        The chain built by ``_build_rag_chain``. Returns None if building it
        raised; in that case the error and its traceback are shown in the
        sidebar.
    """
    try:
        rag_chain = _build_rag_chain()
    except Exception as e:
        import traceback
        st.sidebar.error(f"⚠️ Error initializing system: {str(e)}")
        st.sidebar.text(traceback.format_exc())
        return None
    st.sidebar.success("✅ System initialized successfully!")
    return rag_chain
# Page heading and tagline, styled by the .main-header / .sub-header CSS.
st.markdown(
    "<h1 class=\"main-header\">First Aid USMLE Step 1 Assistant</h1>",
    unsafe_allow_html=True,
)
st.markdown(
    "<p class=\"sub-header\">Ask me any question from First Aid USMLE Step 1 book, and I'll try to help!</p>",
    unsafe_allow_html=True,
)
# Obtain the RAG chain; nothing below can work without it, so halt on failure.
rag_chain = initialize_rag_chain()
if rag_chain is None:
    st.error("⚠️ Failed to initialize the system. Please check the sidebar for error details.")
    st.stop()
# Replay the stored conversation.
# User text is HTML-escaped before it goes inside the styled wrapper, so typed
# markup is shown literally instead of being injected as HTML.
# Assistant answers are rendered as plain Markdown. Inside a <div> wrapper the
# first paragraph became a raw HTML block (so **bold** showed literally), and
# model output would otherwise be interpreted as HTML.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        if message["role"] == "user":
            st.markdown(
                f'<div class="user-message">{html.escape(message["content"])}</div>',
                unsafe_allow_html=True,
            )
        else:
            st.markdown(message["content"])
# Handle a newly submitted question.
if prompt := st.chat_input("Ask a USMLE Step 1 question..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        # Escape so user-typed markup is displayed as text, not rendered.
        st.markdown(f'<div class="user-message">{html.escape(prompt)}</div>', unsafe_allow_html=True)

    # Consume one query from today's allowance. check_rate_limit refuses
    # without incrementing once the limit is reached.
    # NOTE(review): a query that later fails with an exception still counts.
    user_id = get_user_id()
    allowed, count = check_rate_limit(user_id)

    # Render the reply in both branches. Previously the limit-reached message
    # was only appended to history and did not appear until the next rerun.
    with st.chat_message("assistant"):
        if not allowed:
            response = f"⚠️ **Daily limit reached**\n\nYou've used {count} queries today. Please try again tomorrow."
        else:
            with st.spinner("Searching First Aid content..."):
                try:
                    result = rag_chain.invoke({"input": prompt})
                    response = result.get("answer", "Sorry, I couldn't find an answer to that.")
                    # count already includes this query.
                    remaining = MAX_REQUESTS_PER_DAY - count
                    noun = "query" if remaining == 1 else "queries"
                    if remaining <= 1:
                        usage_note = f"⚠️ **{remaining} {noun} remaining today**"
                    else:
                        usage_note = f"ℹ️ {remaining} {noun} remaining today"
                    response += f"\n\n---\n\n{usage_note}"
                except Exception as e:
                    response = f"⚠️ **Error processing your request**\n\n{str(e)}"
        # Plain Markdown: bold and lists render, and model output is not treated as HTML.
        st.markdown(response)

    st.session_state.messages.append({"role": "assistant", "content": response})
# Footer: what the assistant is, where its evaluation lives, and credit.
footer_html = """
<div class="footer-text">
<p><strong>About this assistant</strong></p>
<p>This AI assistant uses retrieval augmented generation to provide information from First Aid USMLE Step 1 content.
It's designed to help with studying, but should not replace professional medical advice.</p>
<p><strong>Performance Data</strong></p>
<p>Our RAG-based system has been rigorously evaluated for accuracy and response quality.
<a href="https://github.com/Nahiyan140212/MedPrepAI-RAG" target="_blank">View detailed performance metrics on GitHub</a>
to learn about our testing methodology and results.</p>
<p>© 2025 USMLE Step 1 Assistant - Created by Nahiyan Noor</p>
</div>
"""
st.markdown("---")
st.markdown(footer_html, unsafe_allow_html=True)

# Reset control: drop the stored conversation, then rerun so the empty history shows.
clear_clicked = st.button("Clear Conversation")
if clear_clicked:
    st.session_state.messages = []
    st.rerun()