AseemD committed on
Commit
25f5525
·
verified ·
1 Parent(s): 4e9c666

Upload 6 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ chroma_langchain_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
app_utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import re
3
+ import uuid
4
+ import base64
5
+ import chromadb
6
+ import gradio as gr
7
+ import numpy as np
8
+ from PIL import Image
9
+ from io import BytesIO
10
+ from operator import itemgetter
11
+ from IPython.display import HTML, display
12
+
13
+ from langchain.vectorstores import Chroma
14
+ from langchain.storage import InMemoryStore
15
+ from langchain.schema.document import Document
16
+ from langchain.embeddings import OpenAIEmbeddings
17
+ from langchain.retrievers.multi_vector import MultiVectorRetriever
18
+
19
+ from langchain_openai import ChatOpenAI
20
+ from langchain_core.messages import HumanMessage
21
+ from langchain_core.output_parsers import StrOutputParser
22
+ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
23
+
# Load the persisted Chroma vector store (embeddings computed via OpenAI)
# and expose a retriever over it for the RAG chain below.
vectorstore = Chroma(collection_name="multi_modal_rag",
                     embedding_function=OpenAIEmbeddings(),
                     persist_directory="chroma_langchain_db")

# Metadata key that would link summary vectors to parent documents.
id_key = "doc_id"
store = InMemoryStore()
# BUG FIX: the original constructed a MultiVectorRetriever(vectorstore, store,
# id_key) here and then immediately rebound `retriever` to
# vectorstore.as_retriever(), so the MultiVectorRetriever was dead code and its
# InMemoryStore docstore was never populated or consulted. The effective
# behavior — a plain vector-store retriever — is kept; the discarded object is
# no longer built.
# NOTE(review): if multi-vector retrieval was actually intended, the docstore
# must be filled with parent documents and the MultiVectorRetriever used —
# confirm with the author.
retriever = vectorstore.as_retriever()
37
+
def plt_img_base64(img_base64):
    """Render a base64-encoded image inline via IPython's HTML display."""
    # Embed the payload as a data-URI <img> tag; the browser decodes it.
    tag = '<img src="data:image/jpeg;base64,' + img_base64 + '" />'
    display(HTML(tag))
44
+
45
+
def looks_like_base64(sb):
    """Heuristically decide whether *sb* is plausibly base64-encoded data."""
    # Base64 alphabet (A-Z, a-z, 0-9, '+', '/') followed by at most two '='
    # padding characters; '^'/'$' anchor the whole string.
    matched = re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb)
    return matched is not None
49
+
50
+
def is_image_data(b64data):
    """Return True if *b64data* decodes to bytes starting with a known image magic.

    Recognized formats: JPEG, PNG, GIF and WebP. Any failure to decode the
    base64 input (invalid characters, bad padding, non-string input) is
    treated as "not an image" rather than raised.
    """
    # Prefix signatures for formats whose magic sits at offset 0.
    image_signatures = (
        b"\xff\xd8\xff",                       # jpg
        b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a",   # png
        b"\x47\x49\x46\x38",                   # gif
    )
    try:
        # 12 bytes cover every signature above, including WebP's
        # RIFF<size>WEBP container layout.
        header = base64.b64decode(b64data)[:12]
    except Exception:
        return False
    # BUG FIX: the original matched on b"RIFF" alone, which also matches WAV
    # and AVI files; WebP requires the "WEBP" FourCC at offset 8 inside the
    # RIFF container.
    if header[:4] == b"\x52\x49\x46\x46" and header[8:12] == b"\x57\x45\x42\x50":
        return True
    return any(header.startswith(sig) for sig in image_signatures)
69
+
70
+
def resize_base64_image(base64_string, size=(128, 128)):
    """Resize a base64-encoded image and return it re-encoded as base64.

    Parameters
    ----------
    base64_string : str
        Base64-encoded image bytes (any format PIL can open).
    size : tuple
        Target (width, height) in pixels.

    Returns
    -------
    str
        The resized image, re-encoded in the source format, as base64 text.
    """
    # Decode the Base64 string and load it as a PIL image.
    img_data = base64.b64decode(base64_string)
    img = Image.open(io.BytesIO(img_data))

    # LANCZOS resampling: best quality for downscaling.
    resized_img = img.resize(size, Image.LANCZOS)

    # Save the resized image to an in-memory buffer.
    buffered = io.BytesIO()
    # BUG FIX: img.format can be None (e.g. images not read from an encoded
    # stream), which makes save() raise. Fall back to JPEG in that case;
    # images decoded from real file bytes keep their original format, so the
    # common path is unchanged.
    resized_img.save(buffered, format=img.format or "JPEG")

    # Re-encode the resized bytes back to Base64 text.
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
88
+
89
+
def split_image_text_types(docs):
    """Partition retrieved docs into base64 images and plain texts.

    Returns a dict: {"images": [base64 strings, resized], "texts": [str]}.
    """
    images = []
    plain_texts = []
    for item in docs:
        # Unwrap LangChain Document objects to their raw page content.
        content = item.page_content if isinstance(item, Document) else item
        if looks_like_base64(content) and is_image_data(content):
            # Shrink images so they fit comfortably in the LLM prompt.
            images.append(resize_base64_image(content, size=(1300, 600)))
        else:
            plain_texts.append(content)
    return {"images": images, "texts": plain_texts}
106
+
107
+
def img_prompt_func(data_dict):
    """Build a single multimodal HumanMessage from retrieved context.

    Expects data_dict["context"] = {"images": [...], "texts": [...]} and
    data_dict["question"] = the user's query string.
    """
    context = data_dict["context"]
    formatted_texts = "\n".join(context["texts"])

    # One image_url part per retrieved image, encoded as a data URI.
    messages = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image}"},
        }
        for image in context["images"]
    ]

    # NOTE(review): there is no separator between the instruction sentence
    # and "User-provided question:" — the adjacent literals concatenate
    # directly. Preserved byte-for-byte here; confirm whether a newline was
    # intended.
    prompt_text = (
        "Answer the question based on the following context, which can include text, tables, and images."
        f"User-provided question: {data_dict['question']}\n\n"
        "Text and / or tables:\n"
        f"{formatted_texts}"
    )
    messages.append({"type": "text", "text": prompt_text})
    return [HumanMessage(content=messages)]
136
+
137
+
def multi_modal_rag_chain(retriever):
    """Assemble the multimodal RAG chain: retrieve -> prompt -> LLM -> text.

    The returned runnable takes a question string and produces the model's
    answer as a plain string (streaming enabled on the model).
    """
    # Multimodal-capable chat model; temperature 0 for deterministic output.
    llm = ChatOpenAI(temperature=0,
                     model="gpt-4o-mini",
                     max_tokens=1024,
                     streaming=True)

    # Fan the incoming question into retrieved context (split into images
    # and texts) alongside the question itself.
    inputs = {
        "context": retriever | RunnableLambda(split_image_text_types),
        "question": RunnablePassthrough(),
    }

    # Format both into a multimodal message, run the model, parse to str.
    return inputs | RunnableLambda(img_prompt_func) | llm | StrOutputParser()
chroma_langchain_db/0029ac48-b5c5-4755-9b42-a61f8e18b33c/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f18abd8c514282db82706e52b0a33ed659cd534e925a6f149deb7af9ce34bd8e
3
+ size 6284000
chroma_langchain_db/0029ac48-b5c5-4755-9b42-a61f8e18b33c/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:effaa959ce2b30070fdafc2fe82096fc46e4ee7561b75920dd3ce43d09679b21
3
+ size 100
chroma_langchain_db/0029ac48-b5c5-4755-9b42-a61f8e18b33c/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6177c4c9be35ad9060ac687da99280169339d5e56adc6a71f735308995b9f0bf
3
+ size 4000
chroma_langchain_db/0029ac48-b5c5-4755-9b42-a61f8e18b33c/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
3
+ size 0
chroma_langchain_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53475ae3bf140921f7a46a979f5411b5a2d537ebf2e564c5f61c015e3b45f92e
3
+ size 2895872