MultiModalRAG / app_utils.py
AseemD's picture
Upload 6 files
25f5525 verified
import io
import re
import uuid
import base64
import chromadb
import gradio as gr
import numpy as np
from PIL import Image
from io import BytesIO
from operator import itemgetter
from IPython.display import HTML, display
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
# Load the vector store and retriever
vectorstore = Chroma(collection_name="multi_modal_rag",
embedding_function=OpenAIEmbeddings(),
persist_directory="chroma_langchain_db")
id_key = "doc_id"
store = InMemoryStore()
retriever = MultiVectorRetriever(
vectorstore=vectorstore,
docstore=store,
id_key=id_key,
)
retriever = vectorstore.as_retriever()
def plt_img_base64(img_base64):
"""Disply base64 encoded string as image"""
# Create an HTML img tag with the base64 string as the source
image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'
# Display the image by rendering the HTML
display(HTML(image_html))
def looks_like_base64(sb):
"""Check if the string looks like base64"""
return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None
def is_image_data(b64data):
"""
Check if the base64 data is an image by looking at the start of the data
"""
image_signatures = {
b"\xff\xd8\xff": "jpg",
b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a": "png",
b"\x47\x49\x46\x38": "gif",
b"\x52\x49\x46\x46": "webp",
}
try:
header = base64.b64decode(b64data)[:8] # Decode and get the first 8 bytes
for sig, format in image_signatures.items():
if header.startswith(sig):
return True
return False
except Exception:
return False
def resize_base64_image(base64_string, size=(128, 128)):
"""
Resize an image encoded as a Base64 string
"""
# Decode the Base64 string
img_data = base64.b64decode(base64_string)
img = Image.open(io.BytesIO(img_data))
# Resize the image
resized_img = img.resize(size, Image.LANCZOS)
# Save the resized image to a bytes buffer
buffered = io.BytesIO()
resized_img.save(buffered, format=img.format)
# Encode the resized image to Base64
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def split_image_text_types(docs):
"""
Split base64-encoded images and texts
"""
b64_images = []
texts = []
for doc in docs:
# Check if the document is of type Document and extract page_content if so
if isinstance(doc, Document):
doc = doc.page_content
if looks_like_base64(doc) and is_image_data(doc):
doc = resize_base64_image(doc, size=(1300, 600))
b64_images.append(doc)
else:
texts.append(doc)
return {"images": b64_images, "texts": texts}
def img_prompt_func(data_dict):
"""
Join the context into a single string
"""
formatted_texts = "\n".join(data_dict["context"]["texts"])
messages = []
# Adding image(s) to the messages if present
if data_dict["context"]["images"]:
for image in data_dict["context"]["images"]:
image_message = {
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image}"},
}
messages.append(image_message)
# Adding the text for analysis
text_message = {
"type": "text",
"text": (
"Answer the question based on the following context, which can include text, tables, and images."
f"User-provided question: {data_dict['question']}\n\n"
"Text and / or tables:\n"
f"{formatted_texts}"
),
}
messages.append(text_message)
return [HumanMessage(content=messages)]
def multi_modal_rag_chain(retriever):
"""
Multi-modal RAG chain
"""
# Multi-modal LLM
model = ChatOpenAI(temperature=0,
model="gpt-4o-mini",
max_tokens=1024,
streaming=True)
# RAG pipeline
chain = (
{
"context": retriever | RunnableLambda(split_image_text_types),
"question": RunnablePassthrough(),
}
| RunnableLambda(img_prompt_func)
| model
| StrOutputParser()
)
return chain