import warnings

warnings.filterwarnings("ignore")

import base64
from io import BytesIO

from tqdm import tqdm
from pdf2image import convert_from_path
import torch
from torch.utils.data import DataLoader
from transformers.utils.import_utils import is_flash_attn_2_available
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from openai import OpenAI
import spaces
import gradio as gr
|
|
|
|
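# Load the ColQwen2.5 visual retriever and its processor once at startup.
# FlashAttention 2 is used when available; otherwise transformers falls back
# to its default attention implementation.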
model = ColQwen2_5.from_pretrained(
    "vidore/colqwen2.5-v0.2",
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
).eval()
processor = ColQwen2_5_Processor.from_pretrained("vidore/colqwen2.5-v0.2")
|
|
|
|
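# Helpers for turning uploaded PDFs into PIL page images and for
# base64-encoding pages for the multimodal LLM call.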
def encode_image_to_base64(image):
    """Encodes a PIL image to a base64 string."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
|
|
|
def convert_files(files):
    """Converts a list of PDF files to a list of images."""
    images = []
    for f in files:
        images.extend(convert_from_path(f, thread_count=4))

    if len(images) >= 150:
        raise gr.Error("The number of images in the dataset should be less than 150.")
    return images
|
|
|
|
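# On ZeroGPU Spaces, @spaces.GPU attaches a GPU only for the duration of the
# call, so the model may still need to be moved onto it inside the function.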
@spaces.GPU
def index_gpu(images, ds):
    """Runs inference on the GPU for the given images with the visual document retrieval model."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if device != model.device:
        model.to(device)

    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: processor.process_images(x).to(model.device),
    )

    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        # Store one multi-vector page embedding at a time, offloaded to the CPU.
        ds.extend(list(torch.unbind(embeddings_doc.to("cpu"))))
    return f"Uploaded and converted {len(images)} pages", ds, images
|
|
|
|
def query_gemini(query, images, api_key):
    """Calls Google's Gemini model with the query and image data."""
    if api_key:
        try:
            # Each entry in `images` is a (PIL image, caption) tuple, so take
            # the image and encode it as a base64 JPEG for the API payload.
            base64_images = [encode_image_to_base64(image[0]) for image in images]

            # Gemini is reached through its OpenAI-compatible endpoint.
            client = OpenAI(
                api_key=api_key.strip(),
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
            )
            PROMPT = """
You are a smart assistant designed to answer questions about a PDF document.
You are given relevant information in the form of PDF pages. Use them to construct a short response to the question, and cite your sources (page numbers, etc).
If it is not possible to answer using the provided pages, do not attempt to provide an answer and simply say the answer is not present within the documents.
Give detailed and extensive answers, only containing info in the pages you are given.
You can answer using information contained in plots and figures if necessary.
Answer in the same language as the query.

Query: {query}
PDF pages:
"""

            response = client.chat.completions.create(
                model="gemini-2.5-flash-lite",
                reasoning_effort="none",
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": PROMPT.format(query=query)}
                        ]
                        + [
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:image/jpeg;base64,{im}"},
                            }
                            for im in base64_images
                        ],
                    }
                ],
                max_tokens=500,
            )
            return response.choices[0].message.content
        except Exception as e:
            return f"API connection error! Please check your API key and try again. ({e})"
    return "Enter your Gemini API key to get a custom response."
|
|
|
|
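# CPU-side entry point: convert the PDFs first, then hand the page images to
# the GPU-decorated indexing function.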
def index(files, ds):
    """Convert files to images and index them."""
    images = convert_files(files)
    return index_gpu(images, ds)
|
|
|
|
@spaces.GPU
def search(query: str, ds, images, k, api_key):
    """Search for the most relevant pages based on the query."""
    k = min(k, len(ds))

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    if device != model.device:
        model.to(device)

    # Embed the query with the same retriever used to index the pages.
    qs = []
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(model.device)
        embeddings_query = model(**batch_query)
    qs.extend(list(torch.unbind(embeddings_query.to("cpu"))))

    # Late-interaction scoring of the query embedding against every indexed page.
    scores = processor.score(qs, ds, device=device)
    top_k_indices = scores[0].topk(k).indices.tolist()

    results = []
    for idx in top_k_indices:
        img = images[idx]
        img_copy = img.copy()
        # Page labels are 1-based for display.
        results.append((img_copy, f"Page {idx + 1}"))

    ai_response = query_gemini(query, results, api_key)

    return results, ai_response
|
|
|
|
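# Gradio UI: upload PDFs, index them, then search and ask Gemini.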
with gr.Blocks(theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# Multimodal RAG with ColVision & Gemini 📚")
    gr.Markdown(
        """Demo to test ColQwen2.5 (ColPali) on PDF documents.
    ColPali is a visual retrieval model introduced in the paper [ColPali: Efficient Document Retrieval with Vision Language Models](https://arxiv.org/abs/2407.01449).
    This demo allows you to upload PDF files and search for the most relevant pages based on your query.
    Refresh the page if you change documents!
    ⚠️ This demo uses a model trained exclusively on A4 PDFs in portrait mode, containing English text. Performance is expected to drop for other page formats and languages.
    Other models will be released with better robustness towards different languages and document formats!
    """
    )
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("## 1️⃣ Upload PDFs")
            file = gr.File(
                file_types=[".pdf"], file_count="multiple", label="Upload PDFs"
            )

            gr.Markdown("## 2️⃣ Index the PDFs")
            message = gr.Textbox("Files not yet uploaded", label="Status")
            convert_button = gr.Button("🔄 Index documents")
            embeds = gr.State(value=[])
            imgs = gr.State(value=[])

        with gr.Column(scale=3):
            gr.Markdown("## 3️⃣ Search")
            api_key = gr.Textbox(
                placeholder="Enter your Gemini API key here (must be valid)",
                label="API key",
            )
            query = gr.Textbox(placeholder="Enter your query here", label="Query")
            k = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                label="Number of results",
                value=3,
                info="Number of pages to retrieve",
            )
            search_button = gr.Button("🔍 Search", variant="primary")

    gr.Markdown("## 4️⃣ Retrieved Images")
    output_gallery = gr.Gallery(
        label="Retrieved Documents", height=600, show_label=True
    )

    gr.Markdown("## 5️⃣ Gemini Response")
    output_text = gr.Textbox(
        label="AI Response",
        placeholder="Generated response based on retrieved documents",
        show_copy_button=True,
    )

    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])
    search_button.click(
        search,
        inputs=[query, embeds, imgs, k, api_key],
        outputs=[output_gallery, output_text],
    )
|
|
|
|
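# Queue requests so concurrent users share the GPU worker; max_size caps how
# many requests may wait in the queue at once.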
if __name__ == "__main__":
    demo.queue(max_size=10).launch()
|
|