Spaces: Runtime error
Huy committed
Commit d8bb2be · 1 Parent(s): c20735c
First commit
- .gitattributes +1 -0
- .gitignore +4 -0
- README.md +0 -13
- app.py +245 -0
- env.yaml +241 -0
- llamaindex_utils.py +558 -0
- models/__init__.py +5 -0
- models/colpali.py +89 -0
- models/colpali_processor.py +89 -0
- models/gemma.py +285 -0
- models/lora.py +68 -0
- models/paligemma.py +162 -0
- models/paligemma_processor.py +103 -0
- models/siglip.py +168 -0
- pretrained/colpaligemma-3b-mix-448-base/adapter_model.safetensors +3 -0
- pretrained/colpaligemma-3b-mix-448-base/config.json +3 -0
- pretrained/colpaligemma-3b-mix-448-base/model-00001-of-00002.safetensors +3 -0
- pretrained/colpaligemma-3b-mix-448-base/model-00002-of-00002.safetensors +3 -0
- pretrained/colpaligemma-3b-mix-448-base/preprocessor_config.json +3 -0
- pretrained/colpaligemma-3b-mix-448-base/tokenizer.json +3 -0
- pretrained/colpaligemma-3b-mix-448-base/tokenizer.model +3 -0
- pretrained/colpaligemma-3b-mix-448-base/tokenizer_config.json +3 -0
- prompt_templates.py +132 -0
- rag_pipeline.py +531 -0
- requirements.txt +225 -0
- utils/__init__.py +2 -0
- utils/utils.py +44 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.json filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,4 @@
__pycache__/
.ipynb_checkpoints/
env/
.DS_Store
README.md
CHANGED
@@ -1,13 +0,0 @@
----
-title: RAG ColPali
-emoji: 📊
-colorFrom: yellow
-colorTo: gray
-sdk: gradio
-sdk_version: 5.4.0
-app_file: app.py
-pinned: false
-license: mit
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,245 @@
import os
import torch
import base64
import asyncio
from io import BytesIO
import gradio as gr
import qdrant_client
from PIL import Image
from typing import List, Dict, Tuple

import llamaindex_utils
from rag_pipeline import async_indexDocument
from models import get_lora_model, enable_lora, ColPali, ColPaliProcessor
from utils import load_tokenizer

from llama_index.llms.gemini import Gemini
from llama_index.core.tools import RetrieverTool


GEMINI_API_KEY = os.getenv(key="GEMINI_API_KEY")
QDRANT_API_KEY = os.getenv(key="QDRANT_API_KEY")
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

async def initialize_model() -> Dict:
    """Initialize models

    Returns:
        model_dict (Dict): Dictionary storing the necessary models
    """

    model = ColPali.from_pretrained(model_dir='./pretrained/colpaligemma-3b-mix-448-base', torch_dtype=torch.bfloat16)
    tokenizer = load_tokenizer(tokenizer_dir='./pretrained/colpaligemma-3b-mix-448-base')
    processor = ColPaliProcessor(tokenizer=tokenizer).from_pretrained(pretrained_dir='./pretrained/colpaligemma-3b-mix-448-base')

    model.model.language_model.model = get_lora_model(model.model.language_model.model,
                                                      rank=32,
                                                      alphas=32,
                                                      lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'],
                                                      training=False,
                                                      dropout_p=0.1,
                                                      torch_dtype=torch.bfloat16)
    model.model.language_model.model = enable_lora(model.model.language_model.model, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], enabled=True)

    model = get_lora_model(model,
                           rank=32,
                           alphas=32,
                           lora_modules=['custom_text_proj'],
                           training=False,
                           dropout_p=0.1,
                           torch_dtype=torch.bfloat16)

    model = enable_lora(model, lora_modules=['custom_text_proj'], enabled=True)

    model.load_lora('./pretrained/colpaligemma-3b-mix-448-base')

    # Initialize LLM
    generation_config = {
        "temperature": 0.0,
        "top_p": 0.95,
        "top_k": 64,
        "max_output_tokens": 1024,
        "response_mime_type": "text/plain",
    }

    llm = Gemini(api_key=GEMINI_API_KEY, generation_config=generation_config)

    # Setup Qdrant
    # Creating Qdrant Client
    vector_store_client = qdrant_client.AsyncQdrantClient(location="https://b3878645-ec71-426c-8afa-b8b3b7589e40.us-east4-0.gcp.cloud.qdrant.io",
                                                          api_key=QDRANT_API_KEY,
                                                          timeout=100)

    embed_model = llamaindex_utils.ColPaliGemmaEmbedding(model=model,
                                                         processor=processor,
                                                         device=device)

    collections = await get_collection_names(vector_store_client)
    retrievers_dict = {}
    for name in collections:
        if name not in retrievers_dict:
            retrievers_dict[name] = llamaindex_utils.ColPaliRetriever(vector_store_client=vector_store_client,
                                                                      target_collection=name,
                                                                      embed_model=embed_model,
                                                                      similarity_top_k=3)
    return {"llm": llm,
            "vector_store_client": vector_store_client,
            "model": model,
            "processor": processor,
            "embed_model": embed_model,
            "collections": collections,
            "retrievers_dict": retrievers_dict}

async def get_collection_names(vector_store_client):
    collections = await vector_store_client.get_collections()
    return [collection.name for collection in collections.collections]

async def index(files: List[str],
                target_collection: str
                ) -> Tuple[str, gr.Dropdown, List[str]]:
    """
    Insert all image pages from the files into the specified target collection in the vector store
    and register a retriever for that collection.

    Args:
        files (List[str]): List of file paths
        target_collection (str): Target collection in the vector store to insert into

    Returns:
        Tuple[str, gr.Dropdown, List[str]]: Status message, updated dropdown component, and the collections' names
    """

    for file in files:
        await async_indexDocument(file_path=file,
                                  vector_store_client=model_dict["vector_store_client"],
                                  target_collection=target_collection,
                                  model=model_dict["model"],
                                  processor=model_dict["processor"],
                                  device=device)

    if target_collection not in retrievers:
        retrievers[target_collection] = llamaindex_utils.ColPaliRetriever(vector_store_client=model_dict["vector_store_client"],
                                                                          target_collection=target_collection,
                                                                          embed_model=model_dict["embed_model"],
                                                                          similarity_top_k=3)
    collection_names = await get_collection_names(model_dict["vector_store_client"])
    return (f"Uploaded and indexed {len(files)} files.",
            gr.Dropdown(choices=collection_names),
            collection_names)

async def search_with_llm(query: str,
                          similarity_top_k: int,
                          num_children: int) -> Tuple[str, List[Image.Image]]:
    """Search for an answer to the query using the module-level retrievers.
    Returns the search response and the list of images supporting that response.

    Args:
        query (str): Query question
        similarity_top_k (int): Top K similar results retrieved from the retriever
        num_children (int): Number of children for tree summarization

    Returns:
        Tuple[str, List[Image.Image]]: The search response and the list of images supporting that response.
    """
    retriever_tools = [RetrieverTool.from_defaults(
        name=key,
        retriever=value,
        description=f"Useful for retrieving information about {key} financials") for key, value in retrievers.items()]

    retriever_mappings = {retriever_tool.metadata.name: retriever_tool.retriever for retriever_tool in retriever_tools}

    fusion_retriever = llamaindex_utils.CustomFusionRetriever(llm=model_dict["llm"],
                                                              retriever_mappings=retriever_mappings,
                                                              similarity_top_k=similarity_top_k)

    query_engine = llamaindex_utils.CustomQueryEngine(retriever_tools=[retriever_tool.metadata for retriever_tool in retriever_tools],
                                                      fusion_retriever=fusion_retriever,
                                                      llm=model_dict["llm"],
                                                      num_children=num_children)
    response = await query_engine.aquery(query_str=query)

    return response.response, [Image.open(BytesIO(base64.b64decode(image))) for image in response.source_images]


def build_gui():
    with gr.Blocks() as demo:
        gr.Markdown("# Image Based RAG System using ColPali 📚🔍")
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown("## 1️. Upload PDFs")
                files = gr.File(file_types=["pdf"],
                                file_count="multiple",
                                interactive=True)

                choices = gr.State(value=model_dict["collections"])
                gr.Markdown("## 2️. Index the PDFs and upload")
                target_collection = gr.Dropdown(choices=choices.value,
                                                allow_custom_value=True,
                                                label="Collection name",
                                                show_label=True,
                                                interactive=True)

                message_box = gr.Textbox(value="File not yet uploaded",
                                         show_label=False,
                                         interactive=False)
                convert_button = gr.Button("🔄 Convert and upload")

                # Define the actions for conversion
                convert_button.click(index, inputs=[files, target_collection], outputs=[message_box, target_collection, choices])

            with gr.Column():
                gr.Markdown("## 3️. Enter your question")
                query = gr.Textbox(placeholder="Enter your query to match",
                                   lines=15,
                                   max_lines=20,
                                   autoscroll=True)
                with gr.Accordion(label="Additional Settings", open=False):
                    similarity_top_k = gr.Slider(minimum=1,
                                                 maximum=10,
                                                 value=3,
                                                 step=1.0,
                                                 label="Top K similarity retrieved from the retriever")

                    num_children = gr.Slider(minimum=1,
                                             maximum=10,
                                             value=3,
                                             step=1.0,
                                             label="Set number of children for Tree Summarization")
                search_button = gr.Button("🔍 Search")

        gr.Markdown("## 4️. ColPali Retrieval")
        with gr.Row(equal_height=True):
            output_text = gr.Textbox(label="Query result",
                                     show_label=True,
                                     placeholder="Response from query",
                                     lines=8,
                                     max_lines=20,
                                     interactive=False)
            output_imgs = gr.Gallery(label="Most relevant images are...",
                                     show_fullscreen_button=True,
                                     show_label=True,
                                     show_download_button=True,
                                     interactive=False)

        # Action for search button
        search_button.click(
            search_with_llm,
            inputs=[query, similarity_top_k, num_children],
            outputs=[output_text, output_imgs])
    return demo

async def amain():
    global model_dict, retrievers
    model_dict = await initialize_model()
    retrievers = model_dict["retrievers_dict"]

    demo = build_gui()
    demo.queue().launch(debug=True, share=False)


if __name__ == "__main__":
    asyncio.run(amain())
env.yaml
ADDED
@@ -0,0 +1,241 @@
channels:
- defaults
dependencies:
- bzip2=1.0.8=h80987f9_6
- ca-certificates=2024.7.2=hca03da5_0
- libffi=3.4.4=hca03da5_1
- ncurses=6.4=h313beb8_0
- openssl=3.0.15=h80987f9_0
- pip=24.2=py311hca03da5_0
- python=3.11.9=hb885b13_0
- readline=8.2=h1a28f6b_0
- setuptools=75.1.0=py311hca03da5_0
- sqlite=3.45.3=h80987f9_0
- tk=8.6.14=h6ba3021_0
- wheel=0.44.0=py311hca03da5_0
- xz=5.4.6=h80987f9_1
- zlib=1.2.13=h18a0788_1
- pip:
  - accelerate==1.1.0
  - aiofiles==23.2.1
  - aiohappyeyeballs==2.4.3
  - aiohttp==3.10.10
  - aiosignal==1.3.1
  - annotated-types==0.7.0
  - anyio==4.6.2.post1
  - appnope==0.1.4
  - argon2-cffi==23.1.0
  - argon2-cffi-bindings==21.2.0
  - arrow==1.3.0
  - asttokens==2.4.1
  - async-lru==2.0.4
  - attrs==24.2.0
  - babel==2.16.0
  - beautifulsoup4==4.12.3
  - bleach==6.2.0
  - cachetools==5.5.0
  - certifi==2024.8.30
  - cffi==1.17.1
  - charset-normalizer==3.4.0
  - click==8.1.7
  - comm==0.2.2
  - contourpy==1.3.0
  - cycler==0.12.1
  - dataclasses-json==0.6.7
  - datasets==3.0.1
  - debugpy==1.8.7
  - decorator==5.1.1
  - defusedxml==0.7.1
  - deprecated==1.2.14
  - dill==0.3.8
  - dirtyjson==1.0.8
  - distro==1.9.0
  - executing==2.1.0
  - fastapi==0.115.4
  - fastjsonschema==2.20.0
  - ffmpy==0.4.0
  - filelock==3.16.1
  - fonttools==4.54.1
  - fqdn==1.5.1
  - frozenlist==1.5.0
  - fsspec==2024.6.1
  - google-ai-generativelanguage==0.6.4
  - google-api-core==2.20.0
  - google-api-python-client==2.147.0
  - google-auth==2.35.0
  - google-auth-httplib2==0.2.0
  - google-generativeai==0.5.4
  - googleapis-common-protos==1.65.0
  - gradio==4.44.1
  - gradio-client==1.3.0
  - greenlet==3.1.1
  - grpcio==1.67.1
  - grpcio-status==1.62.3
  - grpcio-tools==1.62.3
  - h11==0.14.0
  - h2==4.1.0
  - hpack==4.0.0
  - httpcore==1.0.6
  - httplib2==0.22.0
  - httpx==0.27.2
  - huggingface-hub==0.26.2
  - hyperframe==6.0.1
  - idna==3.10
  - importlib-resources==6.4.5
  - instructorembedding==1.0.1
  - ipykernel==6.29.5
  - ipython==8.29.0
  - isoduration==20.11.0
  - jedi==0.19.1
  - jinja2==3.1.4
  - jiter==0.7.0
  - joblib==1.4.2
  - json5==0.9.25
  - jsonpointer==3.0.0
  - jsonschema==4.23.0
  - jsonschema-specifications==2024.10.1
  - jupyter-client==8.6.3
  - jupyter-core==5.7.2
  - jupyter-events==0.10.0
  - jupyter-lsp==2.2.5
  - jupyter-server==2.14.2
  - jupyter-server-terminals==0.5.3
  - jupyterlab==4.2.5
  - jupyterlab-pygments==0.3.0
  - jupyterlab-server==2.27.3
  - kiwisolver==1.4.7
  - llama-cloud==0.1.2
  - llama-index==0.11.17
  - llama-index-agent-openai==0.3.4
  - llama-index-cli==0.3.1
  - llama-index-core==0.11.17
  - llama-index-embeddings-huggingface==0.3.1
  - llama-index-embeddings-instructor==0.2.1
  - llama-index-embeddings-openai==0.2.5
  - llama-index-indices-managed-llama-cloud==0.4.0
  - llama-index-legacy==0.9.48.post3
  - llama-index-llms-gemini==0.3.7
  - llama-index-llms-openai==0.2.13
  - llama-index-multi-modal-llms-gemini==0.3.1
  - llama-index-multi-modal-llms-openai==0.2.2
  - llama-index-postprocessor-colbert-rerank==0.2.1
  - llama-index-program-openai==0.2.0
  - llama-index-question-gen-openai==0.2.0
  - llama-index-readers-file==0.2.2
  - llama-index-readers-llama-parse==0.3.0
  - llama-index-vector-stores-qdrant==0.3.1
  - llama-parse==0.5.7
  - markdown-it-py==3.0.0
  - markupsafe==2.1.5
  - marshmallow==3.23.1
  - matplotlib==3.9.2
  - matplotlib-inline==0.1.7
  - mdurl==0.1.2
  - mistune==3.0.2
  - mpmath==1.3.0
  - multidict==6.1.0
  - multiprocess==0.70.16
  - mypy-extensions==1.0.0
  - nbclient==0.10.0
  - nbconvert==7.16.4
  - nbformat==5.10.4
  - nest-asyncio==1.6.0
  - networkx==3.4.2
  - nltk==3.9.1
  - notebook==7.2.2
  - notebook-shim==0.2.4
  - numpy==1.26.4
  - openai==1.53.0
  - orjson==3.10.11
  - overrides==7.7.0
  - packaging==24.1
  - pandas==2.2.3
  - pandocfilters==1.5.1
  - parso==0.8.4
  - pdf2image==1.17.0
  - peft==0.11.1
  - pexpect==4.9.0
  - pillow==10.4.0
  - platformdirs==4.3.6
  - portalocker==2.10.1
  - prometheus-client==0.21.0
  - prompt-toolkit==3.0.48
  - propcache==0.2.0
  - proto-plus==1.24.0
  - protobuf==4.25.5
  - psutil==6.0.0
  - ptyprocess==0.7.0
  - pure-eval==0.2.3
  - pyarrow==17.0.0
  - pyasn1==0.6.1
  - pyasn1-modules==0.4.1
  - pycparser==2.22
  - pydantic==2.9.2
  - pydantic-core==2.23.4
  - pydub==0.25.1
  - pygments==2.18.0
  - pyparsing==3.1.4
  - pypdf==4.3.1
  - python-dateutil==2.9.0.post0
  - python-json-logger==2.0.7
  - python-multipart==0.0.12
  - pytz==2024.2
  - pyyaml==6.0.2
  - pyzmq==26.2.0
  - qdrant-client==1.12.0
  - referencing==0.35.1
  - regex==2024.9.11
  - requests==2.32.3
  - rfc3339-validator==0.1.4
  - rfc3986-validator==0.1.1
  - rich==13.9.4
  - rpds-py==0.20.1
  - rsa==4.9
  - ruff==0.7.2
  - safetensors==0.4.5
  - scikit-learn==1.5.2
  - scipy==1.14.1
  - semantic-version==2.10.0
  - send2trash==1.8.3
  - sentence-transformers==2.7.0
  - shellingham==1.5.4
  - six==1.16.0
  - sniffio==1.3.1
  - soupsieve==2.6
  - sqlalchemy==2.0.36
  - stack-data==0.6.3
  - starlette==0.41.2
  - striprtf==0.0.26
  - sympy==1.13.3
  - tenacity==8.5.0
  - terminado==0.18.1
  - threadpoolctl==3.5.0
  - tiktoken==0.8.0
  - tinycss2==1.4.0
  - tokenizers==0.20.1
  - tomlkit==0.12.0
  - torch==2.4.1
  - torchinfo==1.8.0
  - torchvision==0.19.1
  - tornado==6.4.1
  - tqdm==4.66.5
  - traitlets==5.14.3
  - transformers==4.45.1
  - typer==0.12.5
  - types-python-dateutil==2.9.0.20241003
  - typing-extensions==4.12.2
  - typing-inspect==0.9.0
  - tzdata==2024.2
  - uri-template==1.3.0
  - uritemplate==4.1.1
  - urllib3==2.2.3
  - uvicorn==0.32.0
  - wcwidth==0.2.13
  - webcolors==24.8.0
  - webencodings==0.5.1
  - websocket-client==1.8.0
  - websockets==12.0
  - wrapt==1.16.0
  - xxhash==3.5.0
  - yarl==1.17.1
llamaindex_utils.py
ADDED
@@ -0,0 +1,558 @@
import torch
import json
import asyncio
import qdrant_client
from PIL import Image
from pydantic import PrivateAttr, Field
from typing import Union, Optional, List, Any, Dict, Set
from dataclasses import dataclass

from llama_index.core.vector_stores.types import VectorStoreQueryResult
from llama_index.core.vector_stores.utils import (
    legacy_metadata_dict_to_node,
    metadata_dict_to_node,
)
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.retrievers import BaseRetriever
from llama_index.core import QueryBundle, PromptTemplate
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.core.llms import LLM
from llama_index.core.question_gen import LLMQuestionGenerator
from llama_index.core.tools import ToolMetadata
from llama_index.core.output_parsers.utils import parse_json_markdown
from llama_index.core.question_gen.types import SubQuestion

from models import ColPali, ColPaliProcessor
from prompt_templates import (DEFAULT_GEN_PROMPT_TMPL,
                              DEFAULT_FINAL_ANSWER_PROMPT_TMPL,
                              DEFAULT_SUB_QUESTION_PROMPT_TMPL,
                              DEFAULT_SYNTHESIZE_PROMPT_TMPL)
from typing import Any, List, Optional, Tuple, cast
from qdrant_client.http.models import Payload

from collections import defaultdict

def parse_to_query_result(response: List[Any]) -> VectorStoreQueryResult:
    """
    Convert vector store response to VectorStoreQueryResult.

    Args:
        response (List[Any]): List of results returned from the vector store.
    """
    nodes = []
    similarities = []
    ids = []

    for point in response:
        payload = cast(Payload, point.payload)
        try:
            node = metadata_dict_to_node(payload)
        except Exception:
            metadata, node_info, relationships = legacy_metadata_dict_to_node(
                payload
            )

            node = TextNode(
                id_=str(point.id),
                text=payload.get("text"),
                metadata=metadata,
                start_char_idx=node_info.get("start", None),
                end_char_idx=node_info.get("end", None),
                relationships=relationships,
            )
        nodes.append(node)
        ids.append(str(point.id))
        try:
            similarities.append(point.score)
        except AttributeError:
            # certain requests do not return a score
            similarities.append(1.0)

    return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)


class ColPaliGemmaEmbedding(BaseEmbedding):
    _model: ColPali = PrivateAttr()
    _processor: ColPaliProcessor = PrivateAttr()

    device: Union[torch.device, str] = Field(default="cpu",
                                             description="Device to use")
    def __init__(self,
                 model: ColPali,
                 processor: ColPaliProcessor,
                 device: Optional[str] = 'cpu',
                 **kwargs):
        super().__init__(device=device,
                         **kwargs)
        self._model = model.to(device).eval()
        self._processor = processor

    @classmethod
    def class_name(cls) -> str:
        return "ColPaliGemmaEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding.

        Args:
            query (str): Query string
        """
        with torch.no_grad():
            processed_query = self._processor.process_queries([query])
            processed_query = {k: v.to(self.device) for k, v in processed_query.items()}
            query_embeddings = self._model(**processed_query)
        return query_embeddings.to('cpu')[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding.

        Args:
            text (str): Text string
        """
        with torch.no_grad():
            processed_query = self._processor.process_queries([text])
            processed_query = {k: v.to(self.device) for k, v in processed_query.items()}
            query_embeddings = self._model(**processed_query)
        return query_embeddings.to('cpu')[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        Args:
            texts (List[str]): List of text strings
        """
        with torch.no_grad():
            processed_queries = self._processor.process_queries(texts)
            processed_queries = {k: v.to(self.device) for k, v in processed_queries.items()}
            query_embeddings = self._model(**processed_queries)
        return query_embeddings.to('cpu')

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)

class ColPaliRetriever(BaseRetriever):
    def __init__(self,
                 vector_store_client: Union[qdrant_client.QdrantClient, qdrant_client.AsyncQdrantClient],
                 target_collection: str,
                 embed_model: ColPaliGemmaEmbedding,
                 query_mode: str = 'default',
                 similarity_top_k: int = 3,
                 ) -> None:
        self._vector_store_client = vector_store_client
        self._target_collection = target_collection
        self._embed_model = embed_model
        self._query_mode = query_mode
        self._similarity_top_k = similarity_top_k
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Get retrieved nodes from the vector store given a query string.

        Args:
            query_bundle (QueryBundle): QueryBundle that includes the query string

        Returns:
            List[NodeWithScore]: List of retrieved nodes.
        """
        if query_bundle.embedding is None:
            query_embedding = self._embed_model._get_query_embedding(query_bundle.query_str)
        else:
            query_embedding = query_bundle.embedding


        query_embedding = query_embedding.cpu().float().numpy().tolist()

        # Get nodes from vector store
        response = self._vector_store_client.query_points(collection_name=self._target_collection,
                                                          query=query_embedding,
                                                          limit=self._similarity_top_k).points
        # Parse to structured output nodes
        query_result = parse_to_query_result(response)
        nodes_with_scores = []
        for idx, node in enumerate(query_result.nodes):
            score = None
            if query_result.similarities is not None:
                score = query_result.similarities[idx]
            nodes_with_scores.append(NodeWithScore(node=node, score=score))
        return nodes_with_scores

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously get retrieved nodes from the vector store given a query string.

        Args:
            query_bundle (QueryBundle): QueryBundle that includes the query string

        Returns:
            List[NodeWithScore]: List of retrieved nodes.
        """
        if query_bundle.embedding is None:
            query_embedding = await self._embed_model._aget_query_embedding(query_bundle.query_str)
        else:
            query_embedding = query_bundle.embedding

        query_embedding = query_embedding.cpu().float().numpy().tolist()

        # Get nodes from vector store
        responses = await self._vector_store_client.query_points(collection_name=self._target_collection,
                                                                 query=query_embedding,
                                                                 limit=self._similarity_top_k)

        responses = responses.points
        # Parse to structured output nodes
        query_result = parse_to_query_result(responses)
        nodes_with_scores = []
        for idx, node in enumerate(query_result.nodes):
            score = None
            if query_result.similarities is not None:
                score = query_result.similarities[idx]
            nodes_with_scores.append(NodeWithScore(node=node, score=score))
        return nodes_with_scores


def fuse_results(retrieved_nodes: List[NodeWithScore], similarity_top_k: int) -> List[NodeWithScore]:
    """Fuse retrieved nodes using Reciprocal Rank

    Args:
        retrieved_nodes (List[NodeWithScore]): List of nodes.
        similarity_top_k (int): Get top K nodes.

    Returns:
        List[NodeWithScore]: List of nodes after fusion
    """
    k = 60.0
    fused_scores = {}
    text_to_node = {}
    for rank, node_with_score in enumerate(sorted(retrieved_nodes, key=lambda x: x.score or 0.0, reverse=True)):
        text = node_with_score.node.get_content(metadata_mode='all')
        text_to_node[text] = node_with_score
        fused_scores[text] = fused_scores.get(text, 0.0) + 1.0 / (rank + k)

    # Sort results by calculated score
    reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True))
    reranked_nodes: List[NodeWithScore] = []
    for text, score in reranked_results.items():
        reranked_nodes.append(text_to_node[text])
        reranked_nodes[-1].score = score
    return reranked_nodes[:similarity_top_k]


def generate_queries(llm: LLM, query: str, num_queries: int) -> List[str]:
    """Generate num_queries queries

    Args:
        llm (LLM): LLM model
        query (str): Query string
        num_queries (int): Number of queries to generate

    Returns:
        generate_queries (List[str]): List of generated queries
    """
    query_prompt = PromptTemplate(DEFAULT_GEN_PROMPT_TMPL)
    generate_queries = llm.predict(query_prompt,
                                   num_queries=num_queries,
                                   query=query)
    generate_queries = generate_queries.split('\n')
    return generate_queries

async def agenerate_queries(llm: LLM, query: str, num_queries: int):
    """Asynchronously generate num_queries queries

    Args:
        llm (LLM): LLM model
        query (str): Query string
        num_queries (int): Number of queries to generate

    Returns:
        generate_queries (List[str]): List of generated queries
    """
    query_prompt = PromptTemplate(DEFAULT_GEN_PROMPT_TMPL)
    generate_queries = await llm.apredict(query_prompt,
                                          num_queries=num_queries,
                                          query=query)
    generate_queries = generate_queries.split('\n')
    return generate_queries


# Tree Summarization
def synthesize_results(queries: List[SubQuestion], contexts: Dict[str, Set[str]], llm: LLM, num_children: int) -> Tuple[str, List[str]]:
    """Summarize the results generated from the LLM.

    Args:
        queries (List[SubQuestion]): Generated sub-questions
        contexts (Dict[str, Set[str]]): Dictionary mapping a context information string to its set of source images
        llm (LLM): LLM model
        num_children (int): Number of children for Tree Summarization

    Returns:
        Tuple[str, List[str]]: Synthesized text, set of source images.
    """
    qa_prompt = PromptTemplate(DEFAULT_SYNTHESIZE_PROMPT_TMPL)

    new_contexts = defaultdict(set)
    keys = list(contexts.keys())
    for idx in range(0, len(keys), num_children):
        contexts_batch = keys[idx: idx + num_children]
        context_str = '\n\n'.join([f"{i + 1}. {text}" for i, text in enumerate(contexts_batch)])

        fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str="\n".join([query.sub_question for query in queries]))
        combined_result = llm.complete(fmt_qa_prompt)

        # Parse json string to dictionary
        json_dict = parse_json_markdown(str(combined_result))
        if len(json_dict['choices']) > 0:
            for choice in json_dict['choices']:
                new_contexts[json_dict['summarized_text']] = new_contexts[json_dict['summarized_text']].union(contexts[contexts_batch[choice - 1]])
        else:
            new_contexts[json_dict['summarized_text']] = set()

    if len(new_contexts) == 1:
        synthesized_text = list(new_contexts.keys())[0]
        return synthesized_text, list(new_contexts[synthesized_text])
    else:
        return synthesize_results(queries, new_contexts, llm, num_children=num_children)


async def asynthesize_results(queries: List[SubQuestion], contexts: Dict[str, Set[str]], llm: LLM, num_children: int) -> Tuple[str, List[str]]:
    """Asynchronously summarize the results generated from the LLM.

    Args:
        queries (List[SubQuestion]): Generated sub-questions
        contexts (Dict[str, Set[str]]): Dictionary mapping a context information string to its set of source images
        llm (LLM): LLM model
        num_children (int): Number of children for Tree Summarization

    Returns:
        Tuple[str, List[str]]: Synthesized text, set of source images.
    """
    qa_prompt = PromptTemplate(DEFAULT_SYNTHESIZE_PROMPT_TMPL)
    fmt_qa_prompts = []
    keys = list(contexts.keys())
    contexts_batches = []
    for idx in range(0, len(keys), num_children):
        contexts_batch = keys[idx: idx + num_children]

        context_str = '\n\n'.join([f"{i + 1}. {text}" for i, text in enumerate(contexts_batch)])

        fmt_qa_prompt = qa_prompt.format(context_str=context_str, query_str="\n".join([query.sub_question for query in queries]))
        fmt_qa_prompts.append(fmt_qa_prompt)
        contexts_batches.append(contexts_batch)

    tasks = []
    async with asyncio.TaskGroup() as tg:
        for fmt_qa_prompt in fmt_qa_prompts:
            task = tg.create_task(llm.acomplete(fmt_qa_prompt))
            tasks.append(task)

    responses = [str(task.result()) for task in tasks]
    new_contexts = defaultdict(set)
    for idx, response in enumerate(responses):
        # Parse json string to dictionary
        json_dict = parse_json_markdown(response)

        if len(json_dict["choices"]) > 1:
            for choice in json_dict["choices"]:
                new_contexts[json_dict["summarized_text"]] = new_contexts[json_dict["summarized_text"]].union(contexts[contexts_batches[idx][choice - 1]])
        else:
            new_contexts[json_dict["summarized_text"]] = set()

    if len(new_contexts) == 1:
        synthesized_text = list(new_contexts.keys())[0]
        return synthesized_text, list(new_contexts[synthesized_text])
    else:
        return await asynthesize_results(queries, new_contexts, llm, num_children=num_children)

class CustomFusionRetriever(BaseRetriever):
    def __init__(self,
                 llm,
                 retriever_mappings: Dict[str, BaseRetriever],
                 similarity_top_k: int = 3,
                 num_generated_queries: int = 3,
                 ) -> None:
        self._retriever_mappings = retriever_mappings
        self._similarity_top_k = similarity_top_k
        self._num_generated_queries = num_generated_queries
        self._llm = llm
        super().__init__()

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve self._similarity_top_k content nodes given the query

        Args:
            query_bundle (QueryBundle): Query bundle that includes the query string
        """
        # Get data from query bundle
        query_dict = json.loads(query_bundle.query_str)
        original_query = query_dict['sub_question']
        tool_name = query_dict['tool_name']

        # Rewrite original query to n queries
        generated_queries = generate_queries(self._llm, original_query, num_queries=self._num_generated_queries)

        # For each generated query, retrieve relevant nodes
        retrieved_nodes = []
        for query in generated_queries:
            if len(query) == 0:
                continue
            retrieved_nodes.extend(self._retriever_mappings[tool_name].retrieve(query))

        # Fuse retrieved nodes using reciprocal rank
        fused_results = fuse_results(retrieved_nodes,
                                     similarity_top_k=self._similarity_top_k)
        return fused_results

    async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Asynchronously retrieve self._similarity_top_k content nodes given the query

        Args:
            query_bundle (QueryBundle): Query bundle that includes the query string
        """
        # Get data from query bundle
        query_dict = json.loads(query_bundle.query_str)
        original_query = query_dict['sub_question']
        tool_name = query_dict['tool_name']

        # Rewrite original query to n queries
        generated_queries = await agenerate_queries(llm=self._llm, query=original_query, num_queries=self._num_generated_queries)

        # For each generated query, retrieve relevant nodes
        tasks = []
        async with asyncio.TaskGroup() as tg:
            for query in generated_queries:
                if len(query) == 0:
                    continue
                task = tg.create_task(self._retriever_mappings[tool_name].aretrieve(query))
                tasks.append(task)

        retrieved_nodes = [node for task in tasks for node in task.result()]

        # Fuse retrieved nodes using reciprocal rank
        fused_results = fuse_results(retrieved_nodes,
                                     similarity_top_k=self._similarity_top_k)
        return fused_results


@dataclass
class Response:
    response: str
    source_images: Optional[List] = None

    def __str__(self):
        return self.response

class CustomQueryEngine:
    def __init__(self,
                 retriever_tools: List[ToolMetadata],
                 fusion_retriever: BaseRetriever,
                 qa_prompt: PromptTemplate = None,
                 llm: LLM = None,
                 num_children: int = 3):
        self._qa_prompt = qa_prompt if qa_prompt else PromptTemplate(DEFAULT_FINAL_ANSWER_PROMPT_TMPL)
        self._llm = llm
        self._num_children = num_children
        self._sub_question_generator = LLMQuestionGenerator.from_defaults(llm=self._llm,
                                                                          prompt_template_str=DEFAULT_SUB_QUESTION_PROMPT_TMPL)
        self._fusion_retriever = fusion_retriever
        self._retriever_tools = retriever_tools


    def query(self, query_str: str) -> Response:
        # Generate sub queries
        sub_queries = self._sub_question_generator.generate(tools=self._retriever_tools,
                                                            query=QueryBundle(query_str=query_str))

        if len(sub_queries) == 0:
            response_template = PromptTemplate("Cannot answer the query: {query_str}")
            return Response(response=response_template.format(query_str=query_str), source_images=[])
        else:
            # Dictionary to map response -> source_images
            response2images_mapping = defaultdict(set)

            # For each sub query, retrieve relevant image nodes.
            # With the fusion retriever, each sub query is rewritten to n queries -> retrieve relevant nodes for each generated query
            # -> fuse all nodes retrieved from the multiple generated queries using reciprocal rank -> get top k results
            for sub_query in sub_queries:
                retrieved_nodes = self._fusion_retriever.retrieve(QueryBundle(query_str=sub_query.model_dump_json()))
                # Use the LLM to get the answer for the sub query from the retrieved nodes
                for retrieved_node in retrieved_nodes:
                    response2images_mapping[str(self._llm.complete([sub_query.sub_question, Image.open(retrieved_node.node.resolve_image())]))].add(retrieved_node.node.image)

            # Synthesize results
            synthesized_text, source_images = synthesize_results(queries=sub_queries,
                                                                 contexts=response2images_mapping,
                                                                 llm=self._llm,
                                                                 num_children=self._num_children)

            final_answer = self._llm.predict(self._qa_prompt,
                                             context_str=synthesized_text,
                                             query_str=query_str)

            response_template = PromptTemplate("Retrieved Information:\n"
                                               "------------------------\n"
                                               "{retrieved_information}\n"
                                               "-------------------------\n\n"
                                               "Answer:\n"
                                               "{final_answer}")

            return Response(response=response_template.format(retrieved_information=synthesized_text, final_answer=final_answer), source_images=source_images)

    async def aquery(self, query_str: str):
        sub_queries = await self._sub_question_generator.agenerate(tools=self._retriever_tools,
                                                                   query=QueryBundle(query_str=query_str))
        if len(sub_queries) == 0:
            response_template = PromptTemplate("Cannot answer the query: {query_str}")
            return Response(response=response_template.format(query_str=query_str), source_images=[])
        else:
            retrieved_subquestion_nodes = []
            async with asyncio.TaskGroup() as tg:
                for sub_query in sub_queries:
                    task = tg.create_task(self._fusion_retriever.aretrieve(QueryBundle(query_str=sub_query.model_dump_json())))
                    retrieved_subquestion_nodes.append([sub_query.sub_question, task])

            retrieved_subquestion_nodes = [[sub_question, task.result()] for sub_question, task in retrieved_subquestion_nodes]

            answers = []
            # For each sub query, retrieve relevant image nodes.
            # With the fusion retriever, each sub query is rewritten to n queries -> retrieve relevant nodes for each generated query
            # -> fuse all nodes retrieved from the multiple generated queries using reciprocal rank -> get top k results
            async with asyncio.TaskGroup() as tg:
                for sub_question, retrieved_nodes in retrieved_subquestion_nodes:
                    for retrieved_node in retrieved_nodes:
                        task = tg.create_task(self._llm.acomplete([sub_question, Image.open(retrieved_node.node.resolve_image())]))
                        answers.append([task, retrieved_node.node.image])

            # Dictionary to map response -> source_images
            response2images_mapping = defaultdict(set)

            for task, image in answers:
                response2images_mapping[str(task.result())].add(image)

            # Synthesize results
            synthesized_text, source_images = await asynthesize_results(queries=sub_queries,
                                                                        contexts=response2images_mapping,
                                                                        llm=self._llm,
                                                                        num_children=self._num_children)


            final_answer = await self._llm.apredict(self._qa_prompt,
                                                    context_str=synthesized_text,
                                                    query_str=query_str)

            response_template = PromptTemplate("Retrieved Information:\n"
                                               "------------------------\n"
                                               "{retrieved_information}\n"
                                               "-------------------------\n\n"
                                               "Answer:\n"
                                               "{final_answer}")

            return Response(response=response_template.format(retrieved_information=synthesized_text, final_answer=final_answer), source_images=source_images)
models/__init__.py
ADDED
@@ -0,0 +1,5 @@
from .colpali import ColPali, KVCache
from .paligemma_processor import PaliGemmaProcessor
from .colpali_processor import ColPaliProcessor
from .paligemma import PaliGemma
from .lora import *
models/colpali.py
ADDED
@@ -0,0 +1,89 @@
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from .gemma import KVCache
from .paligemma import PaliGemma, PaliGemmaConfig
from typing import Optional
from utils import *
from pathlib import Path
from safetensors import safe_open

def convert_weights_dict(original_weights):
    # The PEFT adapter stores every LoRA weight under a "base_model.model." prefix;
    # strip that prefix so the keys line up with this module tree.
    converted_weights = {}
    converted_weights['custom_text_proj.lora_A.weight'] = original_weights['base_model.model.custom_text_proj.lora_A.weight']
    converted_weights['custom_text_proj.lora_B.weight'] = original_weights['base_model.model.custom_text_proj.lora_B.weight']
    for i in range(18):
        for proj in ('mlp.down_proj', 'mlp.gate_proj', 'mlp.up_proj',
                     'self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj', 'self_attn.o_proj'):
            for lora in ('lora_A', 'lora_B'):
                key = f'model.language_model.model.layers.{i}.{proj}.{lora}.weight'
                converted_weights[key] = original_weights[f'base_model.model.{key}']

    return converted_weights


class ColPali(nn.Module):
    def __init__(self, cfg: PaliGemmaConfig):
        super().__init__()
        self.model = PaliGemma(cfg=cfg)
        self.dim = 128
        self.custom_text_proj = nn.Linear(self.model.cfg.text_config.hidden_size, self.dim, bias=False)

    @staticmethod
    def from_pretrained(model_dir, torch_dtype: torch.dtype = torch.float32):
        torch.set_default_dtype(torch_dtype)
        with open(os.path.join(model_dir, 'config.json'), "r") as f:
            model_config = json.loads(f.read())
        config = PaliGemmaConfig.from_dict(model_config)

        safetensor_files = Path(model_dir).glob("*.safetensors")

        weights = {}
        for file in safetensor_files:
            with safe_open(file, framework='pt', device="cpu") as f:
                for key in f.keys():
                    weights[key] = f.get_tensor(key)
        model = ColPali(config)
        model.load_state_dict(weights, strict=False)
        model.tie_weights()
        return model

    def load_lora(self, model_dir):
        weights = {}
        with safe_open(os.path.join(model_dir, "adapter_model.safetensors"), framework="pt", device="cpu") as f:
            for key in f.keys():
                weights[key] = f.get_tensor(key)

        converted_weights = convert_weights_dict(weights)
        self.load_state_dict(converted_weights, strict=False)

    def tie_weights(self):
        self.model.language_model.tie_weights()

    def forward(self, *args, **kwargs) -> torch.Tensor:
        outputs = self.model(*args, **kwargs)
        last_hidden_states = outputs[0]
        proj = self.custom_text_proj(last_hidden_states)
        # L2 normalization
        proj = proj / proj.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)

        proj = proj * kwargs['attention_mask'].unsqueeze(-1)  # (batch_size, sequence_length, dim)

        return proj
models/colpali_processor.py
ADDED
@@ -0,0 +1,89 @@
import torch
from PIL import Image
from typing import Tuple, List
import numpy as np
from transformers import GemmaTokenizerFast
from .paligemma_processor import PaliGemmaProcessor
from typing import Optional

def process_imgs(imgs: List[Image.Image],
                 img_size: Tuple[int, int],
                 rescale: float,
                 mean: Tuple[float, float, float],
                 std: Tuple[float, float, float]):

    def normalize(img, mean, std):
        img = (img - np.array(mean, dtype=img.dtype)) / np.array(std, dtype=img.dtype)
        return img

    resized_imgs = [img.resize((img_size[0], img_size[1]), resample=Image.Resampling.BICUBIC) for img in imgs]

    rescaled_imgs = [np.array(img, dtype=np.float32) * rescale for img in resized_imgs]

    normalized_imgs = [normalize(img, mean, std) for img in rescaled_imgs]

    transposed_imgs = [img.transpose(2, 0, 1) for img in normalized_imgs]

    tensor_imgs = torch.tensor(np.stack(transposed_imgs, axis=0), dtype=torch.float32)
    return tensor_imgs


def process_prompts(prompt, image_token, max_num_image_token, bos_token):
    return f"{image_token * max_num_image_token}{bos_token}{prompt}\n"


class ColPaliProcessor(PaliGemmaProcessor):
    def __init__(self,
                 tokenizer: GemmaTokenizerFast) -> None:
        super().__init__(tokenizer=tokenizer)
        self.mock_image = Image.new(mode='RGB', size=(16, 16), color='black')

    def process_images(self, images: List[Image.Image]):
        input_prompts = ["Describe the image."] * len(images)

        images = [image.convert("RGB") for image in images]

        return_data = self(images,
                           input_prompts,
                           padding="longest",
                           truncation=False)

        return return_data

    def process_queries(self,
                        queries: List[str],
                        max_length: int = 50,
                        suffix: Optional[str] = None):

        if suffix is None:
            suffix = "<pad>" * 10

        texts_query: List[str] = []

        for query in queries:
            query = f"Question: {query}"
            query += suffix
            texts_query.append(query)

        batch_query = self(imgs=[self.mock_image] * len(texts_query),
                           prompts=texts_query,
                           padding="longest",
                           max_length=max_length + self.image_seq_length,
                           truncation=True)

        del batch_query["pixel_values"]

        batch_query["input_ids"] = batch_query["input_ids"][..., self.image_seq_length:]
        batch_query["attention_mask"] = batch_query["attention_mask"][..., self.image_seq_length:]

        return batch_query
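A minimal usage sketch for this processor, assuming the files under pretrained/colpaligemma-3b-mix-448-base provide the tokenizer and preprocessor_config.json (image size, mean and std come from that config, not from values chosen here):

from PIL import Image
from transformers import GemmaTokenizerFast
from models import ColPaliProcessor

model_dir = "pretrained/colpaligemma-3b-mix-448-base"
tokenizer = GemmaTokenizerFast.from_pretrained(model_dir)
processor = ColPaliProcessor(tokenizer=tokenizer).from_pretrained(model_dir)

# Page side: every image gets the same "Describe the image." prompt plus <image> placeholders.
page = Image.new("RGB", (640, 480), "white")
img_batch = processor.process_images([page])      # pixel_values: (1, 3, H, W)

# Query side: a mock image goes through the same __call__, then the image slots are stripped,
# leaving input_ids / attention_mask with ten trailing <pad> tokens as query augmentation.
query_batch = processor.process_queries(["What is the revenue growth of Uber?"])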
models/gemma.py
ADDED
@@ -0,0 +1,285 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from dataclasses import dataclass
from typing import Optional, List
import math
import torch.utils.checkpoint as checkpoint

@dataclass
class GemmaConfig:
    hidden_size: int = 2048
    intermediate_size: int = 16384
    num_attention_heads: int = 8
    num_hidden_layers: int = 18
    num_image_tokens: int = 256
    num_key_value_heads: int = 1
    vocab_size: int = 257216
    norm_eps: float = 1e-6
    max_seq_len: int = 8192
    attention_dropout: float = 0.0
    use_lora: bool = False
    training: bool = False

    @classmethod
    def from_dict(cls, data):
        return cls(
            hidden_size = data['hidden_size'],
            intermediate_size = data['intermediate_size'],
            num_attention_heads = data['num_attention_heads'],
            num_hidden_layers = data['num_hidden_layers'],
            num_image_tokens = data['num_image_tokens'],
            num_key_value_heads = data['num_key_value_heads'],
            vocab_size = data['vocab_size'],
            training = data['training'])

class RMSNorm(nn.Module):
    def __init__(self, dim: int, norm_eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(dim))
        self.norm_eps = norm_eps

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.norm_eps)

    def forward(self, x: torch.Tensor):
        output = self._norm(x.float())
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)


def precompute_freqs(head_dim: int, max_seq_len: int, theta: int = 10000):
    thetas = 1 / (theta ** (torch.arange(0, head_dim, 2, dtype=torch.int64).float() / head_dim))
    m = torch.arange(max_seq_len, dtype=torch.long)

    # (max_seq_len, head_dim // 2)
    freqs = torch.outer(m, thetas)

    # (max_seq_len, head_dim // 2) -> (max_seq_len, head_dim)
    freqs = torch.cat((freqs, freqs), dim=-1)
    return freqs

def rotate_half(x: torch.Tensor):
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]

    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_embed(x: torch.Tensor,
                       freqs: torch.Tensor):
    # x: (n, n_heads, seq_len, head_dim)
    # freqs: (n, seq_len, head_dim)
    device_type = x.device.type
    device_type = device_type if device_type != 'mps' else 'cpu'
    with torch.autocast(device_type=device_type, enabled=False):
        cos = freqs.cos()
        sin = freqs.sin()
        while len(cos.shape) < len(x.shape):
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)
        cos = cos.to(x.dtype)
        sin = sin.to(x.dtype)
        x = (x * cos) + (rotate_half(x) * sin)
    return x

class KVCache:
    def __init__(self):
        self.cache_k: List[torch.Tensor] = []
        self.cache_v: List[torch.Tensor] = []

    def num_items(self):
        if len(self.cache_k) == 0:
            return 0
        else:
            # (n, num_heads, seq_len, head_dim)
            return self.cache_k[0].shape[-2]

    def update(self, xk, xv, layer_idx):
        if layer_idx < len(self.cache_k):
            self.cache_k[layer_idx] = torch.cat((self.cache_k[layer_idx], xk), dim=-2)
            self.cache_v[layer_idx] = torch.cat((self.cache_v[layer_idx], xv), dim=-2)
        else:
            self.cache_k.append(xk)
            self.cache_v.append(xv)

        return self.cache_k[layer_idx], self.cache_v[layer_idx]


class GemmaTransformerAttention(nn.Module):
    def __init__(self, cfg: GemmaConfig, layer_idx: int):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx
        self.vocab_size = cfg.vocab_size
        self.hidden_size = cfg.hidden_size
        self.num_attention_heads = cfg.num_attention_heads
        self.num_key_value_heads = cfg.num_key_value_heads
        self.max_seq_len = cfg.max_seq_len

        assert self.hidden_size % self.num_attention_heads == 0

        self.n_rep = self.num_attention_heads // self.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_attention_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)

        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.attn_dropout = cfg.attention_dropout
        self.training = cfg.training

        self.register_buffer('freqs',
                             precompute_freqs(self.head_dim, cfg.max_seq_len),
                             persistent=False)

    def forward(self, x: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[KVCache] = None):
        batch_size, seq_len, embed_dim = x.shape

        xq = self.q_proj(x)
        xk = self.k_proj(x)
        xv = self.v_proj(x)

        # (n, seq_len, hidden_size) -> (n, seq_len, num_heads, head_dim) -> (n, num_heads, seq_len, head_dim)
        xq = xq.view(batch_size, seq_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        # (n, seq_len, hidden_size) -> (n, seq_len, num_kv_heads, head_dim) -> (n, num_kv_heads, seq_len, head_dim)
        xk = xk.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(batch_size, seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        xq = apply_rotary_embed(xq, self.freqs[position_ids, :])
        xk = apply_rotary_embed(xk, self.freqs[position_ids, :])

        if kv_cache is not None:
            keys, values = kv_cache.update(xk, xv, self.layer_idx)
        else:
            keys, values = xk, xv

        # (n, num_kv_heads, seq_len, head_dim) -> (n, num_kv_heads * n_rep, seq_len, head_dim) -> (n, num_heads, seq_len, head_dim)
        keys = keys[:, :, None, :, :].expand(-1, -1, self.n_rep, -1, -1).view(batch_size, -1, keys.shape[-2], self.head_dim)
        values = values[:, :, None, :, :].expand(-1, -1, self.n_rep, -1, -1).view(batch_size, -1, keys.shape[-2], self.head_dim)

        assert attention_mask is not None
        # (n, num_heads, seq_len, head_dim) @ (n, num_heads, head_dim, seq_len) -> (n, num_heads, seq_len, seq_len)
        attn_weights = torch.softmax(xq @ keys.transpose(2, 3) / math.sqrt(self.head_dim) + attention_mask, dim=-1)

        # dropout when training
        attn_weights = F.dropout(attn_weights, p=self.attn_dropout, training=self.training)
        # (n, num_heads, seq_len, seq_len) @ (n, num_heads, seq_len, head_dim) -> (n, num_heads, seq_len, head_dim)
        attn_output = attn_weights @ values
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(*x.shape)

        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class GemmaTransformerMLP(nn.Module):
    def __init__(self, cfg: GemmaConfig):
        super().__init__()
        self.cfg = cfg

        self.down_proj = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False)
        self.gate_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.up_proj = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)

    def forward(self, x: torch.Tensor):
        return self.down_proj(F.gelu(self.gate_proj(x), approximate="tanh") * self.up_proj(x))


class GemmaTransformerDecoder(nn.Module):
    def __init__(self, cfg: GemmaConfig, layer_idx: int) -> None:
        super().__init__()
        self.cfg = cfg

        self.input_layernorm = RMSNorm(cfg.hidden_size, cfg.norm_eps)
        self.self_attn = GemmaTransformerAttention(cfg, layer_idx)
        self.mlp = GemmaTransformerMLP(cfg)
        self.post_attention_layernorm = RMSNorm(cfg.hidden_size, cfg.norm_eps)
        self.gradient_checking = False

    def forward(self, x: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                kv_cache: Optional[KVCache] = None):

        residual = x
        x = self.input_layernorm(x)

        if self.gradient_checking:
            x = checkpoint.checkpoint(self.self_attn, x, position_ids, attention_mask, kv_cache)
        else:
            x = self.self_attn(x,
                               position_ids,
                               attention_mask,
                               kv_cache)[0]
        x += residual

        residual = x
        x = self.post_attention_layernorm(x)
        x = residual + self.mlp(x)
        return x


class GemmaModel(nn.Module):
    def __init__(self, cfg: GemmaConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)

        self.layers = nn.ModuleList(
            [GemmaTransformerDecoder(cfg, layer_idx) for layer_idx in range(cfg.num_hidden_layers)]
        )

        self.norm = RMSNorm(cfg.hidden_size, cfg.norm_eps)

    def forward(self, x: torch.Tensor,
                position_ids: Optional[torch.Tensor],
                attention_mask: Optional[torch.Tensor],
                kv_cache: Optional[KVCache]) -> torch.Tensor:

        output = x * torch.tensor(self.cfg.hidden_size ** 0.5, dtype=x.dtype)
        for layer in self.layers:
            output = layer(output,
                           position_ids,
                           attention_mask,
                           kv_cache)
        output = self.norm(output)
        return output


class Gemma(nn.Module):
    def __init__(self, cfg: GemmaConfig) -> None:
        super().__init__()
        self.cfg = cfg
        self.model = GemmaModel(cfg)
        self.vocab_size = cfg.vocab_size
        self.lm_head = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)

    def gradient_checkpointing_enabled(self, enabled=False):
        for name, module in self.model.named_modules():
            if isinstance(module, GemmaTransformerDecoder):
                module.gradient_checking = enabled

    def tie_weights(self):
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(self,
                input_embeds: torch.Tensor,
                position_ids: Optional[torch.Tensor],
                attention_mask: Optional[torch.Tensor],
                kv_cache: Optional[KVCache]):

        output = self.model(input_embeds,
                            position_ids,
                            attention_mask,
                            kv_cache)
        return output, kv_cache
models/lora.py
ADDED
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize
from typing import List

class LoRALayer:
    def __init__(self, features_in: int, features_out: int, rank: int = 1, alphas: int = 1):
        super().__init__()
        self.lora_A = nn.Linear(features_in, rank, bias=False)
        self.lora_B = nn.Linear(rank, features_out, bias=False)
        nn.init.normal_(self.lora_A.weight, mean=0, std=1/rank)

        self.scale = alphas / rank

class LoRALinear(nn.Module, LoRALayer):
    def __init__(self, base_layer: nn.Module, rank: int = 1, alphas: int = 1, dropout_p: float = 0.0):
        features_out, features_in = base_layer.weight.shape
        super().__init__()
        LoRALayer.__init__(self, features_in=features_in, features_out=features_out, rank=rank, alphas=alphas)

        self.base_layer = nn.Linear(features_in, features_out, bias=False)
        self.base_layer.weight = base_layer.weight

        if dropout_p > 0.0:
            self.lora_dropout = nn.Dropout(p=dropout_p, inplace=False)
        else:
            self.lora_dropout = nn.Identity()

        self.enabled = False

    def forward(self, x: torch.Tensor):
        result = self.base_layer(x)
        if self.enabled:
            result = result + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scale
        return result

def enable_lora(model: nn.Module, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'], enabled=True):
    for name, module in model.named_modules():
        if name.split('.')[-1] in lora_modules:
            module.enabled = enabled
    return model

def replace_module(module: nn.Module, target_modules: List[str], torch_dtype: torch.dtype, **kwargs):
    for child_name, child_module in module.named_children():
        if child_name in target_modules:
            new_module = LoRALinear(child_module, **kwargs).to(torch_dtype)
            setattr(module, child_name, new_module)
        else:
            replace_module(child_module, target_modules, torch_dtype, **kwargs)

def get_lora_model(model: nn.Module, rank: float, alphas: float, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'], dropout_p: float = 0.0, training: bool = False, torch_dtype: torch.dtype = torch.bfloat16):
    lora_config = {'rank': rank,
                   'alphas': alphas,
                   'dropout_p': dropout_p}
    replace_module(model, lora_modules, torch_dtype, **lora_config)

    for name, param in model.named_parameters():
        if 'lora' not in name:
            param.requires_grad = False
        else:
            if training:
                param.requires_grad = True
            else:
                param.requires_grad = False

    return model
models/paligemma.py
ADDED
@@ -0,0 +1,162 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from .gemma import GemmaConfig, Gemma, KVCache
from .siglip import SigLIPConfig, SigLIPVisionTower
from typing import Optional
import os
import json
from pathlib import Path
from safetensors import safe_open

@dataclass
class PaliGemmaConfig:
    bos_token_id: int = 2
    eos_token_id: int = 1
    hidden_size: int = 2048
    ignore_index: int = -100
    image_token_index: int = 257152
    pad_token_id: int = 0
    projection_dim: int = 2048
    text_config: GemmaConfig = None
    vision_config: SigLIPConfig = None
    vocab_size: int = 257216

    @classmethod
    def from_dict(cls, data):
        return cls(
            bos_token_id = data['bos_token_id'],
            eos_token_id = data['eos_token_id'],
            hidden_size = data['hidden_size'],
            ignore_index = data['ignore_index'],
            image_token_index = data['image_token_index'],
            pad_token_id = data['pad_token_id'],
            projection_dim = data['projection_dim'],
            text_config = GemmaConfig.from_dict(data['text_config']),
            vision_config = SigLIPConfig.from_dict(data['vision_config'])
        )

class PaliGemmaMultimodalProjector(nn.Module):
    def __init__(self, cfg: PaliGemmaConfig):
        super().__init__()
        self.linear = nn.Linear(cfg.vision_config.hidden_size, cfg.vision_config.projection_dim)

    def forward(self, x: torch.Tensor):
        x = self.linear(x)
        return x

class PaliGemma(nn.Module):
    def __init__(self, cfg: PaliGemmaConfig):
        super().__init__()
        self.cfg = cfg
        self.language_model = Gemma(cfg.text_config)

        self.vision_tower = SigLIPVisionTower(cfg.vision_config)

        self.multi_modal_projector = PaliGemmaMultimodalProjector(cfg)

    def tie_weights(self):
        self.language_model.tie_weights()

    def _merge_img_embeds_and_input_embeds(self, img_embeds: torch.Tensor,
                                           input_embeds: torch.Tensor,
                                           input_tokens: torch.Tensor):
        batch_size, seq_len, embed_dim = input_embeds.shape
        scaled_img = img_embeds / (self.cfg.hidden_size ** 0.5)

        final_embeddings = torch.zeros((batch_size, seq_len, embed_dim), dtype=img_embeds.dtype, device=img_embeds.device)

        # (n, seq_len)
        text_mask = (input_tokens != self.cfg.pad_token_id) & (input_tokens != self.cfg.image_token_index)
        img_mask = input_tokens == self.cfg.image_token_index
        pad_mask = input_tokens == self.cfg.pad_token_id

        text_mask = text_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        img_mask = img_mask.unsqueeze(-1).expand(-1, -1, embed_dim)
        pad_mask = pad_mask.unsqueeze(-1).expand(-1, -1, embed_dim)

        # (n, seq_len, embed_dim)
        final_embeddings = torch.where(text_mask, input_embeds, final_embeddings)
        final_embeddings = final_embeddings.masked_scatter(img_mask, scaled_img)
        final_embeddings = torch.where(pad_mask, torch.zeros_like(final_embeddings), final_embeddings)

        return final_embeddings

    def _create_position_ids_and_attention_mask(self,
                                                device: str = '',
                                                dtype: torch.dtype = torch.float32,
                                                batch_size: int = 32,
                                                seq_len: int = 1,
                                                attention_mask: Optional[torch.Tensor] = None,
                                                kv_cache: Optional[KVCache] = None):
        # Create Attention Mask
        if kv_cache is None or kv_cache.num_items() == 0:
            causal_mask = torch.full((batch_size, seq_len, seq_len), 0, dtype=dtype, device=device)
            position_ids = attention_mask.cumsum(dim=-1).masked_fill_((attention_mask == 0), 1).to(device)

        else:
            assert seq_len == 1
            kv_len = kv_cache.num_items() + 1
            causal_mask = torch.full((batch_size, 1, kv_len), 0, dtype=dtype, device=device)
            position_ids = attention_mask.cumsum(dim=-1)[:, -1].to(device)

        # (n, seq_len, kv_len) -> (n, 1, seq_len, kv_len)
        causal_mask = causal_mask.unsqueeze(1)

        return position_ids, causal_mask

    @staticmethod
    def from_pretrained(model_dir):
        with open(os.path.join(model_dir, 'config.json'), "r") as f:
            model_config = json.loads(f.read())
        config = PaliGemmaConfig.from_dict(model_config)

        safetensor_files = Path(model_dir).glob("*.safetensors")

        weights = {}
        for file in safetensor_files:
            with safe_open(file, framework='pt', device="cpu") as f:
                for key in f.keys():
                    weights[key] = f.get_tensor(key)

        model = PaliGemma(config)
        model.load_state_dict(weights, strict=False)
        model.tie_weights()
        return model

    def forward(self, *args, **kwargs):

        # input_tokens: (n, seq_len)
        # -> (n, seq_len, embed_dim)
        kv_cache = kwargs['kv_cache'] if 'kv_cache' in kwargs else None
        input_tokens = kwargs['input_ids']
        pixel_values = kwargs['pixel_values'] if 'pixel_values' in kwargs else None
        attention_mask = kwargs['attention_mask']
        input_embeds = self.language_model.model.embed_tokens(input_tokens)
        if pixel_values is not None:
            img_embeds = self.vision_tower(pixel_values.to(input_embeds.dtype))
            img_embeds = self.multi_modal_projector(img_embeds)
            final_embeddings = self._merge_img_embeds_and_input_embeds(img_embeds=img_embeds,
                                                                       input_embeds=input_embeds,
                                                                       input_tokens=input_tokens)
        else:
            final_embeddings = input_embeds

        position_ids, causal_mask = self._create_position_ids_and_attention_mask(device=final_embeddings.device.type,
                                                                                 dtype=final_embeddings.dtype,
                                                                                 batch_size=final_embeddings.shape[0],
                                                                                 seq_len=final_embeddings.shape[1],
                                                                                 attention_mask=attention_mask,
                                                                                 kv_cache=kv_cache)

        outputs, kv_cache = self.language_model(
            input_embeds=final_embeddings,
            position_ids=position_ids,
            attention_mask=causal_mask,
            kv_cache=kv_cache
        )
        return outputs, kv_cache
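Note that _create_position_ids_and_attention_mask builds an all-zero additive mask, i.e. full prefix-LM attention over the image and prompt tokens, and derives positions from the padding mask. A small sketch of that position-id arithmetic on a left-padded batch:

import torch

attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])   # second row has two padding positions

# Prefill: one position per real token, padding clamped to 1 (matching the masked_fill_ above).
position_ids = attention_mask.cumsum(dim=-1).masked_fill_(attention_mask == 0, 1)
# tensor([[1, 2, 3, 4],
#         [1, 1, 1, 2]])

# Decode step with a KV cache: only the newest token's position is needed.
next_position = attention_mask.cumsum(dim=-1)[:, -1]   # tensor([4, 2])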
models/paligemma_processor.py
ADDED
@@ -0,0 +1,103 @@
import torch
from PIL import Image
from typing import Tuple, List
import numpy as np
from transformers import GemmaTokenizerFast, BatchFeature
import json
import os

def preprocess_imgs(imgs: List[Image.Image],
                    img_size: Tuple[int, int],
                    rescale: float,
                    mean: Tuple[float, float, float],
                    std: Tuple[float, float, float]):

    def normalize(img, mean, std):
        img = (img - np.array(mean, dtype=img.dtype)) / np.array(std, dtype=img.dtype)
        return img

    resized_imgs = [np.array(img.resize((img_size[0], img_size[1]), resample=3)) for img in imgs]

    rescaled_imgs = [(img * rescale).astype(np.float32) for img in resized_imgs]

    normalized_imgs = [normalize(img, mean, std) for img in rescaled_imgs]
    transposed_imgs = [img.transpose(2, 0, 1) for img in normalized_imgs]

    tensor_imgs = torch.tensor(np.stack(transposed_imgs, axis=0), dtype=torch.float32)
    return tensor_imgs


def preprocess_prompts(prompt, image_token, max_num_image_token, bos_token):
    return f"{image_token * max_num_image_token}{bos_token}{prompt}\n"


class PaliGemmaProcessor:
    IMAGE_TOKEN = "<image>"

    def __init__(self,
                 tokenizer: GemmaTokenizerFast) -> None:

        additional_special_tokens = {"additional_special_tokens": [self.IMAGE_TOKEN]}
        tokenizer.add_special_tokens(additional_special_tokens)

        EXTRA_TOKENS = [
            f"<loc{i:04d}>" for i in range(1024)
        ]  # These tokens are used for object detection (bounding boxes)
        EXTRA_TOKENS += [
            f"<seg{i:03d}>" for i in range(128)
        ]

        tokenizer.add_tokens(EXTRA_TOKENS)

        tokenizer.add_bos_token = False
        tokenizer.add_eos_token = False

        self.tokenizer = tokenizer

    def from_pretrained(self, pretrained_dir):

        with open(os.path.join(pretrained_dir, "preprocessor_config.json"), "r") as f:
            config = json.loads(f.read())

        self.image_seq_length = config['image_seq_length']
        self.image_mean = config['image_mean']
        self.image_std = config['image_std']
        self.resample = config['resample']
        self.rescale_factor = config['rescale_factor']
        self.size = (config['size']['height'], config['size']['width'])
        return self

    def __call__(self,
                 imgs: List[Image.Image],
                 prompts: List[str],
                 padding: str = "longest",
                 truncation: bool = True,
                 max_length: int = None):

        processed_imgs = preprocess_imgs(imgs,
                                         img_size=self.size,
                                         rescale=self.rescale_factor,
                                         mean=self.image_mean,
                                         std=self.image_std)

        processed_prompts = [preprocess_prompts(prompt,
                                                image_token=self.IMAGE_TOKEN,
                                                max_num_image_token=self.image_seq_length,
                                                bos_token=self.tokenizer.bos_token) for prompt in prompts]

        model_inputs = self.tokenizer(processed_prompts,
                                      return_tensors='pt',
                                      padding=padding,
                                      truncation=truncation,
                                      max_length=max_length)

        return {**model_inputs, "pixel_values": processed_imgs}
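A quick sketch of preprocess_imgs on a blank page; the size, rescale factor, mean and std below are illustrative stand-ins for the values from_pretrained reads out of preprocessor_config.json (the checkpoint name suggests 448x448 inputs):

from PIL import Image
from models.paligemma_processor import preprocess_imgs

page = Image.new("RGB", (1000, 700), "white")
pixel_values = preprocess_imgs([page],
                               img_size=(448, 448),
                               rescale=1 / 255.0,
                               mean=(0.5, 0.5, 0.5),
                               std=(0.5, 0.5, 0.5))
print(pixel_values.shape)   # torch.Size([1, 3, 448, 448])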
models/siglip.py
ADDED
@@ -0,0 +1,168 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional

@dataclass
class SigLIPConfig:
    hidden_size: int = 1152
    intermediate_size: int = 4304
    num_attention_heads: int = 16
    num_hidden_layers: int = 27
    num_image_tokens: int = 256
    patch_size: int = 14
    projection_dim: int = 2048
    n_channels: int = 3
    img_size: int = 224
    norm_eps: float = 1e-6
    attention_dropout: float = 0.0

    @classmethod
    def from_dict(cls, data):
        return cls(
            hidden_size = data['hidden_size'],
            intermediate_size = data['intermediate_size'],
            num_attention_heads = data['num_attention_heads'],
            num_hidden_layers = data['num_hidden_layers'],
            num_image_tokens = data['num_image_tokens'],
            patch_size = data['patch_size'],
            projection_dim = data['projection_dim']
        )

class SigLIPEmbedding(nn.Module):
    def __init__(self, cfg: SigLIPConfig):
        super().__init__()
        self.patch_embedding = nn.Conv2d(cfg.n_channels, cfg.hidden_size, kernel_size=cfg.patch_size, stride=cfg.patch_size, padding='valid')

        self.num_patches = (cfg.img_size // cfg.patch_size) ** 2
        self.position_embedding = nn.Embedding(cfg.num_image_tokens, cfg.hidden_size)

        self.register_buffer('position_ids',
                             torch.arange(cfg.num_image_tokens).expand(1, -1),
                             persistent=False)

    def forward(self, x: torch.FloatTensor):
        # x: (n, c, h, w) -> (n, c, num_patch_h, num_patch_w)
        img_embeds = self.patch_embedding(x)
        # (n, c, num_patch_h, num_patch_w) -> (n, c, num_patches) -> (n, num_patches, c)
        img_embeds = img_embeds.reshape(*img_embeds.shape[:2], -1).transpose(1, 2)
        return img_embeds + self.position_embedding(self.position_ids.to(torch.int64))

class SigLIPTransformerAttention(nn.Module):
    def __init__(self, cfg: SigLIPConfig):
        super().__init__()
        self.cfg = cfg
        self.num_attention_heads = cfg.num_attention_heads
        self.head_dim = cfg.hidden_size // self.num_attention_heads

        self.q_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
        self.k_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
        self.v_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)

        self.out_proj = nn.Linear(cfg.hidden_size, cfg.hidden_size)
        self.dropout_p = self.cfg.attention_dropout

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor):
        batch_size, num_patches, _ = x.shape

        xq = self.q_proj(x)
        xk = self.k_proj(x)
        xv = self.v_proj(x)

        xq = xq.view(batch_size, num_patches, self.num_attention_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(batch_size, num_patches, self.num_attention_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(batch_size, num_patches, self.num_attention_heads, self.head_dim).transpose(1, 2)

        # attn_weights = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        # attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(xq.dtype)
        # attn_output = torch.matmul(attn_weights, xv)
        # attn_output = attn_output.transpose(1, 2).contiguous()
        # attn_output = attn_output.view(batch_size, num_patches, -1)
        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query=xq,
            key=xk,
            value=xv,
            attn_mask=attention_mask,
            dropout_p=self.dropout_p,
            is_causal=False
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, num_patches, -1)
        attn_output = self.out_proj(attn_output)
        return attn_output, None

class SigLIPTransformerMLP(nn.Module):
    def __init__(self, cfg: SigLIPConfig):
        super().__init__()
        self.cfg = cfg

        self.fc1 = nn.Linear(cfg.hidden_size, cfg.intermediate_size)
        self.fc2 = nn.Linear(cfg.intermediate_size, cfg.hidden_size)

    def forward(self, x: torch.Tensor):

        x = self.fc1(x)
        x = F.gelu(x, approximate='tanh')
        x = self.fc2(x)
        return x

class SigLIPTransformerBlock(nn.Module):
    def __init__(self, cfg: SigLIPConfig):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(cfg.hidden_size, eps=cfg.norm_eps)
        self.layer_norm2 = nn.LayerNorm(cfg.hidden_size, eps=cfg.norm_eps)

        self.self_attn = SigLIPTransformerAttention(cfg)
        self.mlp = SigLIPTransformerMLP(cfg)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.layer_norm1(x)
        x = residual + self.self_attn(x, attention_mask)[0]
        residual = x
        x = self.layer_norm2(x)
        x = residual + self.mlp(x)
        return x

class SigLIPTransformerEncoder(nn.Module):
    def __init__(self, cfg: SigLIPConfig):
        super().__init__()

        self.cfg = cfg
        self.layers = nn.ModuleList(
            [SigLIPTransformerBlock(cfg) for _ in range(cfg.num_hidden_layers)]
        )

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x, attention_mask)
        return x

class SigLIPModel(nn.Module):
    def __init__(self, cfg: SigLIPConfig):
        super().__init__()
        self.embeddings = SigLIPEmbedding(cfg)
        self.encoder = SigLIPTransformerEncoder(cfg)
        self.post_layernorm = nn.LayerNorm(cfg.hidden_size, eps=cfg.norm_eps)

    def forward(self, x: torch.Tensor):
        img_embed = self.embeddings(x)
        output = self.encoder(img_embed)
        output = self.post_layernorm(output)
        return output


class SigLIPVisionTower(nn.Module):
    def __init__(self, cfg: SigLIPConfig):
        super().__init__()
        self.cfg = cfg
        self.vision_model = SigLIPModel(cfg)

    def forward(self, x: torch.Tensor):
        return self.vision_model(x)
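The number of visual tokens follows directly from the patch grid, (img_size // patch_size) ** 2; the shipped config.json, not these dataclass defaults, supplies the values actually used. A small check of that arithmetic:

def num_patches(img_size: int, patch_size: int) -> int:
    return (img_size // patch_size) ** 2

assert num_patches(224, 14) == 256    # the dataclass defaults above
assert num_patches(448, 14) == 1024   # a 448px checkpoint yields 1024 patches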
pretrained/colpaligemma-3b-mix-448-base/adapter_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:caed65068cae6d50e572d984914324a7d8a9360cdd7f4263ea82f1792614391f
size 78625112

pretrained/colpaligemma-3b-mix-448-base/config.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:012239f7d70c76d7f85bfca5e23f6afcde455f9ed23fa3f2ec9057b6028f6a5b
size 1047

pretrained/colpaligemma-3b-mix-448-base/model-00001-of-00002.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c128f5670d7a66942a194be6e2d324dc329c0de19e99c6f047513878e14f988e
size 4986817288

pretrained/colpaligemma-3b-mix-448-base/model-00002-of-00002.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8352c38e4d1785c4a35547d13f4d8d5562faab6fe8e9a30b1f5d8039d355a409
size 862495528

pretrained/colpaligemma-3b-mix-448-base/preprocessor_config.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5fc342baea95529a5eb9746a0232fb88941d759812d7b616c382f2f87ba6123f
size 700

pretrained/colpaligemma-3b-mix-448-base/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ffd310e50986db7a039948ab83441d612689e7f989198e31b5c8984ca458adf6
size 17763459

pretrained/colpaligemma-3b-mix-448-base/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8986bb4f423f07f8c7f70d0dbe3526fb2316056c17bae71b1ea975e77a168fc6
size 4264023

pretrained/colpaligemma-3b-mix-448-base/tokenizer_config.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d5e95b5ab863693113e65e4899e1db28c09d892fa84243c7dfe6ce7f727f1888
size 242696
prompt_templates.py
ADDED
@@ -0,0 +1,132 @@
import json
from llama_index.core.question_gen.types import SubQuestion
from llama_index.core.tools.types import ToolMetadata
from llama_index.core.question_gen.prompts import build_tools_text

PREFIX = """\
Given a user question, and a list of tools, output a list of relevant sub-questions \
in json markdown that when composed can help answer the full user question:

"""

example_query_str = (
    "Compare and contrast the revenue growth and EBITDA of Uber and Lyft for year 2021"
)
example_tools = [
    ToolMetadata(
        name="uber_10k",
        description="Provides information about Uber financials for year 2021",
    ),
    ToolMetadata(
        name="lyft_10k",
        description="Provides information about Lyft financials for year 2021",
    ),
]
example_tools_str = build_tools_text(example_tools)
example_output = [
    SubQuestion(
        sub_question="What is the revenue growth of Uber", tool_name="uber_10k"
    ),
    SubQuestion(sub_question="What is the EBITDA of Uber", tool_name="uber_10k"),
    SubQuestion(
        sub_question="What is the revenue growth of Lyft", tool_name="lyft_10k"
    ),
    SubQuestion(sub_question="What is the EBITDA of Lyft", tool_name="lyft_10k"),
]
example_output_str = json.dumps(
    {"items": [x.model_dump() for x in example_output]}, indent=4
)

EXAMPLES = f"""\
# Example 1
<Tools>
```json
{example_tools_str}
```

<User Question>
{example_query_str}


<Output>
```json
{example_output_str}
```

"""

SUFFIX = """\
# Example 2
<Tools>
```json
{tools_str}
```

<User Question>
{query_str}

<Output>
"""

DEFAULT_SUB_QUESTION_PROMPT_TMPL = PREFIX + EXAMPLES + SUFFIX

DEFAULT_GEN_PROMPT_TMPL = """\
You are a helpful assistant that generates multiple search queries based on a \
single input query. Generate {num_queries} search queries, one on each line, \
related to the following input query:
Query: {query}
Queries:
"""

DEFAULT_FINAL_ANSWER_PROMPT_TMPL = """\
Context information is below.
---------------------
{context_str}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {query_str}
Answer: \
"""


SYNTHESIZE_PROMPT = """\
Context information is below.
---------------------
{context_str}
---------------------
Given the information from multiple sources and not prior knowledge,
summarize the information that is most relevant to the query and return the indices of the choices used in the summary.

Query: {query_str}\n
"""


SYNTHESIZE_OUTPUT_FORMAT = """Return the output that conforms to the JSON schema below.
Here is the output schema.

{
    "properties": {
        "summarized_text": {
            "title": "Summarized Text",
            "type": "string"
        },
        "choices": {
            "items": {
                "type": "integer"
            },
            "title": "Choices",
            "type": "array"
        }
    },
    "required": [
        "summarized_text",
        "choices"
    ],
    "title": "SummarizeAnswer",
    "type": "object"
}

Answer: \
""".replace("{", "{{").replace("}", "}}")

DEFAULT_SYNTHESIZE_PROMPT_TMPL = SYNTHESIZE_PROMPT + SYNTHESIZE_OUTPUT_FORMAT
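The trailing .replace calls double every brace so the embedded JSON schema survives a later str.format; a tiny sketch of why that escaping is needed (the schema string is shortened here):

schema = '{"title": "SummarizeAnswer"}'
template = "Query: {query_str}\n" + schema.replace("{", "{{").replace("}", "}}")
print(template.format(query_str="What is shown in the chart?"))
# Query: What is shown in the chart?
# {"title": "SummarizeAnswer"}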
rag_pipeline.py
ADDED
@@ -0,0 +1,531 @@
| 1 |
+
import torch
|
| 2 |
+
import asyncio
|
| 3 |
+
from torch.utils.data import DataLoader
|
| 4 |
+
import os
|
| 5 |
+
import uuid
|
| 6 |
+
import base64
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
from PIL import Image
|
| 9 |
+
from pdf2image import pdf2image
|
| 10 |
+
from typing import List, Union
|
| 11 |
+
from tqdm.auto import tqdm
|
| 12 |
+
|
| 13 |
+
from utils import *
|
| 14 |
+
from models import ColPali, ColPaliProcessor, get_lora_model, enable_lora
|
| 15 |
+
|
| 16 |
+
import qdrant_client
|
| 17 |
+
from qdrant_client.http import models as rest
|
| 18 |
+
from llamaindex_utils import ColPaliGemmaEmbedding, ColPaliRetriever, CustomFusionRetriever, CustomQueryEngine
|
| 19 |
+
from llama_index.llms.gemini import Gemini
|
| 20 |
+
from llama_index.core.tools import RetrieverTool
|
| 21 |
+
|
| 22 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 23 |
+
|
| 24 |
+
def embed_imgs(model: ColPali,
|
| 25 |
+
processor: ColPaliProcessor,
|
| 26 |
+
input_imgs: List[Image.Image],
|
| 27 |
+
device: str = 'cpu') -> List[torch.Tensor]:
|
| 28 |
+
"""Generates embeddings given images.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
model (ColPali): Main model
|
| 32 |
+
processor (ColPaliProcessor): Data Processor
|
| 33 |
+
input_imgs (List[Image.Image]): List of input images
|
| 34 |
+
device (str, optional): device to run model. Defaults to 'cpu'.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
List[torch.Tensor]: List of output embedings.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
colpali_model = model.to(device=device).eval()
|
| 41 |
+
|
| 42 |
+
dataloader = DataLoader(input_imgs,
|
| 43 |
+
batch_size=8,
|
| 44 |
+
shuffle=False,
|
| 45 |
+
num_workers=0,
|
| 46 |
+
collate_fn=lambda x: processor.process_images(x))
|
| 47 |
+
|
| 48 |
+
document_embeddings = []
|
| 49 |
+
with torch.no_grad():
|
| 50 |
+
for batch, model_inputs in tqdm(enumerate(dataloader)):
|
| 51 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
| 52 |
+
# Encode images
|
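# ColPali emits one projected vector per input token (image patches plus prompt tokens),
# so each page becomes a multi-vector embedding used for late-interaction retrieval.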
| 53 |
+
img_embeds = colpali_model(**model_inputs, kv_cache=None)
|
| 54 |
+
document_embeddings.extend(list(torch.unbind(img_embeds.to('cpu').to(torch.float32))))
|
| 55 |
+
return document_embeddings
|
| 56 |
+
|
| 57 |
+
def embed_queries(model: ColPali,
|
| 58 |
+
processor: ColPaliProcessor,
|
| 59 |
+
queries: List[str],
|
| 60 |
+
device: str = 'cpu') -> List[torch.Tensor]:
|
| 61 |
+
"""Generate embeddings given queries.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
model (ColPali): Embedding model
|
| 65 |
+
processor (ColPaliProcessor): Data Processor
|
| 66 |
+
queries (List[str]): List of query strings
|
| 67 |
+
device (str, optional): Device to run model. Defaults to 'cpu'.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
List[torch.Tensor]: List of embeddings
|
| 71 |
+
"""
|
| 72 |
+
colpali_model = model.to(device=device).eval()
|
| 73 |
+
|
| 74 |
+
dataloader = DataLoader(queries,
|
| 75 |
+
batch_size=8,
|
| 76 |
+
shuffle=False,
|
| 77 |
+
num_workers=0,
|
| 78 |
+
collate_fn=lambda x: processor.process_queries(x))
|
| 79 |
+
|
| 80 |
+
queries_embeddings = []
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
for batch, model_inputs in tqdm(enumerate(dataloader)):
|
| 83 |
+
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
|
| 84 |
+
# Encode Queries
|
| 85 |
+
query_embeds = colpali_model(**model_inputs, kv_cache=None)
|
| 86 |
+
queries_embeddings.extend(torch.unbind(query_embeds.to('cpu').type(torch.float32)))
|
| 87 |
+
|
| 88 |
+
return queries_embeddings
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def score_single_vectors(qs: List[torch.Tensor],
|
| 92 |
+
ps: List[torch.Tensor]) -> torch.FloatTensor:
|
| 93 |
+
"""Calculate similarity between 2 single vectors
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
qs (List[torch.Tensor]): First Embeddings
|
| 97 |
+
ps (List[torch.Tensor]): Second Embeddings
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
torch.FloatTensor: Score Tensor
|
| 101 |
+
"""
|
| 102 |
+
assert len(qs) != 0 and len(ps) != 0
|
| 103 |
+
|
| 104 |
+
qs_stacked = torch.stack(qs)
|
| 105 |
+
ps_stacked = torch.stack(ps)
|
| 106 |
+
|
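# "bd,cd->bc": dot product between every query vector and every passage vector,
# producing a (num_queries, num_passages) score matrix.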
| 107 |
+
scores = torch.einsum("bd,cd->bc", qs_stacked, ps_stacked)
|
| 108 |
+
assert scores.shape[0] == len(qs), f"Expected {len(qs)} scores, got {scores.shape[0]}"
|
| 109 |
+
scores = scores.to(torch.float32)
|
| 110 |
+
return scores
|
| 111 |
+
|
| 112 |
+
def score_multi_vectors(qs: List[torch.Tensor],
|
| 113 |
+
ps: List[torch.Tensor],
|
| 114 |
+
batch_size: int = 8,
|
| 115 |
+
device: Union[torch.device, str] = "cpu") -> torch.FloatTensor:
|
| 116 |
+
"""Calculate MaxSim between 2 list of vectors.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
qs (List[torch.Tensor]): List of query embeddings
|
| 120 |
+
ps (List[torch.Tensor]): List of document embeddings
|
| 121 |
+
batch_size (int, optional): Batch Size. Defaults to 8.
|
| 122 |
+
device (Union[torch.device, str], optional): Device to cast tensors to. Defaults to "cpu".
|
| 123 |
+
|
| 124 |
+
Returns:
|
| 125 |
+
torch.FloatTensor: Score tensors.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
assert len(qs) != 0 and len(ps) != 0
|
| 129 |
+
scores_list = []
|
| 130 |
+
for i in range(0, len(qs), batch_size):
|
| 131 |
+
scores_batch = []
|
| 132 |
+
qs_batch = torch.nn.utils.rnn.pad_sequence(qs[i:i+batch_size], batch_first=True, padding_value=0).to(device)
|
| 133 |
+
for j in range(0, len(ps), batch_size):
|
| 134 |
+
ps_batch = torch.nn.utils.rnn.pad_sequence(ps[j:j+batch_size], batch_first=True, padding_value=0).to(device)
|
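# "abd,ced->acbe" yields per-token similarities between query and document tokens;
# taking the max over document tokens (last dim) and summing over query tokens (dim 2)
# gives the MaxSim / late-interaction score used by ColBERT-style retrieval.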
| 135 |
+
tmp = torch.einsum("abd,ced->acbe", qs_batch, ps_batch).max(dim=-1)[0].sum(dim=2)
|
| 136 |
+
scores_batch.append(tmp)
|
| 137 |
+
|
| 138 |
+
scores_batch = torch.cat(scores_batch, dim=1).cpu()
|
| 139 |
+
scores_list.append(scores_batch)
|
| 140 |
+
|
| 141 |
+
scores = torch.cat(scores_list, dim=0)
|
| 142 |
+
return scores.to(torch.float32)
|
| 143 |
+
|
| 144 |
+
def indexDocument(file_path: str,
|
| 145 |
+
vector_store_client,
|
| 146 |
+
target_collection: str,
|
| 147 |
+
model: nn.Module,
|
| 148 |
+
processor: ColPaliProcessor,
|
| 149 |
+
device: Union[str, torch.device]) -> None:
|
| 150 |
+
"""Index document given file_path.
|
| 151 |
+
Each page in the document is embedded by the ColPaliGemma model, then inserted into the Qdrant vector store under the target collection.
|
| 152 |
+
Creates the target collection if it does not exist in the vector store yet.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
file_path (str): Path to the PDF file to index.
|
| 156 |
+
vector_store_client (qdrant_client.QdrantClient): Qdrant client used to store the page embeddings.
|
| 157 |
+
target_collection (str): Name of the Qdrant collection to upsert into.
|
| 158 |
+
model (nn.Module): ColPali embedding model.
|
| 159 |
+
processor (ColPaliProcessor): Data processor used to prepare the page images.
|
| 160 |
+
device (Union[str, torch.device]): Device to run the embedding model on.
|
| 161 |
+
"""
|
| 162 |
+
document_images = []
|
| 163 |
+
document_embeddings = []
|
| 164 |
+
document_images.extend(pdf2image.convert_from_path(file_path))
|
| 165 |
+
|
| 166 |
+
document_embeddings = embed_imgs(model=model,
|
| 167 |
+
processor=processor,
|
| 168 |
+
input_imgs=document_images,
|
| 169 |
+
device=device)
|
| 170 |
+
|
| 171 |
+
# Create Qdrant collection if it does not exist yet
|
| 172 |
+
if not vector_store_client.collection_exists(collection_name=target_collection):
|
| 173 |
+
# Specify vectors_config
|
| 174 |
+
scalar_quant = rest.ScalarQuantizationConfig(
|
| 175 |
+
type=rest.ScalarType.INT8,
|
| 176 |
+
quantile=0.99,
|
| 177 |
+
always_ram=False
|
| 178 |
+
)
|
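# Each page is stored as a multivector (one 128-dim vector per token) and compared
# with MAX_SIM, matching ColPali's late-interaction scoring; INT8 scalar quantization
# keeps the index memory footprint down.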
| 179 |
+
vector_params = rest.VectorParams(
|
| 180 |
+
size=128,
|
| 181 |
+
distance=rest.Distance.COSINE,
|
| 182 |
+
multivector_config=rest.MultiVectorConfig(
|
| 183 |
+
comparator=rest.MultiVectorComparator.MAX_SIM
|
| 184 |
+
),
|
| 185 |
+
quantization_config=rest.ScalarQuantization(
|
| 186 |
+
scalar=scalar_quant
|
| 187 |
+
),
|
| 188 |
+
)
|
| 189 |
+
vector_store_client.create_collection(
|
| 190 |
+
collection_name=target_collection,
|
| 191 |
+
on_disk_payload=True,
|
| 192 |
+
optimizers_config=rest.OptimizersConfigDiff(
|
| 193 |
+
indexing_threshold=100
|
| 194 |
+
),
|
| 195 |
+
vectors_config=vector_params
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# Add embedding to Qdrant Collection
|
| 199 |
+
points = []
|
| 200 |
+
for i, embedding in enumerate(document_embeddings):
|
| 201 |
+
multivector = embedding.cpu().float().numpy().tolist()
|
| 202 |
+
|
| 203 |
+
buffer = BytesIO()
|
| 204 |
+
document_images[i].save(buffer, format='JPEG')
|
| 205 |
+
image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 206 |
+
# Define payload
|
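# The page image is base64-encoded into a LlamaIndex-style "_node_content" payload,
# so the retriever can later rebuild an ImageNode from the stored Qdrant point.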
| 207 |
+
payload = {}
|
| 208 |
+
node_metadata = {"file_name": file_path,
|
| 209 |
+
"page_id": i + 1}
|
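# Note: Python's built-in hash() of strings is salted per process (PYTHONHASHSEED),
# so these point ids are only stable within a single run; a deterministic hash
# (e.g. hashlib) or UUIDs would give reproducible ids across runs.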
| 210 |
+
node_content = {'id_': abs(hash(file_path + str(i + 1))),
|
| 211 |
+
'image': image_str,
|
| 212 |
+
"metadata": node_metadata}
|
| 213 |
+
|
| 214 |
+
payload["_node_content"] = json.dumps(node_content)
|
| 215 |
+
payload["_node_type"] = "ImageNode"
|
| 216 |
+
|
| 217 |
+
# store ref doc id at top level to allow metadata filtering
|
| 218 |
+
# kept for backwards compatibility, will consolidate in future
|
| 219 |
+
payload["document_id"] = "None" # for Chroma
|
| 220 |
+
payload["doc_id"] = "None" # for Pinecone, Qdrant, Redis
|
| 221 |
+
payload["ref_doc_id"] = "None" # for Weaviate
|
| 222 |
+
|
| 223 |
+
points.append(rest.PointStruct(
|
| 224 |
+
id=node_content['id_'],
|
| 225 |
+
vector=multivector,
|
| 226 |
+
payload=payload,
|
| 227 |
+
))
|
| 228 |
+
|
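# Upsert points in small batches; wait=False returns before Qdrant finishes applying the writes.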
| 229 |
+
step = 8
|
| 230 |
+
for i in range(0, len(points), step):
|
| 231 |
+
points_batch = points[i: i + step]
|
| 232 |
+
vector_store_client.upsert(collection_name=target_collection,
|
| 233 |
+
points=points_batch,
|
| 234 |
+
wait=False)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
async def async_indexDocument(file_path: str,
|
| 238 |
+
vector_store_client: qdrant_client.AsyncQdrantClient,
|
| 239 |
+
target_collection: str,
|
| 240 |
+
model: nn.Module,
|
| 241 |
+
processor: ColPaliProcessor,
|
| 242 |
+
device: Union[str, torch.device]) -> None:
|
| 243 |
+
"""Asynchrously index document given file_path.
|
| 244 |
+
Each page in the document is embedded by the ColPaliGemma model, then inserted into the Qdrant vector store under the target collection.
|
| 245 |
+
Creates the target collection if it does not exist in the vector store yet.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
file_path (str): Path to the PDF file to index.
|
| 249 |
+
vector_store_client (qdrant_client.AsyncQdrantClient): Async Qdrant client used to store the page embeddings.
|
| 250 |
+
target_collection (str): Name of the Qdrant collection to upsert into.
|
| 251 |
+
model (nn.Module): ColPali embedding model.
|
| 252 |
+
processor (ColPaliProcessor): Data processor used to prepare the page images.
|
| 253 |
+
device (Union[str, torch.device]): Device to run the embedding model on.
|
| 254 |
+
"""
|
| 255 |
+
document_images = []
|
| 256 |
+
document_embeddings = []
|
| 257 |
+
document_images.extend(pdf2image.convert_from_path(file_path))
|
| 258 |
+
|
| 259 |
+
document_embeddings = embed_imgs(model=model,
|
| 260 |
+
processor=processor,
|
| 261 |
+
input_imgs=document_images,
|
| 262 |
+
device=device)
|
| 263 |
+
|
| 264 |
+
# Create Qdrant collection if it does not exist yet
|
| 265 |
+
if not await vector_store_client.collection_exists(collection_name=target_collection):
|
| 266 |
+
# Specify vectors_config
|
| 267 |
+
scalar_quant = rest.ScalarQuantizationConfig(
|
| 268 |
+
type=rest.ScalarType.INT8,
|
| 269 |
+
quantile=0.99,
|
| 270 |
+
always_ram=False
|
| 271 |
+
)
|
| 272 |
+
vector_params = rest.VectorParams(
|
| 273 |
+
size=128,
|
| 274 |
+
distance=rest.Distance.COSINE,
|
| 275 |
+
multivector_config=rest.MultiVectorConfig(
|
| 276 |
+
comparator=rest.MultiVectorComparator.MAX_SIM
|
| 277 |
+
),
|
| 278 |
+
quantization_config=rest.ScalarQuantization(
|
| 279 |
+
scalar=scalar_quant
|
| 280 |
+
),
|
| 281 |
+
)
|
| 282 |
+
await vector_store_client.create_collection(
|
| 283 |
+
collection_name=target_collection,
|
| 284 |
+
on_disk_payload=True,
|
| 285 |
+
optimizers_config=rest.OptimizersConfigDiff(
|
| 286 |
+
indexing_threshold=100
|
| 287 |
+
),
|
| 288 |
+
vectors_config=vector_params
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# Add embedding to Qdrant Collection
|
| 292 |
+
points = []
|
| 293 |
+
for i, embedding in enumerate(document_embeddings):
|
| 294 |
+
multivector = embedding.cpu().float().numpy().tolist()
|
| 295 |
+
|
| 296 |
+
buffer = BytesIO()
|
| 297 |
+
document_images[i].save(buffer, format='JPEG')
|
| 298 |
+
image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
| 299 |
+
# Define payload
|
| 300 |
+
payload = {}
|
| 301 |
+
node_metadata = {"file_name": file_path,
|
| 302 |
+
"page_id": i + 1}
|
| 303 |
+
node_content = {'id_': abs(hash(file_path + str(i + 1))),
|
| 304 |
+
'image': image_str,
|
| 305 |
+
"metadata": node_metadata}
|
| 306 |
+
|
| 307 |
+
payload["_node_content"] = json.dumps(node_content)
|
| 308 |
+
payload["_node_type"] = "ImageNode"
|
| 309 |
+
|
| 310 |
+
# store ref doc id at top level to allow metadata filtering
|
| 311 |
+
# kept for backwards compatibility, will consolidate in future
|
| 312 |
+
payload["document_id"] = "None" # for Chroma
|
| 313 |
+
payload["doc_id"] = "None" # for Pinecone, Qdrant, Redis
|
| 314 |
+
payload["ref_doc_id"] = "None" # for Weaviate
|
| 315 |
+
|
| 316 |
+
points.append(rest.PointStruct(
|
| 317 |
+
id=node_content['id_'],
|
| 318 |
+
vector=multivector,
|
| 319 |
+
payload=payload,
|
| 320 |
+
))
|
| 321 |
+
|
| 322 |
+
step = 8
|
| 323 |
+
for i in range(0, len(points), step):
|
| 324 |
+
points_batch = points[i: i + step]
|
| 325 |
+
await vector_store_client.upsert(collection_name=target_collection,
|
| 326 |
+
points=points_batch,
|
| 327 |
+
wait=False)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
GEMINI_API_KEY = os.getenv(key="GEMINI_API_KEY")
|
| 331 |
+
|
| 332 |
+
def main():
|
| 333 |
+
model = ColPali.from_pretrained(model_dir='./pretrained/colpaligemma-3b-mix-448-base', torch_dtype=torch.bfloat16)
|
| 334 |
+
tokenizer = load_tokenizer(tokenizer_dir='./pretrained/colpaligemma-3b-mix-448-base')
|
| 335 |
+
processor = ColPaliProcessor(tokenizer=tokenizer).from_pretrained(pretrained_dir='./pretrained/colpaligemma-3b-mix-448-base')
|
| 336 |
+
|
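# Wrap the Gemma language model's attention/MLP projections and the 128-dim
# custom_text_proj head with LoRA adapters (rank 32, inference mode), then load
# the trained ColPali adapter weights from the checkpoint directory.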
| 337 |
+
model.model.language_model.model = get_lora_model(model.model.language_model.model,
|
| 338 |
+
rank=32,
|
| 339 |
+
alphas=32,
|
| 340 |
+
lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'],
|
| 341 |
+
training=False,
|
| 342 |
+
dropout_p=0.1,
|
| 343 |
+
torch_dtype=torch.bfloat16)
|
| 344 |
+
model.model.language_model.model = enable_lora(model.model.language_model.model, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], enabled=True)
|
| 345 |
+
|
| 346 |
+
model = get_lora_model(model,
|
| 347 |
+
rank=32,
|
| 348 |
+
alphas=32,
|
| 349 |
+
lora_modules=['custom_text_proj'],
|
| 350 |
+
training=False,
|
| 351 |
+
dropout_p=0.1,
|
| 352 |
+
torch_dtype=torch.bfloat16)
|
| 353 |
+
model = enable_lora(model, lora_modules=['custom_text_proj'], enabled=True)
|
| 354 |
+
|
| 355 |
+
model.load_lora('./pretrained/colpaligemma-3b-mix-448-base')
|
| 356 |
+
|
| 357 |
+
# Initialize LLM
|
| 358 |
+
generation_config = {
|
| 359 |
+
"temperature": 0.0,
|
| 360 |
+
"top_p": 0.95,
|
| 361 |
+
"top_k": 64,
|
| 362 |
+
"max_output_tokens": 1024,
|
| 363 |
+
"response_mime_type": "text/plain",
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
llm = Gemini(api_key=GEMINI_API_KEY, generation_config=generation_config)
|
| 367 |
+
|
| 368 |
+
# Setup Qdrant
|
| 369 |
+
# Creating Qdrant Client
|
| 370 |
+
vector_store_client = qdrant_client.QdrantClient(location="http://localhost:6333", timeout=100)
|
| 371 |
+
|
| 372 |
+
indexDocument('./data/pdfs-financial/Alphabet_Inc_goog-10-q-q1-2024.pdf',
|
| 373 |
+
vector_store_client=vector_store_client,
|
| 374 |
+
target_collection="Alphabet",
|
| 375 |
+
model=model,
|
| 376 |
+
processor=processor,
|
| 377 |
+
device='mps')
|
| 378 |
+
|
| 379 |
+
indexDocument('./data/pdfs-financial/Nvidia_ecefb2b2-efcb-45f3-b72b-212d90fcd873.pdf',
|
| 380 |
+
vector_store_client=vector_store_client,
|
| 381 |
+
target_collection="Nvidia",
|
| 382 |
+
model=model,
|
| 383 |
+
processor=processor,
|
| 384 |
+
device='mps')
|
| 385 |
+
|
| 386 |
+
# RAG using LLamaIndex
|
| 387 |
+
|
| 388 |
+
embed_model = ColPaliGemmaEmbedding(model=model, processor=processor, device="mps")
|
| 389 |
+
|
| 390 |
+
alphabet_retriever = ColPaliRetriever(vector_store_client=vector_store_client,
|
| 391 |
+
target_collection="Alphabet",
|
| 392 |
+
embed_model=embed_model,
|
| 393 |
+
query_mode='default',
|
| 394 |
+
similarity_top_k=3)
|
| 395 |
+
|
| 396 |
+
nvidia_retriever = ColPaliRetriever(vector_store_client=vector_store_client,
|
| 397 |
+
target_collection="Nvidia",
|
| 398 |
+
embed_model=embed_model,
|
| 399 |
+
query_mode='default',
|
| 400 |
+
similarity_top_k=3)
|
| 401 |
+
|
| 402 |
+
# Query Router Among Multiple Retrievers
|
| 403 |
+
retriever_tools = [
|
| 404 |
+
RetrieverTool.from_defaults(
|
| 405 |
+
name="alphabet",
|
| 406 |
+
retriever=alphabet_retriever,
|
| 407 |
+
description="Useful for retrieving information about Alphabet Inc financials"
|
| 408 |
+
),
|
| 409 |
+
RetrieverTool.from_defaults(
|
| 410 |
+
name="nvidia",
|
| 411 |
+
retriever=nvidia_retriever,
|
| 412 |
+
description="Useful for retrieving information about Nvidia financials"
|
| 413 |
+
)
|
| 414 |
+
]
|
| 415 |
+
|
| 416 |
+
retriever_mappings = {retriever_tool.metadata.name: retriever_tool.retriever for retriever_tool in retriever_tools}
|
| 417 |
+
|
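# Fusion retrieval (implemented in llamaindex_utils.py): the LLM presumably generates
# additional queries, routes them to the per-company retrievers, and the fused top-k
# page images are handed to Gemini by the query engine for answer synthesis.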
| 418 |
+
fusion_retriever = CustomFusionRetriever(llm=llm,
|
| 419 |
+
retriever_mappings=retriever_mappings,
|
| 420 |
+
num_generated_queries=3,
|
| 421 |
+
similarity_top_k=3)
|
| 422 |
+
|
| 423 |
+
query_engine = CustomQueryEngine(retriever_tools=[retriever_tool.metadata for retriever_tool in retriever_tools],
|
| 424 |
+
fusion_retriever=fusion_retriever,
|
| 425 |
+
llm=llm,
|
| 426 |
+
num_children=3)
|
| 427 |
+
|
| 428 |
+
query_str = "Compare the net income between Nvidia and Alphabet"
|
| 429 |
+
response = query_engine.query(query_str=query_str)
|
| 430 |
+
print(response.response)
|
| 431 |
+
|
| 432 |
+
async def amain():
|
| 433 |
+
model = ColPali.from_pretrained(model_dir='./pretrained/colpaligemma-3b-mix-448-base', torch_dtype=torch.bfloat16)
|
| 434 |
+
tokenizer = load_tokenizer(tokenizer_dir='./pretrained/colpaligemma-3b-mix-448-base')
|
| 435 |
+
processor = ColPaliProcessor(tokenizer=tokenizer).from_pretrained(pretrained_dir='./pretrained/colpaligemma-3b-mix-448-base')
|
| 436 |
+
|
| 437 |
+
model.model.language_model.model = get_lora_model(model.model.language_model.model,
|
| 438 |
+
rank=32,
|
| 439 |
+
alphas=32,
|
| 440 |
+
lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'],
|
| 441 |
+
training=False,
|
| 442 |
+
dropout_p=0.1,
|
| 443 |
+
torch_dtype=torch.bfloat16)
|
| 444 |
+
model.model.language_model.model = enable_lora(model.model.language_model.model, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], enabled=True)
|
| 445 |
+
|
| 446 |
+
model = get_lora_model(model,
|
| 447 |
+
rank=32,
|
| 448 |
+
alphas=32,
|
| 449 |
+
lora_modules=['custom_text_proj'],
|
| 450 |
+
training=False,
|
| 451 |
+
dropout_p=0.1,
|
| 452 |
+
torch_dtype=torch.bfloat16)
|
| 453 |
+
model = enable_lora(model, lora_modules=['custom_text_proj'], enabled=True)
|
| 454 |
+
|
| 455 |
+
model.load_lora('./pretrained/colpaligemma-3b-mix-448-base')
|
| 456 |
+
|
| 457 |
+
# Initialize LLM
|
| 458 |
+
generation_config = {
|
| 459 |
+
"temperature": 0.0,
|
| 460 |
+
"top_p": 0.95,
|
| 461 |
+
"top_k": 64,
|
| 462 |
+
"max_output_tokens": 1024,
|
| 463 |
+
"response_mime_type": "text/plain",
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
llm = Gemini(api_key=GEMINI_API_KEY, generation_config=generation_config)
|
| 467 |
+
|
| 468 |
+
# Setup Qdrant
|
| 469 |
+
# Creating Qdrant Client
|
| 470 |
+
vector_store_client = qdrant_client.AsyncQdrantClient(location="http://localhost:6333", timeout=100)
|
| 471 |
+
|
| 472 |
+
await async_indexDocument('./data/pdfs-financial/Alphabet_Inc_goog-10-q-q1-2024.pdf',
|
| 473 |
+
vector_store_client=vector_store_client,
|
| 474 |
+
target_collection="Alphabet",
|
| 475 |
+
model=model,
|
| 476 |
+
processor=processor,
|
| 477 |
+
device='mps')
|
| 478 |
+
|
| 479 |
+
await async_indexDocument('./data/pdfs-financial/Nvidia_ecefb2b2-efcb-45f3-b72b-212d90fcd873.pdf',
|
| 480 |
+
vector_store_client=vector_store_client,
|
| 481 |
+
target_collection="Nvidia",
|
| 482 |
+
model=model,
|
| 483 |
+
processor=processor,
|
| 484 |
+
device='mps')
|
| 485 |
+
|
| 486 |
+
embed_model = ColPaliGemmaEmbedding(model=model, processor=processor, device="mps")
|
| 487 |
+
|
| 488 |
+
alphabet_retriever = ColPaliRetriever(vector_store_client=vector_store_client,
|
| 489 |
+
target_collection="Alphabet",
|
| 490 |
+
embed_model=embed_model,
|
| 491 |
+
query_mode='default',
|
| 492 |
+
similarity_top_k=3)
|
| 493 |
+
|
| 494 |
+
nvidia_retriever = ColPaliRetriever(vector_store_client=vector_store_client,
|
| 495 |
+
target_collection="Nvidia",
|
| 496 |
+
embed_model=embed_model,
|
| 497 |
+
query_mode='default',
|
| 498 |
+
similarity_top_k=3)
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
# Query Router Among Multiple Retrievers
|
| 502 |
+
retriever_tools = [
|
| 503 |
+
RetrieverTool.from_defaults(
|
| 504 |
+
name="alphabet",
|
| 505 |
+
retriever=alphabet_retriever,
|
| 506 |
+
description="Useful for retrieving information about Alphabet Inc financials"
|
| 507 |
+
),
|
| 508 |
+
RetrieverTool.from_defaults(
|
| 509 |
+
name="nvidia",
|
| 510 |
+
retriever=nvidia_retriever,
|
| 511 |
+
description="Useful for retrieving information about Nvidia financials"
|
| 512 |
+
)
|
| 513 |
+
]
|
| 514 |
+
|
| 515 |
+
retriever_mappings = {retriever_tool.metadata.name: retriever_tool.retriever for retriever_tool in retriever_tools}
|
| 516 |
+
|
| 517 |
+
fusion_retriever = CustomFusionRetriever(llm=llm,
|
| 518 |
+
retriever_mappings=retriever_mappings,
|
| 519 |
+
similarity_top_k=3)
|
| 520 |
+
|
| 521 |
+
query_engine = CustomQueryEngine(retriever_tools=[retriever_tool.metadata for retriever_tool in retriever_tools],
|
| 522 |
+
fusion_retriever=fusion_retriever,
|
| 523 |
+
llm=llm,
|
| 524 |
+
num_children=3)
|
| 525 |
+
|
| 526 |
+
query_str = "Compare the net income between Nvidia and Alphabet"
|
| 527 |
+
response = await query_engine.aquery(query_str=query_str)
|
| 528 |
+
print(str(response))
|
| 529 |
+
|
| 530 |
+
if __name__ == "__main__":
|
| 531 |
+
main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,225 @@
|
| 1 |
+
accelerate==1.1.0
|
| 2 |
+
aiofiles==23.2.1
|
| 3 |
+
aiohappyeyeballs==2.4.3
|
| 4 |
+
aiohttp==3.10.10
|
| 5 |
+
aiosignal==1.3.1
|
| 6 |
+
annotated-types==0.7.0
|
| 7 |
+
anyio==4.6.2.post1
|
| 8 |
+
appnope==0.1.4
|
| 9 |
+
argon2-cffi==23.1.0
|
| 10 |
+
argon2-cffi-bindings==21.2.0
|
| 11 |
+
arrow==1.3.0
|
| 12 |
+
asttokens==2.4.1
|
| 13 |
+
async-lru==2.0.4
|
| 14 |
+
attrs==24.2.0
|
| 15 |
+
babel==2.16.0
|
| 16 |
+
beautifulsoup4==4.12.3
|
| 17 |
+
bleach==6.2.0
|
| 18 |
+
cachetools==5.5.0
|
| 19 |
+
certifi==2024.8.30
|
| 20 |
+
cffi==1.17.1
|
| 21 |
+
charset-normalizer==3.4.0
|
| 22 |
+
click==8.1.7
|
| 23 |
+
comm==0.2.2
|
| 24 |
+
contourpy==1.3.0
|
| 25 |
+
cycler==0.12.1
|
| 26 |
+
dataclasses-json==0.6.7
|
| 27 |
+
datasets==3.0.1
|
| 28 |
+
debugpy==1.8.7
|
| 29 |
+
decorator==5.1.1
|
| 30 |
+
defusedxml==0.7.1
|
| 31 |
+
Deprecated==1.2.14
|
| 32 |
+
dill==0.3.8
|
| 33 |
+
dirtyjson==1.0.8
|
| 34 |
+
distro==1.9.0
|
| 35 |
+
executing==2.1.0
|
| 36 |
+
fastapi==0.115.4
|
| 37 |
+
fastjsonschema==2.20.0
|
| 38 |
+
ffmpy==0.4.0
|
| 39 |
+
filelock==3.16.1
|
| 40 |
+
fonttools==4.54.1
|
| 41 |
+
fqdn==1.5.1
|
| 42 |
+
frozenlist==1.5.0
|
| 43 |
+
fsspec==2024.6.1
|
| 44 |
+
google-ai-generativelanguage==0.6.4
|
| 45 |
+
google-api-core==2.20.0
|
| 46 |
+
google-api-python-client==2.147.0
|
| 47 |
+
google-auth==2.35.0
|
| 48 |
+
google-auth-httplib2==0.2.0
|
| 49 |
+
google-generativeai==0.5.4
|
| 50 |
+
googleapis-common-protos==1.65.0
|
| 51 |
+
gradio==4.44.1
|
| 52 |
+
gradio_client==1.3.0
|
| 53 |
+
greenlet==3.1.1
|
| 54 |
+
grpcio==1.67.1
|
| 55 |
+
grpcio-status==1.62.3
|
| 56 |
+
grpcio-tools==1.62.3
|
| 57 |
+
h11==0.14.0
|
| 58 |
+
h2==4.1.0
|
| 59 |
+
hpack==4.0.0
|
| 60 |
+
httpcore==1.0.6
|
| 61 |
+
httplib2==0.22.0
|
| 62 |
+
httpx==0.27.2
|
| 63 |
+
huggingface-hub==0.26.2
|
| 64 |
+
hyperframe==6.0.1
|
| 65 |
+
idna==3.10
|
| 66 |
+
importlib_resources==6.4.5
|
| 67 |
+
InstructorEmbedding==1.0.1
|
| 68 |
+
ipykernel==6.29.5
|
| 69 |
+
ipython==8.29.0
|
| 70 |
+
isoduration==20.11.0
|
| 71 |
+
jedi==0.19.1
|
| 72 |
+
Jinja2==3.1.4
|
| 73 |
+
jiter==0.7.0
|
| 74 |
+
joblib==1.4.2
|
| 75 |
+
json5==0.9.25
|
| 76 |
+
jsonpointer==3.0.0
|
| 77 |
+
jsonschema==4.23.0
|
| 78 |
+
jsonschema-specifications==2024.10.1
|
| 79 |
+
jupyter_client==8.6.3
|
| 80 |
+
jupyter_core==5.7.2
|
| 81 |
+
jupyter-events==0.10.0
|
| 82 |
+
jupyter-lsp==2.2.5
|
| 83 |
+
jupyter_server==2.14.2
|
| 84 |
+
jupyter_server_terminals==0.5.3
|
| 85 |
+
jupyterlab==4.2.5
|
| 86 |
+
jupyterlab_pygments==0.3.0
|
| 87 |
+
jupyterlab_server==2.27.3
|
| 88 |
+
kiwisolver==1.4.7
|
| 89 |
+
llama-cloud==0.1.2
|
| 90 |
+
llama-index==0.11.17
|
| 91 |
+
llama-index-agent-openai==0.3.4
|
| 92 |
+
llama-index-cli==0.3.1
|
| 93 |
+
llama-index-core==0.11.17
|
| 94 |
+
llama-index-embeddings-huggingface==0.3.1
|
| 95 |
+
llama-index-embeddings-instructor==0.2.1
|
| 96 |
+
llama-index-embeddings-openai==0.2.5
|
| 97 |
+
llama-index-indices-managed-llama-cloud==0.4.0
|
| 98 |
+
llama-index-legacy==0.9.48.post3
|
| 99 |
+
llama-index-llms-gemini==0.3.7
|
| 100 |
+
llama-index-llms-openai==0.2.13
|
| 101 |
+
llama-index-multi-modal-llms-gemini==0.3.1
|
| 102 |
+
llama-index-multi-modal-llms-openai==0.2.2
|
| 103 |
+
llama-index-postprocessor-colbert-rerank==0.2.1
|
| 104 |
+
llama-index-program-openai==0.2.0
|
| 105 |
+
llama-index-question-gen-openai==0.2.0
|
| 106 |
+
llama-index-readers-file==0.2.2
|
| 107 |
+
llama-index-readers-llama-parse==0.3.0
|
| 108 |
+
llama-index-vector-stores-qdrant==0.3.1
|
| 109 |
+
llama-parse==0.5.7
|
| 110 |
+
markdown-it-py==3.0.0
|
| 111 |
+
MarkupSafe==2.1.5
|
| 112 |
+
marshmallow==3.23.1
|
| 113 |
+
matplotlib==3.9.2
|
| 114 |
+
matplotlib-inline==0.1.7
|
| 115 |
+
mdurl==0.1.2
|
| 116 |
+
mistune==3.0.2
|
| 117 |
+
mpmath==1.3.0
|
| 118 |
+
multidict==6.1.0
|
| 119 |
+
multiprocess==0.70.16
|
| 120 |
+
mypy-extensions==1.0.0
|
| 121 |
+
nbclient==0.10.0
|
| 122 |
+
nbconvert==7.16.4
|
| 123 |
+
nbformat==5.10.4
|
| 124 |
+
nest-asyncio==1.6.0
|
| 125 |
+
networkx==3.4.2
|
| 126 |
+
nltk==3.9.1
|
| 127 |
+
notebook==7.2.2
|
| 128 |
+
notebook_shim==0.2.4
|
| 129 |
+
numpy==1.26.4
|
| 130 |
+
openai==1.53.0
|
| 131 |
+
orjson==3.10.11
|
| 132 |
+
overrides==7.7.0
|
| 133 |
+
packaging==24.1
|
| 134 |
+
pandas==2.2.3
|
| 135 |
+
pandocfilters==1.5.1
|
| 136 |
+
parso==0.8.4
|
| 137 |
+
pdf2image==1.17.0
|
| 138 |
+
peft==0.11.1
|
| 139 |
+
pexpect==4.9.0
|
| 140 |
+
pillow==10.4.0
|
| 141 |
+
pip==24.2
|
| 142 |
+
platformdirs==4.3.6
|
| 143 |
+
portalocker==2.10.1
|
| 144 |
+
prometheus_client==0.21.0
|
| 145 |
+
prompt_toolkit==3.0.48
|
| 146 |
+
propcache==0.2.0
|
| 147 |
+
proto-plus==1.24.0
|
| 148 |
+
protobuf==4.25.5
|
| 149 |
+
psutil==6.0.0
|
| 150 |
+
ptyprocess==0.7.0
|
| 151 |
+
pure_eval==0.2.3
|
| 152 |
+
pyarrow==17.0.0
|
| 153 |
+
pyasn1==0.6.1
|
| 154 |
+
pyasn1_modules==0.4.1
|
| 155 |
+
pycparser==2.22
|
| 156 |
+
pydantic==2.9.2
|
| 157 |
+
pydantic_core==2.23.4
|
| 158 |
+
pydub==0.25.1
|
| 159 |
+
Pygments==2.18.0
|
| 160 |
+
pyparsing==3.1.4
|
| 161 |
+
pypdf==4.3.1
|
| 162 |
+
python-dateutil==2.9.0.post0
|
| 163 |
+
python-json-logger==2.0.7
|
| 164 |
+
python-multipart==0.0.12
|
| 165 |
+
pytz==2024.2
|
| 166 |
+
PyYAML==6.0.2
|
| 167 |
+
pyzmq==26.2.0
|
| 168 |
+
qdrant-client==1.12.0
|
| 169 |
+
referencing==0.35.1
|
| 170 |
+
regex==2024.9.11
|
| 171 |
+
requests==2.32.3
|
| 172 |
+
rfc3339-validator==0.1.4
|
| 173 |
+
rfc3986-validator==0.1.1
|
| 174 |
+
rich==13.9.4
|
| 175 |
+
rpds-py==0.20.1
|
| 176 |
+
rsa==4.9
|
| 177 |
+
ruff==0.7.2
|
| 178 |
+
safetensors==0.4.5
|
| 179 |
+
scikit-learn==1.5.2
|
| 180 |
+
scipy==1.14.1
|
| 181 |
+
semantic-version==2.10.0
|
| 182 |
+
Send2Trash==1.8.3
|
| 183 |
+
sentence-transformers==2.7.0
|
| 184 |
+
setuptools==75.1.0
|
| 185 |
+
shellingham==1.5.4
|
| 186 |
+
six==1.16.0
|
| 187 |
+
sniffio==1.3.1
|
| 188 |
+
soupsieve==2.6
|
| 189 |
+
SQLAlchemy==2.0.36
|
| 190 |
+
stack-data==0.6.3
|
| 191 |
+
starlette==0.41.2
|
| 192 |
+
striprtf==0.0.26
|
| 193 |
+
sympy==1.13.3
|
| 194 |
+
tenacity==8.5.0
|
| 195 |
+
terminado==0.18.1
|
| 196 |
+
threadpoolctl==3.5.0
|
| 197 |
+
tiktoken==0.8.0
|
| 198 |
+
tinycss2==1.4.0
|
| 199 |
+
tokenizers==0.20.1
|
| 200 |
+
tomlkit==0.12.0
|
| 201 |
+
torch==2.4.1
|
| 202 |
+
torchinfo==1.8.0
|
| 203 |
+
torchvision==0.19.1
|
| 204 |
+
tornado==6.4.1
|
| 205 |
+
tqdm==4.66.5
|
| 206 |
+
traitlets==5.14.3
|
| 207 |
+
transformers==4.45.1
|
| 208 |
+
typer==0.12.5
|
| 209 |
+
types-python-dateutil==2.9.0.20241003
|
| 210 |
+
typing_extensions==4.12.2
|
| 211 |
+
typing-inspect==0.9.0
|
| 212 |
+
tzdata==2024.2
|
| 213 |
+
uri-template==1.3.0
|
| 214 |
+
uritemplate==4.1.1
|
| 215 |
+
urllib3==2.2.3
|
| 216 |
+
uvicorn==0.32.0
|
| 217 |
+
wcwidth==0.2.13
|
| 218 |
+
webcolors==24.8.0
|
| 219 |
+
webencodings==0.5.1
|
| 220 |
+
websocket-client==1.8.0
|
| 221 |
+
websockets==12.0
|
| 222 |
+
wheel==0.44.0
|
| 223 |
+
wrapt==1.16.0
|
| 224 |
+
xxhash==3.5.0
|
| 225 |
+
yarl==1.17.1
|
utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
| 1 |
+
from .utils import *
|
| 2 |
+
IMAGE_TOKEN = "<image>"
|
utils/utils.py
ADDED
|
@@ -0,0 +1,44 @@
|
| 1 |
+
import torch
|
| 2 |
+
from PIL import Image
|
| 3 |
+
from typing import Tuple, List
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import os
|
| 7 |
+
from transformers import AutoTokenizer, GemmaTokenizerFast
|
| 8 |
+
from safetensors import safe_open
|
| 9 |
+
import json
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from models.paligemma import PaliGemmaConfig, PaliGemma
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def load_model(model_dir: str):
|
| 15 |
+
|
| 16 |
+
with open(os.path.join(model_dir, 'config.json'), "r") as f:
|
| 17 |
+
model_config = json.loads(f.read())
|
| 18 |
+
config = PaliGemmaConfig.from_dict(model_config)
|
| 19 |
+
|
| 20 |
+
safetensor_files = Path(model_dir).glob("*.safetensors")
|
| 21 |
+
|
| 22 |
+
weights = {}
|
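# Merge tensors from every *.safetensors shard into a single state dict (loaded on CPU).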
| 23 |
+
for file in safetensor_files:
|
| 24 |
+
with safe_open(file, framework='pt', device="cpu") as f:
|
| 25 |
+
for key in f.keys():
|
| 26 |
+
weights[key] = f.get_tensor(key)
|
| 27 |
+
|
| 28 |
+
model = PaliGemma(config)
|
| 29 |
+
model.load_state_dict(weights, strict=False)
|
| 30 |
+
model.tie_weights()
|
| 31 |
+
|
| 32 |
+
return model
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def load_tokenizer(tokenizer_dir: str):
|
| 36 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir, padding_side='right')
|
| 37 |
+
return tokenizer
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def freeze_model(model: nn.Module):
|
| 41 |
+
for param in model.parameters():
|
| 42 |
+
param.requires_grad = False
|
| 43 |
+
|
| 44 |
+
return model
|