Spaces:
Runtime error
Runtime error
| import os | |
| import streamlit as st | |
| import json | |
| from datetime import datetime, timedelta | |
| from src.helper import download_hugging_face_embeddings | |
| from langchain_community.vectorstores import Pinecone | |
| from langchain_openai import OpenAI | |
| from langchain.chains import create_retrieval_chain | |
| from langchain.chains.combine_documents import create_stuff_documents_chain | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from dotenv import load_dotenv | |
| from src.prompt import system_prompt | |
# Set up cache directories so Hugging Face model downloads land in a
# writable location (on hosted Spaces only /tmp is reliably writable).
os.environ['TRANSFORMERS_CACHE'] = '/tmp/model_cache'
os.environ['HF_HOME'] = '/tmp/model_cache'
os.makedirs('/tmp/model_cache', exist_ok=True)

# Load environment variables from a local .env file, if one exists.
load_dotenv()

# Rate limiting configuration: per-user daily quota persisted as JSON.
# NOTE(review): /tmp is ephemeral — all counters reset when the host restarts.
RATE_LIMIT_FILE = "/tmp/rate_limits.json"
MAX_REQUESTS_PER_DAY = 5
# Initialize rate limiting storage
def init_rate_limiting():
    """Create the rate-limit JSON store with an empty mapping if absent."""
    if os.path.exists(RATE_LIMIT_FILE):
        return
    with open(RATE_LIMIT_FILE, 'w') as store:
        json.dump({}, store)
# Check if a user has exceeded their daily limit
def check_rate_limit(user_id):
    """Record one request for *user_id* and enforce the daily quota.

    Returns ``(allowed, count)``: ``allowed`` is False once the user has
    already made MAX_REQUESTS_PER_DAY requests today, and ``count`` is the
    number of requests recorded for today (including this one when allowed).
    """
    today = datetime.now().strftime('%Y-%m-%d')
    try:
        with open(RATE_LIMIT_FILE, 'r') as f:
            rate_limits = json.load(f)
    except (json.JSONDecodeError, FileNotFoundError):
        rate_limits = {}

    # Drop every entry that is not from today. The original code removed
    # only yesterday's key, so counters from two or more days ago
    # accumulated in the JSON file forever.
    for uid in list(rate_limits):
        day_counts = rate_limits[uid]
        for day in [d for d in day_counts if d != today]:
            del day_counts[day]
        if not day_counts:  # user has no remaining days — remove entirely
            del rate_limits[uid]

    # Ensure the current user has a counter for today.
    user_days = rate_limits.setdefault(user_id, {})
    count = user_days.get(today, 0)

    # Refuse (without incrementing or persisting) once the quota is reached,
    # matching the original behavior of returning before the file write.
    if count >= MAX_REQUESTS_PER_DAY:
        return False, count

    # Record this request and persist the updated table.
    user_days[today] = count + 1
    with open(RATE_LIMIT_FILE, 'w') as f:
        json.dump(rate_limits, f)
    return True, count + 1
def get_user_id():
    """Return a stable per-session identifier, creating one on first use.

    Uses a random UUID so two sessions cannot collide. The original hashed
    a second-resolution timestamp, so every session started within the same
    second received the same id and shared a single rate-limit bucket.
    """
    if 'user_id' not in st.session_state:
        import uuid  # local import: keeps the module import block unchanged
        st.session_state.user_id = str(uuid.uuid4())
    return st.session_state.user_id
def get_remaining_queries(user_id):
    """Return how many queries *user_id* has left for the current day."""
    date_key = datetime.now().strftime('%Y-%m-%d')
    try:
        with open(RATE_LIMIT_FILE, 'r') as store:
            usage = json.load(store)
    except (json.JSONDecodeError, FileNotFoundError):
        # Missing or unreadable store means nothing has been recorded yet.
        return MAX_REQUESTS_PER_DAY
    used_today = usage.get(user_id, {}).get(date_key, 0)
    return MAX_REQUESTS_PER_DAY - used_today
# Set up page configuration (must run before other Streamlit page calls).
st.set_page_config(
    page_title="Medical Assistant RAG Chatbot",
    page_icon="🩺",
    layout="centered"
)

# Initialize session state for chat history so it survives script reruns.
if 'messages' not in st.session_state:
    st.session_state.messages = []

# Make sure the rate-limit store exists before any request is counted.
init_rate_limiting()

# Display remaining queries for this session in the sidebar.
user_id = get_user_id()
remaining_queries = get_remaining_queries(user_id)
st.sidebar.write(f"Remaining queries today: {remaining_queries}/{MAX_REQUESTS_PER_DAY}")

# Check for API keys and refuse to start without them.
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
if not PINECONE_API_KEY or not OPENAI_API_KEY:
    st.error("Missing API keys. Please set PINECONE_API_KEY and OPENAI_API_KEY environment variables.")
    st.stop()

# Re-export so the Pinecone/OpenAI clients pick the keys up from the env.
os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
# Cache the RAG chain initialization. The original comment claimed caching
# but applied no decorator, so the embedding model reloaded and Pinecone
# reconnected on every Streamlit rerun (i.e. on every user interaction).
@st.cache_resource
def initialize_rag_chain():
    """Build the retrieval-augmented-generation chain, or return None on failure."""
    try:
        st.sidebar.write("Loading embeddings model...")
        embeddings = download_hugging_face_embeddings()

        st.sidebar.write("Connecting to Pinecone...")
        index_name = "medprep"
        docsearch = Pinecone.from_existing_index(
            index_name=index_name,
            embedding=embeddings
        )
        # Retrieve the 3 most similar chunks for each query.
        retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})

        st.sidebar.write("Initializing OpenAI...")
        llm = OpenAI(temperature=0.4, max_tokens=500)
        prompt = ChatPromptTemplate.from_messages([
            ("system", system_prompt),
            ("human", "{input}")
        ])
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)

        st.sidebar.success("✅ System initialized successfully!")
        return rag_chain
    except Exception as e:
        # Surface the failure in the sidebar instead of crashing the page;
        # the caller treats a None return as a fatal setup error.
        st.sidebar.error(f"Error initializing system: {str(e)}")
        import traceback
        st.sidebar.text(traceback.format_exc())
        return None
# Main app title
st.title("Medical Assistant Chatbot")
st.write("Ask me any medical question, and I'll try to help!")

# Initialize the RAG chain; abort rendering the rest of the page if it failed.
rag_chain = initialize_rag_chain()
if rag_chain is None:
    st.error("Failed to initialize the system. Please check the sidebar for error details.")
    st.stop()

# Replay the stored conversation so the chat survives Streamlit reruns.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# Get user input and answer it, enforcing the daily quota first.
if prompt := st.chat_input("Ask a question..."):
    # Record and display the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Check rate limit before doing any expensive LLM work.
    user_id = get_user_id()
    allowed, count = check_rate_limit(user_id)

    # Render the assistant's reply for BOTH branches. Previously the
    # rate-limit notice was only appended to history and never displayed
    # in the run where the limit was hit, because st.markdown(response)
    # sat inside the allowed branch's chat-message context.
    with st.chat_message("assistant"):
        if not allowed:
            response = f"⚠️ Daily limit reached. You've used {count} queries today. Please try again tomorrow."
        else:
            with st.spinner("Thinking..."):
                try:
                    result = rag_chain.invoke({"input": prompt})
                    response = result.get("answer", "Sorry, I couldn't find an answer to that.")
                    remaining = MAX_REQUESTS_PER_DAY - count
                    response += f"\n\n\n_You have {remaining} queries remaining today._"
                except Exception as e:
                    response = f"Error processing your request: {str(e)}"
        st.markdown(response)

    # Add assistant response to chat history.
    st.session_state.messages.append({"role": "assistant", "content": response})
# Footer: static disclaimer shown beneath the chat on every run.
st.markdown("---")
st.markdown("*This is a RAG-based medical assistant chatbot. It retrieves information from a medical knowledge base to answer your questions.*")