Update app.py
app.py CHANGED
@@ -10,6 +10,8 @@ from PIL import Image
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+from pqdm.processes import pqdm
+
 from colpali_engine.models import ColQwen2, ColQwen2Processor
 
 
@@ -30,42 +32,50 @@ def encode_image_to_base64(image):
     return base64.b64encode(buffered.getvalue()).decode("utf-8")
 
 
-def query_gpt4o_mini(query, images, api_key):
+DEFAULT_SYSTEM_PROMPT = """
+You are a smart assistant designed to answer questions about a PDF document.
+You are given relevant information in the form of PDF pages preceded by their metadata (PDF title, page number, surrounding context).
+Use them to construct a short response to the question, and cite your sources (page number, pdf title).
+If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
+Give detailed and extensive answers, only containing info in the pages you are given.
+You can answer using information contained in plots and figures if necessary.
+Answer in the same language as the query.
+"""
+
+def query_gpt4o_mini(query, images, api_key, system_prompt=DEFAULT_SYSTEM_PROMPT):
     """Calls OpenAI's GPT-4o-mini with the query and image data."""
 
     if api_key and api_key.startswith("sk"):
         try:
             from openai import OpenAI
-
-            base64_images = [encode_image_to_base64(image[0]) for image in images]
+
             client = OpenAI(api_key=api_key.strip())
-            PROMPT = """
-            You are a smart assistant designed to answer questions about a PDF document.
-            You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers).
-            If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
-            Give detailed and extensive answers, only containing info in the pages you are given.
-            You can answer using information contained in plots and figures if necessary.
-            Answer in the same language as the query.
-
-            Query: {query}
-            PDF pages:
+            prompt = f"""
+            {system_prompt}
+            Query: {query}
+            PDF pages:
             """
-
+
+            messages = [{"type": "text", "text": prompt}]
+            for im, capt in images:
+                if capt is not None:
+                    messages.append({
+                        "type": "text",
+                        "text": capt
+                    })
+                messages.append({
+                    "type": "image_url",
+                    "image_url": {
+                        "url": f"data:image/jpeg;base64,{encode_image_to_base64(im)}"
+                    },
+                })
+
             response = client.chat.completions.create(
                 model="gpt-4o-mini",
                 messages=[
                     {
                         "role": "user",
-                        "content": [
-                            {
-                                "type": "text",
-                                "text": PROMPT.format(query=query)
-                            }] + [{
-                            "type": "image_url",
-                            "image_url": {
-                                "url": f"data:image/jpeg;base64,{im}"
-                            },
-                        } for im in base64_images]
+                        "content": messages
                     }
                 ],
                 max_tokens=500,
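Note: the reworked query_gpt4o_mini expects images as a list of (image, caption) pairs and interleaves each caption (a text part) with its base64-encoded page (an image_url part) inside a single user turn, which is the multimodal content-list format of OpenAI's Chat Completions API. A minimal sketch of a call, with a placeholder page and caption that are not part of the commit:

from PIL import Image

# Placeholder input: a blank page standing in for a rendered PDF page.
page = Image.new("RGB", (612, 792), "white")
pairs = [(page, "Document: report.pdf, Page: 1, Context: quarterly revenue summary")]

# Builds one user message: prompt text, then caption text, then the image part.
answer = query_gpt4o_mini("What was Q3 revenue?", pairs, api_key="sk-...")  # requires a real key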
@@ -77,7 +87,7 @@ def query_gpt4o_mini(query, images, api_key):
     return "Enter your OpenAI API key to get a custom response"
 
 
-def search(query: str, ds, images, k, api_key):
+def search(query: str, ds, images, metadatas, k, api_key):
     k = min(k, len(ds))
     device = "cuda:0" if torch.cuda.is_available() else "cpu"
     if device != model.device:
@@ -95,7 +105,9 @@ def search(query: str, ds, images, k, api_key):
 
     results = []
     for idx in top_k_indices:
-        results.append((images[idx], f"Page {idx}"))
+        img = images[idx]
+        meta = metadatas[idx]
+        results.append((img, f"Document: {meta['title']}, Page: {meta['page']}, Context: {meta['context']}"))
 
     # Generate response from GPT-4o-mini
     ai_response = query_gpt4o_mini(query, results, api_key)
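gr.Gallery accepts (image, caption) tuples, so the same results list doubles as the gallery payload and as the images argument to query_gpt4o_mini; the caption now carries the per-page metadata. A sketch of the caption built from a hypothetical metadata entry:

meta = {"title": "report.pdf", "page": 3, "context": "Overview of Q3 financials."}
caption = f"Document: {meta['title']}, Page: {meta['page']}, Context: {meta['context']}"
# -> "Document: report.pdf, Page: 3, Context: Overview of Q3 financials."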
@@ -103,22 +115,62 @@ def search(query: str, ds, images, k, api_key):
     return results, ai_response
 
 
-def index(files, ds):
+def index(files, ds, api_key):
     print("Converting files")
-    images = convert_files(files)
+    images, metadatas = convert_files(files, api_key)
     print(f"Files converted with {len(images)} images.")
-    index_gpu(images, ds)
-    return f"Uploaded and converted {len(images)} pages", ds, images
-
-
-def convert_files(files):
+    ds = index_gpu(images, ds)
+    print(f"Indexed {len(ds)} images.")
+    return f"Uploaded and converted {len(images)} pages", ds, images, metadatas
+
+DEFAULT_CONTEXT_PROMPT = """
+You are a smart assistant designed to extract context of PDF pages.
+Give detailed and extensive answers, only containing info in the pages you are given.
+You can answer using information contained in plots and figures if necessary.
+Answer in the same language as the query.
+"""
+
+def extract_context(images, api_key, window=10):
+    """Extracts context from images."""
+    prompt = "Give the general context about these pages."
+    window_contexts = []
+
+    args = [(prompt, (images[max(i-window+1, 0):i+1], None), api_key, DEFAULT_CONTEXT_PROMPT)
+            for i in range(0, len(images), window)]
+    window_contexts = pqdm(args, query_gpt4o_mini, n_jobs=8)
+
+    # for i in tqdm(range(0, len(images), window), desc="Extracting context", total=len(images)//window):
+    #     window_images = images[max(i-window+1, 0):i+1]
+    #     window_images = [(image, None) for image in window_images]
+    #     window_contexts.append(query_gpt4o_mini(prompt, window_images, api_key, system_prompt=DEFAULT_CONTEXT_PROMPT))
+
+    contexts = []
+    for i in range(len(images)):
+        context = window_contexts[i//window]
+        contexts.append(context)
+
+    assert len(contexts) == len(images)
+    return contexts
+
+def extract_metadata(file, images, api_key, window=10):
+    """Extracts metadata from pdfs. Extract page number, file name, and authors."""
+    title = file.split("/")[-1]
+    contexts = extract_context(images, api_key, window=window)
+    return [{"title": title, "page": i+1, "context": contexts[i]} for i in range(len(images))]
+
+def convert_files(files, api_key):
     images = []
+    metadatas = []
+
     for f in files:
-        images.extend(convert_from_path(f, thread_count=4))
+        file_images = convert_from_path(f, thread_count=4)
+        file_metadatas = extract_metadata(f, file_images, api_key)
+        images.extend(file_images)
+        metadatas.extend(file_metadatas)
 
     if len(images) >= 500:
         raise gr.Error("The number of images in the dataset should be less than 500.")
-    return images
+    return images, metadatas
 
 
 def index_gpu(images, ds):
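Note on the pqdm fan-out: pqdm.processes.pqdm maps a function over work items in a process pool, but it only unpacks a tuple of positional arguments when called with argument_type="args"; as written, each tuple is handed to query_gpt4o_mini as a single argument. query_gpt4o_mini also iterates its images argument as (image, caption) pairs, which is what the commented-out sequential loop builds. A sketch of the windowed call under those assumptions, using a plain i:i+window chunking slice in place of the commit's max(i-window+1, 0):i+1:

from pqdm.processes import pqdm

prompt = "Give the general context about these pages."
# One work item per window of pages: the positional args of query_gpt4o_mini.
args = [
    (prompt, [(im, None) for im in images[i:i + window]], api_key, DEFAULT_CONTEXT_PROMPT)
    for i in range(0, len(images), window)
]
# argument_type="args" makes pqdm unpack each tuple as *args.
window_contexts = pqdm(args, query_gpt4o_mini, n_jobs=8, argument_type="args")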
@@ -141,7 +193,7 @@ def index_gpu(images, ds):
             batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
             embeddings_doc = model(**batch_doc)
         ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
-    return
+    return ds
 
 
 
@@ -166,6 +218,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             api_key = gr.Textbox(placeholder="Enter your OpenAI KEY here (optional)", label="API key")
             embeds = gr.State(value=[])
             imgs = gr.State(value=[])
+            metadatas = gr.State(value=[])
 
         with gr.Column(scale=3):
             gr.Markdown("## 2️⃣ Search")
@@ -178,8 +231,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     output_gallery = gr.Gallery(label="Retrieved Documents", height=600, show_label=True)
     output_text = gr.Textbox(label="AI Response", placeholder="Generated response based on retrieved documents")
 
-    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
-    search_button.click(search, inputs=[query, embeds, imgs, k, api_key], outputs=[output_gallery, output_text])
+    convert_button.click(index, inputs=[file, embeds, api_key], outputs=[message, embeds, imgs, metadatas])
+    search_button.click(search, inputs=[query, embeds, imgs, metadatas, k, api_key], outputs=[output_gallery, output_text])
 
 if __name__ == "__main__":
-    demo.queue(max_size=5).launch(debug=True)
+    demo.queue(max_size=5).launch(debug=True, server_name="0.0.0.0")
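server_name="0.0.0.0" binds Gradio to all network interfaces so the app is reachable from outside a container (e.g., Docker or Spaces). A minimal sketch, assuming Gradio's default port:

# Listens on 0.0.0.0:7860; debug=True keeps server tracebacks in the console.
demo.queue(max_size=5).launch(debug=True, server_name="0.0.0.0", server_port=7860)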