| """ | |
| Gradio Demo for Document Retrieval - Hugging Face Spaces with ZeroGPU | |
| This script creates a Gradio interface for testing both BiGemma3 and ColGemma3 models | |
| with PDF document upload, automatic conversion to images, and query-based retrieval. | |
| Features: | |
| - PDF upload with automatic conversion to images | |
| - Model selection: NetraEmbed (BiGemma3), ColNetraEmbed (ColGemma3), or Both | |
| - Query input with top-k selection (default: 5) | |
| - Similarity score display | |
| - Side-by-side comparison when both models are selected | |
| - Progressive loading with real-time updates | |
| - Proper error handling | |
| - ZeroGPU integration for efficient GPU usage | |
| """ | |
import gc
import io
import math
from typing import Iterator, List, Optional, Tuple

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import spaces
import torch
from einops import rearrange
from pdf2image import convert_from_path
from PIL import Image

# Import from colpali_engine
from colpali_engine.models import BiGemma3, BiGemmaProcessor3, ColGemma3, ColGemmaProcessor3
from colpali_engine.interpretability import get_similarity_maps_from_embeddings
from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map

# Configuration
MAX_BATCH_SIZE = 32  # Maximum pages to encode in a single forward pass
DEFAULT_DURATION = 120  # Default ZeroGPU allocation per call, in seconds
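
# ZeroGPU note: on Hugging Face ZeroGPU Spaces, a GPU is only attached while a
# function decorated with @spaces.GPU is running; DEFAULT_DURATION is passed as
# the decorator's `duration` argument on the GPU-heavy entry points below.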


# Global state for models and indexed documents
class DocumentIndex:
    def __init__(self):
        self.images: List[Image.Image] = []
        self.bigemma_embeddings = None
        self.colgemma_embeddings = None
        self.bigemma_model = None
        self.bigemma_processor = None
        self.colgemma_model = None
        self.colgemma_processor = None
        self.models_loaded = {"bigemma": False, "colgemma": False}


doc_index = DocumentIndex()
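
# Note: doc_index is module-level state shared by every Gradio session. With
# queue concurrency above one, two users indexing different PDFs would clobber
# each other's documents; a multi-user deployment would need per-session state
# (e.g., gr.State) instead. For a single-user demo Space this is acceptable.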


# Helper functions
def get_loaded_models() -> List[str]:
    """Get the list of currently loaded models."""
    loaded = []
    if doc_index.bigemma_model is not None:
        loaded.append("BiGemma3")
    if doc_index.colgemma_model is not None:
        loaded.append("ColGemma3")
    return loaded


def get_model_choice_from_loaded() -> str:
    """Determine the model choice string based on what is loaded."""
    loaded = get_loaded_models()
    if "BiGemma3" in loaded and "ColGemma3" in loaded:
        return "Both"
    elif "BiGemma3" in loaded:
        return "NetraEmbed (BiGemma3)"
    elif "ColGemma3" in loaded:
        return "ColNetraEmbed (ColGemma3)"
    else:
        return ""


def load_bigemma_model():
    """Load the BiGemma3 model and processor."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if doc_index.bigemma_model is None:
        print("Loading BiGemma3 (NetraEmbed)...")
        try:
            doc_index.bigemma_processor = BiGemmaProcessor3.from_pretrained(
                "Cognitive-Lab/NetraEmbed",
                use_fast=True,
            )
            doc_index.bigemma_model = BiGemma3.from_pretrained(
                "Cognitive-Lab/NetraEmbed",
                torch_dtype=torch.bfloat16,
                device_map=device,
            )
            doc_index.bigemma_model.eval()
            doc_index.models_loaded["bigemma"] = True
            print("✅ BiGemma3 loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load BiGemma3: {str(e)}")
            raise
    return doc_index.bigemma_model, doc_index.bigemma_processor


def load_colgemma_model():
    """Load the ColGemma3 model and processor."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if doc_index.colgemma_model is None:
        print("Loading ColGemma3 (ColNetraEmbed)...")
        try:
            doc_index.colgemma_model = ColGemma3.from_pretrained(
                "Cognitive-Lab/ColNetraEmbed",
                torch_dtype=torch.bfloat16,  # keyword kept consistent with the BiGemma3 loader
                device_map=device,
            )
            doc_index.colgemma_model.eval()
            doc_index.colgemma_processor = ColGemmaProcessor3.from_pretrained(
                "Cognitive-Lab/ColNetraEmbed",
                use_fast=True,
            )
            doc_index.models_loaded["colgemma"] = True
            print("✅ ColGemma3 loaded successfully")
        except Exception as e:
            print(f"❌ Failed to load ColGemma3: {str(e)}")
            raise
    return doc_index.colgemma_model, doc_index.colgemma_processor


def unload_models():
    """Unload models and free GPU memory."""
    try:
        if doc_index.bigemma_model is not None:
            del doc_index.bigemma_model
            del doc_index.bigemma_processor
            doc_index.bigemma_model = None
            doc_index.bigemma_processor = None
            doc_index.models_loaded["bigemma"] = False
        if doc_index.colgemma_model is not None:
            del doc_index.colgemma_model
            del doc_index.colgemma_processor
            doc_index.colgemma_model = None
            doc_index.colgemma_processor = None
            doc_index.models_loaded["colgemma"] = False

        # Clear embeddings and images
        doc_index.bigemma_embeddings = None
        doc_index.colgemma_embeddings = None
        doc_index.images = []

        # Force garbage collection and release cached GPU memory
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        return "✅ Models unloaded and GPU memory cleared"
    except Exception as e:
        return f"❌ Error unloading models: {str(e)}"


def clear_incompatible_embeddings(model_choice: str) -> str:
    """Clear embeddings that are incompatible with the model(s) being loaded."""
    cleared = []

    # If loading only BiGemma3, clear the ColGemma3 embeddings
    if model_choice == "NetraEmbed (BiGemma3)":
        if doc_index.colgemma_embeddings is not None:
            doc_index.colgemma_embeddings = None
            doc_index.images = []
            cleared.append("ColGemma3")
            print("Cleared ColGemma3 embeddings")
    # If loading only ColGemma3, clear the BiGemma3 embeddings
    elif model_choice == "ColNetraEmbed (ColGemma3)":
        if doc_index.bigemma_embeddings is not None:
            doc_index.bigemma_embeddings = None
            doc_index.images = []
            cleared.append("BiGemma3")
            print("Cleared BiGemma3 embeddings")

    if cleared:
        return f"Cleared {', '.join(cleared)} embeddings - please re-index"
    return ""


def pdf_to_images(pdf_path: str) -> List[Image.Image]:
    """Convert a PDF to a list of PIL Images, with error handling."""
    try:
        print(f"Converting PDF to images: {pdf_path}")
        images = convert_from_path(pdf_path, dpi=200)
        print(f"Converted {len(images)} pages")
        return images
    except Exception as e:
        print(f"❌ PDF conversion error: {str(e)}")
        raise Exception(f"Failed to convert PDF: {str(e)}")


def generate_colgemma_heatmap(
    image: Image.Image,
    query: str,
    query_embedding: torch.Tensor,
    image_embedding: torch.Tensor,
    model,
    processor,
) -> Image.Image:
    """Generate a heatmap overlay for ColGemma3 results."""
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Re-process the single image to get the batch_images dict for the image mask
        batch_images = processor.process_images([image]).to(device)

        # Create the image mask manually (ColGemmaProcessor3 doesn't have get_image_mask)
        if "input_ids" in batch_images and hasattr(model.config, "image_token_id"):
            image_token_id = model.config.image_token_id
            image_mask = batch_images["input_ids"] == image_token_id
        else:
            # Fallback: treat all tokens as image tokens
            image_mask = torch.ones(
                image_embedding.shape[0], image_embedding.shape[1], dtype=torch.bool, device=device
            )

        # Derive n_patches from the actual number of image tokens
        num_image_tokens = image_mask.sum().item()
        n_side = int(math.sqrt(num_image_tokens))
        if n_side * n_side == num_image_tokens:
            n_patches = (n_side, n_side)
        else:
            # Fallback: default patch grid
            n_patches = (16, 16)

        # Generate similarity maps (returns a list of tensors, one per image)
        similarity_maps_list = get_similarity_maps_from_embeddings(
            image_embeddings=image_embedding,
            query_embeddings=query_embedding,
            n_patches=n_patches,
            image_mask=image_mask,
        )

        # Similarity map for our single image: (query_length, n_patches_x, n_patches_y)
        similarity_map = similarity_maps_list[0]

        # Aggregate across all query tokens (mean)
        if similarity_map.dtype == torch.bfloat16:
            similarity_map = similarity_map.float()
        aggregated_map = torch.mean(similarity_map, dim=0)

        # Convert the image to an array
        img_array = np.array(image.convert("RGBA"))

        # Normalize the similarity map and convert to numpy
        similarity_map_array = normalize_similarity_map(aggregated_map).to(torch.float32).cpu().numpy()

        # Transpose to match the PIL (width, height) convention
        similarity_map_array = rearrange(similarity_map_array, "h w -> w h")

        # Create a PIL image from the similarity map, resized to the page
        similarity_map_image = Image.fromarray((similarity_map_array * 255).astype("uint8")).resize(
            image.size, Image.Resampling.BICUBIC
        )

        # Overlay the map on the page with matplotlib
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(img_array)
        ax.imshow(
            similarity_map_image,
            cmap=sns.color_palette("mako", as_cmap=True),
            alpha=0.5,
        )
        ax.set_axis_off()
        plt.tight_layout()

        # Render the figure to a PIL Image
        buffer = io.BytesIO()
        plt.savefig(buffer, format="png", dpi=150, bbox_inches="tight", pad_inches=0)
        buffer.seek(0)
        heatmap_image = Image.open(buffer).copy()
        plt.close(fig)

        return heatmap_image
    except Exception as e:
        print(f"❌ Heatmap generation error: {str(e)}")
        # Return the original image if heatmap generation fails
        return image
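
# Shape walkthrough for the heatmap above, assuming a square patch grid:
#   image_embedding:          (1, seq_len, dim) multi-vector page embedding
#   similarity_maps_list[0]:  (num_query_tokens, n_patches_x, n_patches_y)
#   aggregated_map:           (n_patches_x, n_patches_y) after the mean over query tokens
# The map is normalized to [0, 1], resized to the page size, and alpha-blended
# over the page. Mean aggregation shows overall relevance; taking a max over
# dim=0 instead would highlight the single best-matching query token.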


def index_bigemma_images(images: List[Image.Image]) -> torch.Tensor:
    """Index images with the BiGemma3 model, in batches of MAX_BATCH_SIZE."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = doc_index.bigemma_model, doc_index.bigemma_processor
    chunks = []
    with torch.no_grad():
        for i in range(0, len(images), MAX_BATCH_SIZE):
            batch_images = processor.process_images(images[i : i + MAX_BATCH_SIZE]).to(device)
            chunks.append(model(**batch_images, embedding_dim=768))
    # Single-vector embeddings: (num_pages, embedding_dim)
    return torch.cat(chunks, dim=0)
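
# Note: embedding_dim=768 requests a truncated Matryoshka embedding (this demo
# describes NetraEmbed as using Matryoshka representation). If you change it
# here, use the same embedding_dim when encoding queries in query_bigemma so
# document and query vectors live in the same space.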


def index_colgemma_images(images: List[Image.Image]) -> torch.Tensor:
    """Index images with the ColGemma3 model, in batches of MAX_BATCH_SIZE."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = doc_index.colgemma_model, doc_index.colgemma_processor
    chunks = []
    with torch.no_grad():
        for i in range(0, len(images), MAX_BATCH_SIZE):
            batch_images = processor.process_images(images[i : i + MAX_BATCH_SIZE]).to(device)
            chunks.append(model(**batch_images))
    # Multi-vector embeddings: (batch, seq_len, dim). Padding length can differ
    # between batches, so zero-pad to a common seq_len before stacking; zero
    # vectors score 0 and in practice do not affect MaxSim.
    max_len = max(chunk.shape[1] for chunk in chunks)
    chunks = [torch.nn.functional.pad(chunk, (0, 0, 0, max_len - chunk.shape[1])) for chunk in chunks]
    return torch.cat(chunks, dim=0)


@spaces.GPU(duration=DEFAULT_DURATION)
def index_document(pdf_file, model_choice: str) -> Iterator[str]:
    """Upload and index a PDF document, yielding progress updates."""
    if pdf_file is None:
        yield "⚠️ Please upload a PDF document first."
        return
    try:
        status_messages = []

        # Convert PDF to images
        status_messages.append("⏳ Converting PDF to images...")
        yield "\n".join(status_messages)

        # gr.File may hand us a file-like object or a plain path, depending on
        # the Gradio version
        pdf_path = pdf_file.name if hasattr(pdf_file, "name") else pdf_file
        doc_index.images = pdf_to_images(pdf_path)
        num_pages = len(doc_index.images)
        status_messages.append(f"✅ Converted PDF to {num_pages} images")

        # Check whether batch processing will kick in
        if num_pages > MAX_BATCH_SIZE:
            status_messages.append(f"⚠️ Large PDF ({num_pages} pages). Processing in batches of {MAX_BATCH_SIZE}...")
        yield "\n".join(status_messages)

        # Index with BiGemma3
        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
            if doc_index.bigemma_model is None:
                status_messages.append("⏳ Loading BiGemma3 model...")
                yield "\n".join(status_messages)
                load_bigemma_model()
                status_messages.append("✅ BiGemma3 loaded")
            else:
                status_messages.append("✅ Using cached BiGemma3 model")
            yield "\n".join(status_messages)

            status_messages.append("⏳ Encoding images with BiGemma3...")
            yield "\n".join(status_messages)
            doc_index.bigemma_embeddings = index_bigemma_images(doc_index.images)
            status_messages.append("✅ Indexed with BiGemma3 (shape: {})".format(doc_index.bigemma_embeddings.shape))
            yield "\n".join(status_messages)

        # Index with ColGemma3
        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
            if doc_index.colgemma_model is None:
                status_messages.append("⏳ Loading ColGemma3 model...")
                yield "\n".join(status_messages)
                load_colgemma_model()
                status_messages.append("✅ ColGemma3 loaded")
            else:
                status_messages.append("✅ Using cached ColGemma3 model")
            yield "\n".join(status_messages)

            status_messages.append("⏳ Encoding images with ColGemma3...")
            yield "\n".join(status_messages)
            doc_index.colgemma_embeddings = index_colgemma_images(doc_index.images)
            status_messages.append(
                "✅ Indexed with ColGemma3 (shape: {})".format(doc_index.colgemma_embeddings.shape)
            )
            yield "\n".join(status_messages)

        final_status = "\n".join(status_messages) + "\n\n✅ Document ready for querying!"
        yield final_status
    except Exception as e:
        import traceback

        error_details = traceback.format_exc()
        print(f"Indexing error: {error_details}")
        yield f"❌ Error indexing document: {str(e)}"


def query_bigemma(query: str, top_k: int) -> Tuple[str, List]:
    """Query the indexed documents with BiGemma3."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = doc_index.bigemma_model, doc_index.bigemma_processor

    # Encode the query
    with torch.no_grad():
        batch_query = processor.process_texts([query]).to(device)
        query_embedding = model(**batch_query, embedding_dim=768)

    # Compute scores (cosine similarity)
    scores = processor.score(
        qs=query_embedding,
        ps=doc_index.bigemma_embeddings,
    )

    # Get the top-k results
    top_k_actual = min(top_k, len(doc_index.images))
    top_indices = scores[0].argsort(descending=True)[:top_k_actual]

    # Format the results
    results_text = "### BiGemma3 (NetraEmbed) Results\n\n"
    gallery_images = []
    for rank, idx in enumerate(top_indices):
        score = scores[0, idx].item()
        results_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.4f}\n"
        gallery_images.append(
            (doc_index.images[idx.item()], f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.4f})")
        )
    return results_text, gallery_images
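
# Illustrative only (never called): assuming processor.score computes cosine
# similarity between L2-normalized single-vector embeddings, it is equivalent
# to this minimal sketch.
def _cosine_scores_reference(qs: torch.Tensor, ps: torch.Tensor) -> torch.Tensor:
    """qs: (num_queries, dim); ps: (num_pages, dim) -> (num_queries, num_pages)."""
    qs = torch.nn.functional.normalize(qs.float(), dim=-1)
    ps = torch.nn.functional.normalize(ps.float(), dim=-1)
    return qs @ ps.T  # each entry lies in [-1, 1]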


def query_colgemma(query: str, top_k: int, show_heatmap: bool = False) -> Tuple[str, List]:
    """Query the indexed documents with ColGemma3."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, processor = doc_index.colgemma_model, doc_index.colgemma_processor

    # Encode the query
    with torch.no_grad():
        batch_query = processor.process_queries([query]).to(device)
        query_embedding = model(**batch_query)

    # Compute scores (MaxSim late interaction)
    scores = processor.score_multi_vector(
        qs=query_embedding,
        ps=doc_index.colgemma_embeddings,
    )

    # Get the top-k results
    top_k_actual = min(top_k, len(doc_index.images))
    top_indices = scores[0].argsort(descending=True)[:top_k_actual]

    # Format the results
    results_text = "### ColGemma3 (ColNetraEmbed) Results\n\n"
    gallery_images = []
    for rank, idx in enumerate(top_indices):
        score = scores[0, idx].item()
        results_text += f"**Rank {rank + 1}:** Page {idx.item() + 1} - Score: {score:.2f}\n"
        caption = f"Rank {rank + 1} - Page {idx.item() + 1} (Score: {score:.2f})"
        if show_heatmap:
            # Generate a heatmap overlay for this page
            heatmap_image = generate_colgemma_heatmap(
                image=doc_index.images[idx.item()],
                query=query,
                query_embedding=query_embedding,
                image_embedding=doc_index.colgemma_embeddings[idx.item()].unsqueeze(0),
                model=model,
                processor=processor,
            )
            gallery_images.append((heatmap_image, caption))
        else:
            gallery_images.append((doc_index.images[idx.item()], caption))
    return results_text, gallery_images
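
# Illustrative only (never called): the MaxSim late-interaction score that
# score_multi_vector computes for one query/page pair, sketched for a single
# unbatched pair of token-embedding matrices.
def _maxsim_reference(query_tokens: torch.Tensor, image_tokens: torch.Tensor) -> float:
    """query_tokens: (num_query_tokens, dim); image_tokens: (num_image_tokens, dim)."""
    sim = query_tokens.float() @ image_tokens.float().T  # pairwise token similarities
    # Each query token keeps its best-matching image token; the page score is
    # the sum over query tokens, so it grows with query length (unlike cosine).
    return sim.max(dim=1).values.sum().item()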


@spaces.GPU(duration=DEFAULT_DURATION)
def query_documents(
    query: str, model_choice: str, top_k: int, show_heatmap: bool = False
) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[List]]:
    """Query the indexed documents."""
    if not doc_index.images:
        return "⚠️ Please upload and index a document first.", None, None, None
    if not query.strip():
        return "⚠️ Please enter a query.", None, None, None
    try:
        results_bi = None
        results_col = None
        gallery_images_bi = []
        gallery_images_col = []

        # Query with BiGemma3
        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
            if doc_index.bigemma_embeddings is None:
                return "⚠️ Please index the document with BiGemma3 first.", None, None, None
            results_bi, gallery_images_bi = query_bigemma(query, top_k)

        # Query with ColGemma3
        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
            if doc_index.colgemma_embeddings is None:
                return "⚠️ Please index the document with ColGemma3 first.", None, None, None
            results_col, gallery_images_col = query_colgemma(query, top_k, show_heatmap)

        # Return results in (bi_text, col_text, bi_gallery, col_gallery) order
        if model_choice == "NetraEmbed (BiGemma3)":
            return results_bi, None, gallery_images_bi, None
        elif model_choice == "ColNetraEmbed (ColGemma3)":
            # ColGemma3 text belongs in the second slot, which feeds the
            # ColGemma3 results panel (the first slot feeds the hidden BiGemma3 panel)
            return None, results_col, None, gallery_images_col
        else:  # Both
            return results_bi, results_col, gallery_images_bi, gallery_images_col
    except Exception as e:
        import traceback

        error_details = traceback.format_exc()
        print(f"Query error: {error_details}")
        return f"❌ Error during query: {str(e)}", None, None, None


def _locked_ui(status: str, index_msg: str) -> Tuple:
    """Build the UI state with the main interface hidden and all inputs disabled.

    The tuple order matches the outputs list wired to the load/unload buttons:
    (model_status, loading_info, main_interface, bigemma_column, colgemma_column,
    heatmap_checkbox, pdf_upload, index_btn, query_input, top_k_slider,
    query_btn, index_status).
    """
    return (
        status,
        gr.update(visible=True),  # loading_info
        gr.update(visible=False),  # main_interface
        gr.update(visible=False),  # bigemma_column
        gr.update(visible=False),  # colgemma_column
        gr.update(visible=False),  # heatmap_checkbox
        gr.update(interactive=False),  # pdf_upload
        gr.update(interactive=False),  # index_btn
        gr.update(interactive=False),  # query_input
        gr.update(interactive=False),  # top_k_slider
        gr.update(interactive=False),  # query_btn
        gr.update(value=index_msg),  # index_status
    )


def load_models_with_progress(model_choice: str) -> Iterator[Tuple]:
    """Load models with progress updates."""
    if not model_choice:
        yield _locked_ui("❌ Please select a model first.", "Load model first")
        return
    try:
        status_messages = []

        # Clear incompatible embeddings
        clear_msg = clear_incompatible_embeddings(model_choice)
        if clear_msg:
            status_messages.append(f"⚠️ {clear_msg}")

        # Load BiGemma3
        if model_choice in ["NetraEmbed (BiGemma3)", "Both"]:
            status_messages.append("⏳ Loading BiGemma3 (NetraEmbed)...")
            yield _locked_ui("\n".join(status_messages), "Loading models...")
            load_bigemma_model()
            status_messages[-1] = "✅ BiGemma3 loaded successfully"
            yield _locked_ui("\n".join(status_messages), "Loading models...")

        # Load ColGemma3
        if model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]:
            status_messages.append("⏳ Loading ColGemma3 (ColNetraEmbed)...")
            yield _locked_ui("\n".join(status_messages), "Loading models...")
            load_colgemma_model()
            status_messages[-1] = "✅ ColGemma3 loaded successfully"
            yield _locked_ui("\n".join(status_messages), "Loading models...")

        # Determine column visibility based on the loaded models
        show_bigemma = model_choice in ["NetraEmbed (BiGemma3)", "Both"]
        show_colgemma = model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]
        show_heatmap_checkbox = model_choice in ["ColNetraEmbed (ColGemma3)", "Both"]

        final_status = "\n".join(status_messages) + "\n\n✅ Ready!"
        yield (
            final_status,
            gr.update(visible=False),  # loading_info
            gr.update(visible=True),  # main_interface
            gr.update(visible=show_bigemma),  # bigemma_column
            gr.update(visible=show_colgemma),  # colgemma_column
            gr.update(visible=show_heatmap_checkbox),  # heatmap_checkbox
            gr.update(interactive=True),  # pdf_upload
            gr.update(interactive=True),  # index_btn
            gr.update(interactive=True),  # query_input
            gr.update(interactive=True),  # top_k_slider
            gr.update(interactive=True),  # query_btn
            gr.update(value="Ready to index"),  # index_status
        )
    except Exception as e:
        import traceback

        error_details = traceback.format_exc()
        print(f"Model loading error: {error_details}")
        yield _locked_ui(f"❌ Failed to load models: {str(e)}", "Load model first")


def unload_models_and_hide_ui():
    """Unload models and hide the main UI."""
    return _locked_ui(unload_models(), "Load model first")


# Create the Gradio interface
with gr.Blocks(
    title="NetraEmbed Demo",
) as demo:
    # Header section with model info and banner
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("# NetraEmbed")
            gr.HTML(
                """
                <div style="display: flex; gap: 8px; flex-wrap: wrap; margin-bottom: 15px;">
                    <a href="https://arxiv.org/abs/2512.03514" target="_blank">
                        <img src="https://img.shields.io/badge/arXiv-2512.03514-b31b1b.svg" alt="Paper">
                    </a>
                    <a href="https://github.com/adithya-s-k/colpali" target="_blank">
                        <img src="https://img.shields.io/badge/GitHub-colpali-181717?logo=github" alt="GitHub">
                    </a>
                    <a href="https://huggingface.co/Cognitive-Lab/ColNetraEmbed" target="_blank">
                        <img src="https://img.shields.io/badge/🤗%20HuggingFace-Model-yellow" alt="Model">
                    </a>
                    <a href="https://www.cognitivelab.in/blog/introducing-netraembed" target="_blank">
                        <img src="https://img.shields.io/badge/Blog-CognitiveLab-blue" alt="Blog">
                    </a>
                    <a href="https://cloud.cognitivelab.in" target="_blank">
                        <img src="https://img.shields.io/badge/Demo-Try%20it%20out-green" alt="Demo">
                    </a>
                </div>
                """
            )
            gr.Markdown(
                """
                **🌐 Universal Multilingual Multimodal Document Retrieval**

                Upload a PDF document, select your model(s), and query using semantic search.

                **Available Models:**
                - **NetraEmbed (BiGemma3)**: single-vector embedding with Matryoshka representation;
                  fast retrieval with cosine similarity
                - **ColNetraEmbed (ColGemma3)**: multi-vector embedding with late interaction;
                  high-quality retrieval with MaxSim scoring and attention heatmaps
                """
            )
        with gr.Column(scale=1):
            gr.HTML(
                """
                <div style="text-align: center;">
                    <img src="https://cdn-uploads.huggingface.co/production/uploads/6442d975ad54813badc1ddf7/-fYMikXhSuqRqm-UIdulK.png"
                         alt="NetraEmbed Banner"
                         style="width: 100%; height: auto; border-radius: 8px;">
                </div>
                """
            )

    gr.Markdown("---")
    # Compact three-column layout
    with gr.Row():
        # Column 1: model management
        with gr.Column(scale=1):
            gr.Markdown("### 🤖 Model Management")
            model_select = gr.Radio(
                choices=["NetraEmbed (BiGemma3)", "ColNetraEmbed (ColGemma3)", "Both"],
                value="Both",
                label="Select Model(s)",
            )
            load_model_btn = gr.Button("🚀 Load Model", variant="primary", size="sm")
            unload_model_btn = gr.Button("🗑️ Unload", variant="secondary", size="sm")
            model_status = gr.Textbox(
                label="Status",
                lines=6,
                interactive=False,
                value="Select and load a model",
            )
            loading_info = gr.Markdown(
                """
                **First load:** 2-3 min
                **Cached:** ~30 sec
                """,
                visible=True,
            )

        # Column 2: document upload & indexing
        with gr.Column(scale=1):
            gr.Markdown("### 📄 Upload & Index")
            pdf_upload = gr.File(label="Upload PDF", file_types=[".pdf"], interactive=False)
            index_btn = gr.Button("📥 Index Document", variant="primary", size="sm", interactive=False)
            index_status = gr.Textbox(
                label="Indexing Status",
                lines=6,
                interactive=False,
                value="Load model first",
            )

        # Column 3: query
        with gr.Column(scale=1):
            gr.Markdown("### 🔍 Query Document")
            query_input = gr.Textbox(
                label="Enter Query",
                placeholder="e.g., financial report, organizational structure...",
                lines=2,
                interactive=False,
            )
            with gr.Row():
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Top K",
                    scale=2,
                    interactive=False,
                )
                heatmap_checkbox = gr.Checkbox(
                    label="Heatmaps",
                    value=False,
                    visible=False,
                    scale=1,
                )
            query_btn = gr.Button("🔍 Search", variant="primary", size="sm", interactive=False)

    gr.Markdown("---")
    # Results section (shown once a model is loaded)
    with gr.Column(visible=False) as main_interface:
        gr.Markdown("### 📊 Results")
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, visible=False) as bigemma_column:
                bigemma_results = gr.Markdown(
                    value="*BiGemma3 results will appear here...*",
                )
                bigemma_gallery = gr.Gallery(
                    label="BiGemma3 - Top Retrieved Pages",
                    show_label=True,
                    columns=2,
                    height="auto",
                    object_fit="contain",
                )
            with gr.Column(scale=1, visible=False) as colgemma_column:
                colgemma_results = gr.Markdown(
                    value="*ColGemma3 results will appear here...*",
                )
                colgemma_gallery = gr.Gallery(
                    label="ColGemma3 - Top Retrieved Pages",
                    show_label=True,
                    columns=2,
                    height="auto",
                    object_fit="contain",
                )

    # Tips
    with gr.Accordion("💡 Tips", open=False):
        gr.Markdown(
            """
            - **Both models**: compare results side by side
            - **Scores**: BiGemma3 uses cosine similarity (-1 to 1); ColGemma3 uses MaxSim (unbounded; higher is better)
            - **Heatmaps**: enable to visualize ColGemma3 similarity patterns (brighter = stronger match)
            """
        )
    # Event handlers - model management
    load_model_btn.click(
        fn=load_models_with_progress,
        inputs=[model_select],
        outputs=[
            model_status,
            loading_info,
            main_interface,
            bigemma_column,
            colgemma_column,
            heatmap_checkbox,
            pdf_upload,
            index_btn,
            query_input,
            top_k_slider,
            query_btn,
            index_status,
        ],
    )

    unload_model_btn.click(
        fn=unload_models_and_hide_ui,
        outputs=[
            model_status,
            loading_info,
            main_interface,
            bigemma_column,
            colgemma_column,
            heatmap_checkbox,
            pdf_upload,
            index_btn,
            query_input,
            top_k_slider,
            query_btn,
            index_status,
        ],
    )
    # Event handlers - main interface
    def index_with_current_models(pdf_file):
        """Index the document with the currently loaded models."""
        if pdf_file is None:
            yield "⚠️ Please upload a PDF document first."
            return
        model_choice = get_model_choice_from_loaded()
        if not model_choice:
            yield "⚠️ No models loaded. Please load a model first."
            return
        # Delegate to the index_document generator
        yield from index_document(pdf_file, model_choice)

    def query_with_current_models(query, top_k, show_heatmap):
        """Query with the currently loaded models."""
        model_choice = get_model_choice_from_loaded()
        if not model_choice:
            return "⚠️ No models loaded. Please load a model first.", None, None, None
        return query_documents(query, model_choice, top_k, show_heatmap)

    index_btn.click(
        fn=index_with_current_models,
        inputs=[pdf_upload],
        outputs=[index_status],
    )

    query_btn.click(
        fn=query_with_current_models,
        inputs=[query_input, top_k_slider, heatmap_checkbox],
        outputs=[bigemma_results, colgemma_results, bigemma_gallery, colgemma_gallery],
    )

# Enable the queue to handle multiple requests
demo.queue(max_size=20)

# Launch the app
if __name__ == "__main__":
    demo.launch()