File size: 5,918 Bytes
90948b7
2568443
90948b7
cbd5a41
 
 
 
90948b7
b7d7e4b
 
1ad8736
b7d7e4b
 
 
 
 
 
ee2d79b
b7d7e4b
 
 
 
 
2568443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b7d7e4b
 
 
2568443
 
 
 
 
b7d7e4b
2568443
 
b7d7e4b
 
 
 
 
 
 
 
 
 
 
2568443
b7d7e4b
2568443
 
b7d7e4b
2568443
 
 
 
 
 
 
 
 
 
 
 
 
b7d7e4b
ad51c6c
2568443
 
b7d7e4b
2568443
b7d7e4b
 
2568443
b7d7e4b
2568443
b7d7e4b
2568443
 
 
b7d7e4b
 
 
 
 
 
 
2568443
b7d7e4b
2568443
 
b7d7e4b
2568443
b7d7e4b
 
 
2568443
 
b7d7e4b
 
 
2568443
 
 
b7d7e4b
 
 
 
 
 
 
 
 
 
 
2568443
b7d7e4b
 
 
 
 
2568443
b7d7e4b
2568443
b7d7e4b
 
2568443
 
 
b7d7e4b
 
 
 
 
2568443
 
 
 
 
 
b7d7e4b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os
import sys
import traceback

# Redirect Hugging Face caches to /tmp, where the runtime user is guaranteed
# write permission (the default cache dir is often read-only in containers).
# These env vars MUST be set before any transformers/HF code is imported,
# which is why they precede the imports below.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/model_cache'
os.environ['HF_HOME'] = '/tmp/model_cache'
os.makedirs('/tmp/model_cache', exist_ok=True)

from flask import Flask, render_template, jsonify, request
from src.helper import download_hugging_face_embeddings
from langchain_community.vectorstores import Pinecone
from langchain_openai import OpenAI
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from dotenv import load_dotenv
from src.prompt import *

app = Flask(__name__)

# Load environment variables - these will be set in Hugging Face Space secrets
load_dotenv()  # Still useful for local development

print("Starting application initialization")
# Fixed: use sys.version directly rather than os.sys.version — `os.sys` is an
# undocumented implementation detail of the os module, not a public API.
print(f"Python version: {sys.version}")
# Add debugging endpoints
@app.route("/test")
def test():
    """Liveness probe: return a fixed message proving Flask is serving requests."""
    message = "Flask app is working. This is a test endpoint."
    return message

@app.route("/check-env")
def check_env():
    """Report presence and rough validity of the required API keys.

    Never reveals the key values themselves — only whether each key is set
    and whether it passes a superficial shape check (length / prefix).
    """
    pinecone_key = os.environ.get("PINECONE_API_KEY", "")
    openai_key = os.environ.get("OPENAI_API_KEY", "")

    has_pinecone = "Yes" if pinecone_key else "No"
    has_openai = "Yes" if openai_key else "No"

    # Shape checks only run when the key is present; otherwise report "N/A".
    pinecone_valid = len(pinecone_key) > 10 if pinecone_key else "N/A"
    openai_valid = openai_key.startswith("sk-") if openai_key else "N/A"

    return (
        f"Pinecone key present: {has_pinecone} (appears valid: {pinecone_valid})<br>"
        f"OpenAI key present: {has_openai} (appears valid: {openai_valid})"
    )

print("Checking environment variables...")
PINECONE_API_KEY = os.environ.get('PINECONE_API_KEY')
OPENAI_API_KEY = os.environ.get('OPENAI_API_KEY')

if not PINECONE_API_KEY:
    print("WARNING: Missing PINECONE_API_KEY")
if not OPENAI_API_KEY:
    print("WARNING: Missing OPENAI_API_KEY")

if not PINECONE_API_KEY or not OPENAI_API_KEY:
    print("CRITICAL ERROR: Missing API keys")
    # We'll continue anyway to allow debugging, but the app won't work properly

# BUG FIX: the original unconditionally did os.environ[...] = KEY, which
# raises TypeError when a key is missing (None is not a valid environ value)
# and crashed the whole app at import time — defeating the "continue anyway
# to allow debugging" intent above. Only write keys that are actually set.
if PINECONE_API_KEY:
    os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
if OPENAI_API_KEY:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY

# Populated by initialize_chain(); None means "not initialized yet".
embeddings = None
rag_chain = None

def initialize_chain():
    """Build the global embeddings model and RAG chain used by /get.

    Downloads the embedding model, connects to the existing Pinecone index,
    and wires a similarity retriever plus an OpenAI LLM into a retrieval
    chain. Results are stored in the module-level globals ``embeddings``
    and ``rag_chain``.

    Returns:
        bool: True on success; False if any step raised. On failure the
        exception is printed with a traceback and the app keeps running so
        the /test, /check-env and /health endpoints stay usable for
        debugging.
    """
    global embeddings, rag_chain
    try:
        print("Step 1: Starting to download embeddings")
        embeddings = download_hugging_face_embeddings()
        print("Step 2: Successfully downloaded embeddings")
        
        index_name = "medprep"
        print(f"Step 3: Connecting to Pinecone index: {index_name}")
        
        # Best-effort connectivity probe: list the account's indexes so a
        # missing or misnamed index shows up clearly in the logs. Failures
        # here are logged but do not abort — from_existing_index below will
        # raise its own error if the index is truly unreachable.
        try:
            from pinecone import Pinecone as PineconeClient
            pc = PineconeClient(api_key=PINECONE_API_KEY)
            # List available indexes to verify connection
            indexes = pc.list_indexes()
            print(f"Available Pinecone indexes: {indexes}")
            
            if index_name not in [idx.name for idx in indexes]:
                print(f"WARNING: Index '{index_name}' not found in your Pinecone account!")
        except Exception as e:
            print(f"Failed to connect to Pinecone API: {e}")
        
        docsearch = Pinecone.from_existing_index(
            index_name=index_name,
            embedding=embeddings
        )
        print("Step 4: Successfully connected to Pinecone")
        
        # Retrieve the 3 most similar chunks for each query.
        retriever = docsearch.as_retriever(search_type="similarity", search_kwargs={"k":3})
        print("Step 5: Created retriever")
        
        print("Step 6: Initializing OpenAI")
        # Moderate temperature for mostly-factual answers; cap completion size.
        llm = OpenAI(temperature=0.4, max_tokens=500)
        print("Step 7: OpenAI initialized")
        
        print("Step 8: Creating prompt template")
        # system_prompt is provided by the star-import of src.prompt.
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", system_prompt),
                ("human", "{input}"),
            ]
        )
        
        print("Step 9: Creating QA chain")
        question_answer_chain = create_stuff_documents_chain(llm, prompt)
        
        print("Step 10: Creating RAG chain")
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)
        print("Step 11: RAG chain initialized successfully")
        return True
    except Exception as e:
        print(f"Failed to initialize RAG chain: {e}")
        print(f"Error type: {type(e)}")
        traceback.print_exc()
        return False

# Build the RAG chain eagerly at import time so the first request is not
# penalized; the result is logged for diagnosis if startup init fails.
print("Starting chain initialization...")
initialization_result = initialize_chain()
print("Chain initialization result: " + str(initialization_result))

@app.route("/")
def index():
    """Serve the chat UI page."""
    template_name = 'chat.html'
    return render_template(template_name)

@app.route("/get", methods=["GET", "POST"])
def chat():
    """Answer a user message through the RAG chain.

    Accepts the message in the ``msg`` field — form data for POST or the
    query string for GET (the route declares both methods).

    Returns:
        str: the chain's answer, or a human-readable error string when the
        system is uninitialized, the message is missing, or the chain fails.
    """
    global rag_chain

    # Lazy recovery: retry initialization if it failed at startup.
    if rag_chain is None:
        print("RAG chain not initialized, attempting to initialize again...")
        if not initialize_chain():
            return "Error: System not initialized properly. Please check the logs."

    # BUG FIX: the original used request.form["msg"], which only sees POST
    # bodies and aborts with an opaque 400 on GET requests (which this route
    # explicitly allows) or when the field is absent — before ever reaching
    # the error handling below. request.values covers form data AND query
    # string, and .get() lets us return a readable error instead of a 400.
    msg = request.values.get("msg")
    if not msg:
        return "Error: No message provided."

    try:
        print(f"Processing message: {msg[:30]}...")  # Log only first 30 chars for privacy
        response = rag_chain.invoke({"input": msg})
        print("Successfully generated response")
        return str(response["answer"])
    except Exception as e:
        error_msg = f"Error processing request: {e}"
        print(error_msg)
        traceback.print_exc()
        return f"Error: {str(e)}"

# Health check endpoint for monitoring
@app.route("/health")
def health_check():
    is_initialized = rag_chain is not None
    return jsonify({
        "status": "healthy", 
        "rag_chain_initialized": is_initialized,
        "embeddings_loaded": embeddings is not None
    })

# Entry point for local runs; in production a WSGI server imports `app`.
if __name__ == '__main__':
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, debug=False)