Huy committed on
Commit
d8bb2be
·
1 Parent(s): c20735c

First commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
 *.zip filter=lfs diff=lfs merge=lfs -text
34
 *.zst filter=lfs diff=lfs merge=lfs -text
35
 *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.json filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ __pycache__/
2
+ .ipynb_checkpoints/
3
+ env/
4
+ .DS_Store
README.md CHANGED
@@ -1,13 +0,0 @@
1
- ---
2
- title: RAG ColPali
3
- emoji: 📊
4
- colorFrom: yellow
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.4.0
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,245 @@
1
+ import os
2
+ import torch
3
+ import base64
4
+ import asyncio
5
+ from io import BytesIO
6
+ import gradio as gr
7
+ import qdrant_client
8
+ from PIL import Image
9
+ from typing import List, Dict, Tuple
10
+
11
+ import llamaindex_utils
12
+ from rag_pipeline import async_indexDocument
13
+ from models import get_lora_model, enable_lora, ColPali, ColPaliProcessor
14
+ from utils import load_tokenizer
15
+
16
+ from llama_index.llms.gemini import Gemini
17
+ from llama_index.core.tools import RetrieverTool
18
+
19
+
20
+ GEMINI_API_KEY = os.getenv(key="GEMINI_API_KEY")
21
+ QDRANT_API_KEY = os.getenv(key="QDRANT_API_KEY")
22
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
23
+
24
+ async def initialize_model() -> Dict:
25
+ """Initialize models
26
+
27
+ Returns:
28
+ model_dict: Dict: Dictionary storing the necessary models
29
+ """
30
+
31
+ model = ColPali.from_pretrained(model_dir='./pretrained/colpaligemma-3b-mix-448-base', torch_dtype=torch.bfloat16)
32
+ tokenizer = load_tokenizer(tokenizer_dir='./pretrained/colpaligemma-3b-mix-448-base')
33
+ processor = ColPaliProcessor(tokenizer=tokenizer).from_pretrained(pretrained_dir='./pretrained/colpaligemma-3b-mix-448-base')
34
+
35
+ model.model.language_model.model = get_lora_model(model.model.language_model.model,
36
+ rank=32,
37
+ alphas=32,
38
+ lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'],
39
+ training=False,
40
+ dropout_p=0.1,
41
+ torch_dtype=torch.bfloat16)
42
+ model.model.language_model.model = enable_lora(model.model.language_model.model, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], enabled=True)
43
+
44
+ model = get_lora_model(model,
45
+ rank=32,
46
+ alphas=32,
47
+ lora_modules=['custom_text_proj'],
48
+ training=False,
49
+ dropout_p=0.1,
50
+ torch_dtype=torch.bfloat16)
51
+
52
+ model = enable_lora(model, lora_modules=['custom_text_proj'], enabled=True)
53
+
54
+ model.load_lora('./pretrained/colpaligemma-3b-mix-448-base')
55
+
56
+ # Initialize LLM
57
+ generation_config = {
58
+ "temperature": 0.0,
59
+ "top_p": 0.95,
60
+ "top_k": 64,
61
+ "max_output_tokens": 1024,
62
+ "response_mime_type": "text/plain",
63
+ }
64
+
65
+ llm = Gemini(api_key=GEMINI_API_KEY, generation_config=generation_config)
66
+
67
+ # Setup Qdrant
68
+ # Creating Qdrant Client
69
+ vector_store_client = qdrant_client.AsyncQdrantClient(location="https://b3878645-ec71-426c-8afa-b8b3b7589e40.us-east4-0.gcp.cloud.qdrant.io",
70
+ api_key=QDRANT_API_KEY,
71
+ timeout=100)
72
+
73
+ embed_model = llamaindex_utils.ColPaliGemmaEmbedding(model=model,
74
+ processor=processor,
75
+ device=device)
76
+
77
+ collections = await get_collection_names(vector_store_client)
78
+ retrievers_dict = {}
79
+ for name in collections:
80
+ if name not in retrievers_dict:
81
+ retrievers_dict[name] = llamaindex_utils.ColPaliRetriever(vector_store_client=vector_store_client,
82
+ target_collection=name,
83
+ embed_model=embed_model,
84
+ similarity_top_k=3)
85
+ return {"llm": llm,
86
+ "vector_store_client": vector_store_client,
87
+ "model": model,
88
+ "processor": processor,
89
+ "embed_model": embed_model,
90
+ "collections": collections,
91
+ "retrievers_dict": retrievers_dict}
92
+
93
+ async def get_collection_names(vector_store_client):
94
+ collections = await vector_store_client.get_collections()
95
+ return [collection.name for collection in collections.collections]
96
+
97
+ async def index(files: List[str],
98
+ target_collection: str
99
+ ) -> Tuple[str, gr.Dropdown, List[str]]:
100
+ """
101
+ Insert all image pages from the uploaded files into the specified target collection in the vector store
102
+ and update the module-level retrievers mapping with a retriever for that collection.
103
+
104
+ Args:
105
+ files (List[str]): List of file path
106
+ target_collection (str): Target collection to insert into the vector store
107
+
108
+ Returns:
109
+ Tuple[str, gr.Dropdown, List[str]]: Status message, updated dropdown component, and the collections' names
110
+ """
111
+
112
+ for file in files:
113
+ await async_indexDocument(file_path=file,
114
+ vector_store_client=model_dict["vector_store_client"],
115
+ target_collection=target_collection,
116
+ model=model_dict["model"],
117
+ processor=model_dict["processor"],
118
+ device=device)
119
+
120
+ if target_collection not in retrievers:
121
+ retrievers[target_collection] = llamaindex_utils.ColPaliRetriever(vector_store_client=model_dict["vector_store_client"],
122
+ target_collection=target_collection,
123
+ embed_model=model_dict["embed_model"],
124
+ similarity_top_k=3)
125
+ collection_names = await get_collection_names(model_dict["vector_store_client"])
126
+ return (f"Uploaded and indexed {len(files)} files.",
127
+ gr.Dropdown(choices=collection_names),
128
+ collection_names)
129
+
130
+ async def search_with_llm(query: str,
131
+ similarity_top_k: int,
132
+ num_children: int) -> Tuple[str, List[Image.Image]]:
133
+ """Search for an answer to the query using the available retrievers.
134
+ Returns the search response and the list of images supporting that response.
135
+
136
+ Args:
137
+ query (str): Query question
138
+ retrievers (Dict[str, llamaindex_utils.ColPaliRetriever]): module-level mapping from retriever names to their instances (read from global state, not passed as a parameter)
139
+ similarity_top_k (int): top K similarity results retrieved from the retriever
140
+ num_children (int): number of children for tree summarization
141
+
142
+ Returns:
143
+ Tuple[str, List[Image.Image]]: The search response and the list of images supporting that response.
144
+ """
145
+ retriever_tools = [RetrieverTool.from_defaults(
146
+ name=key,
147
+ retriever=value,
148
+ description=f"Useful for retrieving information about {key} financials") for key, value in retrievers.items()]
149
+
150
+ retriever_mappings = {retriever_tool.metadata.name: retriever_tool.retriever for retriever_tool in retriever_tools}
151
+
152
+ fusion_retriever = llamaindex_utils.CustomFusionRetriever(llm=model_dict["llm"],
153
+ retriever_mappings=retriever_mappings,
154
+ similarity_top_k=similarity_top_k)
155
+
156
+ query_engine = llamaindex_utils.CustomQueryEngine(retriever_tools=[retriever_tool.metadata for retriever_tool in retriever_tools],
157
+ fusion_retriever=fusion_retriever,
158
+ llm=model_dict["llm"],
159
+ num_children=num_children)
160
+ response = await query_engine.aquery(query_str=query)
161
+
162
+ return response.response, [Image.open(BytesIO(base64.b64decode(image))) for image in response.source_images]
163
+
164
+
165
+ def build_gui():
166
+ with gr.Blocks() as demo:
167
+ gr.Markdown("# Image Based RAG System using ColPali 📚🔍")
168
+ with gr.Row(equal_height=True):
169
+ with gr.Column():
170
+ gr.Markdown("## 1️. Upload PDFs")
171
+ files = gr.File(file_types=["pdf"],
172
+ file_count="multiple",
173
+ interactive=True)
174
+
175
+ choices = gr.State(value=model_dict["collections"])
176
+ gr.Markdown("## 2️. Index the PDFs and upload")
177
+ target_collection = gr.Dropdown(choices=choices.value,
178
+ allow_custom_value=True,
179
+ label="Collection name",
180
+ show_label=True,
181
+ interactive=True)
182
+
183
+ message_box = gr.Textbox(value="File not yet uploaded",
184
+ show_label=False,
185
+ interactive=False)
186
+ convert_button = gr.Button("🔄 Convert and upload")
187
+
188
+ # Define the actions for conversion
189
+ convert_button.click(index, inputs=[files, target_collection], outputs=[message_box, target_collection, choices])
190
+
191
+ with gr.Column():
192
+ gr.Markdown("## 3️. Enter your question")
193
+ query = gr.Textbox(placeholder="Enter your query to match",
194
+ lines=15,
195
+ max_lines=20,
196
+ autoscroll=True)
197
+ with gr.Accordion(label="Additional Settings", open=False):
198
+ similarity_top_k = gr.Slider(minimum=1,
199
+ maximum=10,
200
+ value=3,
201
+ step=1.0,
202
+ label="Top K similarity retrieved from the retriever")
203
+
204
+ num_children = gr.Slider(minimum=1,
205
+ maximum=10,
206
+ value=3,
207
+ step=1.0,
208
+ label="Set number of children for Tree Summarization")
209
+ search_button = gr.Button("🔍 Search")
210
+
211
+ gr.Markdown("## 4️. ColPali Retrieval")
212
+ with gr.Row(equal_height=True):
213
+ output_text = gr.Textbox(label="Query result",
214
+ show_label=True,
215
+ placeholder="Response from query",
216
+ lines=8,
217
+ max_lines=20,
218
+ interactive=False)
219
+ output_imgs = gr.Gallery(label="Most relevant images are...",
220
+ show_fullscreen_button=True,
221
+ show_label=True,
222
+ show_download_button=True,
223
+ interactive=False)
224
+
225
+
226
+ # Action for search button
227
+ search_button.click(
228
+ search_with_llm,
229
+ inputs=[query, similarity_top_k, num_children],
230
+ outputs=[output_text, output_imgs])
231
+ return demo
232
+
233
+ async def amain():
234
+ global model_dict, retrievers
235
+ model_dict = await initialize_model()
236
+ retrievers = model_dict["retrievers_dict"]
237
+
238
+ demo = build_gui()
239
+ demo.queue().launch(debug=True, share=False)
240
+
241
+
242
+ if __name__ == "__main__":
243
+ asyncio.run(amain())
244
+
245
+
env.yaml ADDED
@@ -0,0 +1,241 @@
1
+
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - bzip2=1.0.8=h80987f9_6
6
+ - ca-certificates=2024.7.2=hca03da5_0
7
+ - libffi=3.4.4=hca03da5_1
8
+ - ncurses=6.4=h313beb8_0
9
+ - openssl=3.0.15=h80987f9_0
10
+ - pip=24.2=py311hca03da5_0
11
+ - python=3.11.9=hb885b13_0
12
+ - readline=8.2=h1a28f6b_0
13
+ - setuptools=75.1.0=py311hca03da5_0
14
+ - sqlite=3.45.3=h80987f9_0
15
+ - tk=8.6.14=h6ba3021_0
16
+ - wheel=0.44.0=py311hca03da5_0
17
+ - xz=5.4.6=h80987f9_1
18
+ - zlib=1.2.13=h18a0788_1
19
+ - pip:
20
+ - accelerate==1.1.0
21
+ - aiofiles==23.2.1
22
+ - aiohappyeyeballs==2.4.3
23
+ - aiohttp==3.10.10
24
+ - aiosignal==1.3.1
25
+ - annotated-types==0.7.0
26
+ - anyio==4.6.2.post1
27
+ - appnope==0.1.4
28
+ - argon2-cffi==23.1.0
29
+ - argon2-cffi-bindings==21.2.0
30
+ - arrow==1.3.0
31
+ - asttokens==2.4.1
32
+ - async-lru==2.0.4
33
+ - attrs==24.2.0
34
+ - babel==2.16.0
35
+ - beautifulsoup4==4.12.3
36
+ - bleach==6.2.0
37
+ - cachetools==5.5.0
38
+ - certifi==2024.8.30
39
+ - cffi==1.17.1
40
+ - charset-normalizer==3.4.0
41
+ - click==8.1.7
42
+ - comm==0.2.2
43
+ - contourpy==1.3.0
44
+ - cycler==0.12.1
45
+ - dataclasses-json==0.6.7
46
+ - datasets==3.0.1
47
+ - debugpy==1.8.7
48
+ - decorator==5.1.1
49
+ - defusedxml==0.7.1
50
+ - deprecated==1.2.14
51
+ - dill==0.3.8
52
+ - dirtyjson==1.0.8
53
+ - distro==1.9.0
54
+ - executing==2.1.0
55
+ - fastapi==0.115.4
56
+ - fastjsonschema==2.20.0
57
+ - ffmpy==0.4.0
58
+ - filelock==3.16.1
59
+ - fonttools==4.54.1
60
+ - fqdn==1.5.1
61
+ - frozenlist==1.5.0
62
+ - fsspec==2024.6.1
63
+ - google-ai-generativelanguage==0.6.4
64
+ - google-api-core==2.20.0
65
+ - google-api-python-client==2.147.0
66
+ - google-auth==2.35.0
67
+ - google-auth-httplib2==0.2.0
68
+ - google-generativeai==0.5.4
69
+ - googleapis-common-protos==1.65.0
70
+ - gradio==4.44.1
71
+ - gradio-client==1.3.0
72
+ - greenlet==3.1.1
73
+ - grpcio==1.67.1
74
+ - grpcio-status==1.62.3
75
+ - grpcio-tools==1.62.3
76
+ - h11==0.14.0
77
+ - h2==4.1.0
78
+ - hpack==4.0.0
79
+ - httpcore==1.0.6
80
+ - httplib2==0.22.0
81
+ - httpx==0.27.2
82
+ - huggingface-hub==0.26.2
83
+ - hyperframe==6.0.1
84
+ - idna==3.10
85
+ - importlib-resources==6.4.5
86
+ - instructorembedding==1.0.1
87
+ - ipykernel==6.29.5
88
+ - ipython==8.29.0
89
+ - isoduration==20.11.0
90
+ - jedi==0.19.1
91
+ - jinja2==3.1.4
92
+ - jiter==0.7.0
93
+ - joblib==1.4.2
94
+ - json5==0.9.25
95
+ - jsonpointer==3.0.0
96
+ - jsonschema==4.23.0
97
+ - jsonschema-specifications==2024.10.1
98
+ - jupyter-client==8.6.3
99
+ - jupyter-core==5.7.2
100
+ - jupyter-events==0.10.0
101
+ - jupyter-lsp==2.2.5
102
+ - jupyter-server==2.14.2
103
+ - jupyter-server-terminals==0.5.3
104
+ - jupyterlab==4.2.5
105
+ - jupyterlab-pygments==0.3.0
106
+ - jupyterlab-server==2.27.3
107
+ - kiwisolver==1.4.7
108
+ - llama-cloud==0.1.2
109
+ - llama-index==0.11.17
110
+ - llama-index-agent-openai==0.3.4
111
+ - llama-index-cli==0.3.1
112
+ - llama-index-core==0.11.17
113
+ - llama-index-embeddings-huggingface==0.3.1
114
+ - llama-index-embeddings-instructor==0.2.1
115
+ - llama-index-embeddings-openai==0.2.5
116
+ - llama-index-indices-managed-llama-cloud==0.4.0
117
+ - llama-index-legacy==0.9.48.post3
118
+ - llama-index-llms-gemini==0.3.7
119
+ - llama-index-llms-openai==0.2.13
120
+ - llama-index-multi-modal-llms-gemini==0.3.1
121
+ - llama-index-multi-modal-llms-openai==0.2.2
122
+ - llama-index-postprocessor-colbert-rerank==0.2.1
123
+ - llama-index-program-openai==0.2.0
124
+ - llama-index-question-gen-openai==0.2.0
125
+ - llama-index-readers-file==0.2.2
126
+ - llama-index-readers-llama-parse==0.3.0
127
+ - llama-index-vector-stores-qdrant==0.3.1
128
+ - llama-parse==0.5.7
129
+ - markdown-it-py==3.0.0
130
+ - markupsafe==2.1.5
131
+ - marshmallow==3.23.1
132
+ - matplotlib==3.9.2
133
+ - matplotlib-inline==0.1.7
134
+ - mdurl==0.1.2
135
+ - mistune==3.0.2
136
+ - mpmath==1.3.0
137
+ - multidict==6.1.0
138
+ - multiprocess==0.70.16
139
+ - mypy-extensions==1.0.0
140
+ - nbclient==0.10.0
141
+ - nbconvert==7.16.4
142
+ - nbformat==5.10.4
143
+ - nest-asyncio==1.6.0
144
+ - networkx==3.4.2
145
+ - nltk==3.9.1
146
+ - notebook==7.2.2
147
+ - notebook-shim==0.2.4
148
+ - numpy==1.26.4
149
+ - openai==1.53.0
150
+ - orjson==3.10.11
151
+ - overrides==7.7.0
152
+ - packaging==24.1
153
+ - pandas==2.2.3
154
+ - pandocfilters==1.5.1
155
+ - parso==0.8.4
156
+ - pdf2image==1.17.0
157
+ - peft==0.11.1
158
+ - pexpect==4.9.0
159
+ - pillow==10.4.0
160
+ - platformdirs==4.3.6
161
+ - portalocker==2.10.1
162
+ - prometheus-client==0.21.0
163
+ - prompt-toolkit==3.0.48
164
+ - propcache==0.2.0
165
+ - proto-plus==1.24.0
166
+ - protobuf==4.25.5
167
+ - psutil==6.0.0
168
+ - ptyprocess==0.7.0
169
+ - pure-eval==0.2.3
170
+ - pyarrow==17.0.0
171
+ - pyasn1==0.6.1
172
+ - pyasn1-modules==0.4.1
173
+ - pycparser==2.22
174
+ - pydantic==2.9.2
175
+ - pydantic-core==2.23.4
176
+ - pydub==0.25.1
177
+ - pygments==2.18.0
178
+ - pyparsing==3.1.4
179
+ - pypdf==4.3.1
180
+ - python-dateutil==2.9.0.post0
181
+ - python-json-logger==2.0.7
182
+ - python-multipart==0.0.12
183
+ - pytz==2024.2
184
+ - pyyaml==6.0.2
185
+ - pyzmq==26.2.0
186
+ - qdrant-client==1.12.0
187
+ - referencing==0.35.1
188
+ - regex==2024.9.11
189
+ - requests==2.32.3
190
+ - rfc3339-validator==0.1.4
191
+ - rfc3986-validator==0.1.1
192
+ - rich==13.9.4
193
+ - rpds-py==0.20.1
194
+ - rsa==4.9
195
+ - ruff==0.7.2
196
+ - safetensors==0.4.5
197
+ - scikit-learn==1.5.2
198
+ - scipy==1.14.1
199
+ - semantic-version==2.10.0
200
+ - send2trash==1.8.3
201
+ - sentence-transformers==2.7.0
202
+ - shellingham==1.5.4
203
+ - six==1.16.0
204
+ - sniffio==1.3.1
205
+ - soupsieve==2.6
206
+ - sqlalchemy==2.0.36
207
+ - stack-data==0.6.3
208
+ - starlette==0.41.2
209
+ - striprtf==0.0.26
210
+ - sympy==1.13.3
211
+ - tenacity==8.5.0
212
+ - terminado==0.18.1
213
+ - threadpoolctl==3.5.0
214
+ - tiktoken==0.8.0
215
+ - tinycss2==1.4.0
216
+ - tokenizers==0.20.1
217
+ - tomlkit==0.12.0
218
+ - torch==2.4.1
219
+ - torchinfo==1.8.0
220
+ - torchvision==0.19.1
221
+ - tornado==6.4.1
222
+ - tqdm==4.66.5
223
+ - traitlets==5.14.3
224
+ - transformers==4.45.1
225
+ - typer==0.12.5
226
+ - types-python-dateutil==2.9.0.20241003
227
+ - typing-extensions==4.12.2
228
+ - typing-inspect==0.9.0
229
+ - tzdata==2024.2
230
+ - uri-template==1.3.0
231
+ - uritemplate==4.1.1
232
+ - urllib3==2.2.3
233
+ - uvicorn==0.32.0
234
+ - wcwidth==0.2.13
235
+ - webcolors==24.8.0
236
+ - webencodings==0.5.1
237
+ - websocket-client==1.8.0
238
+ - websockets==12.0
239
+ - wrapt==1.16.0
240
+ - xxhash==3.5.0
241
+ - yarl==1.17.1
llamaindex_utils.py ADDED
@@ -0,0 +1,558 @@
1
+ import torch
2
+ import json
3
+ import asyncio
4
+ import qdrant_client
5
+ from PIL import Image
6
+ from pydantic import PrivateAttr, Field
7
+ from typing import Union, Optional, List, Any, Dict, Set
8
+ from dataclasses import dataclass
9
+
10
+ from llama_index.core.vector_stores.types import VectorStoreQueryResult
11
+ from llama_index.core.vector_stores.utils import (
12
+ legacy_metadata_dict_to_node,
13
+ metadata_dict_to_node,
14
+ )
15
+ from llama_index.core.embeddings import BaseEmbedding
16
+ from llama_index.core.retrievers import BaseRetriever
17
+ from llama_index.core import QueryBundle, PromptTemplate
18
+ from llama_index.core.schema import NodeWithScore, TextNode
19
+ from llama_index.core.llms import LLM
20
+ from llama_index.core.question_gen import LLMQuestionGenerator
21
+ from llama_index.core.tools import ToolMetadata
22
+ from llama_index.core.output_parsers.utils import parse_json_markdown
23
+ from llama_index.core.question_gen.types import SubQuestion
24
+
25
+ from models import ColPali, ColPaliProcessor
26
+ from prompt_templates import (DEFAULT_GEN_PROMPT_TMPL,
27
+ DEFAULT_FINAL_ANSWER_PROMPT_TMPL,
28
+ DEFAULT_SUB_QUESTION_PROMPT_TMPL,
29
+ DEFAULT_SYNTHESIZE_PROMPT_TMPL)
30
+ from typing import Any, List, Optional, Tuple, cast
31
+ from qdrant_client.http.models import Payload
32
+
33
+ from collections import defaultdict
34
+
35
+ def parse_to_query_result(response: List[Any]) -> VectorStoreQueryResult:
36
+ """
37
+ Convert vector store response to VectorStoreQueryResult.
38
+
39
+ Args:
40
+ response: List[Any]: List of results returned from the vector store.
41
+ """
42
+ nodes = []
43
+ similarities = []
44
+ ids = []
45
+
46
+ for point in response:
47
+ payload = cast(Payload, point.payload)
48
+ try:
49
+ node = metadata_dict_to_node(payload)
50
+ except Exception:
51
+ metadata, node_info, relationships = legacy_metadata_dict_to_node(
52
+ payload
53
+ )
54
+
55
+ node = TextNode(
56
+ id_=str(point.id),
57
+ text=payload.get("text"),
58
+ metadata=metadata,
59
+ start_char_idx=node_info.get("start", None),
60
+ end_char_idx=node_info.get("end", None),
61
+ relationships=relationships,
62
+ )
63
+ nodes.append(node)
64
+ ids.append(str(point.id))
65
+ try:
66
+ similarities.append(point.score)
67
+ except AttributeError:
68
+ # certain requests do not return a score
69
+ similarities.append(1.0)
70
+
71
+ return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
72
+
73
+
74
+ class ColPaliGemmaEmbedding(BaseEmbedding):
75
+ _model: ColPali = PrivateAttr()
76
+ _processor: ColPaliProcessor = PrivateAttr()
77
+
78
+ device: Union[torch.device, str] = Field(default="cpu",
79
+ description="Device to use")
80
+ def __init__(self,
81
+ model: ColPali,
82
+ processor: ColPaliProcessor,
83
+ device: Optional[str] = 'cpu',
84
+ **kwargs):
85
+ super().__init__(device=device,
86
+ **kwargs)
87
+ self._model = model.to(device).eval()
88
+ self._processor = processor
89
+
90
+ @classmethod
91
+ def class_name(cls) -> str:
92
+ return "ColPaliGemmaEmbedding"
93
+
94
+ def _get_query_embedding(self, query: str) -> List[float]:
95
+ """Get query embedding.
96
+
97
+ Args:
98
+ query (str): Query String
99
+ """
100
+ with torch.no_grad():
101
+ processed_query = self._processor.process_queries([query])
102
+ processed_query = {k: v.to(self.device) for k, v in processed_query.items()}
103
+ query_embeddings = self._model(**processed_query)
104
+ return query_embeddings.to('cpu')[0]
105
+
106
+ def _get_text_embedding(self, text: str) -> List[float]:
107
+ """Get text embedding.
108
+
109
+ Args:
110
+ text (str): Text String
111
+ """
112
+ with torch.no_grad():
113
+ processed_query = self._processor.process_queries([text])
114
+ processed_query = {k: v.to(self.device) for k, v in processed_query.items()}
115
+ query_embeddings = self._model(**processed_query)
116
+ return query_embeddings.to('cpu')[0]
117
+
118
+ def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
119
+ """Get text embeddings.
120
+
121
+ Args:
122
+ texts (List[str]): List of text string
123
+ """
124
+ with torch.no_grad():
125
+ processed_queries = self._processor.process_queries(texts)
126
+ processed_queries = {k: v.to(self.device) for k, v in processed_queries.items()}
127
+ query_embeddings = self._model(**processed_queries)
128
+ return query_embeddings.to('cpu')
129
+
130
+ async def _aget_query_embedding(self, query: str) -> List[float]:
131
+ return self._get_query_embedding(query)
132
+
133
+ async def _aget_text_embedding(self, text: str) -> List[float]:
134
+ return self._get_text_embedding(text)
135
+
136
+ class ColPaliRetriever(BaseRetriever):
137
+ def __init__(self,
138
+ vector_store_client: Union[qdrant_client.QdrantClient, qdrant_client.AsyncQdrantClient],
139
+ target_collection: str,
140
+ embed_model: ColPaliGemmaEmbedding,
141
+ query_mode: str = 'default',
142
+ similarity_top_k: int = 3,
143
+ ) -> None:
144
+ self._vector_store_client = vector_store_client
145
+ self._target_collection = target_collection
146
+ self._embed_model = embed_model
147
+ self._query_mode = query_mode
148
+ self._similarity_top_k = similarity_top_k
149
+ super().__init__()
150
+
151
+ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
152
+ """Get retrieved nodes from the vector store given the query string.
153
+
154
+ Args:
155
+ query_bundle (QueryBundle): QueryBundle class includes query string
156
+
157
+ Returns:
158
+ List[NodeWithScore]: List of retrieved nodes.
159
+ """
160
+ if query_bundle.embedding is None:
161
+ query_embedding = self._embed_model._get_query_embedding(query_bundle.query_str)
162
+ else:
163
+ query_embedding = query_bundle.embedding
164
+
165
+
166
+ query_embedding = query_embedding.cpu().float().numpy().tolist()
167
+
168
+ # Get nodes from vector store
169
+ response = self._vector_store_client.query_points(collection_name=self._target_collection,
170
+ query=query_embedding,
171
+ limit=self._similarity_top_k).points
172
+ # Parse to structured output nodes
173
+ query_result = parse_to_query_result(response)
174
+ nodes_with_scores = []
175
+ for idx, node in enumerate(query_result.nodes):
176
+ score = None
177
+ if query_result.similarities is not None:
178
+ score = query_result.similarities[idx]
179
+ nodes_with_scores.append(NodeWithScore(node=node, score=score))
180
+ return nodes_with_scores
181
+
182
+ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
183
+ """Asynchronously get retrieved nodes from the vector store given the query string.
184
+
185
+ Args:
186
+ query_bundle (QueryBundle): QueryBundle class includes query string
187
+
188
+ Returns:
189
+ List[NodeWithScore]: List of retrieved nodes.
190
+ """
191
+ if query_bundle.embedding is None:
192
+ query_embedding = await self._embed_model._aget_query_embedding(query_bundle.query_str)
193
+ else:
194
+ query_embedding = query_bundle.embedding
195
+
196
+ query_embedding = query_embedding.cpu().float().numpy().tolist()
197
+
198
+ # Get nodes from vector store
199
+ responses = await self._vector_store_client.query_points(collection_name=self._target_collection,
200
+ query=query_embedding,
201
+ limit=self._similarity_top_k)
202
+
203
+ responses = responses.points
204
+ # Parse to structured output nodes
205
+ query_result = parse_to_query_result(responses)
206
+ nodes_with_scores = []
207
+ for idx, node in enumerate(query_result.nodes):
208
+ score = None
209
+ if query_result.similarities is not None:
210
+ score = query_result.similarities[idx]
211
+ nodes_with_scores.append(NodeWithScore(node=node, score=score))
212
+ return nodes_with_scores
213
+
214
+
215
+ def fuse_results(retrieved_nodes: List[NodeWithScore], similarity_top_k: int) -> List[NodeWithScore]:
216
+ """Fuse retrieved nodes using Reciprocal Rank Fusion (RRF)
217
+
218
+ Args:
219
+ retrieved_nodes (List[NodeWithScore]): List of nodes.
220
+ similarity_top_k (int): get top K nodes.
221
+
222
+ Returns:
223
+ List[NodeWithScore]: List of nodes after fusion
224
+ """
225
+ k = 60.0
226
+ fused_scores = {}
227
+ text_to_node = {}
228
+ for rank, node_with_score in enumerate(sorted(retrieved_nodes, key=lambda x: x.score or 0.0, reverse=True)):
229
+ text = node_with_score.node.get_content(metadata_mode='all')
230
+ text_to_node[text] = node_with_score
231
+ fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)
232
+
233
+ # Sort results by calculated score
234
+ reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True))
235
+ reranked_nodes: List[NodeWithScore] = []
236
+ for text, score in reranked_results.items():
237
+ reranked_nodes.append(text_to_node[text])
238
+ reranked_nodes[-1].score = score
239
+ return reranked_nodes[:similarity_top_k]
240
+
241
+
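# Illustration (not part of this commit): a minimal, self-contained sketch of the
# reciprocal-rank accumulation performed by fuse_results above, with plain strings
# standing in for NodeWithScore objects. A text retrieved by several rewritten
# queries accumulates 1 / (rank + k) more than once and is therefore ranked higher.
def toy_rrf(ranked_texts, k=60.0, top_k=3):
    fused = {}
    for rank, text in enumerate(ranked_texts):
        fused[text] = fused.get(text, 0.0) + 1.0 / (rank + k)
    return sorted(fused.items(), key=lambda x: x[1], reverse=True)[:top_k]

print(toy_rrf(["page_7", "page_2", "page_7", "page_9"]))  # "page_7" wins: two contributions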
242
+ def generate_queries(llm: LLM, query: str, num_queries: int) -> List[str]:
243
+ """Generate num_queries queries
244
+
245
+ Args:
246
+ llm (LLM): LLM model
247
+ query (str): query string
248
+ num_queries (int): Number of queries to generate
249
+
250
+ Returns:
251
+ generate_queries List[str]: List of generated queries
252
+ """
253
+ query_prompt = PromptTemplate(DEFAULT_GEN_PROMPT_TMPL)
254
+ generate_queries = llm.predict(query_prompt,
255
+ num_queries=num_queries,
256
+ query=query)
257
+ generate_queries = generate_queries.split('\n')
258
+ return generate_queries
259
+
260
+ async def agenerate_queries(llm: LLM, query: str, num_queries: int) -> List[str]:
261
+ """Asynchronously generate num_queries queries
262
+
263
+ Args:
264
+ llm (LLM): LLM model
265
+ query (str): query string
266
+ num_queries (int): Number of queries to generate
267
+
268
+ Returns:
269
+ generate_queries List[str]: List of generated queries
270
+ """
271
+ query_prompt = PromptTemplate(DEFAULT_GEN_PROMPT_TMPL)
272
+ generate_queries = await llm.apredict(query_prompt,
273
+ num_queries=num_queries,
274
+ query=query)
275
+ generate_queries = generate_queries.split('\n')
276
+ return generate_queries
277
+
278
+
279
+ # Tree Summarization
280
+ def synthesize_results(queries: List[SubQuestion], contexts: Dict[str, Set[str]], llm: LLM, num_children: int) -> Tuple[str, List[str]]:
281
+ """Summarize the results generated from LLM.
282
+
283
+ Args:
284
+ queries (List[SubQuestion]): Generated results
285
+ contexts (Dict[str, Set[str]]): Dictionary maps context information string to its set of source images
286
+ llm (LLM): LLM Model
287
+ num_children (int): Number of children for Tree Summarization
288
+
289
+ Returns:
290
+ Tuple[str, List[str]]: Synthesized text, set of source images.
291
+ """
292
+ qa_prompt = PromptTemplate(DEFAULT_SYNTHESIZE_PROMPT_TMPL)
293
+
294
+ new_contexts = defaultdict(set)
295
+ keys = list(contexts.keys())
296
+ for idx in range(0, len(keys), num_children):
297
+ contexts_batch = keys[idx: idx + num_children]
298
+ context_str = '\n\n'.join([f"{i + 1}. {text}" for i, text in enumerate(contexts_batch)])
299
+
300
+ fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str="\n".join([query.sub_question for query in queries]))
301
+ combined_result = llm.complete(fmt_qa_prompt)
302
+
303
+ # Parse json string to dictionary
304
+ json_dict = parse_json_markdown(str(combined_result))
305
+ if len(json_dict['choices']) > 0:
306
+ for choice in json_dict['choices']:
307
+ new_contexts[json_dict['summarized_text']] = new_contexts[json_dict['summarized_text']].union(contexts[contexts_batch[choice - 1]])
308
+ else:
309
+ new_contexts[json_dict['summarized_text']] = set()
310
+
311
+ if len(new_contexts) == 1:
312
+ synthesized_text = list(new_contexts.keys())[0]
313
+ return synthesized_text, list(new_contexts[synthesized_text])
314
+ else:
315
+ return synthesize_results(queries, new_contexts, llm, num_children=num_children)
316
+
317
+
318
+ async def asynthesize_results(queries: List[SubQuestion], contexts: Dict[str, Set[str]], llm: LLM, num_children: int) -> Tuple[str, List[str]]:
319
+ """Asynchronously summarize the results generated from LLM.
320
+
321
+ Args:
322
+ queries (List[SubQuestion]): Generated results
323
+ contexts (Dict[str, Set[str]]): Dictionary maps context information string to its set of source images
324
+ llm (LLM): LLM Model
325
+ num_children (int): Number of children for Tree Summarization
326
+
327
+ Returns:
328
+ Tuple[str, List[str]]: Synthesized text, set of source images.
329
+ """
330
+ qa_prompt = PromptTemplate(DEFAULT_SYNTHESIZE_PROMPT_TMPL)
331
+ fmt_qa_prompts = []
332
+ keys = list(contexts.keys())
333
+ contexts_batches = []
334
+ for idx in range(0, len(keys), num_children):
335
+ contexts_batch = keys[idx: idx + num_children]
336
+
337
+ context_str = '\n\n'.join([f"{idx + 1}. {text}" for idx, text in enumerate(contexts_batch)])
338
+
339
+ fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str="\n".join([query.sub_question for query in queries]))
340
+ fmt_qa_prompts.append(fmt_qa_prompt)
341
+ contexts_batches.append(contexts_batch)
342
+
343
+ tasks = []
344
+ async with asyncio.TaskGroup() as tg:
345
+ for fmt_qa_prompt in fmt_qa_prompts:
346
+ task = tg.create_task(llm.acomplete(fmt_qa_prompt))
347
+ tasks.append(task)
348
+
349
+ responses = [str(task.result()) for task in tasks]
350
+ new_contexts = defaultdict(set)
351
+ for idx, response in enumerate(responses):
352
+ # Parse json string to dictionary
353
+ json_dict = parse_json_markdown(response)
354
+
355
+ if len(json_dict["choices"]) > 0:
356
+ for choice in json_dict["choices"]:
357
+ new_contexts[json_dict["summarized_text"]] = new_contexts[json_dict["summarized_text"]].union(contexts[contexts_batches[idx][choice - 1]])
358
+ else:
359
+ new_contexts[json_dict["summarized_text"]] = set()
360
+
361
+ if len(new_contexts) == 1:
362
+ synthesized_text = list(new_contexts.keys())[0]
363
+ return synthesized_text, list(new_contexts[synthesized_text])
364
+ else:
365
+ return await asynthesize_results(queries, new_contexts, llm, num_children=num_children)
366
+
367
+ class CustomFusionRetriever(BaseRetriever):
368
+ def __init__(self,
369
+ llm,
370
+ retriever_mappings: Dict[str, BaseRetriever],
371
+ similarity_top_k: int = 3,
372
+ num_generated_queries = 3,
373
+ ) -> None:
374
+ self._retriever_mappings = retriever_mappings
375
+ self._similarity_top_k = similarity_top_k
376
+ self._num_generated_queries = num_generated_queries
377
+ self._llm = llm
378
+ super().__init__()
379
+
380
+ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
381
+ """Retrieve self._similarity_top_k content nodes given query
382
+
383
+ Args:
384
+ query_bundle (QueryBundle): query bundle include query string
385
+ """
386
+ # Get data from query bundle
387
+ query_dict = json.loads(query_bundle.query_str)
388
+ original_query = query_dict['sub_question']
389
+ tool_name = query_dict['tool_name']
390
+
391
+ # Rewrite original query to n queries
392
+ generated_queries = generate_queries(self._llm, original_query, num_queries=self._num_generated_queries)
393
+
394
+ # For each generated query, retrieve relevant nodes
395
+ retrieved_nodes = []
396
+ for query in generated_queries:
397
+ if len(query) == 0:
398
+ continue
399
+ retrieved_nodes.extend(self._retriever_mappings[tool_name].retrieve(query))
400
+
401
+ # Fuse retrieved nodes using reciprocal rank
402
+ fused_results = fuse_results(retrieved_nodes,
403
+ similarity_top_k=self._similarity_top_k)
404
+ return fused_results
405
+
406
+ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
407
+ """Asynchronously retrieve self._similarity_top_k content nodes given query
408
+
409
+ Args:
410
+ query_bundle (QueryBundle): query bundle include query string
411
+ """
412
+ # Get data from query bundle
413
+ query_dict = json.loads(query_bundle.query_str)
414
+ original_query = query_dict['sub_question']
415
+ tool_name = query_dict['tool_name']
416
+
417
+ # Rewrite original query to n queries
418
+ generated_queries = await agenerate_queries(llm=self._llm, query=original_query, num_queries=self._num_generated_queries)
419
+
420
+ # For each generated query, retrieve relevant nodes
421
+ tasks = []
422
+ async with asyncio.TaskGroup() as tg:
423
+ for query in generated_queries:
424
+ if len(query) == 0:
425
+ continue
426
+ task = tg.create_task(self._retriever_mappings[tool_name].aretrieve(query))
427
+ tasks.append(task)
428
+
429
+ retrieved_nodes = [node for task in tasks for node in task.result()]
430
+
431
+ # Fuse retrieved nodes using reciprocal rank
432
+ fused_results = fuse_results(retrieved_nodes,
433
+ similarity_top_k=self._similarity_top_k)
434
+ return fused_results
435
+
436
+
437
+ @dataclass
438
+ class Response:
439
+ response: str
440
+ source_images: Optional[List] = None
441
+
442
+ def __str__(self):
443
+ return self.response
444
+
445
+ class CustomQueryEngine:
446
+ def __init__(self,
447
+ retriever_tools: List[ToolMetadata],
448
+ fusion_retriever: BaseRetriever,
449
+ qa_prompt: PromptTemplate = None,
450
+ llm: LLM = None,
451
+ num_children: int = 3):
452
+ self._qa_prompt = qa_prompt if qa_prompt else PromptTemplate(DEFAULT_FINAL_ANSWER_PROMPT_TMPL)
453
+ self._llm = llm
454
+ self._num_children = num_children
455
+ self._sub_question_generator = LLMQuestionGenerator.from_defaults(llm=self._llm,
456
+ prompt_template_str=DEFAULT_SUB_QUESTION_PROMPT_TMPL)
457
+ self._fusion_retriever = fusion_retriever
458
+ self._retriever_tools = retriever_tools
459
+
460
+
461
+ def query(self, query_str: str) -> Response:
462
+ # Generate sub queries
463
+ sub_queries = self._sub_question_generator.generate(tools=self._retriever_tools,
464
+ query=QueryBundle(query_str=query_str))
465
+
466
+ if len(sub_queries) == 0:
467
+ response_template = PromptTemplate("Cannot answer the query: {query_str}")
468
+ return Response(response=response_template.format(query_str=query_str), source_images=[])
469
+ else:
470
+ # Dictionary to map response -> source_images
471
+ response2images_mapping = defaultdict(set)
472
+
473
+ # For each sub query, retrieve relevant image nodes
474
+ # With fusion retriever, each sub query is rewritten to n queries -> retrieve relevant nodes for each generated query
475
+ # -> fuse all nodes retrieved from multiple generated queries using reciprocal rank -> get top k results
476
+ for sub_query in sub_queries:
477
+ retrieved_nodes = self._fusion_retriever.retrieve(QueryBundle(query_str=sub_query.model_dump_json()))
478
+ # Using LLM to get the answer for sub query from retrieved nodes
479
+ for retrieved_node in retrieved_nodes:
480
+ response2images_mapping[str(self._llm.complete([sub_query.sub_question, Image.open(retrieved_node.node.resolve_image())]))].add(retrieved_node.node.image)
481
+
482
+ # Synthesize results
483
+ synthesized_text, source_images = synthesize_results(queries=sub_queries,
484
+ contexts=response2images_mapping,
485
+ llm=self._llm,
486
+ num_children=self._num_children)
487
+
488
+ final_answer = self._llm.predict(self._qa_prompt,
489
+ context_str=synthesized_text,
490
+ query_str=query_str)
491
+
492
+ response_template = PromptTemplate("Retrieved Information:\n"
493
+ "------------------------\n"
494
+ "{retrieved_information}\n"
495
+ "-------------------------\n\n"
496
+ "Answer:\n"
497
+ "{final_answer}")
498
+
499
+ return Response(response=response_template.format(retrieved_information=synthesized_text, final_answer=final_answer), source_images=source_images)
500
+
501
+ async def aquery(self, query_str: str):
502
+ sub_queries = await self._sub_question_generator.agenerate(tools=self._retriever_tools,
503
+ query=QueryBundle(query_str=query_str))
504
+ if len(sub_queries) == 0:
505
+ response_template = PromptTemplate("Cannot answer the query: {query_str}")
506
+ return Response(response=response_template.format(query_str=query_str), source_images=[])
507
+ else:
508
+ retrieved_subquestion_nodes = []
509
+ async with asyncio.TaskGroup() as tg:
510
+ for sub_query in sub_queries:
511
+ task = tg.create_task(self._fusion_retriever.aretrieve(QueryBundle(query_str=sub_query.model_dump_json())))
512
+ retrieved_subquestion_nodes.append([sub_query.sub_question, task])
513
+
514
+ retrieved_subquestion_nodes = [[sub_question, task.result()] for sub_question, task in retrieved_subquestion_nodes]
515
+
516
+ answers = []
517
+ # For each sub query, retrieve relevant image nodes
518
+ # With fusion retriever, each sub query is rewritten to n queries -> retrieve relevant nodes for each generated query
519
+ # -> fuse all nodes retrieved from multiple generated queries using reciprocal rank -> get top k results
520
+ async with asyncio.TaskGroup() as tg:
521
+ for sub_question, retrieved_nodes in retrieved_subquestion_nodes:
522
+ for retrieved_node in retrieved_nodes:
523
+ task = tg.create_task(self._llm.acomplete([sub_question, Image.open(retrieved_node.node.resolve_image())]))
524
+ answers.append([task, retrieved_node.node.image])
525
+
526
+ # Dictionary to map response -> source_images
527
+ response2images_mapping = defaultdict(set)
528
+
529
+ for task, image in answers:
530
+ response2images_mapping[str(task.result())].add(image)
531
+
532
+ # Synthesize results
533
+ synthesized_text, source_images = await asynthesize_results(queries=sub_queries,
534
+ contexts=response2images_mapping,
535
+ llm=self._llm,
536
+ num_children=self._num_children)
537
+
538
+
539
+ final_answer = await self._llm.apredict(self._qa_prompt,
540
+ context_str=synthesized_text,
541
+ query_str=query_str)
542
+
543
+ response_template = PromptTemplate("Retrieved Information:\n"
544
+ "------------------------\n"
545
+ "{retrieved_information}\n"
546
+ "-------------------------\n\n"
547
+ "Answer:\n"
548
+ "{final_answer}")
549
+
550
+ return Response(response=response_template.format(retrieved_information=synthesized_text, final_answer=final_answer), source_images=source_images)
551
+
552
+
553
+
554
+
555
+
556
+
557
+
558
+
models/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .colpali import ColPali, KVCache
2
+ from .paligemma_processor import PaliGemmaProcessor
3
+ from .colpali_processor import ColPaliProcessor
4
+ from .paligemma import PaliGemma
5
+ from .lora import *
models/colpali.py ADDED
@@ -0,0 +1,89 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from dataclasses import dataclass
7
+ from .gemma import KVCache
8
+ from .paligemma import PaliGemma, PaliGemmaConfig
9
+ from typing import Optional
10
+ from utils import *
11
+ from pathlib import Path
12
+ from safetensors import safe_open
13
+
14
+ def convert_weights_dict(original_weights):
15
+ converted_weights = {}
16
+ converted_weights['custom_text_proj.lora_A.weight'] = original_weights['base_model.model.custom_text_proj.lora_A.weight']
17
+ converted_weights['custom_text_proj.lora_B.weight'] = original_weights['base_model.model.custom_text_proj.lora_B.weight']
18
+ for i in range(18):
19
+ converted_weights[f'model.language_model.model.layers.{i}.mlp.down_proj.lora_A.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.mlp.down_proj.lora_A.weight']
20
+ converted_weights[f'model.language_model.model.layers.{i}.mlp.down_proj.lora_B.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.mlp.down_proj.lora_B.weight']
21
+ converted_weights[f'model.language_model.model.layers.{i}.mlp.gate_proj.lora_A.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.mlp.gate_proj.lora_A.weight']
22
+ converted_weights[f'model.language_model.model.layers.{i}.mlp.gate_proj.lora_B.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.mlp.gate_proj.lora_B.weight']
23
+ converted_weights[f'model.language_model.model.layers.{i}.mlp.up_proj.lora_A.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.mlp.up_proj.lora_A.weight']
24
+ converted_weights[f'model.language_model.model.layers.{i}.mlp.up_proj.lora_B.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.mlp.up_proj.lora_B.weight']
25
+ converted_weights[f'model.language_model.model.layers.{i}.self_attn.q_proj.lora_A.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.self_attn.q_proj.lora_A.weight']
26
+ converted_weights[f'model.language_model.model.layers.{i}.self_attn.q_proj.lora_B.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.self_attn.q_proj.lora_B.weight']
27
+ converted_weights[f'model.language_model.model.layers.{i}.self_attn.k_proj.lora_A.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.self_attn.k_proj.lora_A.weight']
28
+ converted_weights[f'model.language_model.model.layers.{i}.self_attn.k_proj.lora_B.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.self_attn.k_proj.lora_B.weight']
29
+ converted_weights[f'model.language_model.model.layers.{i}.self_attn.v_proj.lora_A.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.self_attn.v_proj.lora_A.weight']
30
+ converted_weights[f'model.language_model.model.layers.{i}.self_attn.v_proj.lora_B.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.self_attn.v_proj.lora_B.weight']
31
+ converted_weights[f'model.language_model.model.layers.{i}.self_attn.o_proj.lora_A.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.self_attn.o_proj.lora_A.weight']
32
+ converted_weights[f'model.language_model.model.layers.{i}.self_attn.o_proj.lora_B.weight'] = original_weights[f'base_model.model.model.language_model.model.layers.{i}.self_attn.o_proj.lora_B.weight']
33
+
34
+ return converted_weights
35
+
36
+
37
+ class ColPali(nn.Module):
38
+ def __init__(self, cfg: PaliGemmaConfig):
39
+ super().__init__()
40
+ self.model = PaliGemma(cfg=cfg)
41
+ self.dim = 128
42
+ self.custom_text_proj = nn.Linear(self.model.cfg.text_config.hidden_size, self.dim, bias=False)
43
+
44
+ @staticmethod
45
+ def from_pretrained(model_dir, torch_dtype: torch.dtype = torch.float32):
46
+ torch.set_default_dtype(torch_dtype)
47
+ with open(os.path.join(model_dir, 'config.json'), "r") as f:
48
+ model_config = json.loads(f.read())
49
+ config = PaliGemmaConfig.from_dict(model_config)
50
+
51
+ safetensor_files = Path(model_dir).glob("*.safetensors")
52
+
53
+ weights = {}
54
+ for file in safetensor_files:
55
+ with safe_open(file, framework='pt', device="cpu") as f:
56
+ for key in f.keys():
57
+ weights[key] = f.get_tensor(key)
58
+ model = ColPali(config)
59
+ model.load_state_dict(weights, strict=False)
60
+ model.tie_weights()
61
+ return model
62
+
63
+ def load_lora(self, model_dir):
64
+ weights = {}
65
+ with safe_open(os.path.join(model_dir, "adapter_model.safetensors"), framework="pt", device="cpu") as f:
66
+ for key in f.keys():
67
+ weights[key] = f.get_tensor(key)
68
+
69
+ converted_weights = convert_weights_dict(weights)
70
+ self.load_state_dict(converted_weights, strict=False)
71
+
72
+ def tie_weights(self):
73
+ self.model.language_model.tie_weights()
74
+
75
+ def forward(self, *args, **kwargs) -> torch.Tensor:
76
+ outputs = self.model(*args, **kwargs)
77
+ last_hidden_states = outputs[0]
78
+ proj = self.custom_text_proj(last_hidden_states)
79
+ # L2 normalization
80
+ proj = proj / proj.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
81
+
82
+ proj = proj * kwargs['attention_mask'].unsqueeze(-1) # (batch_size, sequence_length, dim)
83
+
84
+ return proj
85
+
86
+
87
+
88
+
89
+
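# Context (not part of this commit): ColPali.forward above returns L2-normalised
# per-token embeddings of shape (batch, seq_len, 128). Multi-vector embeddings of
# this kind are typically compared with a late-interaction (MaxSim) score; the
# sketch below shows that formulation on random tensors. How this repository
# actually scores matches (via the Qdrant-backed retriever) is configured elsewhere.
import torch
import torch.nn.functional as F

def maxsim_score(query_emb: torch.Tensor, doc_emb: torch.Tensor) -> torch.Tensor:
    # query_emb: (num_query_tokens, dim), doc_emb: (num_doc_tokens, dim)
    sim = query_emb @ doc_emb.T            # (num_query_tokens, num_doc_tokens)
    return sim.max(dim=-1).values.sum()    # best document token per query token, summed

q = F.normalize(torch.randn(20, 128), dim=-1)
d = F.normalize(torch.randn(1030, 128), dim=-1)
print(maxsim_score(q, d))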
models/colpali_processor.py ADDED
@@ -0,0 +1,89 @@
1
+ import torch
2
+ from PIL import Image
3
+ from typing import Tuple, List
4
+ import numpy as np
5
+ from transformers import GemmaTokenizerFast
6
+ from .paligemma_processor import PaliGemmaProcessor
7
+ from typing import Optional
8
+
9
+ def process_imgs(imgs: List[Image.Image],
10
+ img_size: Tuple[int, int],
11
+ rescale: float,
12
+ mean: Tuple[float, float, float],
13
+ std: Tuple[float, float, float]):
14
+
15
+ def normalize(img, mean, std):
16
+ img = (img - np.array(mean, dtype=img.dtype)) / np.array(std, dtype=img.dtype)
17
+ return img
18
+
19
+ resized_imgs = [img.resize((img_size[0], img_size[1]), resample=Image.Resampling.BICUBIC) for img in imgs]
20
+
21
+ rescaled_imgs = [np.array(img, dtype=np.float32) * rescale for img in resized_imgs]
22
+
23
+ normalized_imgs = [normalize(img, mean, std) for img in rescaled_imgs]
24
+
25
+ transposed_imgs = [img.transpose(2, 0, 1) for img in normalized_imgs]
26
+
27
+ tensor_imgs = torch.tensor(np.stack(transposed_imgs, axis=0), dtype=torch.float32)
28
+ return tensor_imgs
29
+
30
+
31
+ def process_prompts(prompt, image_token, max_num_image_token, bos_token):
32
+ return f"{image_token * max_num_image_token}{bos_token}{prompt}\n"
33
+
34
+
35
+ class ColPaliProcessor(PaliGemmaProcessor):
36
+ def __init__(self,
37
+ tokenizer: GemmaTokenizerFast) -> None:
38
+ super().__init__(tokenizer=tokenizer)
39
+ self.mock_image = Image.new(mode='RGB', size=(16, 16), color='black')
40
+
41
+ def process_images(self, images: List[Image.Image]):
42
+ input_prompts = ["Describe the image."] * len(images)
43
+
44
+ images = [image.convert("RGB") for image in images]
45
+
46
+ return_data = self(images,
47
+ input_prompts,
48
+ padding="longest",
49
+ truncation=False)
50
+
51
+ return return_data
52
+
53
+ def process_queries(self,
54
+ queries: List[str],
55
+ max_length: int = 50,
56
+ suffix: Optional[str] = None):
57
+
58
+ if suffix is None:
59
+ suffix = "<pad>" * 10
60
+
61
+ texts_query: List[str] = []
62
+
63
+ for query in queries:
64
+ query = f"Question: {query}"
65
+ query += suffix
66
+ texts_query.append(query)
67
+
68
+
69
+ batch_query = self(imgs=[self.mock_image] * len(texts_query),
70
+ prompts=texts_query,
71
+ padding="longest",
72
+ max_length=max_length + self.image_seq_length,
73
+ truncation=True)
74
+
75
+ del batch_query["pixel_values"]
76
+
77
+ batch_query["input_ids"] = batch_query["input_ids"][..., self.image_seq_length:]
78
+ batch_query["attention_mask"] = batch_query["attention_mask"][..., self.image_seq_length:]
79
+
80
+ return batch_query
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
models/gemma.py ADDED
@@ -0,0 +1,285 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.nn.utils.parametrize as parametrize
5
+ from dataclasses import dataclass
6
+ from typing import Optional, List
7
+ import math
8
+ import torch.utils.checkpoint as checkpoint
9
+
10
+ @dataclass
11
+ class GemmaConfig:
12
+ hidden_size: int = 2048
13
+ intermediate_size: int = 16384
14
+ num_attention_heads: int = 8
15
+ num_hidden_layers: int = 18
16
+ num_image_tokens: int = 256
17
+ num_key_value_heads: int = 1
18
+ vocab_size: int = 257216
19
+ norm_eps: float = 1e-6
20
+ max_seq_len: int = 8192
21
+ attention_dropout: float = 0.0
22
+ use_lora: bool = False
23
+ training: bool = False
24
+
25
+ @classmethod
26
+ def from_dict(cls, data):
27
+ return cls(
28
+ hidden_size = data['hidden_size'],
29
+ intermediate_size = data['intermediate_size'],
30
+ num_attention_heads = data['num_attention_heads'],
31
+ num_hidden_layers = data['num_hidden_layers'],
32
+ num_image_tokens = data['num_image_tokens'],
33
+ num_key_value_heads = data['num_key_value_heads'],
34
+ vocab_size = data['vocab_size'],
35
+ training = data['training'])
36
+
37
+ class RMSNorm(nn.Module):
38
+ def __init__(self, dim: int, norm_eps: float = 1e-6):
39
+ super().__init__()
40
+ self.weight = nn.Parameter(torch.zeros(dim))
41
+ self.norm_eps = norm_eps
42
+
43
+ def _norm(self, x):
44
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.norm_eps)
45
+
46
+ def forward(self, x: torch.Tensor):
47
+ output = self._norm(x.float())
48
+ output = output * (1.0 + self.weight.float())
49
+ return output.type_as(x)
50
+
51
+
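# Illustration (not part of this commit): with the zero-initialised weight above,
# Gemma's RMSNorm rescales every position to (approximately) unit root-mean-square.
import torch

norm = RMSNorm(dim=2048)
x = torch.randn(1, 4, 2048) * 7.0
print(norm(x).pow(2).mean(-1).sqrt())  # ~1.0 at every position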
52
+ def precompute_freqs(head_dim: int, max_seq_len: int, theta: int = 10000):
53
+ thetas = 1 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float() / head_dim))
54
+ m = torch.arange(max_seq_len, dtype=torch.long)
55
+
56
+ # (max_seq_len, head_dim // 2)
57
+ freqs = torch.outer(m, thetas)
58
+
59
+ # (max_seq_len, head_dim // 2) -> (max_seq_len, head_dim)
60
+ freqs = torch.cat((freqs, freqs), dim=-1)
61
+ return freqs
62
+
63
+ def rotate_half(x: torch.Tensor):
64
+ x1 = x[..., :x.shape[-1] // 2]
65
+ x2 = x[..., x.shape[-1] // 2:]
66
+
67
+ return torch.cat((-x2, x1), dim=-1)
68
+
69
+ def apply_rotary_embed(x: torch.Tensor,
70
+ freqs: torch.Tensor):
71
+ # x: (n, n_heads, seq_len, head_dim)
72
+ # freqs: (n, seq_len, head_dim)
73
+ device_type = x.device.type
74
+ device_type = device_type if device_type != 'mps' else 'cpu'
75
+ with torch.autocast(device_type=device_type, enabled=False):
76
+ cos = freqs.cos()
77
+ sin = freqs.sin()
78
+ while len(cos.shape) < len(x.shape):
79
+ cos = cos.unsqueeze(1)
80
+ sin = sin.unsqueeze(1)
81
+ cos = cos.to(x.dtype)
82
+ sin = sin.to(x.dtype)
83
+ x = (x * cos) + (rotate_half(x) * sin)
84
+ return x
85
+
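# Illustration (not part of this commit): shape check for the rotary-embedding
# helpers above, assuming they are called exactly as in GemmaTransformerAttention.
import torch

head_dim, max_seq_len = 256, 8192
freqs = precompute_freqs(head_dim, max_seq_len)           # (8192, 256)

batch, n_heads, seq_len = 2, 8, 16
x = torch.randn(batch, n_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).expand(batch, -1)    # (2, 16)
rotated = apply_rotary_embed(x, freqs[position_ids, :])
print(rotated.shape)                                      # torch.Size([2, 8, 16, 256])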
86
+ class KVCache:
87
+ def __init__(self):
88
+ self.cache_k: List[torch.Tensor] = []
89
+ self.cache_v: List[torch.Tensor] = []
90
+
91
+ def num_items(self):
92
+ if len(self.cache_k) == 0:
93
+ return 0
94
+ else:
95
+ # (n, num_heads, seq_len, head_dim)
96
+ return self.cache_k[0].shape[-2]
97
+
98
+ def update(self, xk, xv, layer_idx):
99
+ if layer_idx < len(self.cache_k):
100
+ self.cache_k[layer_idx] = torch.cat((self.cache_k[layer_idx], xk), dim=-2)
101
+ self.cache_v[layer_idx] = torch.cat((self.cache_v[layer_idx], xv), dim=-2)
102
+ else:
103
+ self.cache_k.append(xk)
104
+ self.cache_v.append(xv)
105
+
106
+ return self.cache_k[layer_idx], self.cache_v[layer_idx]
107
+
108
+
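# Illustration (not part of this commit): how KVCache grows along the sequence
# dimension during autoregressive decoding (toy shapes: batch=1, 1 KV head,
# head_dim=256).
import torch

cache = KVCache()
cache.update(torch.randn(1, 1, 5, 256), torch.randn(1, 1, 5, 256), layer_idx=0)
print(cache.num_items())   # 5 cached positions for layer 0

keys, values = cache.update(torch.randn(1, 1, 1, 256), torch.randn(1, 1, 1, 256), layer_idx=0)
print(keys.shape)          # torch.Size([1, 1, 6, 256]) -- one new position appended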
109
+ class GemmaTransformerAttention(nn.Module):
110
+ def __init__(self, cfg: GemmaConfig, layer_idx: int):
111
+ super().__init__()
112
+ self.cfg = cfg
113
+ self.layer_idx = layer_idx
114
+ self.vocab_size = cfg.vocab_size
115
+ self.hidden_size = cfg.hidden_size
116
+ self.num_attention_heads = cfg.num_attention_heads
117
+ self.num_key_value_heads = cfg.num_key_value_heads
118
+ self.max_seq_len = cfg.max_seq_len
119
+
120
+ assert self.hidden_size % self.num_attention_heads == 0
121
+
122
+ self.n_rep = self.num_attention_heads // self.num_key_value_heads
123
+ self.head_dim = self.hidden_size // self.num_attention_heads
124
+
125
+ self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
126
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
127
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
128
+
129
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
130
+
131
+ self.attn_dropout = cfg.attention_dropout
132
+ self.training = cfg.training
133
+
134
+ self.register_buffer('freqs',
135
+ precompute_freqs(self.head_dim, cfg.max_seq_len),
136
+ persistent=False)
137
+
138
+ def forward(self, x: torch.Tensor,
139
+ position_ids: Optional[torch.Tensor] = None,
140
+ attention_mask: Optional[torch.Tensor] = None,
141
+ kv_cache: Optional[KVCache] = None):
142
+ batch_size, seq_len, embed_dim = x.shape
143
+
144
+ xq = self.q_proj(x)
145
+ xk = self.k_proj(x)
146
+ xv = self.v_proj(x)
147
+
148
+ # (n, seq_len, hidden_size) -> (n, seq_len, num_heads, head_dim) -> (n, num_heads, seq_len, head_dim)
149
+ xq = xq.view(batch_size, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
150
+ # (n, seq_len, hidden_size) -> (n, seq_len, num_kv_heads, head_dim) -> (n, num_kv_heads, seq_len, head_dim)
151
+ xk = xk.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
152
+ xv = xv.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
153
+
154
+ xq = apply_rotary_embed(xq, self.freqs[position_ids, :])
155
+ xk = apply_rotary_embed(xk, self.freqs[position_ids, :])
156
+
157
+ if kv_cache is not None:
158
+ keys, values = kv_cache.update(xk, xv, self.layer_idx)
159
+ else:
160
+ keys, values = xk, xv
161
+
162
+ # (n, num_kv_heads, seq_len, head_dim) -> (n, num_kv_heads * n_rep, seq_len, head_dim) -> (n, num_heads, seq_len, head_dim)
163
+ keys = keys[:, :, None, :, :].expand(-1, -1, self.n_rep, -1, -1).view(batch_size, -1, keys.shape[-2], self.head_dim)
164
+ values = values[:, :, None, :, :].expand(-1, -1, self.n_rep, -1, -1).view(batch_size, -1, keys.shape[-2], self.head_dim)
165
+
166
+ assert attention_mask is not None
167
+ # (n, num_heads, seq_len, head_dim) @ (n, num_heads, head_dim, seq_len) -> (n, num_heads, seq_len, seq_len)
168
+ attn_weights = torch.softmax(xq @ keys.transpose(2, 3) / math.sqrt(self.head_dim) + attention_mask, dim=-1)
169
+
170
+ # dropout when training
171
+ attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training)
172
+ # (n, num_heads, seq_len, seq_len) @ (n, num_heads, seq_len, head_dim) -> (n, num_heads, seq_len, head_dim)
173
+ attn_output = attn_weights @ values
174
+ attn_output = attn_output.transpose(1, 2).contiguous()
175
+ attn_output = attn_output.view(*x.shape)
176
+
177
+ attn_output = self.o_proj(attn_output)
178
+ return attn_output, attn_weights
179
+
180
+
181
+ class GemmaTransformerMLP(nn.Module):
182
+ def __init__(self, cfg: GemmaConfig):
183
+ super().__init__()
184
+ self.cfg = cfg
185
+
186
+ self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False)
187
+ self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
188
+ self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
189
+
190
+ def forward(self, x: torch.Tensor):
191
+ return self.down_proj(F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x))
192
+
193
+
194
+
195
+ class GemmaTransformerDecoder(nn.Module):
196
+ def __init__(self, cfg: GemmaConfig, layer_idx: int) -> None:
197
+ super().__init__()
198
+ self.cfg = cfg
199
+
200
+ self.input_layernorm = RMSNorm(cfg.hidden_size, cfg.norm_eps)
201
+ self.self_attn = GemmaTransformerAttention(cfg, layer_idx)
202
+ self.mlp = GemmaTransformerMLP(cfg)
203
+ self.post_attention_layernorm = RMSNorm(cfg.hidden_size, cfg.norm_eps)
204
+ self.gradient_checking = False
205
+
206
+
207
+ def forward(self, x: torch.Tensor,
208
+ position_ids: Optional[torch.Tensor] = None,
209
+ attention_mask: Optional[torch.Tensor] = None,
210
+ kv_cache: Optional[KVCache] = None):
211
+
212
+ residual = x
213
+ x = self.input_layernorm(x)
214
+
215
+ if self.gradient_checking:
216
+ x = checkpoint.checkpoint(self.self_attn, x, position_ids, attention_mask, kv_cache)[0]
217
+ else:
218
+ x = self.self_attn(x,
219
+ position_ids,
220
+ attention_mask,
221
+ kv_cache)[0]
222
+ x += residual
223
+
224
+
225
+ residual = x
226
+ x = self.post_attention_layernorm(x)
227
+ x = residual + self.mlp(x)
228
+ return x
229
+
230
+
231
+ class GemmaModel(nn.Module):
232
+ def __init__(self, cfg: GemmaConfig) -> None:
233
+ super().__init__()
234
+ self.cfg = cfg
235
+ self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
236
+
237
+ self.layers = nn.ModuleList(
238
+ [GemmaTransformerDecoder(cfg, layer_idx) for layer_idx in range(cfg.num_hidden_layers)]
239
+ )
240
+
241
+ self.norm = RMSNorm(cfg.hidden_size, cfg.norm_eps)
242
+
243
+ def forward(self, x: torch.Tensor,
244
+ position_ids: Optional[torch.Tensor],
245
+ attention_mask: Optional[torch.Tensor],
246
+ kv_cache: Optional[KVCache]) -> torch.Tensor:
247
+
248
+ output = x * torch.tensor(self.cfg.hidden_size ** 0.5, dtype=x.dtype)
249
+ for layer in self.layers:
250
+ output = layer(output,
251
+ position_ids,
252
+ attention_mask,
253
+ kv_cache)
254
+ output = self.norm(output)
255
+ return output
256
+
257
+
258
+ class Gemma(nn.Module):
259
+ def __init__(self, cfg: GemmaConfig) -> None:
260
+ super().__init__()
261
+ self.cfg = cfg
262
+ self.model = GemmaModel(cfg)
263
+ self.vocab_size = cfg.vocab_size
264
+ self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
265
+
266
+
267
+ def gradient_checkpointing_enabled(self, enabled=False):
268
+ for name, module in self.model.named_modules():
269
+ if isinstance(module, GemmaTransformerDecoder):
270
+ module.gradient_checking = enabled
271
+
272
+ def tie_weights(self):
273
+ self.lm_head.weight = self.model.embed_tokens.weight
274
+
275
+ def forward(self,
276
+ input_embeds: torch.Tensor,
277
+ position_ids: Optional[torch.Tensor],
278
+ attention_mask: Optional[torch.Tensor],
279
+ kv_cache: Optional[KVCache]):
280
+
281
+ output = self.model(input_embeds,
282
+ position_ids,
283
+ attention_mask,
284
+ kv_cache)
285
+ return output, kv_cache
models/lora.py ADDED
@@ -0,0 +1,68 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torch.nn.utils.parametrize as parametrize
5
+ from typing import List
6
+
7
+ class LoRALayer:
8
+ def __init__(self, features_in: int, features_out: int, rank: int=1, alphas: int=1):
9
+ super().__init__()
10
+ self.lora_A = nn.Linear(features_in, rank, bias=False)
11
+ self.lora_B = nn.Linear(rank, features_out, bias=False)
12
+ nn.init.normal_(self.lora_A.weight, mean=0, std=1/rank)
13
+
14
+ self.scale = alphas / rank
15
+
16
+ class LoRALinear(nn.Module, LoRALayer):
17
+ def __init__(self, base_layer: nn.Module, rank: int=1, alphas: int=1, dropout_p: float=0.0):
18
+ features_out, features_in = base_layer.weight.shape
19
+ super().__init__()
20
+ LoRALayer.__init__(self, features_in=features_in, features_out=features_out, rank=rank, alphas=alphas)
21
+
22
+ self.base_layer = nn.Linear(features_in, features_out, bias=False)
23
+ self.base_layer.weight = base_layer.weight
24
+
25
+ if dropout_p > 0.0:
26
+ self.lora_dropout = nn.Dropout(p=dropout_p, inplace=False)
27
+ else:
28
+ self.lora_dropout = nn.Identity()
29
+
30
+ self.enabled = False
31
+
32
+ def forward(self, x: torch.Tensor):
33
+ result = self.base_layer(x)
34
+ if self.enabled:
35
+ result = result + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scale
36
+ return result
37
+
38
+ def enable_lora(model: nn.Module, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'], enabled=True):
39
+ for name, module in model.named_modules():
40
+ if name.split('.')[-1] in lora_modules:
41
+ module.enabled = enabled
42
+ return model
43
+
44
+ def replace_module(module: nn.Module, target_modules: List[str], torch_dtype: torch.dtype, **kwargs):
45
+ for child_name, child_module in module.named_children():
46
+ if child_name in target_modules:
47
+ new_module = LoRALinear(child_module, **kwargs).to(torch_dtype)
48
+ setattr(module, child_name, new_module)
49
+ else:
50
+ replace_module(child_module, target_modules, torch_dtype, **kwargs)
51
+
52
+ def get_lora_model(model: nn.Module, rank: float, alphas: float, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'], dropout_p: float = 0.0, training: bool = False, torch_dtype: torch.dtype = torch.bfloat16):
53
+ lora_config = {'rank': rank,
54
+ 'alphas': alphas,
55
+ 'dropout_p': dropout_p}
56
+ replace_module(model, lora_modules, torch_dtype, **lora_config)
57
+
58
+ for name, param in model.named_parameters():
59
+ if 'lora' not in name:
60
+ param.requires_grad = False
61
+ else:
62
+ if training:
63
+ param.requires_grad = True
64
+ else:
65
+ param.requires_grad = False
66
+
67
+ return model
68
+
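Note (added for clarity, not part of the commit): a small usage sketch for get_lora_model and enable_lora. A toy module (an assumption for illustration) gets its q_proj / v_proj linears wrapped in LoRALinear adapters, everything except the adapter weights is frozen, and the adapters are switched on.

import torch
import torch.nn as nn
from models.lora import get_lora_model, enable_lora

class ToyAttention(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)

toy = ToyAttention()
toy = get_lora_model(toy, rank=4, alphas=8, lora_modules=['q_proj', 'v_proj'],
                     training=True, dropout_p=0.0, torch_dtype=torch.float32)
toy = enable_lora(toy, lora_modules=['q_proj', 'v_proj'], enabled=True)
# Only the LoRA adapter weights remain trainable.
print([name for name, p in toy.named_parameters() if p.requires_grad])
# ['q_proj.lora_A.weight', 'q_proj.lora_B.weight', 'v_proj.lora_A.weight', 'v_proj.lora_B.weight']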
models/paligemma.py ADDED
@@ -0,0 +1,162 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ from .gemma import GemmaConfig, Gemma, KVCache
6
+ from .siglip import SigLIPConfig, SigLIPVisionTower
7
+ from typing import Optional
8
+ import os
9
+ import json
10
+ from pathlib import Path
11
+ from safetensors import safe_open
12
+
13
+ @dataclass
14
+ class PaliGemmaConfig:
15
+ bos_token_id: int = 2
16
+ eos_token_id: int = 1
17
+ hidden_size: int = 2048
18
+ ignore_index: int = -100
19
+ image_token_index: int = 257152
20
+ pad_token_id: int = 0
21
+ projection_dim: int = 2048
22
+ text_config: GemmaConfig = None
23
+ vision_config: SigLIPConfig = None
24
+ vocab_size: int = 257216
25
+ @classmethod
26
+ def from_dict(cls, data):
27
+ return cls(
28
+ bos_token_id = data['bos_token_id'],
29
+ eos_token_id = data['eos_token_id'],
30
+ hidden_size = data['hidden_size'],
31
+ ignore_index = data['ignore_index'],
32
+ image_token_index = data['image_token_index'],
33
+ pad_token_id = data['pad_token_id'],
34
+ projection_dim = data['projection_dim'],
35
+ text_config = GemmaConfig.from_dict(data['text_config']),
36
+ vision_config = SigLIPConfig.from_dict(data['vision_config'])
37
+ )
38
+
39
+ class PaliGemmaMultimodalProjector(nn.Module):
40
+ def __init__(self, cfg: PaliGemmaConfig):
41
+ super().__init__()
42
+ self.linear = nn.Linear(cfg.vision_config.hidden_size, cfg.vision_config.projection_dim)
43
+
44
+ def forward(self, x: torch.Tensor):
45
+ x = self.linear(x)
46
+ return x
47
+
48
+ class PaliGemma(nn.Module):
49
+ def __init__(self, cfg: PaliGemmaConfig):
50
+ super().__init__()
51
+ self.cfg = cfg
52
+ self.language_model = Gemma(cfg.text_config)
53
+
54
+ self.vision_tower = SigLIPVisionTower(cfg.vision_config)
55
+
56
+ self.multi_modal_projector = PaliGemmaMultimodalProjector(cfg)
57
+
58
+ def tie_weights(self):
59
+ self.language_model.tie_weights()
60
+
61
+ def _merge_img_embeds_and_input_embeds(self, img_embeds: torch.Tensor,
62
+ input_embeds: torch.Tensor,
63
+ input_tokens: torch.Tensor):
64
+ batch_size, seq_len, embed_dim = input_embeds.shape
65
+ scaled_img = img_embeds / (self.cfg.hidden_size ** 0.5)
66
+
67
+ final_embeddings = torch.zeros((batch_size, seq_len, embed_dim), dtype=img_embeds.dtype, device=img_embeds.device)
68
+
69
+
70
+ # (n, seq_len)
71
+ text_mask = (input_tokens != self.cfg.pad_token_id) & (input_tokens != self.cfg.image_token_index)
72
+ img_mask = input_tokens == self.cfg.image_token_index
73
+ pad_mask = input_tokens == self.cfg.pad_token_id
74
+
75
+ text_mask = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
76
+ img_mask = img_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
77
+ pad_mask = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
78
+
79
+ # (n, seq_len, embed_dim)
80
+ final_embeddings = torch.where(text_mask, input_embeds, final_embeddings)
81
+ final_embeddings = final_embeddings.masked_scatter(img_mask, scaled_img)
82
+ final_embeddings = torch.where(pad_mask, torch.zeros_like(final_embeddings), final_embeddings)
83
+
84
+ return final_embeddings
85
+
86
+ def _create_position_ids_and_attention_mask(self,
87
+ device: str = '',
88
+ dtype: torch.dtype = torch.float32,
89
+ batch_size: int = 32,
90
+ seq_len: int = 1,
91
+ attention_mask: Optional[torch.Tensor] = None,
92
+ kv_cache: Optional[KVCache] = None):
93
+ # Create Attention Mask
94
+ if kv_cache is None or kv_cache.num_items() == 0:
95
+ causal_mask = torch.full((batch_size, seq_len, seq_len), 0, dtype=dtype, device=device)
96
+ position_ids = attention_mask.cumsum(dim=-1).masked_fill_((attention_mask == 0), 1).to(device)
97
+
98
+ else:
99
+ assert seq_len == 1
100
+ kv_len = kv_cache.num_items() + 1
101
+ causal_mask = torch.full((batch_size, 1, kv_len), 0, dtype=dtype, device=device)
102
+ position_ids = attention_mask.cumsum(dim=-1)[:, -1].to(device)
103
+
104
+ # (n, seq_len, kv_len) -> (n, 1, seq_len, kv_len)
105
+ causal_mask = causal_mask.unsqueeze(1)
106
+
107
+ return position_ids, causal_mask
108
+
109
+ @staticmethod
110
+ def from_pretrained(model_dir):
111
+ with open(os.path.join(model_dir, 'config.json'), "r") as f:
112
+ model_config = json.loads(f.read())
113
+ config = PaliGemmaConfig.from_dict(model_config)
114
+
115
+ safetensor_files = Path(model_dir).glob("*.safetensors")
116
+
117
+ weights = {}
118
+ for file in safetensor_files:
119
+ with safe_open(file, framework='pt', device="cpu") as f:
120
+ for key in f.keys():
121
+ weights[key] = f.get_tensor(key)
122
+
123
+ model = PaliGemma(config)
124
+ model.load_state_dict(weights, strict=False)
125
+ model.tie_weights()
126
+ return model
127
+
128
+
129
+ def forward(self, *args, **kwargs):
130
+
131
+ # input_tokens: (n, seq_len)
132
+
133
+ # -> (n, seq_len, embed_dim)
134
+ kv_cache = kwargs['kv_cache'] if 'kv_cache' in kwargs else None
135
+ input_tokens = kwargs['input_ids']
136
+ pixel_values = kwargs['pixel_values'] if 'pixel_values' in kwargs else None
137
+ attention_mask = kwargs['attention_mask']
138
+ input_embeds = self.language_model.model.embed_tokens(input_tokens)
139
+ if pixel_values is not None:
140
+ img_embeds = self.vision_tower(pixel_values.to(input_embeds.dtype))
141
+ img_embeds = self.multi_modal_projector(img_embeds)
142
+ final_embeddings = self._merge_img_embeds_and_input_embeds(img_embeds=img_embeds,
143
+ input_embeds=input_embeds,
144
+ input_tokens=input_tokens)
145
+ else:
146
+ final_embeddings = input_embeds
147
+
148
+ position_ids, causal_mask = self._create_position_ids_and_attention_mask(device=final_embeddings.device.type,
149
+ dtype=final_embeddings.dtype,
150
+ batch_size=final_embeddings.shape[0],
151
+ seq_len=final_embeddings.shape[1],
152
+ attention_mask=attention_mask,
153
+ kv_cache=kv_cache)
154
+
155
+ outputs, kv_cache = self.language_model(
156
+ input_embeds=final_embeddings,
157
+ position_ids=position_ids,
158
+ attention_mask=causal_mask,
159
+ kv_cache=kv_cache
160
+ )
161
+ return outputs, kv_cache
162
+
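Note (added for clarity, not part of the commit): a toy illustration of the mask logic inside _merge_img_embeds_and_input_embeds. Positions holding image_token_index receive the projected image patches, pad positions are zeroed, and every other position keeps its text embedding. The token ids reuse the PaliGemmaConfig defaults above; the example sequence itself is made up.

import torch

IMAGE_TOKEN_INDEX, PAD_TOKEN_ID = 257152, 0
input_tokens = torch.tensor([[IMAGE_TOKEN_INDEX, IMAGE_TOKEN_INDEX, 2, 105, 42, PAD_TOKEN_ID]])

text_mask = (input_tokens != PAD_TOKEN_ID) & (input_tokens != IMAGE_TOKEN_INDEX)
img_mask = input_tokens == IMAGE_TOKEN_INDEX
pad_mask = input_tokens == PAD_TOKEN_ID

print(text_mask.int())   # tensor([[0, 0, 1, 1, 1, 0]]) -> keep text embeddings
print(img_mask.int())    # tensor([[1, 1, 0, 0, 0, 0]]) -> scatter in image patches
print(pad_mask.int())    # tensor([[0, 0, 0, 0, 0, 1]]) -> zero out padding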
models/paligemma_processor.py ADDED
@@ -0,0 +1,103 @@
1
+ import torch
2
+ from PIL import Image
3
+ from typing import Tuple, List
4
+ import numpy as np
5
+ from transformers import GemmaTokenizerFast, BatchFeature
6
+ import json
7
+ import os
8
+
9
+ def preprocess_imgs(imgs: List[Image.Image],
10
+ img_size: Tuple[int, int],
11
+ rescale: float,
12
+ mean: Tuple[float, float, float],
13
+ std: Tuple[float, float, float]):
14
+
15
+ def normalize(img, mean, std):
16
+ img = (img - np.array(mean, dtype=img.dtype)) / np.array(std, dtype=img.dtype)
17
+ return img
18
+
19
+ resized_imgs = [np.array(img.resize((img_size[0], img_size[1]), resample=3)) for img in imgs]
20
+
21
+ rescaled_imgs = [(img * rescale).astype(np.float32) for img in resized_imgs]
22
+
23
+
24
+ normalized_imgs = [normalize(img, mean, std) for img in rescaled_imgs]
25
+ transposed_imgs = [img.transpose(2, 0, 1) for img in normalized_imgs]
26
+
27
+ tensor_imgs = torch.tensor(np.stack(transposed_imgs, axis=0), dtype=torch.float32)
28
+ return tensor_imgs
29
+
30
+
31
+ def preprocess_prompts(prompt, image_token, max_num_image_token, bos_token):
32
+ return f"{image_token * max_num_image_token}{bos_token}{prompt}\n"
33
+
34
+
35
+ class PaliGemmaProcessor:
36
+ IMAGE_TOKEN = "<image>"
37
+ def __init__(self,
38
+ tokenizer: GemmaTokenizerFast) -> None:
39
+
40
+ additional_special_tokens = {"additional_special_tokens": [self.IMAGE_TOKEN]}
41
+ tokenizer.add_special_tokens(additional_special_tokens)
42
+
43
+ EXTRA_TOKENS = [
44
+ f"<loc{i:04d}>" for i in range(1024)
45
+ ] # These tokens are used for object detection (bounding boxes)
46
+ EXTRA_TOKENS += [
47
+ f"<seg{i:03d}>" for i in range(128)
48
+ ]
49
+
50
+ tokenizer.add_tokens(EXTRA_TOKENS)
51
+
52
+ tokenizer.add_bos_token = False
53
+ tokenizer.add_eos_token = False
54
+
55
+ self.tokenizer = tokenizer
56
+
57
+ def from_pretrained(self, pretrained_dir):
58
+
59
+ with open(os.path.join(pretrained_dir, "preprocessor_config.json"), "r") as f:
60
+ config = json.loads(f.read())
61
+
62
+ self.image_seq_length = config['image_seq_length']
63
+ self.image_mean = config['image_mean']
64
+ self.image_std = config['image_std']
65
+ self.resample = config['resample']
66
+ self.rescale_factor = config['rescale_factor']
67
+ self.size = (config['size']['height'], config['size']['width'])
68
+ return self
69
+
70
+
71
+ def __call__(self,
72
+ imgs: List[Image.Image],
73
+ prompts: List[str],
74
+ padding: str = "longest",
75
+ truncation: bool = True,
76
+ max_length: int = None):
77
+
78
+ processed_imgs = preprocess_imgs(imgs,
79
+ img_size=self.size,
80
+ rescale=self.rescale_factor,
81
+ mean=self.image_mean,
82
+ std=self.image_std)
83
+
84
+ processed_prompts = [preprocess_prompts(prompt,
85
+ image_token=self.IMAGE_TOKEN,
86
+ max_num_image_token=self.image_seq_length,
87
+ bos_token=self.tokenizer.bos_token) for prompt in prompts]
88
+
89
+ model_inputs = self.tokenizer(processed_prompts,
90
+ return_tensors='pt',
91
+ padding=padding,
92
+ truncation=truncation,
93
+ max_length=max_length)
94
+
95
+ return {**model_inputs, "pixel_values": processed_imgs}
96
+
97
+
98
+
99
+
100
+
101
+
102
+
103
+
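Note (added for clarity, not part of the commit): a sketch of calling the processor on one image and one prompt. It assumes the pretrained directory has been pulled from LFS so that preprocessor_config.json and the Gemma tokenizer files are present, and that AutoTokenizer resolves to GemmaTokenizerFast for that directory.

from PIL import Image
from transformers import AutoTokenizer
from models.paligemma_processor import PaliGemmaProcessor

pretrained_dir = './pretrained/colpaligemma-3b-mix-448-base'
tokenizer = AutoTokenizer.from_pretrained(pretrained_dir, padding_side='right')
processor = PaliGemmaProcessor(tokenizer=tokenizer).from_pretrained(pretrained_dir)

img = Image.new('RGB', (448, 448))                      # stand-in for a rendered PDF page
inputs = processor(imgs=[img], prompts=["Describe this page."])
print(inputs['pixel_values'].shape)   # (1, 3, height, width) from preprocessor_config.json
print(inputs['input_ids'].shape)      # (1, image_seq_length + prompt tokens)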
models/siglip.py ADDED
@@ -0,0 +1,168 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from dataclasses import dataclass
5
+ from typing import Optional
6
+
7
+ @dataclass
8
+ class SigLIPConfig:
9
+ hidden_size: int = 1152
10
+ intermediate_size: int = 4304
11
+ num_attention_heads: int = 16
12
+ num_hidden_layers: int = 27
13
+ num_image_tokens: int = 256
14
+ patch_size: int = 14
15
+ projection_dim: int = 2048
16
+ n_channels: int = 3
17
+ img_size: int = 224
18
+ norm_eps: float = 1e-6
19
+ attention_dropout: float = 0.0
20
+
21
+ @classmethod
22
+ def from_dict(cls, data):
23
+ return cls(
24
+ hidden_size = data['hidden_size'],
25
+ intermediate_size = data['intermediate_size'],
26
+ num_attention_heads = data['num_attention_heads'],
27
+ num_hidden_layers = data['num_hidden_layers'],
28
+ num_image_tokens = data['num_image_tokens'],
29
+ patch_size = data['patch_size'],
30
+ projection_dim = data['projection_dim']
31
+ )
32
+
33
+ class SigLIPEmbedding(nn.Module):
34
+ def __init__(self, cfg: SigLIPConfig):
35
+ super().__init__()
36
+ self.patch_embedding = nn.Conv2d(cfg.n_channels, cfg.hidden_size, kernel_size=cfg.patch_size, stride=cfg.patch_size, padding='valid')
37
+
38
+ self.num_patches = (cfg.img_size // cfg.patch_size) ** 2
39
+ self.position_embedding = nn.Embedding(cfg.num_image_tokens, cfg.hidden_size)
40
+
41
+ self.register_buffer('position_ids',
42
+ torch.arange(cfg.num_image_tokens).expand(1, -1),
43
+ persistent=False)
44
+
45
+ def forward(self, x: torch.FloatTensor):
46
+ # x: (n, c, h, w) -> (n, c, num_patch_h, num_patch_w)
47
+ img_embeds = self.patch_embedding(x)
48
+ # (n, c, num_patch_h, num_patch_w) -> (n, c, num_patches) -> (n, num_patches, c)
49
+ img_embeds = img_embeds.reshape(*img_embeds.shape[:2], -1).transpose(1, 2)
50
+ return img_embeds + self.position_embedding(self.position_ids.to(torch.int64))
51
+
52
+ class SigLIPTransformerAttention(nn.Module):
53
+ def __init__(self, cfg: SigLIPConfig):
54
+ super().__init__()
55
+ self.cfg = cfg
56
+ self.num_attention_heads = cfg.num_attention_heads
57
+ self.head_dim = cfg.hidden_size // self.num_attention_heads
58
+
59
+ self.q_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
60
+ self.k_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
61
+ self.v_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
62
+
63
+ self.out_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
64
+ self.dropout_p = self.cfg.attention_dropout
65
+
66
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor):
67
+ batch_size, num_patches, _ = x.shape
68
+
69
+ xq = self.q_proj(x)
70
+ xk = self.k_proj(x)
71
+ xv = self.v_proj(x)
72
+
73
+ xq = xq.view(batch_size, num_patches, self.num_attention_heads, self.head_dim).transpose(1, 2)
74
+ xk = xk.view(batch_size, num_patches, self.num_attention_heads, self.head_dim).transpose(1, 2)
75
+ xv = xv.view(batch_size, num_patches, self.num_attention_heads, self.head_dim).transpose(1, 2)
76
+
77
+ # attn_weights = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
78
+
79
+ # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
80
+
81
+ # attn_output = torch.matmul(attn_weights, xv)
82
+ # attn_output = attn_output.transpose(1, 2).contiguous()
83
+ # attn_output = attn_output.view(batch_size, num_patches, -1)
84
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
85
+ query=xq,
86
+ key=xk,
87
+ value=xv,
88
+ attn_mask=attention_mask,
89
+ dropout_p=self.dropout_p,
90
+ is_causal=False
91
+ )
92
+ attn_output = attn_output.transpose(1, 2).contiguous()
93
+ attn_output = attn_output.view(batch_size, num_patches, -1)
94
+ attn_output = self.out_proj(attn_output)
95
+ return attn_output, None
96
+
97
+ class SigLIPTransformerMLP(nn.Module):
98
+ def __init__(self, cfg: SigLIPConfig):
99
+ super().__init__()
100
+ self.cfg = cfg
101
+
102
+ self.fc1 = nn.Linear(cfg.hidden_size, cfg.intermediate_size)
103
+ self.fc2 = nn.Linear(cfg.intermediate_size, cfg.hidden_size)
104
+
105
+ def forward(self, x: torch.Tensor):
106
+
107
+ x = self.fc1(x)
108
+ x = F.gelu(x, approximate='tanh')
109
+ x = self.fc2(x)
110
+ return x
111
+
112
+ class SigLIPTransformerBlock(nn.Module):
113
+ def __init__(self, cfg: SigLIPConfig):
114
+ super().__init__()
115
+ self.layer_norm1 = nn.LayerNorm(cfg.hidden_size, eps=cfg.norm_eps)
116
+ self.layer_norm2 = nn.LayerNorm(cfg.hidden_size, eps=cfg.norm_eps)
117
+
118
+ self.self_attn = SigLIPTransformerAttention(cfg)
119
+ self.mlp = SigLIPTransformerMLP(cfg)
120
+
121
+ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
122
+ residual = x
123
+ x = self.layer_norm1(x)
124
+ x = residual + self.self_attn(x, attention_mask)[0]
125
+ residual = x
126
+ x = self.layer_norm2(x)
127
+ x = residual + self.mlp(x)
128
+ return x
129
+
130
+ class SigLIPTransformerEncoder(nn.Module):
131
+ def __init__(self, cfg: SigLIPConfig):
132
+ super().__init__()
133
+
134
+ self.cfg = cfg
135
+ self.layers = nn.ModuleList(
136
+ [SigLIPTransformerBlock(cfg) for _ in range(cfg.num_hidden_layers)]
137
+ )
138
+
139
+ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
140
+ for layer in self.layers:
141
+ x = layer(x, attention_mask)
142
+ return x
143
+ class SigLIPModel(nn.Module):
144
+ def __init__(self, cfg: SigLIPConfig):
145
+ super().__init__()
146
+ self.embeddings = SigLIPEmbedding(cfg)
147
+ self.encoder = SigLIPTransformerEncoder(cfg)
148
+ self.post_layernorm = nn.LayerNorm(cfg.hidden_size, eps=cfg.norm_eps)
149
+
150
+ def forward(self, x: torch.Tensor):
151
+ img_embed = self.embeddings(x)
152
+ output = self.encoder(img_embed)
153
+ output = self.post_layernorm(output)
154
+ return output
155
+
156
+
157
+
158
+ class SigLIPVisionTower(nn.Module):
159
+ def __init__(self, cfg: SigLIPConfig):
160
+ super().__init__()
161
+ self.cfg = cfg
162
+ self.vision_model = SigLIPModel(cfg)
163
+
164
+ def forward(self, x: torch.Tensor):
165
+ return self.vision_model(x)
166
+
167
+
168
+
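Note (added for clarity, not part of the commit): a quick shape check of the vision tower using the default SigLIPConfig above. A 224x224 image cut into 14x14 patches yields (224 / 14)^2 = 256 patch embeddings, matching num_image_tokens, each of width hidden_size = 1152. Batch size 2 is arbitrary.

import torch
from models.siglip import SigLIPConfig, SigLIPVisionTower

cfg = SigLIPConfig()                       # defaults: img_size=224, patch_size=14, hidden_size=1152
tower = SigLIPVisionTower(cfg)
x = torch.randn(2, 3, cfg.img_size, cfg.img_size)
out = tower(x)
print((cfg.img_size // cfg.patch_size) ** 2)   # 256 == cfg.num_image_tokens
print(out.shape)                               # torch.Size([2, 256, 1152])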
pretrained/colpaligemma-3b-mix-448-base/adapter_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:caed65068cae6d50e572d984914324a7d8a9360cdd7f4263ea82f1792614391f
3
+ size 78625112
pretrained/colpaligemma-3b-mix-448-base/config.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:012239f7d70c76d7f85bfca5e23f6afcde455f9ed23fa3f2ec9057b6028f6a5b
3
+ size 1047
pretrained/colpaligemma-3b-mix-448-base/model-00001-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c128f5670d7a66942a194be6e2d324dc329c0de19e99c6f047513878e14f988e
3
+ size 4986817288
pretrained/colpaligemma-3b-mix-448-base/model-00002-of-00002.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8352c38e4d1785c4a35547d13f4d8d5562faab6fe8e9a30b1f5d8039d355a409
3
+ size 862495528
pretrained/colpaligemma-3b-mix-448-base/preprocessor_config.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5fc342baea95529a5eb9746a0232fb88941d759812d7b616c382f2f87ba6123f
3
+ size 700
pretrained/colpaligemma-3b-mix-448-base/tokenizer.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ffd310e50986db7a039948ab83441d612689e7f989198e31b5c8984ca458adf6
3
+ size 17763459
pretrained/colpaligemma-3b-mix-448-base/tokenizer.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8986bb4f423f07f8c7f70d0dbe3526fb2316056c17bae71b1ea975e77a168fc6
3
+ size 4264023
pretrained/colpaligemma-3b-mix-448-base/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d5e95b5ab863693113e65e4899e1db28c09d892fa84243c7dfe6ce7f727f1888
3
+ size 242696
prompt_templates.py ADDED
@@ -0,0 +1,132 @@
1
+ import json
2
+ from llama_index.core.question_gen.types import SubQuestion
3
+ from llama_index.core.tools.types import ToolMetadata
4
+ from llama_index.core.question_gen.prompts import build_tools_text
5
+
6
+ PREFIX = """\
7
+ Given a user question, and a list of tools, output a list of relevant sub-questions \
8
+ in json markdown that when composed can help answer the full user question:
9
+
10
+ """
11
+
12
+ example_query_str = (
13
+ "Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021"
14
+ )
15
+ example_tools = [
16
+ ToolMetadata(
17
+ name="uber_10k",
18
+ description="Provides information about Uber financials for year 2021",
19
+ ),
20
+ ToolMetadata(
21
+ name="lyft_10k",
22
+ description="Provides information about Lyft financials for year 2021",
23
+ ),
24
+ ]
25
+ example_tools_str = build_tools_text(example_tools)
26
+ example_output = [
27
+ SubQuestion(
28
+ sub_question="What is the revenue growth of Uber", tool_name="uber_10k"
29
+ ),
30
+ SubQuestion(sub_question="What is the EBITDA of Uber", tool_name="uber_10k"),
31
+ SubQuestion(
32
+ sub_question="What is the revenue growth of Lyft", tool_name="lyft_10k"
33
+ ),
34
+ SubQuestion(sub_question="What is the EBITDA of Lyft", tool_name="lyft_10k"),
35
+ ]
36
+ example_output_str = json.dumps(
37
+ {"items": [x.model_dump() for x in example_output]}, indent=4
38
+ )
39
+
40
+ EXAMPLES = f"""\
41
+ # Example 1
42
+ <Tools>
43
+ ```json
44
+ {example_tools_str}
45
+ ```
46
+
47
+ <User Question>
48
+ {example_query_str}
49
+
50
+
51
+ <Output>
52
+ ```json
53
+ {example_output_str}
54
+ ```
55
+
56
+ """
57
+
58
+ SUFFIX = """\
59
+ # Example 2
60
+ <Tools>
61
+ ```json
62
+ {tools_str}
63
+ ```
64
+
65
+ <User Question>
66
+ {query_str}
67
+
68
+ <Output>
69
+ """
70
+
71
+ DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX
72
+
73
+ DEFAULT_GEN_PROMPT_TMPL = """\
74
+ You are a helpful assistant that generates multiple search queries based on a \
75
+ single input query. Generate {num_queries} search queries, one on each line, \
76
+ related to the following input query:
77
+ Query: {query}
78
+ Queries:
79
+ """
80
+
81
+ DEFAULT_FINAL_ANSWER_PROMPT_TMPL = """\
82
+ Context information is below.
83
+ ---------------------
84
+ {context_str}
85
+ ---------------------
86
+ Given the context information and not prior knowledge, answer the query.
87
+ Query: {query_str}
88
+ Answer: \
89
+ """
90
+
91
+
92
+ SYNTHESIZE_PROMPT = """\
93
+ Context information is below.
94
+ ---------------------
95
+ {context_str}
96
+ ---------------------
97
+ Given the information from multiple sources and not prior knowledge,
98
+ Summarize the information that is most relevant to the query, and return the indices of the choices used in the summary.
99
+
100
+ Query: {query_str}\n
101
+ """
102
+
103
+
104
+ SYNTHESIZE_OUTPUT_FORMAT = """Return the output that conforms to the JSON schema below.
105
+ Here is the output schema.
106
+
107
+ {
108
+ "properties": {
109
+ "summarized_text": {
110
+ "title": "Summarized Text",
111
+ "type": "string"
112
+ },
113
+ "choices": {
114
+ "items": {
115
+ "type": "integer"
116
+ },
117
+ "title": "Choices",
118
+ "type": "array"
119
+ }
120
+ },
121
+ "required": [
122
+ "summarized_text",
123
+ "choices"
124
+ ],
125
+ "title": "SummarizeAnswer",
126
+ "type": "object"
127
+ }
128
+
129
+ Answer: \
130
+ """.replace("{", "{{").replace("}", "}}")
131
+
132
+ DEFAULT_SYNTHESIZE_PROMPT_TMPL = SYNTHESIZE_PROMPT + SYNTHESIZE_OUTPUT_FORMAT
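Note (added for clarity, not part of the commit): how these templates are filled at query time. DEFAULT_GEN_PROMPT_TMPL only needs num_queries and query; DEFAULT_SYNTHESIZE_PROMPT_TMPL needs context_str and query_str, and the doubled braces in SYNTHESIZE_OUTPUT_FORMAT survive .format() as a literal JSON schema. The query and context strings below are examples only.

from prompt_templates import DEFAULT_GEN_PROMPT_TMPL, DEFAULT_SYNTHESIZE_PROMPT_TMPL

gen_prompt = DEFAULT_GEN_PROMPT_TMPL.format(
    num_queries=3,
    query="Compare the net income between Nvidia and Alphabet",
)

synth_prompt = DEFAULT_SYNTHESIZE_PROMPT_TMPL.format(
    context_str="[1] Nvidia quarterly income statement ...\n[2] Alphabet quarterly income statement ...",
    query_str="Compare the net income between Nvidia and Alphabet",
)
print(gen_prompt)
print(synth_prompt[:200])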
rag_pipeline.py ADDED
@@ -0,0 +1,531 @@
1
+ import torch
2
+ import asyncio
3
+ from torch.utils.data import DataLoader
4
+ import os
5
+ import uuid
6
+ import base64
7
+ from io import BytesIO
8
+ from PIL import Image
9
+ from pdf2image import pdf2image
10
+ from typing import List, Union
11
+ from tqdm.auto import tqdm
12
+
13
+ from utils import *
14
+ from models import ColPali, ColPaliProcessor, get_lora_model, enable_lora
15
+
16
+ import qdrant_client
17
+ from qdrant_client.http import models as rest
18
+ from llamaindex_utils import ColPaliGemmaEmbedding, ColPaliRetriever, CustomFusionRetriever, CustomQueryEngine
19
+ from llama_index.llms.gemini import Gemini
20
+ from llama_index.core.tools import RetrieverTool
21
+
22
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
23
+
24
+ def embed_imgs(model: ColPali,
25
+ processor: ColPaliProcessor,
26
+ input_imgs: List[Image.Image],
27
+ device: str = 'cpu') -> List[torch.Tensor]:
28
+ """Generates embeddings given images.
29
+
30
+ Args:
31
+ model (ColPali): Main model
32
+ processor (ColPaliProcessor): Data Processor
33
+ input_imgs (List[Image.Image]): List of input images
34
+ device (str, optional): device to run model. Defaults to 'cpu'.
35
+
36
+ Returns:
37
+ List[torch.Tensor]: List of output embeddings.
38
+ """
39
+
40
+ colpali_model = model.to(device=device).eval()
41
+
42
+ dataloader = DataLoader(input_imgs,
43
+ batch_size=8,
44
+ shuffle=False,
45
+ num_workers=0,
46
+ collate_fn=lambda x: processor.process_images(x))
47
+
48
+ document_embeddings = []
49
+ with torch.no_grad():
50
+ for batch, model_inputs in tqdm(enumerate(dataloader)):
51
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
52
+ # Encode images
53
+ img_embeds = colpali_model(**model_inputs, kv_cache=None)
54
+ document_embeddings.extend(list(torch.unbind(img_embeds.to('cpu').to(torch.float32))))
55
+ return document_embeddings
56
+
57
+ def embed_queries(model: ColPali,
58
+ processor: ColPaliProcessor,
59
+ queries: List[str],
60
+ device: str = 'cpu') -> List[torch.Tensor]:
61
+ """Generate embeddings given queries.
62
+
63
+ Args:
64
+ model (ColPali): Embedding model
65
+ processor (ColPaliProcessor): Data Processor
66
+ queries (List[str]): List of query strings
67
+ device (str, optional): Device to run model. Defaults to 'cpu'.
68
+
69
+ Returns:
70
+ List[torch.Tensor]: List of embeddings
71
+ """
72
+ colpali_model = model.to(device=device).eval()
73
+
74
+ dataloader = DataLoader(queries,
75
+ batch_size=8,
76
+ shuffle=False,
77
+ num_workers=0,
78
+ collate_fn=lambda x: processor.process_queries(x))
79
+
80
+ queries_embeddings = []
81
+ with torch.no_grad():
82
+ for batch, model_inputs in tqdm(enumerate(dataloader)):
83
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
84
+ # Encode Queries
85
+ query_embeds = colpali_model(**model_inputs, kv_cache=None)
86
+ queries_embeddings.extend(torch.unbind(query_embeds.to('cpu').type(torch.float32)))
87
+
88
+ return queries_embeddings
89
+
90
+
91
+ def score_single_vectors(qs: List[torch.Tensor],
92
+ ps: List[torch.Tensor]) -> torch.FloatTensor:
93
+ """Calculate similarity between 2 single vectors
94
+
95
+ Args:
96
+ qs (List[torch.Tensor]): First Embeddings
97
+ ps (List[torch.Tensor]): Second Embeddings
98
+
99
+ Returns:
100
+ torch.FloatTensor: Score Tensor
101
+ """
102
+ assert len(qs) != 0 and len(ps) != 0
103
+
104
+ qs_stacked = torch.stack(qs)
105
+ ps_stacked = torch.stack(ps)
106
+
107
+ scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
108
+ assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
109
+ scores = scores.to(torch.float32)
110
+ return scores
111
+
112
+ def score_multi_vectors(qs: List[torch.Tensor],
113
+ ps: List[torch.Tensor],
114
+ batch_size: int = 8,
115
+ device: Union[torch.device|str] = "cpu") -> torch.FloatTensor:
116
+ """Calculate MaxSim between 2 list of vectors.
117
+
118
+ Args:
119
+ qs (List[torch.Tensor]): List of query embeddings
120
+ ps (List[torch.Tensor]): List of document embeddings
121
+ batch_size (int, optional): Batch Size. Defaults to 8.
122
+ device (Union[torch.device | str], optional): Device to cast tensor to. Defaults to "cpu".
123
+
124
+ Returns:
125
+ torch.FloatTensor: Score tensors.
126
+ """
127
+
128
+ assert len(qs) != 0 and len(ps) != 0
129
+ scores_list = []
130
+ for i in range(0, len(qs), batch_size):
131
+ scores_batch = []
132
+ qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i:i+batch_size], batch_first=True, padding_value=0).to(device)
133
+ for j in range(0, len(ps), batch_size):
134
+ ps_batch = torch.nn.utils.rnn.pad_sequence(ps[j:j+batch_size], batch_first=True, padding_value=0).to(device)
135
+ tmp = torch.einsum("abd,ced->acbe", qs_batch, ps_batch).max(dim=-1)[0].sum(dim=2)
136
+ scores_batch.append(tmp)
137
+
138
+ scores_batch = torch.cat(scores_batch, dim=1).cpu()
139
+ scores_list.append(scores_batch)
140
+
141
+ scores = torch.cat(scores_list, dim=0)
142
+ return scores.to(torch.float32)
143
+
144
+ def indexDocument(file_path: str,
145
+ vector_store_client,
146
+ target_collection: str,
147
+ model: nn.Module,
148
+ processor: ColPaliProcessor,
149
+ device: Union[str|torch.device]) -> None:
150
+ """Index document given file_path.
151
+ Each page in the document is embedded by the ColPaliGemma model, then inserted into the Qdrant vector store under the target collection.
153
+ Creates the target collection if it does not exist in the vector store yet.
153
+
154
+ Args:
155
+ file_path (str): Path to the PDF document to index.
156
+ vector_store_client (qdrant_client.QdrantClient): Qdrant client used to store page embeddings.
157
+ target_collection (str): Name of the Qdrant collection to upsert into.
158
+ model (nn.Module): ColPali embedding model.
159
+ processor (ColPaliProcessor): Processor that converts page images into model inputs.
160
+ device (Union[str | torch.device]): Device to run the embedding model on.
161
+ """
162
+ document_images = []
163
+ document_embeddings = []
164
+ document_images.extend(pdf2image.convert_from_path(file_path))
165
+
166
+ document_embeddings = embed_imgs(model=model,
167
+ processor=processor,
168
+ input_imgs=document_images,
169
+ device=device)
170
+
171
+ # Create Qdrant Collection
172
+ if not vector_store_client.collection_exists(collection_name=target_collection):
173
+ # Specify vectors_config
174
+ scalar_quant = rest.ScalarQuantizationConfig(
175
+ type=rest.ScalarType.INT8,
176
+ quantile=0.99,
177
+ always_ram=False
178
+ )
179
+ vector_params = rest.VectorParams(
180
+ size=128,
181
+ distance=rest.Distance.COSINE,
182
+ multivector_config=rest.MultiVectorConfig(
183
+ comparator=rest.MultiVectorComparator.MAX_SIM
184
+ ),
185
+ quantization_config=rest.ScalarQuantization(
186
+ scalar=scalar_quant
187
+ ),
188
+ )
189
+ vector_store_client.create_collection(
190
+ collection_name=target_collection,
191
+ on_disk_payload=True,
192
+ optimizers_config=rest.OptimizersConfigDiff(
193
+ indexing_threshold=100
194
+ ),
195
+ vectors_config=vector_params
196
+ )
197
+
198
+ # Add embedding to Qdrant Collection
199
+ points = []
200
+ for i, embedding in enumerate(document_embeddings):
201
+ multivector = embedding.cpu().float().numpy().tolist()
202
+
203
+ buffer = BytesIO()
204
+ document_images[i].save(buffer, format='JPEG')
205
+ image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
206
+ # Define payload
207
+ payload = {}
208
+ node_metadata = {"file_name": file_path,
209
+ "page_id": i + 1}
210
+ node_content = {'id_': abs(hash(file_path + str(i + 1))),
211
+ 'image': image_str,
212
+ "metadata": node_metadata}
213
+
214
+ payload["_node_content"] = json.dumps(node_content)
215
+ payload["_node_type"] = "ImageNode"
216
+
217
+ # store ref doc id at top level to allow metadata filtering
218
+ # kept for backwards compatibility, will consolidate in future
219
+ payload["document_id"] = "None" # for Chroma
220
+ payload["doc_id"] = "None" # for Pinecone, Qdrant, Redis
221
+ payload["ref_doc_id"] = "None" # for Weaviate
222
+
223
+ points.append(rest.PointStruct(
224
+ id=node_content['id_'],
225
+ vector=multivector,
226
+ payload=payload,
227
+ ))
228
+
229
+ step = 8
230
+ for i in range(0, len(points), step):
231
+ points_batch = points[i: i + step]
232
+ vector_store_client.upsert(collection_name=target_collection,
233
+ points=points_batch,
234
+ wait=False)
235
+
236
+
237
+ async def async_indexDocument(file_path: str,
238
+ vector_store_client: qdrant_client.AsyncQdrantClient,
239
+ target_collection: str,
240
+ model: nn.Module,
241
+ processor: ColPaliProcessor,
242
+ device: Union[str|torch.device]) -> None:
243
+ """Asynchrously index document given file_path.
244
+ Each page in document is embedded by ColPaliGemma Model, then insert into Qdrant vector store given target collection.
245
+ Creates taret collection if it is not created in the vector store yet.
246
+
247
+ Args:
248
+ file_path (str): Path to the PDF document to index.
249
+ vector_store_client (qdrant_client.AsyncQdrantClient): Async Qdrant client used to store page embeddings.
250
+ target_collection (str): Name of the Qdrant collection to upsert into.
251
+ model (nn.Module): ColPali embedding model.
252
+ processor (ColPaliProcessor): Processor that converts page images into model inputs.
253
+ device (Union[str | torch.device]): Device to run the embedding model on.
254
+ """
255
+ document_images = []
256
+ document_embeddings = []
257
+ document_images.extend(pdf2image.convert_from_path(file_path))
258
+
259
+ document_embeddings = embed_imgs(model=model,
260
+ processor=processor,
261
+ input_imgs=document_images,
262
+ device=device)
263
+
264
+ # Create Qdrant Collection
265
+ if not await vector_store_client.collection_exists(collection_name=target_collection):
266
+ # Specify vectors_config
267
+ scalar_quant = rest.ScalarQuantizationConfig(
268
+ type=rest.ScalarType.INT8,
269
+ quantile=0.99,
270
+ always_ram=False
271
+ )
272
+ vector_params = rest.VectorParams(
273
+ size=128,
274
+ distance=rest.Distance.COSINE,
275
+ multivector_config=rest.MultiVectorConfig(
276
+ comparator=rest.MultiVectorComparator.MAX_SIM
277
+ ),
278
+ quantization_config=rest.ScalarQuantization(
279
+ scalar=scalar_quant
280
+ ),
281
+ )
282
+ await vector_store_client.create_collection(
283
+ collection_name=target_collection,
284
+ on_disk_payload=True,
285
+ optimizers_config=rest.OptimizersConfigDiff(
286
+ indexing_threshold=100
287
+ ),
288
+ vectors_config=vector_params
289
+ )
290
+
291
+ # Add embedding to Qdrant Collection
292
+ points = []
293
+ for i, embedding in enumerate(document_embeddings):
294
+ multivector = embedding.cpu().float().numpy().tolist()
295
+
296
+ buffer = BytesIO()
297
+ document_images[i].save(buffer, format='JPEG')
298
+ image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
299
+ # Define payload
300
+ payload = {}
301
+ node_metadata = {"file_name": file_path,
302
+ "page_id": i + 1}
303
+ node_content = {'id_': abs(hash(file_path + str(i + 1))),
304
+ 'image': image_str,
305
+ "metadata": node_metadata}
306
+
307
+ payload["_node_content"] = json.dumps(node_content)
308
+ payload["_node_type"] = "ImageNode"
309
+
310
+ # store ref doc id at top level to allow metadata filtering
311
+ # kept for backwards compatibility, will consolidate in future
312
+ payload["document_id"] = "None" # for Chroma
313
+ payload["doc_id"] = "None" # for Pinecone, Qdrant, Redis
314
+ payload["ref_doc_id"] = "None" # for Weaviate
315
+
316
+ points.append(rest.PointStruct(
317
+ id=node_content['id_'],
318
+ vector=multivector,
319
+ payload=payload,
320
+ ))
321
+
322
+ step = 8
323
+ for i in range(0, len(points), step):
324
+ points_batch = points[i: i + step]
325
+ await vector_store_client.upsert(collection_name=target_collection,
326
+ points=points_batch,
327
+ wait=False)
328
+
329
+
330
+ GEMINI_API_KEY = os.getenv(key="GEMINI_API_KEY")
331
+
332
+ def main():
333
+ model = ColPali.from_pretrained(model_dir='./pretrained/colpaligemma-3b-mix-448-base', torch_dtype=torch.bfloat16)
334
+ tokenizer = load_tokenizer(tokenizer_dir='./pretrained/colpaligemma-3b-mix-448-base')
335
+ processor = ColPaliProcessor(tokenizer=tokenizer).from_pretrained(pretrained_dir='./pretrained/colpaligemma-3b-mix-448-base')
336
+
337
+ model.model.language_model.model = get_lora_model(model.model.language_model.model,
338
+ rank=32,
339
+ alphas=32,
340
+ lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'],
341
+ training=False,
342
+ dropout_p=0.1,
343
+ torch_dtype=torch.bfloat16)
344
+ model.model.language_model.model = enable_lora(model.model.language_model.model, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], enabled=True)
345
+
346
+ model = get_lora_model(model,
347
+ rank=32,
348
+ alphas=32,
349
+ lora_modules=['custom_text_proj'],
350
+ training=False,
351
+ dropout_p=0.1,
352
+ torch_dtype=torch.bfloat16)
353
+ model = enable_lora(model, lora_modules=['custom_text_proj'], enabled=True)
354
+
355
+ model.load_lora('./pretrained/colpaligemma-3b-mix-448-base')
356
+
357
+ # Initialize LLM
358
+ generation_config = {
359
+ "temperature": 0.0,
360
+ "top_p": 0.95,
361
+ "top_k": 64,
362
+ "max_output_tokens": 1024,
363
+ "response_mime_type": "text/plain",
364
+ }
365
+
366
+ llm = Gemini(api_key=GEMINI_API_KEY, generation_config=generation_config)
367
+
368
+ # Setup Qdrant
369
+ # Creating Qdrant Client
370
+ vector_store_client = qdrant_client.QdrantClient(location="http://localhost:6333", timeout=100)
371
+
372
+ indexDocument('./data/pdfs-financial/Alphabet_Inc_goog-10-q-q1-2024.pdf',
373
+ vector_store_client=vector_store_client,
374
+ target_collection="Alphabet",
375
+ model=model,
376
+ processor=processor,
377
+ device='mps')
378
+
379
+ indexDocument('./data/pdfs-financial/Nvidia_ecefb2b2-efcb-45f3-b72b-212d90fcd873.pdf',
380
+ vector_store_client=vector_store_client,
381
+ target_collection="Nvidia",
382
+ model=model,
383
+ processor=processor,
384
+ device='mps')
385
+
386
+ # RAG using LLamaIndex
387
+
388
+ embed_model = ColPaliGemmaEmbedding(model=model, processor=processor, device="mps")
389
+
390
+ alphabet_retriever = ColPaliRetriever(vector_store_client=vector_store_client,
391
+ target_collection="Alphabet",
392
+ embed_model=embed_model,
393
+ query_mode='default',
394
+ similarity_top_k=3)
395
+
396
+ nvidia_retriever = ColPaliRetriever(vector_store_client=vector_store_client,
397
+ target_collection="Nvidia",
398
+ embed_model=embed_model,
399
+ query_mode='default',
400
+ similarity_top_k=3)
401
+
402
+ # Query Router Among Multiple Retrievers
403
+ retriever_tools = [
404
+ RetrieverTool.from_defaults(
405
+ name="alphabet",
406
+ retriever=alphabet_retriever,
407
+ description="Useful for retrieving information about Alphabet Inc financials"
408
+ ),
409
+ RetrieverTool.from_defaults(
410
+ name="nvidia",
411
+ retriever=nvidia_retriever,
412
+ description="Useful for retrieving information about Nvidia financials"
413
+ )
414
+ ]
415
+
416
+ retriever_mappings = {retriever_tool.metadata.name: retriever_tool.retriever for retriever_tool in retriever_tools}
417
+
418
+ fusion_retriever = CustomFusionRetriever(llm=llm,
419
+ retriever_mappings=retriever_mappings,
420
+ num_generated_queries=3,
421
+ similarity_top_k=3)
422
+
423
+ query_engine = CustomQueryEngine(retriever_tools=[retriever_tool.metadata for retriever_tool in retriever_tools],
424
+ fusion_retriever=fusion_retriever,
425
+ llm=llm,
426
+ num_children=3)
427
+
428
+ query_str = "Compare the net income between Nvidia and Alphabet"
429
+ response = query_engine.query(query_str=query_str)
430
+ print(response.response)
431
+
432
+ async def amain():
433
+ model = ColPali.from_pretrained(model_dir='./pretrained/colpaligemma-3b-mix-448-base', torch_dtype=torch.bfloat16)
434
+ tokenizer = load_tokenizer(tokenizer_dir='./pretrained/colpaligemma-3b-mix-448-base')
435
+ processor = ColPaliProcessor(tokenizer=tokenizer).from_pretrained(pretrained_dir='./pretrained/colpaligemma-3b-mix-448-base')
436
+
437
+ model.model.language_model.model = get_lora_model(model.model.language_model.model,
438
+ rank=32,
439
+ alphas=32,
440
+ lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'],
441
+ training=False,
442
+ dropout_p=0.1,
443
+ torch_dtype=torch.bfloat16)
444
+ model.model.language_model.model = enable_lora(model.model.language_model.model, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], enabled=True)
445
+
446
+ model = get_lora_model(model,
447
+ rank=32,
448
+ alphas=32,
449
+ lora_modules=['custom_text_proj'],
450
+ training=False,
451
+ dropout_p=0.1,
452
+ torch_dtype=torch.bfloat16)
453
+ model = enable_lora(model, lora_modules=['custom_text_proj'], enabled=True)
454
+
455
+ model.load_lora('./pretrained/colpaligemma-3b-mix-448-base')
456
+
457
+ # Initialize LLM
458
+ generation_config = {
459
+ "temperature": 0.0,
460
+ "top_p": 0.95,
461
+ "top_k": 64,
462
+ "max_output_tokens": 1024,
463
+ "response_mime_type": "text/plain",
464
+ }
465
+
466
+ llm = Gemini(api_key=GEMINI_API_KEY, generation_config=generation_config)
467
+
468
+ # Setup Qdrant
469
+ # Creating Qdrant Client
470
+ vector_store_client = qdrant_client.AsyncQdrantClient(location="http://localhost:6333", timeout=100)
471
+
472
+ await async_indexDocument('./data/pdfs-financial/Alphabet_Inc_goog-10-q-q1-2024.pdf',
473
+ vector_store_client=vector_store_client,
474
+ target_collection="Alphabet",
475
+ model=model,
476
+ processor=processor,
477
+ device='mps')
478
+
479
+ await async_indexDocument('./data/pdfs-financial/Nvidia_ecefb2b2-efcb-45f3-b72b-212d90fcd873.pdf',
480
+ vector_store_client=vector_store_client,
481
+ target_collection="Nvidia",
482
+ model=model,
483
+ processor=processor,
484
+ device='mps')
485
+
486
+ embed_model = ColPaliGemmaEmbedding(model=model, processor=processor, device="mps")
487
+
488
+ alphabet_retriever = ColPaliRetriever(vector_store_client=vector_store_client,
489
+ target_collection="Alphabet",
490
+ embed_model=embed_model,
491
+ query_mode='default',
492
+ similarity_top_k=3)
493
+
494
+ nvidia_retriever = ColPaliRetriever(vector_store_client=vector_store_client,
495
+ target_collection="Nvidia",
496
+ embed_model=embed_model,
497
+ query_mode='default',
498
+ similarity_top_k=3)
499
+
500
+
501
+ # Query Router Among Multiple Retrievers
502
+ retriever_tools = [
503
+ RetrieverTool.from_defaults(
504
+ name="alphabet",
505
+ retriever=alphabet_retriever,
506
+ description="Useful for retrieving information about Alphabet Inc financials"
507
+ ),
508
+ RetrieverTool.from_defaults(
509
+ name="nvidia",
510
+ retriever=nvidia_retriever,
511
+ description="Useful for retrieving information about Nvidia financials"
512
+ )
513
+ ]
514
+
515
+ retriever_mappings = {retriever_tool.metadata.name: retriever_tool.retriever for retriever_tool in retriever_tools}
516
+
517
+ fusion_retriever = CustomFusionRetriever(llm=llm,
518
+ retriever_mappings=retriever_mappings,
519
+ similarity_top_k=3)
520
+
521
+ query_engine = CustomQueryEngine(retriever_tools=[retriever_tool.metadata for retriever_tool in retriever_tools],
522
+ fusion_retriever=fusion_retriever,
523
+ llm=llm,
524
+ num_children=3)
525
+
526
+ query_str = "Compare the net income between Nvidia and Alphabet"
527
+ response = await query_engine.aquery(query_str=query_str)
528
+ print(str(response))
529
+
530
+ if __name__ == "__main__":
531
+ main()
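Note (added for clarity, not part of the commit): a tiny self-contained check of the MaxSim scoring used above. score_multi_vectors pads each multi-vector embedding, takes the best-matching page vector for every query vector, and sums those maxima, giving one score per (query, page) pair. Dimension 4 here is illustrative; the real embeddings are 128-dimensional.

import torch
from rag_pipeline import score_multi_vectors

qs = [torch.randn(5, 4), torch.randn(7, 4)]        # two query multi-vectors
ps = [torch.randn(100, 4), torch.randn(120, 4)]    # two page multi-vectors
scores = score_multi_vectors(qs, ps, batch_size=2, device='cpu')
print(scores.shape)   # torch.Size([2, 2]): one MaxSim score per (query, page) pair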
requirements.txt ADDED
@@ -0,0 +1,225 @@
1
+ accelerate==1.1.0
2
+ aiofiles==23.2.1
3
+ aiohappyeyeballs==2.4.3
4
+ aiohttp==3.10.10
5
+ aiosignal==1.3.1
6
+ annotated-types==0.7.0
7
+ anyio==4.6.2.post1
8
+ appnope==0.1.4
9
+ argon2-cffi==23.1.0
10
+ argon2-cffi-bindings==21.2.0
11
+ arrow==1.3.0
12
+ asttokens==2.4.1
13
+ async-lru==2.0.4
14
+ attrs==24.2.0
15
+ babel==2.16.0
16
+ beautifulsoup4==4.12.3
17
+ bleach==6.2.0
18
+ cachetools==5.5.0
19
+ certifi==2024.8.30
20
+ cffi==1.17.1
21
+ charset-normalizer==3.4.0
22
+ click==8.1.7
23
+ comm==0.2.2
24
+ contourpy==1.3.0
25
+ cycler==0.12.1
26
+ dataclasses-json==0.6.7
27
+ datasets==3.0.1
28
+ debugpy==1.8.7
29
+ decorator==5.1.1
30
+ defusedxml==0.7.1
31
+ Deprecated==1.2.14
32
+ dill==0.3.8
33
+ dirtyjson==1.0.8
34
+ distro==1.9.0
35
+ executing==2.1.0
36
+ fastapi==0.115.4
37
+ fastjsonschema==2.20.0
38
+ ffmpy==0.4.0
39
+ filelock==3.16.1
40
+ fonttools==4.54.1
41
+ fqdn==1.5.1
42
+ frozenlist==1.5.0
43
+ fsspec==2024.6.1
44
+ google-ai-generativelanguage==0.6.4
45
+ google-api-core==2.20.0
46
+ google-api-python-client==2.147.0
47
+ google-auth==2.35.0
48
+ google-auth-httplib2==0.2.0
49
+ google-generativeai==0.5.4
50
+ googleapis-common-protos==1.65.0
51
+ gradio==4.44.1
52
+ gradio_client==1.3.0
53
+ greenlet==3.1.1
54
+ grpcio==1.67.1
55
+ grpcio-status==1.62.3
56
+ grpcio-tools==1.62.3
57
+ h11==0.14.0
58
+ h2==4.1.0
59
+ hpack==4.0.0
60
+ httpcore==1.0.6
61
+ httplib2==0.22.0
62
+ httpx==0.27.2
63
+ huggingface-hub==0.26.2
64
+ hyperframe==6.0.1
65
+ idna==3.10
66
+ importlib_resources==6.4.5
67
+ InstructorEmbedding==1.0.1
68
+ ipykernel==6.29.5
69
+ ipython==8.29.0
70
+ isoduration==20.11.0
71
+ jedi==0.19.1
72
+ Jinja2==3.1.4
73
+ jiter==0.7.0
74
+ joblib==1.4.2
75
+ json5==0.9.25
76
+ jsonpointer==3.0.0
77
+ jsonschema==4.23.0
78
+ jsonschema-specifications==2024.10.1
79
+ jupyter_client==8.6.3
80
+ jupyter_core==5.7.2
81
+ jupyter-events==0.10.0
82
+ jupyter-lsp==2.2.5
83
+ jupyter_server==2.14.2
84
+ jupyter_server_terminals==0.5.3
85
+ jupyterlab==4.2.5
86
+ jupyterlab_pygments==0.3.0
87
+ jupyterlab_server==2.27.3
88
+ kiwisolver==1.4.7
89
+ llama-cloud==0.1.2
90
+ llama-index==0.11.17
91
+ llama-index-agent-openai==0.3.4
92
+ llama-index-cli==0.3.1
93
+ llama-index-core==0.11.17
94
+ llama-index-embeddings-huggingface==0.3.1
95
+ llama-index-embeddings-instructor==0.2.1
96
+ llama-index-embeddings-openai==0.2.5
97
+ llama-index-indices-managed-llama-cloud==0.4.0
98
+ llama-index-legacy==0.9.48.post3
99
+ llama-index-llms-gemini==0.3.7
100
+ llama-index-llms-openai==0.2.13
101
+ llama-index-multi-modal-llms-gemini==0.3.1
102
+ llama-index-multi-modal-llms-openai==0.2.2
103
+ llama-index-postprocessor-colbert-rerank==0.2.1
104
+ llama-index-program-openai==0.2.0
105
+ llama-index-question-gen-openai==0.2.0
106
+ llama-index-readers-file==0.2.2
107
+ llama-index-readers-llama-parse==0.3.0
108
+ llama-index-vector-stores-qdrant==0.3.1
109
+ llama-parse==0.5.7
110
+ markdown-it-py==3.0.0
111
+ MarkupSafe==2.1.5
112
+ marshmallow==3.23.1
113
+ matplotlib==3.9.2
114
+ matplotlib-inline==0.1.7
115
+ mdurl==0.1.2
116
+ mistune==3.0.2
117
+ mpmath==1.3.0
118
+ multidict==6.1.0
119
+ multiprocess==0.70.16
120
+ mypy-extensions==1.0.0
121
+ nbclient==0.10.0
122
+ nbconvert==7.16.4
123
+ nbformat==5.10.4
124
+ nest-asyncio==1.6.0
125
+ networkx==3.4.2
126
+ nltk==3.9.1
127
+ notebook==7.2.2
128
+ notebook_shim==0.2.4
129
+ numpy==1.26.4
130
+ openai==1.53.0
131
+ orjson==3.10.11
132
+ overrides==7.7.0
133
+ packaging==24.1
134
+ pandas==2.2.3
135
+ pandocfilters==1.5.1
136
+ parso==0.8.4
137
+ pdf2image==1.17.0
138
+ peft==0.11.1
139
+ pexpect==4.9.0
140
+ pillow==10.4.0
141
+ pip==24.2
142
+ platformdirs==4.3.6
143
+ portalocker==2.10.1
144
+ prometheus_client==0.21.0
145
+ prompt_toolkit==3.0.48
146
+ propcache==0.2.0
147
+ proto-plus==1.24.0
148
+ protobuf==4.25.5
149
+ psutil==6.0.0
150
+ ptyprocess==0.7.0
151
+ pure_eval==0.2.3
152
+ pyarrow==17.0.0
153
+ pyasn1==0.6.1
154
+ pyasn1_modules==0.4.1
155
+ pycparser==2.22
156
+ pydantic==2.9.2
157
+ pydantic_core==2.23.4
158
+ pydub==0.25.1
159
+ Pygments==2.18.0
160
+ pyparsing==3.1.4
161
+ pypdf==4.3.1
162
+ python-dateutil==2.9.0.post0
163
+ python-json-logger==2.0.7
164
+ python-multipart==0.0.12
165
+ pytz==2024.2
166
+ PyYAML==6.0.2
167
+ pyzmq==26.2.0
168
+ qdrant-client==1.12.0
169
+ referencing==0.35.1
170
+ regex==2024.9.11
171
+ requests==2.32.3
172
+ rfc3339-validator==0.1.4
173
+ rfc3986-validator==0.1.1
174
+ rich==13.9.4
175
+ rpds-py==0.20.1
176
+ rsa==4.9
177
+ ruff==0.7.2
178
+ safetensors==0.4.5
179
+ scikit-learn==1.5.2
180
+ scipy==1.14.1
181
+ semantic-version==2.10.0
182
+ Send2Trash==1.8.3
183
+ sentence-transformers==2.7.0
184
+ setuptools==75.1.0
185
+ shellingham==1.5.4
186
+ six==1.16.0
187
+ sniffio==1.3.1
188
+ soupsieve==2.6
189
+ SQLAlchemy==2.0.36
190
+ stack-data==0.6.3
191
+ starlette==0.41.2
192
+ striprtf==0.0.26
193
+ sympy==1.13.3
194
+ tenacity==8.5.0
195
+ terminado==0.18.1
196
+ threadpoolctl==3.5.0
197
+ tiktoken==0.8.0
198
+ tinycss2==1.4.0
199
+ tokenizers==0.20.1
200
+ tomlkit==0.12.0
201
+ torch==2.4.1
202
+ torchinfo==1.8.0
203
+ torchvision==0.19.1
204
+ tornado==6.4.1
205
+ tqdm==4.66.5
206
+ traitlets==5.14.3
207
+ transformers==4.45.1
208
+ typer==0.12.5
209
+ types-python-dateutil==2.9.0.20241003
210
+ typing_extensions==4.12.2
211
+ typing-inspect==0.9.0
212
+ tzdata==2024.2
213
+ uri-template==1.3.0
214
+ uritemplate==4.1.1
215
+ urllib3==2.2.3
216
+ uvicorn==0.32.0
217
+ wcwidth==0.2.13
218
+ webcolors==24.8.0
219
+ webencodings==0.5.1
220
+ websocket-client==1.8.0
221
+ websockets==12.0
222
+ wheel==0.44.0
223
+ wrapt==1.16.0
224
+ xxhash==3.5.0
225
+ yarl==1.17.1
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .utils import *
2
+ IMAGE_TOKEN = "<image>"
utils/utils.py ADDED
@@ -0,0 +1,44 @@
1
+ import torch
2
+ from PIL import Image
3
+ from typing import Tuple, List
4
+ import numpy as np
5
+ import torch.nn as nn
6
+ import os
7
+ from transformers import AutoTokenizer, GemmaTokenizerFast
8
+ from safetensors import safe_open
9
+ import json
10
+ from pathlib import Path
11
+ from models.paligemma import PaliGemmaConfig, PaliGemma
12
+
13
+
14
+ def load_model(model_dir: str):
15
+
16
+ with open(os.path.join(model_dir, 'config.json'), "r") as f:
17
+ model_config = json.loads(f.read())
18
+ config = PaliGemmaConfig.from_dict(model_config)
19
+
20
+ safetensor_files = Path(model_dir).glob("*.safetensors")
21
+
22
+ weights = {}
23
+ for file in safetensor_files:
24
+ with safe_open(file, framework='pt', device="cpu") as f:
25
+ for key in f.keys():
26
+ weights[key] = f.get_tensor(key)
27
+
28
+ model = PaliGemma(config)
29
+ model.load_state_dict(weights, strict=False)
30
+ model.tie_weights()
31
+
32
+ return model
33
+
34
+
35
+ def load_tokenizer(tokenizer_dir: str):
36
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, padding_side='right')
37
+ return tokenizer
38
+
39
+
40
+ def freeze_model(model: nn.Module):
41
+ for param in model.parameters():
42
+ param.requires_grad = False
43
+
44
+ return model
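Note (added for clarity, not part of the commit): a short sketch of the loading helpers above, using the pretrained directory referenced elsewhere in this commit (it must be pulled from LFS first). load_model rebuilds PaliGemma from config.json plus the safetensors shards, load_tokenizer wraps AutoTokenizer, and freeze_model disables gradients on the base weights before LoRA adapters are attached.

from utils import load_model, load_tokenizer, freeze_model

model_dir = './pretrained/colpaligemma-3b-mix-448-base'
model = load_model(model_dir)            # PaliGemma with lm_head tied to embed_tokens
tokenizer = load_tokenizer(model_dir)    # Gemma tokenizer with right-side padding
model = freeze_model(model)              # base weights frozen; LoRA params are added separately
print(sum(p.requires_grad for p in model.parameters()))   # 0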