File size: 2,802 Bytes
c14624d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d05fec7
 
c14624d
 
 
 
 
3e65911
c14624d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gradio as gr
import pickle
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.chat_message_histories import SQLChatMessageHistory
from langchain_groq import ChatGroq
from langchain_core.runnables import RunnablePassthrough
from langchain.schema.output_parser import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from operator import itemgetter

# Load the pre-built Qdrant vector store from disk at import time.
# NOTE(review): pickle.load executes arbitrary code from the file — confirm
# this artifact is produced locally and never accepted from untrusted input.
with open("qdrant_vectorstore.pkl", "rb") as vectorstore_file:
    qdrant_vectorstore = pickle.load(vectorstore_file)

# Chat handler: RAG lookup + history-aware LLM call.
def echo_user_input(*args):
    """Answer a chat message using retrieval over the Qdrant vector store.

    Gradio's ChatInterface invokes this as fn(message, history); only the
    message is used here — conversational memory is handled separately by
    LangChain's SQL-backed message history.

    Returns:
        A markdown string with the LLM response followed by the raw
        retrieved context.
    """
    import os  # local import so this function is self-contained

    user_input = args[0]  # (message, history) -> take the message

    # Retrieve supporting documents: the retriever feeds the chain below,
    # and similarity_search provides the raw context shown to the user.
    qdrant_retriever = qdrant_vectorstore.as_retriever()
    found_docs = qdrant_vectorstore.similarity_search(user_input)
    context_str = "".join(doc.page_content + '\n\n' for doc in found_docs)

    # Prompt: system instruction carrying the retrieved {context}, prior
    # conversation turns, then the new human message.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "Act as a helpful AI Assistant. Here is some {context}"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{human_input}")
    ])

    def get_session_history(session_id):
        # Persist per-session chat history in a local SQLite database.
        return SQLChatMessageHistory(session_id, "sqlite:///memory.db")

    # SECURITY FIX: the Groq API key was hard-coded in source (a leaked
    # secret). Read it from the environment instead; revoke the old key.
    groq_api_key = os.environ.get("GROQ_API_KEY")
    if not groq_api_key:
        return "Error: the GROQ_API_KEY environment variable is not set."
    llm = ChatGroq(groq_api_key=groq_api_key, model_name="Gemma2-9b-It")

    # Pipe the user's input through the retriever to populate {context},
    # then run prompt -> LLM -> plain-string output.
    context = itemgetter("human_input") | qdrant_retriever
    first_step = RunnablePassthrough.assign(context=context)
    llm_chain = first_step | prompt | llm | StrOutputParser()

    conv_chain = RunnableWithMessageHistory(
        llm_chain,
        get_session_history,
        input_messages_key="human_input",
        history_messages_key="history",
    )

    # Single fixed session id — every visitor shares one conversation.
    # TODO(review): derive a per-user session id if isolation is needed.
    session_id = 'bond007'
    llm_response = conv_chain.invoke(
        {"human_input": user_input},
        {'configurable': {'session_id': session_id}},
    )

    # Bug fix: the original f-string ran the response straight into the
    # "**Retrieved Context:**" header with no separator.
    combined_output = (
        f"**Summarized Response:**\n{llm_response}\n\n"
        f"**Retrieved Context:**\n{context_str}\n\n"
    )
    return combined_output

# Wire the chat handler into a Gradio ChatInterface.
interface = gr.ChatInterface(
    title="MCG Demo",
    description="Type your question and press enter to see a conversational response. 🤖",
    fn=echo_user_input,
)

if __name__ == "__main__":
    # share=True publishes a temporary public gradio.live URL.
    interface.launch(share=True)