Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| import fitz # PyMuPDF | |
| from transformers.utils.import_utils import is_flash_attn_2_available | |
| from colpali_engine.models import ColQwen2, ColQwen2Processor | |
| # ----------------------------- | |
| # Load ColPali Model | |
| # ----------------------------- | |
def load_colpali():
    """Load the ColQwen2 retrieval model and its matching processor.

    The model is placed on the first CUDA device when one is available and
    on the CPU otherwise, loaded in bfloat16, and switched to eval mode.
    FlashAttention-2 is requested only when the installed environment
    reports support for it; otherwise the library default is used.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint = "vidore/colqwen2-v1.0"

    if torch.cuda.is_available():
        target_device = "cuda:0"
    else:
        target_device = "cpu"

    if is_flash_attn_2_available():
        attention_backend = "flash_attention_2"
    else:
        # None lets from_pretrained pick its default attention implementation.
        attention_backend = None

    retriever = ColQwen2.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map=target_device,
        attn_implementation=attention_backend,
    )
    retriever = retriever.eval()

    preprocessor = ColQwen2Processor.from_pretrained(checkpoint)
    return retriever, preprocessor
# Streamlit re-executes this script from the top on every widget interaction.
# Routing the loader through st.cache_resource keeps a single loaded model and
# processor alive across reruns instead of reloading the weights each time.
colpali_model, colpali_processor = st.cache_resource(load_colpali)()

# NOTE(review): the leading "π" looks like a mis-encoded emoji from the
# original source — confirm the intended character before changing it.
st.title("π Visual PDF Search with ColPali")
pdf_file = st.file_uploader("Upload a PDF", type="pdf")
| # ----------------------------- | |
| # Convert PDF to image | |
| # ----------------------------- | |
def render_pdf_page_as_image(doc, zoom=2.0):
    """Rasterize every page of *doc* into a PIL RGB image.

    Args:
        doc: An iterable of pages exposing ``get_pixmap`` (the caller passes
            an opened PyMuPDF document).
        zoom: Scale factor applied to both axes when rendering each page.

    Returns:
        A list with one ``PIL.Image`` per page, in page order; empty when the
        document yields no pages.
    """

    def _rasterize(page):
        # The same zoom on both axes keeps the page's aspect ratio.
        pixmap = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
        return Image.frombytes(
            "RGB", [pixmap.width, pixmap.height], pixmap.samples
        )

    return [_rasterize(page) for page in doc]
| # ----------------------------- | |
| # Chunk image into pieces | |
| # ----------------------------- | |
def chunk_image(image, rows=3, cols=3, size=(512, 512)):
    """Split *image* into a ``rows`` x ``cols`` grid of resized tiles.

    Tile boundaries are computed so the grid always covers the whole image:
    when the width or height is not evenly divisible, the leftover pixels are
    spread across tiles instead of being dropped from the right and bottom
    edges. For evenly divisible sizes the tiles are the same as a plain
    ``width // cols`` by ``height // rows`` grid.

    Args:
        image: An object with a ``size`` attribute ``(width, height)`` and
            ``crop`` / ``resize`` methods (e.g. a ``PIL.Image``).
        rows: Number of grid rows; must be at least 1.
        cols: Number of grid columns; must be at least 1.
        size: ``(width, height)`` each tile is resized to.

    Returns:
        A list of ``rows * cols`` tiles in row-major order (left to right,
        then top to bottom).

    Raises:
        ValueError: If ``rows`` or ``cols`` is smaller than 1.
    """
    if rows < 1 or cols < 1:
        raise ValueError(
            f"rows and cols must be >= 1, got rows={rows}, cols={cols}"
        )

    width, height = image.size
    # Edge i sits at i/n of the full extent, so the last edge is exactly the
    # image border and no remainder pixels are lost.
    x_edges = [c * width // cols for c in range(cols + 1)]
    y_edges = [r * height // rows for r in range(rows + 1)]

    return [
        image.crop((left, top, right, bottom)).resize(size)
        for top, bottom in zip(y_edges, y_edges[1:])
        for left, right in zip(x_edges, x_edges[1:])
    ]
if pdf_file:
    # Open the upload from memory. The context manager closes the PyMuPDF
    # document as soon as the pages are rasterized; the PIL images are built
    # from the pixmap sample bytes and are used after the document is closed.
    with fitz.open(stream=pdf_file.read(), filetype="pdf") as doc:
        images = render_pdf_page_as_image(doc)

    if not images:
        st.warning("Failed to read content from the PDF.")
    else:
        # Every page is cut into a 2x2 grid; the chunks of all pages form a
        # single flat list that is searched as one collection.
        all_chunks = []
        for image in images:
            all_chunks.extend(chunk_image(image, rows=2, cols=2))

        user_query = st.text_input("What are you looking for in the document?")
        if user_query:
            # Embed the chunks in small batches so memory use does not grow
            # with the number of pages in the PDF.
            embed_batch_size = 8
            image_embeddings = []
            with torch.no_grad():
                batch_queries = colpali_processor.process_queries(
                    [user_query]
                ).to(colpali_model.device)
                query_embeddings = list(
                    torch.unbind(colpali_model(**batch_queries).to("cpu"))
                )

                for start in range(0, len(all_chunks), embed_batch_size):
                    batch_images = colpali_processor.process_images(
                        all_chunks[start:start + embed_batch_size]
                    ).to(colpali_model.device)
                    # Keep one tensor per chunk so every batch can be scored
                    # together afterwards.
                    image_embeddings.extend(
                        torch.unbind(colpali_model(**batch_images).to("cpu"))
                    )

            # There is a single query, so row 0 holds its score against
            # every chunk.
            scores = colpali_processor.score_multi_vector(
                query_embeddings, image_embeddings
            )
            best_idx = torch.argmax(scores[0]).item()
            best_image = all_chunks[best_idx]
            best_score = scores[0, best_idx].item()

            st.markdown("### π Most Relevant Image Chunk")
            # NOTE(review): newer Streamlit releases deprecate
            # use_column_width in favour of use_container_width — confirm the
            # pinned Streamlit version before switching.
            st.image(best_image, caption=f"Score: {best_score:.4f}", use_column_width=True)