# Source listing metadata: 1,839 bytes; revisions 26b9ebe, 39e1edd.

import gradio as gr
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever

from app_utils import multi_modal_rag_chain

# Load the persisted vector store and build the RAG chain at import time.
# The Chroma collection is reopened from `persist_directory`, so it is
# presumably populated by a separate indexing step — TODO confirm.
vectorstore = Chroma(
    collection_name="multi_modal_rag",
    embedding_function=OpenAIEmbeddings(),
    persist_directory="chroma_langchain_db",
)

# NOTE(review): a MultiVectorRetriever (imported above) is deliberately not
# used. Its docstore would be a fresh InMemoryStore, which is empty at
# startup and is not persisted alongside Chroma. Parent-document lookups
# would therefore presumably find nothing unless the documents are re-added
# here first. Re-introduce it only together with that loading step.
retriever = vectorstore.as_retriever()
chain_multimodal_rag = multi_modal_rag_chain(retriever)

def generate_response(message, history):
    """Answer the latest chat message with the multi-modal RAG chain.

    Called by ``gr.ChatInterface`` once per user turn. Only the newest
    message is sent to the chain. ``history`` is accepted because the
    interface passes it, but it is not used.

    Args:
        message: The user's message. Normally a ``str``. A dict with a
            ``"text"`` key (the multimodal message format) is also
            accepted and reduced to its text.
        history: Previous turns supplied by Gradio (ignored).

    Returns:
        The chain's answer as a single string.
    """
    # Accept the multimodal {"text": ..., "files": [...]} shape as well as
    # plain strings, so the chain always receives the question text.
    if isinstance(message, dict):
        message = message.get("text", "")

    response = chain_multimodal_rag.invoke(message)

    # A str is itself iterable, so check for it first and return it as-is
    # instead of re-joining it character by character.
    if isinstance(response, str):
        return response
    # ``invoke`` returns a complete result, not a stream. If the chain still
    # hands back a sequence of pieces (e.g. a list), concatenate their text.
    if hasattr(response, "__iter__"):
        return "".join(str(chunk) for chunk in response)
    # Anything else is converted so the caller always receives a string.
    return str(response)

# Text-only chat UI (``multimodal`` is left at its default). Examples are
# therefore plain strings. Dicts with "text"/"files" keys are the format
# for multimodal interfaces.
demo = gr.ChatInterface(
    fn=generate_response,
    title="Multi-modal RAG Chatbot",
    description="Ask a question about the LongNet paper.",
    examples=[
        "What is Dilated attention?",
        "How is Dilated attention better than vanilla attention?",
        "What is the difference between the computational cost of Dilated and Vanilla Attention?",
    ],
)

if __name__ == "__main__":
    # Build first, launch afterwards (not inside a ``with`` block).
    # The guard keeps a plain import of this module, e.g. by Gradio's
    # reload mode, from starting a second server.
    demo.launch()