"""Streamlit chat front end for a First Aid USMLE Step 1 assistant.

Questions are answered by a LangChain retrieval chain over the "medprep"
Pinecone index. Each browser session may make MAX_REQUESTS_PER_DAY queries
per day; counts are kept in a JSON file under /tmp.
"""

import contextlib
import html
import json
import os
import tempfile
import uuid
from datetime import datetime, timedelta

# Set up cache directories. Done before the third-party/project imports below
# because Hugging Face libraries may read these variables at import time.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/model_cache'
os.environ['HF_HOME'] = '/tmp/model_cache'
os.makedirs('/tmp/model_cache', exist_ok=True)

import streamlit as st  # noqa: E402
from dotenv import load_dotenv  # noqa: E402
from langchain.chains import create_retrieval_chain  # noqa: E402
from langchain.chains.combine_documents import create_stuff_documents_chain  # noqa: E402
from langchain_community.vectorstores import Pinecone  # noqa: E402
from langchain_core.prompts import ChatPromptTemplate  # noqa: E402
from langchain_openai import OpenAI  # noqa: E402

from src.helper import download_hugging_face_embeddings  # noqa: E402
from src.prompt import system_prompt  # noqa: E402

# Load environment variables
load_dotenv()

# Rate limiting configuration
RATE_LIMIT_FILE = "/tmp/rate_limits.json"
MAX_REQUESTS_PER_DAY = 5


def _today():
    """Return today's local date as the 'YYYY-MM-DD' key used in RATE_LIMIT_FILE."""
    return datetime.now().strftime('%Y-%m-%d')


def _load_rate_limits():
    """Read the rate-limit table, shaped {user_id: {date: count}}.

    Returns {} when the file is missing, is not valid JSON, or does not
    contain a JSON object.
    """
    try:
        with open(RATE_LIMIT_FILE, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        return {}
    return data if isinstance(data, dict) else {}


def _save_rate_limits(rate_limits):
    """Persist the rate-limit table atomically.

    The JSON is written to a temporary file in the same directory and then
    moved over RATE_LIMIT_FILE with os.replace. A concurrent reader therefore
    never sees a half-written file; _load_rate_limits would treat such a file
    as empty and reset every user's count.
    """
    fd, tmp_path = tempfile.mkstemp(
        dir=os.path.dirname(RATE_LIMIT_FILE) or '.', suffix='.tmp'
    )
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(rate_limits, f)
        os.replace(tmp_path, RATE_LIMIT_FILE)
    except BaseException:
        # Don't leave a stray temp file behind if the write or rename fails.
        with contextlib.suppress(OSError):
            os.remove(tmp_path)
        raise


def _escape_html(text):
    """Escape &, < and > so user-typed text is shown literally.

    Chat messages are rendered with unsafe_allow_html=True, so raw user input
    would otherwise be interpreted as HTML.
    """
    return html.escape(text, quote=False)


# Initialize rate limiting storage
def init_rate_limiting():
    """Create an empty rate-limit file if one does not exist yet."""
    if not os.path.exists(RATE_LIMIT_FILE):
        _save_rate_limits({})


# Check if a user has exceeded their daily limit
def check_rate_limit(user_id):
    """Consume one of *user_id*'s daily queries if any remain.

    Returns:
        (allowed, count): ``allowed`` is False when the user has already made
        MAX_REQUESTS_PER_DAY queries today. ``count`` is today's count after
        this call; it is incremented and saved only when allowed.
    """
    today = _today()

    # Keep only today's counts. Dropping every other date (not just
    # yesterday's) stops entries from idle periods accumulating forever.
    rate_limits = {
        uid: {today: days[today]}
        for uid, days in _load_rate_limits().items()
        if isinstance(days, dict) and today in days
    }

    count = rate_limits.get(user_id, {}).get(today, 0)
    if count >= MAX_REQUESTS_PER_DAY:
        return False, count

    count += 1
    rate_limits[user_id] = {today: count}
    _save_rate_limits(rate_limits)
    return True, count


def get_user_id():
    """Return this browser session's rate-limit identity, creating it on first use.

    A random UUID (rather than a timestamp) guarantees that sessions started
    in the same second get distinct IDs.

    NOTE: the ID lives in st.session_state, so the limit applies per
    session, not per person.
    """
    if 'user_id' not in st.session_state:
        st.session_state.user_id = uuid.uuid4().hex
    return st.session_state.user_id


def get_remaining_queries(user_id):
    """Return how many queries *user_id* has left today (never negative)."""
    days = _load_rate_limits().get(user_id)
    count = days.get(_today(), 0) if isinstance(days, dict) else 0
    return max(0, MAX_REQUESTS_PER_DAY - count)


# Set up page configuration (must be the first Streamlit command)
st.set_page_config(
    page_title="USMLE Step 1 AI",
    page_icon="🩺",
    layout="centered",
    initial_sidebar_state="expanded",
)

# Apply custom CSS for better visual appearance
# NOTE(review): this CSS string is blank as it appears here, and the markup
# strings below carry only their text; confirm the intended HTML/CSS is present.
st.markdown(""" """, unsafe_allow_html=True)

# Initialize session state for chat history
if 'messages' not in st.session_state:
    st.session_state.messages = []

# Initialize rate limiting
init_rate_limiting()

# Sidebar content
with st.sidebar:
    st.image("https://online.flipbuilder.com/clinical-library/vxes/files/shot.png", width=80)
    st.markdown("### USMLE Step 1 Assistant")
    st.markdown("---")

    # Display remaining queries with visual indicator
    user_id = get_user_id()
    remaining_queries = get_remaining_queries(user_id)

    # Determine styling based on remaining queries.
    # NOTE(review): status_color is not referenced by the usage markup below
    # as it appears here; confirm the styled indicator HTML uses it.
    if remaining_queries <= 2:
        status_color = "#F44336"  # Red for low queries
    elif remaining_queries <= 3:
        status_color = "#FFC107"  # Yellow/amber for warning
    else:
        status_color = "#4CAF50"  # Green for good status

    # Create a universally visible usage indicator
    st.markdown(""" """, unsafe_allow_html=True)
    st.markdown(
        f"\nDaily Usage\n{remaining_queries}/{MAX_REQUESTS_PER_DAY} queries remaining\n",
        unsafe_allow_html=True,
    )

    # Help section in sidebar (the daily quota is derived from the constant
    # so the text cannot drift from the enforced limit).
    with st.expander("ℹ️ How to use"):
        st.markdown(f"""
        1. Type your USMLE Step 1 question in the chat input
        2. The AI will search First Aid content and respond
        3. You have {MAX_REQUESTS_PER_DAY} queries per day

        **Best for:**
        - Fact checking First Aid content
        - Understanding complex topics
        - Quick reference during study
        """)

    with st.expander("🔍 Example Questions"):
        st.markdown("""
        - "Explain the Krebs cycle"
        - "What are the symptoms of Parkinson's disease?"
        - "Differentiate between type 1 and type 2 diabetes"
        - "What antibiotics are used for MRSA?"
        """)

# Check for API keys (the chain's clients read them from the environment).
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
if not PINECONE_API_KEY or not OPENAI_API_KEY:
    st.error("⚠️ Missing API keys. Please set PINECONE_API_KEY and OPENAI_API_KEY environment variables.")
    st.stop()


# Cache the RAG chain initialization
@st.cache_resource
def initialize_rag_chain():
    """Build the retrieval-augmented QA chain, reporting progress in the sidebar.

    Loads the embeddings, connects to the existing "medprep" Pinecone index,
    builds a top-3 similarity retriever and an OpenAI completion LLM, and
    combines them with create_retrieval_chain.

    Returns:
        The chain, or None if any step raised. On failure, the error and its
        traceback are shown in the sidebar.
    """
    try:
        progress_text = st.sidebar.empty()
        progress_bar = st.sidebar.progress(0)

        # Step 1: Load embeddings
        progress_text.text("Loading embeddings model... (1/4)")
        embeddings = download_hugging_face_embeddings()
        progress_bar.progress(25)

        # Step 2: Connect to Pinecone
        progress_text.text("Connecting to Pinecone database... (2/4)")
        docsearch = Pinecone.from_existing_index(
            index_name="medprep",
            embedding=embeddings,
        )
        progress_bar.progress(50)

        # Step 3: Set up retriever
        progress_text.text("Setting up retrieval system... (3/4)")
        retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        progress_bar.progress(75)

        # Step 4: Initialize LLM and chain
        progress_text.text("Initializing language model... (4/4)")
        llm = OpenAI(temperature=0.4, max_tokens=500)
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}"),
        ])
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)
        progress_bar.progress(100)

        # Clean up progress indicators
        progress_text.empty()
        progress_bar.empty()
        st.sidebar.success("✅ System initialized successfully!")
        return rag_chain
    except Exception as e:
        import traceback

        st.sidebar.error(f"⚠️ Error initializing system: {str(e)}")
        st.sidebar.text(traceback.format_exc())
        return None


# Main app content
st.markdown('\n\nFirst Aid USMLE Step 1 Assistant\n\n', unsafe_allow_html=True)
st.markdown(
    "\n\nAsk me any question from First Aid USMLE Step 1 book, and I'll try to help!\n\n",
    unsafe_allow_html=True,
)

# Initialize the RAG chain
rag_chain = initialize_rag_chain()
if rag_chain is None:
    # st.cache_resource also caches the None returned on failure. Clear it so
    # the next rerun retries initialization instead of failing until restart.
    initialize_rag_chain.clear()
    st.error("⚠️ Failed to initialize the system. Please check the sidebar for error details.")
    st.stop()

# Display chat history. User-typed text is escaped; assistant text is rendered
# as-is so its markdown (bold, separators) still formats.
for message in st.session_state.messages:
    # NOTE(review): message_class is not referenced by the markup as it
    # appears here; confirm the styled wrapper HTML uses it.
    message_class = "user-message" if message["role"] == "user" else "assistant-message"
    content = message["content"]
    if message["role"] == "user":
        content = _escape_html(content)
    with st.chat_message(message["role"]):
        st.markdown(f'\n{content}\n', unsafe_allow_html=True)

# Get user input
if prompt := st.chat_input("Ask a USMLE Step 1 question..."):
    # Add user message to chat history and display it
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(f'\n{_escape_html(prompt)}\n', unsafe_allow_html=True)

    # Check rate limit (consumes one query when allowed)
    user_id = get_user_id()
    allowed, count = check_rate_limit(user_id)

    # Render the reply in both branches, so the limit warning is visible
    # immediately rather than only after the next rerun.
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        if not allowed:
            response = f"⚠️ **Daily limit reached**\n\nYou've used {count} queries today. Please try again tomorrow."
        else:
            # Process the query with the RAG chain
            with st.spinner("Searching First Aid content..."):
                try:
                    result = rag_chain.invoke({"input": prompt})
                except Exception as e:
                    response = f"⚠️ **Error processing your request**\n\n{str(e)}"
                else:
                    response = result.get("answer", "Sorry, I couldn't find an answer to that.")

                    # Format the remaining queries notification
                    remaining = MAX_REQUESTS_PER_DAY - count
                    noun = "query" if remaining == 1 else "queries"
                    if remaining <= 1:
                        usage_note = f"⚠️ **{remaining} {noun} remaining today**"
                    else:
                        usage_note = f"ℹ️ {remaining} queries remaining today"

                    # Add a separator and the usage note
                    response += f"\n\n---\n\n{usage_note}"
        message_placeholder.markdown(f'\n{response}\n', unsafe_allow_html=True)

    # Add assistant response to chat history
    st.session_state.messages.append({"role": "assistant", "content": response})

# Footer with improved styling
st.markdown("---")
st.markdown(""" """, unsafe_allow_html=True)

# Add a reset button at the bottom
if st.button("Clear Conversation"):
    st.session_state.messages = []
    st.rerun()