File size: 6,418 Bytes
bea2229
 
84301bc
bea2229
84301bc
 
 
bea2229
 
bbb0818
bea2229
 
 
 
 
 
84301bc
e2535b2
84301bc
 
 
 
bea2229
 
 
84301bc
bea2229
84301bc
 
 
bea2229
 
84301bc
 
 
bea2229
 
84301bc
 
bea2229
84301bc
 
bea2229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84301bc
 
bea2229
84301bc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2535b2
84301bc
bea2229
84301bc
 
bea2229
 
84301bc
 
 
bea2229
84301bc
dd95a79
84301bc
bea2229
84301bc
 
 
 
 
 
 
 
 
 
 
 
 
 
bea2229
 
84301bc
 
bea2229
 
 
84301bc
 
 
bea2229
84301bc
bea2229
 
84301bc
 
bea2229
 
 
 
 
84301bc
e2535b2
84301bc
 
bea2229
84301bc
 
bea2229
84301bc
bea2229
 
84301bc
 
 
bea2229
84301bc
bea2229
 
84301bc
bea2229
 
84301bc
 
 
bea2229
84301bc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import os
import time
import json
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from dotenv import load_dotenv
import logging
from langchain_groq import ChatGroq
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_huggingface import HuggingFaceEmbeddings

# DEBUG level for development; consider INFO/WARNING in production.
logging.basicConfig(level=logging.DEBUG)

# ==========================================================
#  Load environment variables
# ==========================================================
# Reads a local .env file (if present) into the process environment.
load_dotenv()
groq_api_key = os.getenv("GROQ_API_KEY")

# Fail fast at import time: the app is unusable without an API key.
if not groq_api_key:
    raise ValueError("❌ GROQ_API_KEY not found. Please set it in your .env file or as an environment variable.")

# ==========================================================
#  Initialize LLM
# ==========================================================
# Groq-hosted Llama 3.1 8B chat model; created once at module load and
# shared by every request (the client itself holds no per-request state).
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.1-8b-instant")

# ==========================================================
#  Function: Load / Build Retrieval Chain
# ==========================================================
def load_retrieval_chain(max_docs: int = 50):
    """
    Load or build the FAISS vector index and create a retrieval chain.

    Lazy-loaded (called on first request, see ``init_retrieval``) to
    prevent Gunicorn worker boot crashes while the index is built.

    Args:
        max_docs: Maximum number of loaded PDF pages to embed during
            first-time index construction (previously a hard-coded 50).
            Ignored when an existing "faiss_index" directory is loaded.

    Returns:
        A LangChain retrieval chain whose ``invoke({'input': ...})`` result
        carries 'answer' (LLM output) and 'context' (retrieved documents).

    Raises:
        ValueError: If the 'data' folder is missing or contains no PDFs.
    """
    print("πŸ”„ Initializing retrieval chain...")

    # The LLM is instructed to emit a JSON object with "intent" and
    # "response" keys; /chat parses this output with json.loads.
    prompt_template = """
You are a friendly and helpful hotel assistant.
Your role is to provide clear, welcoming, and professional responses to guest questions.
You MUST respond in a valid JSON format.

The JSON object must have two keys:
1. "intent": (string) This will always be "qa" for this version.
2. "response": (string) Your natural language, conversational response to the user.

RULES:
- Base your answers ONLY on the provided context. If the information is not in the context,
  politely say "I'm sorry, I don't have that information, but I can connect you with our front desk for assistance."
- Do not make up information.

<context>
{context}
</context>

Question: {input}

Your JSON Response:
"""
    # Fixed: the context block previously ended with a second opening
    # "<context>" tag instead of "</context>", leaving it unterminated.
    prompt = ChatPromptTemplate.from_template(prompt_template)

    # --- Load Embeddings ---
    # Small, CPU-friendly sentence-transformer; must match the model used
    # when any persisted index was built, or similarity scores degrade.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # --- Create or Load FAISS Vectorstore ---
    if not os.path.exists("data"):
        os.makedirs("data")
        print("⚠️ 'data' folder created. Please add your PDFs and restart.")
        raise ValueError("No PDFs found in 'data' folder.")

    if os.path.exists("faiss_index"):
        print("βœ… Loading existing FAISS index...")
        # allow_dangerous_deserialization: the index is pickled; only safe
        # because we load a file this app wrote itself, never user uploads.
        vectors = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
    else:
        print("πŸ“„ Loading PDFs and building FAISS index (first-time setup)...")
        loader = PyPDFDirectoryLoader("data")
        docs = loader.load()
        if not docs:
            raise ValueError("No PDF documents found in 'data' folder.")
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        # Cap the embedded pages to bound first-boot time and memory.
        final_docs = text_splitter.split_documents(docs[:max_docs])
        vectors = FAISS.from_documents(final_docs, embeddings)
        vectors.save_local("faiss_index")
        print("πŸ’Ύ FAISS index saved to 'faiss_index' for future runs.")

    # --- Create Chains ---
    retriever = vectors.as_retriever()
    document_chain = create_stuff_documents_chain(llm, prompt)
    retrieval_chain = create_retrieval_chain(retriever, document_chain)

    print("βœ… Retrieval chain initialized successfully.")
    return retrieval_chain

# ==========================================================
#  Flask App Setup
# ==========================================================
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the browser frontend

# Global chain handle; populated on first request by init_retrieval()
# so that index building never runs during Gunicorn worker boot.
retrieval_chain = None  # Lazy-load later

@app.before_request
def init_retrieval():
    """Initialize the retrieval chain on first request (prevents Gunicorn crash).

    Registered on every request but a no-op once the chain exists.  On
    failure the chain stays None — /chat returns 500 — and the next
    request retries initialization.
    """
    global retrieval_chain
    if retrieval_chain is None:
        try:
            retrieval_chain = load_retrieval_chain()
        except Exception:
            # logger.exception records the full traceback through the
            # configured logging setup; the bare print() it replaces was
            # invisible to log aggregation and dropped the stack trace.
            app.logger.exception("❌ Failed to initialize retrieval chain")
            retrieval_chain = None

# ==========================================================
#  Routes
# ==========================================================
@app.route("/")
def index():
    """Render the chat UI's main HTML page."""
    template_name = "index.html"
    return render_template(template_name)

@app.route("/chat", methods=["POST"])
def chat():
    """Main chat endpoint.

    Expects a JSON body carrying the user's text under "query" (or, for
    older frontends, "message").  Runs the retrieval chain and returns the
    LLM's parsed JSON ({"intent", "response"}) plus the retrieved context
    passages, or an error payload with a 4xx/5xx status.
    """
    if retrieval_chain is None:
        return jsonify({"error": "Vector database not initialized. Try again in a few seconds."}), 500

    try:
        # silent=True yields None (not a 400 exception) on a missing or
        # non-JSON body, so the guard below produces our own error payload.
        data = request.get_json(silent=True) or {}
        # The handler previously read "message" into a dead variable while
        # actually requiring "query"; accept either key so no client breaks.
        user_query = data.get("query") or data.get("message")
        app.logger.info(f"Received user input: {user_query}")
        if not user_query:
            return jsonify({"error": "No query provided"}), 400

        print(f"πŸ’¬ Received query: {user_query}")
        # perf_counter measures wall-clock time; process_time (used before)
        # counts only CPU time and under-reports an I/O-bound LLM call.
        start = time.perf_counter()

        # Run retrieval chain
        response = retrieval_chain.invoke({'input': user_query})
        elapsed = time.perf_counter() - start
        print(f"⏱️ Response time: {elapsed:.3f} sec")

        # The prompt instructs the LLM to emit a JSON object; parse it and
        # attach the retrieved source passages for the frontend.
        try:
            llm_output_str = response['answer']
            parsed = json.loads(llm_output_str)
            parsed["context"] = [doc.page_content for doc in response['context']]
            return jsonify(parsed)
        except json.JSONDecodeError:
            print(f"⚠️ Invalid JSON from LLM: {response.get('answer', '')}")
            return jsonify({"intent": "qa", "response": "I'm sorry, I had a small glitch. Could you rephrase that?"})
    except Exception as e:
        # Boundary handler: report the failure rather than crash the worker.
        print(f"❌ Error during chat request: {e}")
        return jsonify({"error": str(e)}), 500

# ==========================================================
#  App Runner (for local debugging)
# ==========================================================
if __name__ == "__main__":
    print("πŸš€ Starting Flask development server...")
    # Dev server only (debug=True reloader); production runs under
    # Gunicorn, which imports `app` and never enters this branch.
    app.run(host="0.0.0.0", port=7860, debug=True)