import os

from dotenv import load_dotenv
from flask import Flask, render_template, request
from langchain_classic.chains import create_retrieval_chain, create_history_aware_retriever
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_groq import ChatGroq
from langchain_pinecone import PineconeVectorStore

from src.helper import download_embeddings
from src.prompt import system_prompt

# -----------------------------------------------------------------------------
# Flask + Environment Setup
# -----------------------------------------------------------------------------
app = Flask(__name__)

load_dotenv()  # Load variables from a local .env file if present

PINECONE_API_KEY = os.getenv("PINECONE_API_KEY", "")
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")

# Re-export so the Pinecone and Groq clients can read the keys from the environment.
os.environ["PINECONE_API_KEY"] = PINECONE_API_KEY
os.environ["GROQ_API_KEY"] = GROQ_API_KEY

# -----------------------------------------------------------------------------
# RAG Pipeline Setup
# -----------------------------------------------------------------------------
embeddings = download_embeddings()

index_name = "virtual-doc"
docsearch = PineconeVectorStore.from_existing_index(
    index_name=index_name,
    embedding=embeddings,
)
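
# MMR (maximal marginal relevance) retrieval: fetch the 20 nearest chunks, then
# keep the 4 that are most relevant while staying mutually diverse.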
retriever = docsearch.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 4, "fetch_k": 20},
)

chat_model = ChatGroq(model="llama-3.3-70b-versatile")

# Prompt to rewrite follow-up questions into standalone queries
contextualize_q_prompt = ChatPromptTemplate.from_messages([
    ("system",
     "Rewrite the user's follow-up question using chat history. "
     "Only rewrite — do not answer."),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])
history_aware_retriever = create_history_aware_retriever(
    llm=chat_model,
    retriever=retriever,
    prompt=contextualize_q_prompt,
)
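# If chat_history is empty, the raw input goes straight to the retriever;
# otherwise the LLM rewrites it into a standalone query first.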

answer_prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])

# "Stuff" all retrieved documents into the prompt, then chain retrieval and answering.
qa_chain = create_stuff_documents_chain(chat_model, answer_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, qa_chain)
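# rag_chain.invoke({"input": ..., "chat_history": [...]}) returns a dict with the
# retrieved documents under "context" and the model's reply under "answer".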

# -----------------------------------------------------------------------------
# Session History
# -----------------------------------------------------------------------------
SESSION_STORE: dict[str, ChatMessageHistory] = {}

def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the message history for a session, creating it on first use."""
    if session_id not in SESSION_STORE:
        SESSION_STORE[session_id] = ChatMessageHistory()
    return SESSION_STORE[session_id]
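# Histories live in process memory only: they are lost on restart and not shared
# across workers. A persistent store (e.g. Redis) would be needed in production.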

chain_with_history = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    input_messages_key="input",
    history_messages_key="chat_history",
    output_messages_key="answer",
)

# -----------------------------------------------------------------------------
# Routes
# -----------------------------------------------------------------------------
@app.route("/")
def index():
    return render_template("chat.html")
@app.route("/get", methods=["GET", "POST"])
def chat():
msg = request.form.get("msg", "")
session_id = request.remote_addr or "anon"
out = chain_with_history.invoke(
{"input": msg},
config={"configurable": {"session_id": session_id}}
)
return str(out["answer"])
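# Example request (hypothetical message text):
#   curl -X POST -d "msg=What are the symptoms of anemia?" http://localhost:7860/get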

# -----------------------------------------------------------------------------
# Run the Flask app (PORT 7860 for Hugging Face)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    port = int(os.getenv("PORT", 7860))  # Hugging Face Spaces expects 7860 by default
    app.run(host="0.0.0.0", port=port)