# USMLEStep1Prep / app.py
# Author: Nahiyan14 — last change: "Update app.py" (2568443, verified)
import os
import sys
import traceback
# Redirect the Hugging Face model caches to /tmp, a directory that should be
# writable even where the default cache location is not. This sits above the
# third-party imports below, presumably so the HF libraries see these values
# when they are first imported — keep it there.
os.environ.update(
    TRANSFORMERS_CACHE='/tmp/model_cache',
    HF_HOME='/tmp/model_cache',
)
os.makedirs('/tmp/model_cache', exist_ok=True)
from flask import Flask, render_template, jsonify, request
from src.helper import download_hugging_face_embeddings
from langchain_community.vectorstores import Pinecone
from langchain_openai import OpenAI
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from dotenv import load_dotenv
from src.prompt import *
app = Flask(__name__)
# Load environment variables - these will be set in Hugging Face Space secrets
load_dotenv()  # Still useful for local development
print("Starting application initialization")
# Use the sys module directly: `os.sys` only works because the os module
# happens to import sys internally; it is not part of os's documented API.
print(f"Python version: {sys.version}")
# Debugging endpoints
@app.route("/test")
def test():
    """Smoke-test endpoint: confirms Flask is up and routing requests."""
    reply = "Flask app is working. This is a test endpoint."
    return reply
@app.route("/check-env")
def check_env():
    """Debug endpoint: report whether each API key is configured, never its value.

    For each key the response shows "Yes"/"No" for presence plus a cheap
    plausibility heuristic: the Pinecone key must be longer than 10
    characters and the OpenAI key must start with "sk-". The heuristic is
    reported as "N/A" when the key is unset or empty.

    Returns:
        A short HTML fragment (two lines separated by ``<br>``).
    """
    pinecone_key = os.environ.get("PINECONE_API_KEY")
    openai_key = os.environ.get("OPENAI_API_KEY")

    if pinecone_key:
        has_pinecone, pinecone_valid = "Yes", len(pinecone_key) > 10
    else:
        has_pinecone, pinecone_valid = "No", "N/A"

    if openai_key:
        has_openai, openai_valid = "Yes", openai_key.startswith("sk-")
    else:
        has_openai, openai_valid = "No", "N/A"

    return f"Pinecone key present: {has_pinecone} (appears valid: {pinecone_valid})<br>OpenAI key present: {has_openai} (appears valid: {openai_valid})"
# Read the API keys once at import time and warn (but do not abort) when they
# are missing, so the debug endpoints stay reachable for troubleshooting.
print("Checking environment variables...")
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')
if not PINECONE_API_KEY:
    print("WARNING: Missing PINECONE_API_KEY")
if not OPENAI_API_KEY:
    print("WARNING: Missing OPENAI_API_KEY")
if not PINECONE_API_KEY or not OPENAI_API_KEY:
    print("CRITICAL ERROR: Missing API keys")
    # We'll continue anyway to allow debugging, but the app won't work properly
# The keys are deliberately NOT written back into os.environ: they were just
# read from it (so that is a no-op when set), and assigning None there raises
# TypeError, which would abort startup instead of "continuing anyway".
# Module-level state filled in by initialize_chain(); each stays None until
# the step that builds it succeeds.
embeddings = rag_chain = None
def _warn_if_index_missing(index_name):
    """Best-effort Pinecone pre-flight check that never raises.

    Lists the indexes visible with PINECONE_API_KEY and prints a warning when
    *index_name* is not among them. Any failure is printed and ignored, so
    the caller still attempts the real connection and reports that error.
    """
    try:
        from pinecone import Pinecone as PineconeClient
        pc = PineconeClient(api_key=PINECONE_API_KEY)
        # List available indexes to verify connection
        indexes = pc.list_indexes()
        print(f"Available Pinecone indexes: {indexes}")
        if index_name not in [idx.name for idx in indexes]:
            print(f"WARNING: Index '{index_name}' not found in your Pinecone account!")
    except Exception as e:
        print(f"Failed to connect to Pinecone API: {e}")


def initialize_chain():
    """Build the retrieval-augmented QA chain and store it in module globals.

    Loads the embedding model (reusing one loaded by an earlier call),
    connects to the existing "medprep" Pinecone index, and wires a
    similarity retriever (top 3 documents), an OpenAI LLM and a chat prompt
    built from the module-level ``system_prompt`` into a retrieval chain.

    Side effects:
        Sets the module-level ``embeddings`` as soon as the model loads, and
        ``rag_chain`` only after every step has succeeded.

    Returns:
        True on success; False if any step raised. The error and traceback
        are printed and the exception is not propagated.
    """
    global embeddings, rag_chain
    try:
        if embeddings is None:
            print("Step 1: Starting to download embeddings")
            embeddings = download_hugging_face_embeddings()
            print("Step 2: Successfully downloaded embeddings")
        else:
            # chat() calls this again on every request while rag_chain is
            # None; don't reload the embedding model when only a later step
            # failed last time.
            print("Steps 1-2: Reusing previously loaded embeddings")

        index_name = "medprep"
        print(f"Step 3: Connecting to Pinecone index: {index_name}")
        _warn_if_index_missing(index_name)
        docsearch = Pinecone.from_existing_index(
            index_name=index_name,
            embedding=embeddings,
        )
        print("Step 4: Successfully connected to Pinecone")

        retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k": 3})
        print("Step 5: Created retriever")

        print("Step 6: Initializing OpenAI")
        llm = OpenAI(temperature=0.4, max_tokens=500)
        print("Step 7: OpenAI initialized")

        print("Step 8: Creating prompt template")
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                ("human", "{input}"),
            ]
        )

        print("Step 9: Creating QA chain")
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        print("Step 10: Creating RAG chain")
        # Assigned last so rag_chain is non-None only when every step succeeded.
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)
        print("Step 11: RAG chain initialized successfully")
        return True
    except Exception as e:
        # Top-level boundary: report and let the caller decide what to do.
        print(f"Failed to initialize RAG chain: {e}")
        print(f"Error type: {type(e)}")
        traceback.print_exc()
        return False
# Build the chain eagerly when the application starts; if this fails, chat()
# retries on each request while the chain is still missing.
print("Starting chain initialization...")
initialization_result = initialize_chain()
print("Chain initialization result: " + str(initialization_result))
@app.route("/")
def index():
    """Serve the chat UI page."""
    page = 'chat.html'
    return render_template(page)
@app.route("/get", methods=["GET", "POST"])
def chat():
    """Answer a user question with the RAG chain.

    The question comes from the ``msg`` parameter: the POST form body first,
    falling back to the query string so the advertised GET method works.
    If the chain was not built at startup, one re-initialization attempt is
    made per request.

    Returns:
        The chain's ``"answer"`` as plain text, or a short ``"Error: ..."``
        string. Errors keep the default 200 status, as before, so the
        existing front end can display them.
    """
    # request.form["msg"] raised a 400 BadRequestKeyError for GET requests
    # and for POSTs without the field; read form first, then query string.
    msg = request.form.get("msg") or request.args.get("msg", "")
    if not msg.strip():
        return "Error: Please enter a question."

    # Only read here; initialize_chain() rebinds the module-level rag_chain.
    if rag_chain is None:
        print("RAG chain not initialized, attempting to initialize again...")
        if not initialize_chain():
            return "Error: System not initialized properly. Please check the logs."

    try:
        print(f"Processing message: {msg[:30]}...")  # Log only first 30 chars for privacy
        response = rag_chain.invoke({"input": msg})
        print("Successfully generated response")
        return str(response["answer"])
    except Exception as e:
        # Full details go to the server log only; raw exception text (upstream
        # API errors, internal details) is not echoed back to the client.
        print(f"Error processing request: {e}")
        traceback.print_exc()
        return "Error: Could not process your request. Please check the logs."
# Health check endpoint for monitoring
@app.route("/health")
def health_check():
    """Report liveness plus whether the embeddings and RAG chain are loaded.

    ``status`` is always "healthy" (it only signals that Flask is serving);
    the two boolean flags carry the readiness information.
    """
    payload = {
        "status": "healthy",
        "rag_chain_initialized": rag_chain is not None,
        "embeddings_loaded": embeddings is not None,
    }
    return jsonify(payload)
if __name__ == '__main__':
    # Listen on all interfaces; the PORT env var overrides the 7860 default
    # (presumably the Hugging Face Spaces app port — confirm if redeploying).
    listen_port = int(os.getenv("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port, debug=False)