import io
import re
import uuid
import base64
from io import BytesIO
from operator import itemgetter

import chromadb
import gradio as gr
import numpy as np
from PIL import Image
from IPython.display import HTML, display
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough

# Load the persisted vector store and build a retriever over it
vectorstore = Chroma(
    collection_name="multi_modal_rag",
    embedding_function=OpenAIEmbeddings(),
    persist_directory="chroma_langchain_db",
)
id_key = "doc_id"
store = InMemoryStore()
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)
# Note: this overrides the MultiVectorRetriever above with a plain similarity
# retriever over the vector store (the in-memory docstore is empty in a fresh
# session, so the multi-vector lookup would return nothing)
retriever = vectorstore.as_retriever()


def plt_img_base64(img_base64):
    """Display a base64-encoded string as an image."""
    # Create an HTML img tag with the base64 string as the source
    image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'
    # Display the image by rendering the HTML
    display(HTML(image_html))


def looks_like_base64(sb):
    """Check if the string looks like base64."""
    return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None


def is_image_data(b64data):
    """Check if the base64 data is an image by looking at the start of the data."""
    image_signatures = {
        b"\xff\xd8\xff": "jpg",
        b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a": "png",
        b"\x47\x49\x46\x38": "gif",
        b"\x52\x49\x46\x46": "webp",
    }
    try:
        header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
        for sig, fmt in image_signatures.items():
            if header.startswith(sig):
                return True
        return False
    except Exception:
        return False


def resize_base64_image(base64_string, size=(128, 128)):
    """Resize an image encoded as a base64 string."""
    # Decode the base64 string
    img_data = base64.b64decode(base64_string)
    img = Image.open(io.BytesIO(img_data))

    # Resize the image
    resized_img = img.resize(size, Image.LANCZOS)

    # Save the resized image to a bytes buffer, keeping the original format
    buffered = io.BytesIO()
    resized_img.save(buffered, format=img.format)

    # Encode the resized image back to base64
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def split_image_text_types(docs):
    """Split base64-encoded images and texts."""
    b64_images = []
    texts = []
    for doc in docs:
        # Check if the document is of type Document and extract page_content if so
        if isinstance(doc, Document):
            doc = doc.page_content
        if looks_like_base64(doc) and is_image_data(doc):
            doc = resize_base64_image(doc, size=(1300, 600))
            b64_images.append(doc)
        else:
            texts.append(doc)
    return {"images": b64_images, "texts": texts}


def img_prompt_func(data_dict):
    """Join the retrieved context into a single multimodal message."""
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = []

    # Adding image(s) to the messages if present
    if data_dict["context"]["images"]:
        for image in data_dict["context"]["images"]:
            image_message = {
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{image}"},
            }
            messages.append(image_message)

    # Adding the text for analysis
    text_message = {
        "type": "text",
        "text": (
            "Answer the question based on the following context, "
            "which can include text, tables, and images.\n"
f"User-provided question: {data_dict['question']}\n\n" "Text and / or tables:\n" f"{formatted_texts}" ), } messages.append(text_message) return [HumanMessage(content=messages)] def multi_modal_rag_chain(retriever): """ Multi-modal RAG chain """ # Multi-modal LLM model = ChatOpenAI(temperature=0, model="gpt-4o-mini", max_tokens=1024, streaming=True) # RAG pipeline chain = ( { "context": retriever | RunnableLambda(split_image_text_types), "question": RunnablePassthrough(), } | RunnableLambda(img_prompt_func) | model | StrOutputParser() ) return chain