codewithharsha commited on
Commit
e2535b2
·
verified ·
1 Parent(s): 7485d6e

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +87 -87
main.py CHANGED
@@ -1,43 +1,39 @@
1
  import os
2
  import time
3
- import json
4
  from flask import Flask, request, jsonify, render_template
5
- from flask_cors import CORS
6
- from dotenv import load_dotenv
7
-
8
  from langchain_groq import ChatGroq
9
  from langchain_text_splitters import RecursiveCharacterTextSplitter
10
  from langchain.chains.combine_documents import create_stuff_documents_chain
 
11
  from langchain_core.prompts import ChatPromptTemplate
12
  from langchain.chains import create_retrieval_chain
13
  from langchain_community.vectorstores import FAISS
14
  from langchain_community.document_loaders import PyPDFDirectoryLoader
15
  from langchain_huggingface import HuggingFaceEmbeddings
 
16
 
17
- # ==========================================================
18
- # Load environment variables
19
- # ==========================================================
20
  load_dotenv()
 
 
21
  groq_api_key = os.getenv("GROQ_API_KEY")
22
 
23
  if not groq_api_key:
24
- raise ValueError("GROQ_API_KEY not found. Please set it in your .env file or as an environment variable.")
25
 
26
- # ==========================================================
27
- # Initialize LLM
28
- # ==========================================================
29
  llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.1-8b-instant")
30
 
31
- # ==========================================================
32
- # Function: Load / Build Retrieval Chain
33
- # ==========================================================
34
  def load_retrieval_chain():
35
  """
36
- Loads or builds the FAISS vector index and creates a retrieval chain.
37
- This is now lazy-loaded to prevent Gunicorn worker boot crashes.
38
  """
39
- print("🔄 Initializing retrieval chain...")
40
-
 
41
  prompt_template = """
42
  You are a friendly and helpful hotel assistant.
43
  Your role is to provide clear, welcoming, and professional responses to guest questions.
@@ -61,104 +57,108 @@ Question: {input}
61
  Your JSON Response:
62
  """
63
  prompt = ChatPromptTemplate.from_template(prompt_template)
64
-
65
- # --- Load Embeddings ---
66
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
67
-
68
- # --- Create or Load FAISS Vectorstore ---
69
- if not os.path.exists("data"):
70
- os.makedirs("data")
71
- print("⚠️ 'data' folder created. Please add your PDFs and restart.")
72
- raise ValueError("No PDFs found in 'data' folder.")
73
-
74
- if os.path.exists("faiss_index"):
75
- print(" Loading existing FAISS index...")
76
- vectors = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
77
- else:
78
- print("📄 Loading PDFs and building FAISS index (first-time setup)...")
79
- loader = PyPDFDirectoryLoader("data")
80
- docs = loader.load()
81
- if not docs:
82
- raise ValueError("No PDF documents found in 'data' folder.")
83
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
84
- final_docs = text_splitter.split_documents(docs[:50])
85
- vectors = FAISS.from_documents(final_docs, embeddings)
86
- vectors.save_local("faiss_index")
87
- print("💾 FAISS index saved to 'faiss_index' for future runs.")
88
-
89
- # --- Create Chains ---
90
- retriever = vectors.as_retriever()
91
  document_chain = create_stuff_documents_chain(llm, prompt)
 
 
 
 
92
  retrieval_chain = create_retrieval_chain(retriever, document_chain)
93
-
94
- print("✅ Retrieval chain initialized successfully.")
95
  return retrieval_chain
96
 
97
- # ==========================================================
98
- # Flask App Setup
99
- # ==========================================================
100
  app = Flask(__name__)
101
- CORS(app)
102
 
103
- retrieval_chain = None # Lazy-load later
 
 
 
 
 
104
 
105
- @app.before_request
106
- def init_retrieval():
107
- """Initialize retrieval chain after Flask starts (prevents Gunicorn crash)."""
108
- global retrieval_chain
109
- if retrieval_chain is None:
110
- try:
111
- retrieval_chain = load_retrieval_chain()
112
- except Exception as e:
113
- print(f"❌ Failed to initialize retrieval chain: {e}")
114
- retrieval_chain = None
115
-
116
- # ==========================================================
117
- # Routes
118
- # ==========================================================
119
  @app.route("/")
120
  def index():
121
- """Serve main web page."""
122
- return render_template("index.html")
 
 
123
 
124
  @app.route("/chat", methods=["POST"])
125
  def chat():
126
- """Main chat endpoint."""
127
- global retrieval_chain
128
-
 
129
  if retrieval_chain is None:
130
- return jsonify({"error": "Vector database not initialized. Try again in a few seconds."}), 500
131
 
132
  try:
133
  data = request.json
134
  user_query = data.get("query")
 
135
  if not user_query:
136
  return jsonify({"error": "No query provided"}), 400
137
 
138
- print(f"💬 Received query: {user_query}")
139
- start = time.process_time()
140
 
141
- # Run retrieval chain
 
 
142
  response = retrieval_chain.invoke({'input': user_query})
143
- elapsed = time.process_time() - start
144
- print(f"⏱️ Response time: {elapsed:.3f} sec")
 
145
 
146
- # Parse LLM JSON
147
  try:
 
148
  llm_output_str = response['answer']
149
- parsed = json.loads(llm_output_str)
150
- parsed["context"] = [doc.page_content for doc in response['context']]
151
- return jsonify(parsed)
 
 
 
 
 
 
 
152
  except json.JSONDecodeError:
153
- print(f"⚠️ Invalid JSON from LLM: {response.get('answer', '')}")
154
  return jsonify({"intent": "qa", "response": "I'm sorry, I had a small glitch. Could you rephrase that?"})
 
 
 
 
155
  except Exception as e:
156
- print(f"Error during chat request: {e}")
157
  return jsonify({"error": str(e)}), 500
158
 
159
- # ==========================================================
160
- # App Runner (for local debugging)
161
- # ==========================================================
162
  if __name__ == "__main__":
163
- print("🚀 Starting Flask development server...")
164
- app.run(host="0.0.0.0", port=5000, debug=True)
 
 
 
 
 
 
 
 
 
1
  import os
2
  import time
3
+ import json # Import JSON for parsing
4
  from flask import Flask, request, jsonify, render_template
5
+ from flask_cors import CORS
 
 
6
  from langchain_groq import ChatGroq
7
  from langchain_text_splitters import RecursiveCharacterTextSplitter
8
  from langchain.chains.combine_documents import create_stuff_documents_chain
9
+ # from langchain.chains import create_stuff_documents_chain
10
  from langchain_core.prompts import ChatPromptTemplate
11
  from langchain.chains import create_retrieval_chain
12
  from langchain_community.vectorstores import FAISS
13
  from langchain_community.document_loaders import PyPDFDirectoryLoader
14
  from langchain_huggingface import HuggingFaceEmbeddings
15
+ from dotenv import load_dotenv
16
 
17
+ # Load environment variables
 
 
18
  load_dotenv()
19
+
20
+ # --- LLM and API Key Setup ---
21
  groq_api_key = os.getenv("GROQ_API_KEY")
22
 
23
  if not groq_api_key:
24
+ raise ValueError("GROQ_API_KEY not found. Please set it in your .env file or as an environment variable.")
25
 
 
 
 
26
  llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.1-8b-instant")
27
 
28
+
 
 
29
  def load_retrieval_chain():
30
  """
31
+ Loads the vector database and creates the retrieval chain.
32
+ This function runs once when the server starts.
33
  """
34
+ print("Loading vector database... This may take a moment.")
35
+
36
+ # --- PROMPT TEMPLATE - Reverted to simple stateless version ---
37
  prompt_template = """
38
  You are a friendly and helpful hotel assistant.
39
  Your role is to provide clear, welcoming, and professional responses to guest questions.
 
57
  Your JSON Response:
58
  """
59
  prompt = ChatPromptTemplate.from_template(prompt_template)
60
+
 
61
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
62
+ loader = PyPDFDirectoryLoader("data")
63
+ docs = loader.load()
64
+
65
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
66
+ final_documents = text_splitter.split_documents(docs[:50])
67
+
68
+ vectors = FAISS.from_documents(final_documents, embeddings)
69
+
70
+ print("Vector database loaded successfully.")
71
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  document_chain = create_stuff_documents_chain(llm, prompt)
73
+
74
+ # 1. Create the retriever from the vector store
75
+ retriever = vectors.as_retriever()
76
+ # 2. Create the retrieval chain
77
  retrieval_chain = create_retrieval_chain(retriever, document_chain)
78
+
 
79
  return retrieval_chain
80
 
81
+ # --- Flask App Initialization ---
 
 
82
  app = Flask(__name__)
83
+ CORS(app) # Enable CORS for all routes in your app
84
 
85
+ # Load the retrieval chain ONCE when the app starts
86
+ try:
87
+ retrieval_chain = load_retrieval_chain()
88
+ except Exception as e:
89
+ print(f"Failed to load vector database on startup: {e}")
90
+ retrieval_chain = None
91
 
92
+ # --- NEW ROUTE TO SERVE YOUR WEBPAGE ---
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  @app.route("/")
94
  def index():
95
+ """
96
+ Serves the index.html file from the 'templates' folder.
97
+ """
98
+ return render_template('index.html')
99
 
100
  @app.route("/chat", methods=["POST"])
101
  def chat():
102
+ """
103
+ The main chat endpoint.
104
+ Receives a JSON with "query" and returns a JSON with "intent" and "response".
105
+ """
106
  if retrieval_chain is None:
107
+ return jsonify({"error": "Vector database is not initialized. Check server logs."}), 500
108
 
109
  try:
110
  data = request.json
111
  user_query = data.get("query")
112
+
113
  if not user_query:
114
  return jsonify({"error": "No query provided"}), 400
115
 
116
+ print(f"Received query: {user_query}")
 
117
 
118
+ start = time.process_time()
119
+
120
+ # Invoke the chain with the user's query
121
  response = retrieval_chain.invoke({'input': user_query})
122
+
123
+ response_time = time.process_time() - start
124
+ print(f"Response time: {response_time:.4f} seconds")
125
 
126
+ # --- Parse the JSON response from the LLM ---
127
  try:
128
+ # The LLM's answer is in the 'answer' field
129
  llm_output_str = response['answer']
130
+ # The LLM output itself is a JSON string, so we parse it.
131
+ parsed_response = json.loads(llm_output_str)
132
+
133
+ # We can also add the RAG context for debugging
134
+ parsed_response["context"] = [doc.page_content for doc in response['context']]
135
+
136
+ print(f"LLM Response: {parsed_response}")
137
+
138
+ return jsonify(parsed_response)
139
+
140
  except json.JSONDecodeError:
141
+ print(f"Error: LLM did not return valid JSON. Response was: {llm_output_str}")
142
  return jsonify({"intent": "qa", "response": "I'm sorry, I had a small glitch. Could you rephrase that?"})
143
+ except Exception as e:
144
+ print(f"Error parsing LLM response: {e}")
145
+ return jsonify({"intent": "qa", "response": "I'm sorry, I'm having trouble processing that request."})
146
+
147
  except Exception as e:
148
+ print(f"Error processing request: {e}")
149
  return jsonify({"error": str(e)}), 500
150
 
151
+ # --- /book ENDPOINT REMOVED ---
152
+
153
+ # --- Run the Flask Server ---
154
  if __name__ == "__main__":
155
+ # Ensure a 'data' directory exists
156
+ if not os.path.exists("data"):
157
+ os.makedirs("data")
158
+ print("Created 'data' directory. Please add your PDF files here and restart the server.")
159
+
160
+ # init_db() call removed
161
+
162
+ print("Starting Flask server...")
163
+ # Running on 0.0.0.0 makes it accessible on your network, ready for EC2
164
+ app.run(debug=True, host="0.0.0.0", port=7860)