# main.py — codewithharsha, "Update main.py" (commit 84301bc, verified)
import os
import time
import json
from flask import Flask, request, jsonify, render_template
from flask_cors import CORS
from dotenv import load_dotenv
import logging
from langchain_groq import ChatGroq
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_huggingface import HuggingFaceEmbeddings
logging.basicConfig(level=logging.DEBUG)  # DEBUG is verbose; consider INFO for production
# ==========================================================
# Load environment variables
# ==========================================================
load_dotenv()  # reads a local .env file into os.environ (no-op if the file is absent)
groq_api_key = os.getenv("GROQ_API_KEY")
# Fail fast at import time: nothing below works without an API key.
# `not` (rather than `is None`) also rejects an empty-string value.
if not groq_api_key:
    raise ValueError("❌ GROQ_API_KEY not found. Please set it in your .env file or as an environment variable.")
# ==========================================================
# Initialize LLM
# ==========================================================
# Module-level LLM client, shared by load_retrieval_chain() below.
llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.1-8b-instant")
# ==========================================================
# Function: Load / Build Retrieval Chain
# ==========================================================
def load_retrieval_chain():
    """Build (or load from disk) the FAISS index and assemble the RAG chain.

    Lazy-loaded on the first request so Gunicorn workers do not crash at
    boot while the embedding model / index is being prepared.

    Returns:
        A LangChain retrieval chain whose ``invoke({'input': ...})`` result
        carries ``answer`` (raw LLM output) and ``context`` (retrieved docs).

    Raises:
        ValueError: if no saved index exists and the 'data' folder is
            missing or contains no PDF documents.
    """
    print("🔄 Initializing retrieval chain...")

    # The prompt instructs the model to answer as strict JSON with
    # "intent" and "response" keys; /chat parses that JSON downstream.
    # BUGFIX: the context block was previously never closed
    # (`<context> ... <context>`); it now uses a proper closing tag.
    prompt_template = """
You are a friendly and helpful hotel assistant.
Your role is to provide clear, welcoming, and professional responses to guest questions.
You MUST respond in a valid JSON format.
The JSON object must have two keys:
1. "intent": (string) This will always be "qa" for this version.
2. "response": (string) Your natural language, conversational response to the user.
RULES:
- Base your answers ONLY on the provided context. If the information is not in the context,
politely say "I'm sorry, I don't have that information, but I can connect you with our front desk for assistance."
- Do not make up information.
<context>
{context}
</context>
Question: {input}
Your JSON Response:
"""
    prompt = ChatPromptTemplate.from_template(prompt_template)

    # --- Embeddings (small CPU-friendly MiniLM sentence encoder) ---
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # --- Create or Load FAISS Vectorstore ---
    if os.path.exists("faiss_index"):
        # Reuse the saved index; the 'data' PDFs are not needed on this path.
        print("✅ Loading existing FAISS index...")
        vectors = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
    else:
        # BUGFIX: the 'data' folder was previously required even when a saved
        # index existed; the check now runs only when we must (re)build.
        if not os.path.exists("data"):
            os.makedirs("data")
            print("⚠️ 'data' folder created. Please add your PDFs and restart.")
            raise ValueError("No PDFs found in 'data' folder.")
        print("📄 Loading PDFs and building FAISS index (first-time setup)...")
        loader = PyPDFDirectoryLoader("data")
        docs = loader.load()
        if not docs:
            raise ValueError("No PDF documents found in 'data' folder.")
        # Chunk with overlap so answers spanning chunk boundaries stay retrievable.
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        # Cap at 50 documents to bound first-run build time.
        final_docs = text_splitter.split_documents(docs[:50])
        vectors = FAISS.from_documents(final_docs, embeddings)
        vectors.save_local("faiss_index")
        print("💾 FAISS index saved to 'faiss_index' for future runs.")

    # --- Assemble: retriever feeds context into the stuffed prompt + LLM ---
    retriever = vectors.as_retriever()
    document_chain = create_stuff_documents_chain(llm, prompt)
    retrieval_chain = create_retrieval_chain(retriever, document_chain)
    print("✅ Retrieval chain initialized successfully.")
    return retrieval_chain
# ==========================================================
# Flask App Setup
# ==========================================================
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the web front-end
retrieval_chain = None # Lazy-load later  (populated by init_retrieval on the first request)
@app.before_request
def init_retrieval():
    """Lazily build the retrieval chain on the first incoming request.

    Deferring the slow index/embedding setup past worker boot prevents
    Gunicorn timeout crashes. On failure the chain stays None and /chat
    answers 500 until a later request succeeds (each request retries).
    """
    global retrieval_chain
    if retrieval_chain is None:
        try:
            retrieval_chain = load_retrieval_chain()
        except Exception:
            # Log the full traceback through the configured logger instead
            # of a bare print, so failures are visible in production logs.
            app.logger.exception("❌ Failed to initialize retrieval chain")
            retrieval_chain = None
# ==========================================================
# Routes
# ==========================================================
@app.route("/")
def index():
"""Serve main web page."""
return render_template("index.html")
@app.route("/chat", methods=["POST"])
def chat():
"""Main chat endpoint."""
global retrieval_chain
if retrieval_chain is None:
return jsonify({"error": "Vector database not initialized. Try again in a few seconds."}), 500
try:
user_input = request.json.get("message")
app.logger.info(f"Received user input: {user_input}")
data = request.json
user_query = data.get("query")
if not user_query:
return jsonify({"error": "No query provided"}), 400
print(f"πŸ’¬ Received query: {user_query}")
start = time.process_time()
# Run retrieval chain
response = retrieval_chain.invoke({'input': user_query})
elapsed = time.process_time() - start
print(f"⏱️ Response time: {elapsed:.3f} sec")
# Parse LLM JSON
try:
llm_output_str = response['answer']
parsed = json.loads(llm_output_str)
parsed["context"] = [doc.page_content for doc in response['context']]
return jsonify(parsed)
except json.JSONDecodeError:
print(f"⚠️ Invalid JSON from LLM: {response.get('answer', '')}")
return jsonify({"intent": "qa", "response": "I'm sorry, I had a small glitch. Could you rephrase that?"})
except Exception as e:
print(f"❌ Error during chat request: {e}")
return jsonify({"error": str(e)}), 500
# ==========================================================
# App Runner (for local debugging)
# ==========================================================
if __name__ == "__main__":
    print("🚀 Starting Flask development server...")
    # debug=True enables the interactive debugger/reloader — local use only;
    # run under Gunicorn (which imports `app` directly) in production.
    app.run(host="0.0.0.0", port=7860, debug=True)