# Multi-modal RAG chatbot (Gradio app) answering questions about the LongNet paper.
import gradio as gr
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from app_utils import multi_modal_rag_chain
# --- Vector store / retriever / chain setup ------------------------------
# Chroma collection persisted on disk; embeddings come from
# OpenAIEmbeddings (requires OPENAI_API_KEY in the environment).
vectorstore = Chroma(
    collection_name="multi_modal_rag",
    embedding_function=OpenAIEmbeddings(),
    persist_directory="chroma_langchain_db",
)

# Key linking each summary vector in Chroma back to its parent document
# in the docstore.
id_key = "doc_id"
store = InMemoryStore()

# MultiVectorRetriever searches the vector store, then resolves hits to
# full parent documents via `store` using `id_key`.
# NOTE(review): `store` is empty here — parent documents must be loaded
# (e.g. retriever.docstore.mset(...)) before retrieval returns content;
# confirm against the ingestion script.
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

# BUG FIX: the original immediately overwrote `retriever` with
# `vectorstore.as_retriever()`, discarding the MultiVectorRetriever (and
# its docstore indirection) built just above. Keep the multi-vector
# retriever so the RAG chain receives parent documents, not raw chunks.
chain_multimodal_rag = multi_modal_rag_chain(retriever)
def generate_response(message, history):
    """Produce the chatbot reply for the latest user message.

    Called by gr.ChatInterface once per user turn. Only the newest
    message is passed to the RAG chain; prior turns are ignored.

    Args:
        message: The new user message, forwarded verbatim to the chain.
        history: Previous conversation turns (unused).

    Returns:
        str: The chain's answer collected into a single string.
    """
    response = chain_multimodal_rag.invoke(message)

    # A plain string is already the final answer. This check MUST come
    # before the iterable check: strings are themselves iterable, so the
    # original code's join branch rebuilt the string character by
    # character on every call.
    if isinstance(response, str):
        return response

    # A streaming chain may hand back a list/generator of chunks. Chunks
    # are not guaranteed to be str (e.g. message-chunk objects), so
    # coerce each one before joining — the bare "".join would crash.
    if hasattr(response, "__iter__"):
        return "".join(str(chunk) for chunk in response)

    # Fallback: ensure Gradio always receives a string.
    return str(response)
# --- Gradio UI -----------------------------------------------------------
# BUG FIX: the original wrapped ChatInterface in a `with ...:` block and
# called launch() inside it, and used dict-form examples
# ({"text": ...}), which ChatInterface only accepts in multimodal mode.
# Build the interface normally, use plain-string examples, and launch
# under the import guard so importing this module doesn't start a server.
demo = gr.ChatInterface(
    fn=generate_response,
    title="Multi-modal RAG Chatbot",
    description="Ask a question about the LongNet paper.",
    examples=[
        "What is Dilated attention?",
        "How is Dilated attention better than vanilla attention?",
        "What is the difference between the computational cost of Dilated and Vanilla Attention?",
    ],
)

if __name__ == "__main__":
    demo.launch()