Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import torch | |
| from llama_parse import LlamaParse | |
| from llama_index.core import StorageContext, load_index_from_storage | |
| from llama_index.core.indices import MultiModalVectorStoreIndex | |
| from llama_index.core.schema import Document, ImageDocument | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
# Pre-made indexes offered in the UI, keyed by the label shown in the
# dropdown; values are the directories the index was persisted to.
example_indexes = {
    "IONIQ 2024": "./iconiq_report_index",
    "Uber 10k 2021": "./uber_index",
}

# Device shared by both embedding models. Pinned to CPU; set to "cuda" or
# "mps" to run on an accelerator where one is available.
_EMBED_DEVICE = "cpu"
# Optional Hugging Face access token (None when the env var is unset), read
# once and passed to both models.
_HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

# Embedding model for the image documents (passed to the index as
# `image_embed_model`).
# NOTE(review): float16 weights on CPU can be slow, and some torch versions
# lack CPU Half kernels for matmul — confirm against the pinned torch version.
image_embed_model = HuggingFaceEmbedding(
    model_name="llamaindex/vdr-2b-multi-v1",
    device=_EMBED_DEVICE,
    trust_remote_code=True,
    token=_HF_TOKEN,
    model_kwargs={"torch_dtype": torch.float16},
    embed_batch_size=2,
)

# Embedding model for the text documents (passed to the index as
# `embed_model`); this is the plain-text retrieval baseline.
text_embed_model = HuggingFaceEmbedding(
    model_name="BAAI/bge-small-en",
    device=_EMBED_DEVICE,
    trust_remote_code=True,
    token=_HF_TOKEN,
    embed_batch_size=2,
)
def load_index(index_path: str) -> MultiModalVectorStoreIndex:
    """Load a previously persisted index from the directory ``index_path``.

    The module-level ``text_embed_model`` and ``image_embed_model`` are
    handed to the loader so queries against the loaded index use the same
    models as freshly created indexes.
    """
    persisted = StorageContext.from_defaults(persist_dir=index_path)
    loaded = load_index_from_storage(
        persisted,
        embed_model=text_embed_model,
        image_embed_model=image_embed_model,
    )
    return loaded
def create_index(file, llama_parse_key, progress=gr.Progress()):
    """Parse an uploaded document with LlamaParse and build a multi-modal index.

    The markdown of every parsed page is concatenated into a single text
    ``Document``. Each entry returned by ``parser.get_images`` becomes an
    ``ImageDocument``. ``take_screenshot=True`` asks LlamaParse for page
    screenshots, which presumably arrive through ``get_images``.

    Args:
        file: The upload from ``gr.File``. This is either an object carrying
            the local path in ``.name`` or the path string itself.
        llama_parse_key: LlamaParse API key.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(index, status_message)``. ``index`` is ``None`` when an input is
        missing or any step fails; the message then describes the problem.
    """
    if not file or not llama_parse_key:
        return None, "Please provide both a file and LlamaParse API key"

    # Depending on the Gradio version, the upload is an object with `.name`
    # or a plain path string; accept either.
    file_path = getattr(file, "name", file)

    try:
        progress(0, desc="Initializing LlamaParse...")
        parser = LlamaParse(
            api_key=llama_parse_key,
            take_screenshot=True,
        )

        # Process document
        progress(0.2, desc="Processing document with LlamaParse...")
        md_json_obj = parser.get_json_result(file_path)[0]

        progress(0.4, desc="Downloading and processing images...")
        # Keep images beside the upload as "<path>_images". Joining
        # os.path.dirname(path) with the full path would repeat the
        # directory whenever the path is relative.
        image_dicts = parser.get_images(
            [md_json_obj],
            download_path=f"{file_path}_images",
        )

        # Create text document: every page's markdown, blank-line separated.
        progress(0.6, desc="Creating text documents...")
        text = "\n\n".join(page["md"] for page in md_json_obj["pages"])
        text_docs = [Document(text=text.strip())]

        # Create image documents
        progress(0.8, desc="Creating image documents...")
        image_docs = [
            ImageDocument(text=image_dict["name"], image_path=image_dict["path"])
            for image_dict in image_dicts
        ]

        # Create index
        progress(0.9, desc="Creating final index...")
        index = MultiModalVectorStoreIndex.from_documents(
            text_docs + image_docs,
            embed_model=text_embed_model,
            image_embed_model=image_embed_model,
        )
        progress(1.0, desc="Complete!")
        return index, "Index created successfully!"
    except Exception as e:
        # UI boundary: report any failure in the status box instead of raising.
        return None, f"Error creating index: {str(e)}"
def _format_score(score):
    """Format a retrieval score to three decimals, or "n/a" when it is None."""
    return "n/a" if score is None else f"{score:.3f}"


def run_search(index, query, text_top_k, image_top_k):
    """Run text and text-to-image retrieval for ``query`` against ``index``.

    Args:
        index: The index to search. A falsy value means no index is loaded.
        query: Search text.
        text_top_k: Number of text nodes to retrieve. This may arrive as a
            float from a Gradio slider.
        image_top_k: Number of image nodes to retrieve, handled like
            ``text_top_k``.

    Returns:
        ``(status, text_results, image_results)``:

        - ``text_results`` is a list of ``{"text", "score"}`` dicts.
        - ``image_results`` is a list of ``(image_path, caption)`` tuples
          for the Gallery. It only includes nodes whose ``image_path``
          exists on disk.
    """
    if not index:
        return "Please create or select an index first.", [], []

    # Top-k values are used as counts downstream; coerce slider floats (2.0).
    retriever = index.as_retriever(
        similarity_top_k=int(text_top_k),
        image_similarity_top_k=int(image_top_k),
    )
    image_nodes = retriever.text_to_image_retrieve(query)
    text_nodes = retriever.text_retrieve(query)

    # Extract text and scores from nodes
    text_results = [
        {"text": node.text, "score": _format_score(node.score)}
        for node in text_nodes
    ]

    # Collect displayable images with their scores as captions.
    image_results = []
    for node in image_nodes:
        image_path = getattr(node.node, "image_path", None)
        # image_path may be missing or None; os.path.exists(None) raises
        # TypeError, so require a truthy path first.
        if image_path and os.path.exists(image_path):
            image_results.append(
                (image_path, f"Similarity: {_format_score(node.score)}")
            )

    return "Search completed!", text_results, image_results
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Multi-Modal Retrieval with LlamaIndex and llamaindex/vdr-2b-multi-v1")
    gr.Markdown("""
This demo shows how to use the new `llamaindex/vdr-2b-multi-v1` model for multi-modal document search.
Using this model, we can index images and perform text-to-image retrieval.
This demo compares to pure text retrieval using the `BAAI/bge-small-en` model. Is this a fair comparison? Not really,
but it's the easiest to run in a limited huggingface space, and shows the strengths of screenshot-based retrieval.
"""
    )

    # The pre-made index that is both pre-selected in the dropdown and loaded
    # on page load. It is derived from `example_indexes` so the two cannot
    # disagree.
    default_index_name = next(iter(example_indexes))

    with gr.Row():
        with gr.Column():
            # Index selection/creation
            with gr.Tab("Use Existing Index"):
                existing_index_dropdown = gr.Dropdown(
                    choices=list(example_indexes.keys()),
                    label="Select Pre-made Index",
                    value=default_index_name,
                )
            with gr.Tab("Create New Index"):
                gr.Markdown(
                    """
To create a new index, enter your LlamaParse API key and upload a PDF.
You can get a free API key by signing up [here](https://cloud.llamaindex.ai).
Processing will take a few minutes when creating a new index, depending on the size of the document.
"""
                )
                file_upload = gr.File(label="Upload PDF")
                llama_parse_key = gr.Textbox(
                    label="LlamaParse API Key",
                    type="password",
                )
                create_btn = gr.Button("Create Index")
                create_status = gr.Textbox(label="Status", interactive=False)

            # Search controls
            query_input = gr.Textbox(label="Search Query", value="What is the Executive Summary?")
            with gr.Row():
                text_top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Text Top-K",
                )
                image_top_k = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=2,
                    step=1,
                    label="Image Top-K",
                )
            search_btn = gr.Button("Search")

        with gr.Column():
            # Results display
            status_output = gr.Textbox(label="Search Status")
            image_output = gr.Gallery(
                label="Retrieved Images",
                # Shows the "Retrieved Images" label; per-image captions come
                # from the (path, caption) tuples returned by run_search.
                show_label=True,
                elem_id="gallery",
            )
            text_output = gr.JSON(
                label="Retrieved Text with Similarity Scores",
                elem_id="text_results",
            )

    # State: the index currently used for search in this session.
    index_state = gr.State()

    # Load the default index once at startup, so a missing index directory
    # fails immediately.
    default_index = load_index(example_indexes[default_index_name])
    # Hand that shared object to each session on page load rather than
    # assigning `index_state.value`. The gr.State docs say its initial value
    # is deep-copied, and that copy would presumably include the embedding
    # models the index references.
    demo.load(
        fn=lambda: default_index,
        inputs=None,
        outputs=index_state,
        api_name=False,
    )

    # Event handlers
    def load_existing_index(index_name):
        """Load the pre-made index named ``index_name``.

        Returns ``(index_or_None, status_message)``.
        """
        if index_name:
            try:
                index = load_index(example_indexes[index_name])
                return index, f"Loaded index: {index_name}"
            except Exception as e:
                return None, f"Error loading index: {str(e)}"
        return None, "No index selected"

    existing_index_dropdown.change(
        fn=load_existing_index,
        inputs=[existing_index_dropdown],
        outputs=[index_state, create_status],
        api_name=False,
    )
    create_btn.click(
        fn=create_index,
        inputs=[file_upload, llama_parse_key],
        outputs=[index_state, create_status],
        api_name=False,
        show_progress=True,  # Enable progress bar
    )
    search_btn.click(
        fn=run_search,
        inputs=[index_state, query_input, text_top_k, image_top_k],
        outputs=[status_output, text_output, image_output],
        api_name=False,
    )

    gr.Markdown("""
This demo was built with [LlamaIndex](https://docs.llamaindex.ai) and [LlamaParse](https://cloud.llamaindex.ai). To see more multi-modal demos, check out the [llama parse examples](https://github.com/run-llama/llama_parse/tree/main/examples/multimodal).
"""
    )
if __name__ == "__main__":
    # Script entry point: start the Gradio server for the `demo` app.
    demo.launch()