# File size: 3,952 Bytes
# 503a7f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import threading

from dotenv import load_dotenv
from flask import Flask, render_template, request

from src.helper import download_embeddings
from src.prompt import *
from langchain_pinecone import PineconeVectorStore
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_classic.chains import create_retrieval_chain, create_history_aware_retriever
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.runnables.history import RunnableWithMessageHistory


# -----------------------------------------------------------------------------
# Flask + Environment Setup
# -----------------------------------------------------------------------------
app = Flask(__name__)
load_dotenv()  # pull variables from a local .env file when one exists

# API keys for the vector store and the LLM. The clients constructed below are
# not handed a key explicitly, so the values are (re)published to os.environ.
# NOTE(review): an absent key is published as an empty string rather than left
# unset — confirm that is the intended failure mode.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY", "")
GROQ_API_KEY = os.getenv("GROQ_API_KEY", "")
os.environ.update(PINECONE_API_KEY=PINECONE_API_KEY, GROQ_API_KEY=GROQ_API_KEY)


# -----------------------------------------------------------------------------
# RAG Pipeline Setup
# -----------------------------------------------------------------------------
# Everything in this section runs once, at import time, and builds module-level
# objects that the routes below share.

# Embedding model used to embed queries. download_embeddings() comes from
# src.helper; presumably it fetches/loads a model — TODO confirm cost at startup.
embeddings = download_embeddings()

# Pinecone index opened (not created) here. Queries are embedded with
# `embeddings`, so the index is presumably populated with vectors from the same
# model by a separate indexing step — verify they match.
index_name = "virtual-doc"
docsearch = PineconeVectorStore.from_existing_index(
    index_name=index_name,
    embedding=embeddings,
)

# Maximal Marginal Relevance search: take the 20 nearest chunks (fetch_k) and
# return 4 of them (k) chosen to balance relevance against redundancy.
retriever = docsearch.as_retriever(
    search_type="mmr",
    search_kwargs={"k": 4, "fetch_k": 20}
)

# No api_key is passed, so ChatGroq is expected to read GROQ_API_KEY from the
# environment. The same model instance serves both the rewrite and answer steps.
chatModel = ChatGroq(model="llama-3.3-70b-versatile")

# Prompt to rewrite follow-up questions into standalone queries. Its output is
# used only as the retrieval query, never shown to the user.
contextualize_q_prompt = ChatPromptTemplate.from_messages([
    ("system", "Rewrite the user's follow-up question using chat history. "
               "Only rewrite — do not answer."),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])

# Retriever that first rewrites "input" using "chat_history" (via chatModel)
# and then queries `retriever` with the rewritten text. When chat_history is
# empty, LangChain skips the rewrite and sends "input" directly.
history_aware_retriever = create_history_aware_retriever(
    llm=chatModel,
    retriever=retriever,
    prompt=contextualize_q_prompt
)

# Answering prompt. `system_prompt` is not defined in this file; it presumably
# comes from the `from src.prompt import *` star import. The stuff-documents
# chain fills a `{context}` variable with the retrieved documents, so
# system_prompt must contain `{context}` — TODO confirm.
answer_prompt = ChatPromptTemplate.from_messages([
    ("system", system_prompt),
    MessagesPlaceholder("chat_history"),
    ("human", "{input}"),
])

# Full pipeline: the retrieved documents are stuffed into answer_prompt. The
# result dict carries the model's reply under "answer", which the /get route
# returns.
qa_chain = create_stuff_documents_chain(chatModel, answer_prompt)
rag_chain = create_retrieval_chain(history_aware_retriever, qa_chain)


# -----------------------------------------------------------------------------
# Session History
# -----------------------------------------------------------------------------
# In-memory chat histories keyed by session id. Nothing is persisted: the
# store is lost on restart, and each worker process keeps its own copy.
SESSION_STORE = {}

# Bounds that keep a long-running server from growing without limit.
MAX_SESSIONS = 1000  # least-recently-used sessions beyond this are dropped
MAX_HISTORY_MESSAGES = 20  # keep this even so human/AI turns stay paired

# app.run() serves requests on multiple threads by default, so the
# read-modify-write sequences on SESSION_STORE below are serialized.
_SESSION_LOCK = threading.Lock()


def get_session_history(session_id: str) -> ChatMessageHistory:
    """Return the chat history for *session_id*, creating it on first use.

    This is the history factory handed to ``RunnableWithMessageHistory``,
    which calls it for every chain invocation. Besides the lookup it enforces
    two bounds:

    * The history is trimmed to the newest ``MAX_HISTORY_MESSAGES`` messages.
      The whole history is injected into both the question-rewriting prompt
      and the answer prompt, so without trimming the prompts grow every turn
      until they exceed the model's context window.
    * At most ``MAX_SESSIONS`` histories are retained. When a new session
      would exceed that, the least recently used one is evicted.

    Args:
        session_id: Opaque key identifying one conversation.

    Returns:
        The (possibly newly created) ``ChatMessageHistory`` for the session.
    """
    with _SESSION_LOCK:
        # Dicts preserve insertion order. Popping and re-inserting moves this
        # session to the end, so the first key is always the least recently used.
        history = SESSION_STORE.pop(session_id, None)
        if history is None:
            history = ChatMessageHistory()
        SESSION_STORE[session_id] = history
        while len(SESSION_STORE) > MAX_SESSIONS:
            del SESSION_STORE[next(iter(SESSION_STORE))]

        # Drop the oldest messages. Computing `excess` explicitly (instead of
        # slicing with [-N:]) keeps MAX_HISTORY_MESSAGES == 0 meaning "keep none".
        excess = len(history.messages) - MAX_HISTORY_MESSAGES
        if excess > 0:
            del history.messages[:excess]
    return history

# Wrap the RAG chain with per-session memory. On each call the wrapper looks up
# the history via get_session_history and supplies its messages as
# "chat_history", passing the user's text as "input". Afterwards it appends the
# new human message and the chain's "answer" to that same history.
chain_with_history = RunnableWithMessageHistory(
    rag_chain,
    get_session_history,
    output_messages_key="answer",
    history_messages_key="chat_history",
    input_messages_key="input",
)


# -----------------------------------------------------------------------------
# Routes
# -----------------------------------------------------------------------------
@app.route("/")
def index():
    """Serve the chat page (templates/chat.html)."""
    page = "chat.html"
    return render_template(page)

@app.route("/get", methods=["GET", "POST"])
def chat():
    """Answer one chat message with the history-aware RAG chain.

    The message is read from the ``msg`` form field (POST). For GET requests,
    or when the form field is missing, it is read from the ``msg`` query
    parameter. The reply is returned as plain text.

    Returns:
        The model's answer as a string, or ``("Please enter a message.", 400)``
        when the message is blank. A blank message is rejected before the model
        is called, so it neither costs an LLM call nor adds an empty turn to the
        session history.
    """
    # request.form is only populated from a form-encoded request body, so for a
    # GET request (which this route accepts) it is always empty. Fall back to
    # the query string.
    msg = (request.form.get("msg") or request.args.get("msg", "")).strip()
    if not msg:
        return "Please enter a message.", 400

    # NOTE(review): the client IP is the session key. Clients behind the same NAT
    # or reverse proxy therefore share (and can read) one conversation history.
    # Consider a per-browser id such as a cookie.
    session_id = request.remote_addr or "anon"

    out = chain_with_history.invoke(
        {"input": msg},
        config={"configurable": {"session_id": session_id}},
    )
    return str(out["answer"])


# -----------------------------------------------------------------------------
# Run the Flask app (PORT 7860 for Hugging Face)
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Hugging Face Spaces expects port 7860; a PORT environment variable
    # overrides it. Binding 0.0.0.0 makes the server reachable from outside
    # the container.
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", 7860)))