Spaces:
Runtime error
Runtime error
| import os | |
| import torch | |
| import base64 | |
| import asyncio | |
| from io import BytesIO | |
| import gradio as gr | |
| import qdrant_client | |
| from PIL import Image | |
| from typing import List, Dict, Tuple | |
| import llamaindex_utils | |
| from rag_pipeline import async_indexDocument | |
| from models import get_lora_model, enable_lora, ColPali, ColPaliProcessor | |
| from utils import load_tokenizer | |
| from llama_index.llms.gemini import Gemini | |
| from llama_index.core.tools import RetrieverTool | |
| from huggingface_hub import hf_hub_download | |
| GEMINI_API_KEY = os.getenv(key="GEMINI_API_KEY") | |
| QDRANT_API_KEY = os.getenv(key="QDRANT_API_KEY") | |
| HF_TOKEN_KEY = os.getenv(key="HF_TOKEN_KEY") | |
| device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" | |
| async def initialize_model() -> Dict: | |
| """Initialize models | |
| Returns: | |
| model_dict: Dict: Dictionary stores neccessary models | |
| """ | |
| if not os.path.exists("./pretrained/colpaligemma-3b-mix-448-base"): | |
| os.makedirs("./pretrained/colpaligemma-3b-mix-448-base", exist_ok=True) | |
| files_to_download = ["adapter_model.safetensors", | |
| "config.json", | |
| "model-00001-of-00002.safetensors", | |
| "model-00002-of-00002.safetensors", | |
| "preprocessor_config.json", | |
| "tokenizer.json", | |
| "tokenizer.model", | |
| "tokenizer_config.json"] | |
| for file in files_to_download: | |
| hf_hub_download(repo_id="dnnhhuy/colpaligemma-3b-mix-448-base", | |
| filename=file, | |
| token=HF_TOKEN_KEY, | |
| local_dir="./pretrained/colpaligemma-3b-mix-448-base") | |
| model = ColPali.from_pretrained(model_dir='./pretrained/colpaligemma-3b-mix-448-base', torch_dtype=torch.bfloat16) | |
| tokenizer = load_tokenizer(tokenizer_dir='./pretrained/colpaligemma-3b-mix-448-base') | |
| processor = ColPaliProcessor(tokenizer=tokenizer).from_pretrained(pretrained_dir='./pretrained/colpaligemma-3b-mix-448-base') | |
| model.model.language_model.model = get_lora_model(model.model.language_model.model, | |
| rank=32, | |
| alphas=32, | |
| lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], | |
| training=False, | |
| dropout_p=0.1, | |
| torch_dtype=torch.bfloat16) | |
| model.model.language_model.model = enable_lora(model.model.language_model.model, lora_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj', 'gate_proj', 'up_proj'], enabled=True) | |
| model = get_lora_model(model, | |
| rank=32, | |
| alphas=32, | |
| lora_modules=['custom_text_proj'], | |
| training=False, | |
| dropout_p=0.1, | |
| torch_dtype=torch.bfloat16) | |
| model = enable_lora(model, lora_modules=['custom_text_proj'], enabled=True) | |
| model.load_lora('./pretrained/colpaligemma-3b-mix-448-base') | |
| # Initialize LLM | |
| generation_config = { | |
| "temperature": 0.0, | |
| "top_p": 0.95, | |
| "top_k": 64, | |
| "max_output_tokens": 1024, | |
| "response_mime_type": "text/plain", | |
| } | |
| llm = Gemini(api_key=GEMINI_API_KEY, generation_config=generation_config) | |
| # Setup Qdrant | |
| # Creating Qdrant Client | |
| vector_store_client = qdrant_client.AsyncQdrantClient(location="https://b3878645-ec71-426c-8afa-b8b3b7589e40.us-east4-0.gcp.cloud.qdrant.io", | |
| api_key=QDRANT_API_KEY, | |
| timeout=100) | |
| embed_model = llamaindex_utils.ColPaliGemmaEmbedding(model=model, | |
| processor=processor, | |
| device=device) | |
| collections = await get_collection_names(vector_store_client) | |
| retrievers_dict = {} | |
| for name in collections: | |
| if name not in retrievers_dict: | |
| retrievers_dict[name] = llamaindex_utils.ColPaliRetriever(vector_store_client=vector_store_client, | |
| target_collection=name, | |
| embed_model=embed_model, | |
| similarity_top_k=3) | |
| return {"llm": llm, | |
| "vector_store_client": vector_store_client, | |
| "model": model, | |
| "processor": processor, | |
| "embed_model": embed_model, | |
| "collections": collections, | |
| "retrievers_dict": retrievers_dict} | |
| async def get_collection_names(vector_store_client): | |
| collections = await vector_store_client.get_collections() | |
| return [collection.name for collection in collections.collections] | |
| async def index(files: List[str], | |
| target_collection: str | |
| ) -> Tuple[str, gr.Dropdown, List[str], Dict[str, llamaindex_utils.ColPaliRetriever]]: | |
| """ | |
| Insert all image pages from files to speicified target collection to the vector store | |
| and return the mapping from retriever's name to its object instance. | |
| Args: | |
| files (List[str]): List of file path | |
| target_collection (str): Target collection to insert into the vector store | |
| Returns: | |
| Tuple[str, gr.Dropdown, List[str], Dict[str, llamaindex_utils.ColPaliRetriever]]: Return message, dropdown component, collections' names, dictionary mapping retriever to its object instance | |
| """ | |
| for file in files: | |
| await async_indexDocument(file_path=file, | |
| vector_store_client=model_dict["vector_store_client"], | |
| target_collection=target_collection, | |
| model=model_dict["model"], | |
| processor=model_dict["processor"], | |
| device=device) | |
| if target_collection not in retrievers: | |
| retrievers[target_collection] = llamaindex_utils.ColPaliRetriever(vector_store_client=model_dict["vector_store_client"], | |
| target_collection=target_collection, | |
| embed_model=model_dict["embed_model"], | |
| similarity_top_k=3) | |
| collection_names = await get_collection_names(model_dict["vector_store_client"]) | |
| return (f"Uploaded and index {len(files)} files.", | |
| gr.Dropdown(choices=collection_names), | |
| collection_names) | |
| async def search_with_llm(query: str, | |
| similarity_top_k: int, | |
| num_children: int) -> Tuple[str, List[Image.Image]]: | |
| """Search the result given query and list of retrievers. | |
| Returns the search's response and list of images support for that response. | |
| Args: | |
| query (str): Query question | |
| retrievers (Dict[str, llamaindex_utils.ColPaliRetriever]): Dictionary mapping between retrievers' names and their object instances | |
| similarity_top_k (int): top K similarity results retrieved from the retriever | |
| num_children (int): number of children for tree summarization | |
| Returns: | |
| Tuple[str, List[Image.Image]]: Returns the search's response and list of images support for that response. | |
| """ | |
| retriever_tools = [RetrieverTool.from_defaults( | |
| name=key, | |
| retriever=value, | |
| description=f"Useful for retrieving information about {key}.") for key, value in retrievers.items()] | |
| retriever_mappings = {retriever_tool.metadata.name: retriever_tool.retriever for retriever_tool in retriever_tools} | |
| fusion_retriever = llamaindex_utils.CustomFusionRetriever(llm=model_dict["llm"], | |
| retriever_mappings=retriever_mappings, | |
| similarity_top_k=similarity_top_k) | |
| query_engine = llamaindex_utils.CustomQueryEngine(retriever_tools=[retriever_tool.metadata for retriever_tool in retriever_tools], | |
| fusion_retriever=fusion_retriever, | |
| llm=model_dict["llm"], | |
| num_children=num_children) | |
| response = await query_engine.aquery(query_str=query) | |
| return response.response, [Image.open(BytesIO(base64.b64decode(image))) for image in response.source_images] | |
| async def delete_collection(target_collection): | |
| if await model_dict["vector_store_client"].collection_exists(collection_name=target_collection): | |
| await model_dict["vector_store_client"].delete_collection(collection_name=target_collection, timeout=100) | |
| choices = await get_collection_names(model_dict["vector_store_client"]) | |
| return (f"Deleted collection {target_collection}", gr.Dropdown(choices=choices), choices) | |
| else: | |
| choices = await get_collection_names(model_dict["vector_store_client"]) | |
| return (f"Collection {target_collection} is not found.", gr.Dropdown(choices=choices), choices) | |
| def build_gui(): | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Image Based RAG System using ColPali ๐๐") | |
| with gr.Row(equal_height=True): | |
| with gr.Column(): | |
| gr.Markdown("## 1๏ธ. Upload PDFs") | |
| files = gr.File(file_types=["pdf"], | |
| file_count="multiple", | |
| interactive=True) | |
| choices = gr.State(value=model_dict["collections"]) | |
| gr.Markdown("## 2๏ธ. Index the PDFs and upload") | |
| target_collection = gr.Dropdown(choices=choices.value, | |
| allow_custom_value=True, | |
| label="Collection name", | |
| show_label=True, | |
| interactive=True) | |
| message_box = gr.Textbox(value="File not yet uploaded", | |
| show_label=False, | |
| interactive=False) | |
| with gr.Row(equal_height=True): | |
| delete_button = gr.Button("๐๏ธ Delete collection") | |
| convert_button = gr.Button("๐ Convert and upload") | |
| # Define the actions for conversion | |
| convert_button.click(index, inputs=[files, target_collection], outputs=[message_box, target_collection, choices]) | |
| # Define the actions for delete collection | |
| delete_button.click(delete_collection, inputs=[target_collection], outputs=[message_box, target_collection, choices]) | |
| with gr.Column(): | |
| gr.Markdown("## 3๏ธ. Enter your question") | |
| query = gr.Textbox(placeholder="Enter your query to match", | |
| lines=15, | |
| max_lines=20, | |
| autoscroll=True) | |
| with gr.Accordion(label="Additional Settings", open=False): | |
| similarity_top_k = gr.Slider(minimum=1, | |
| maximum=10, | |
| value=3, | |
| step=1.0, | |
| label="Top K similarity retrieved from the retriever") | |
| num_children = gr.Slider(minimum=1, | |
| maximum=10, | |
| value=3, | |
| step=1.0, | |
| label="Set number of children for Tree Summarization") | |
| search_button = gr.Button("๐ Search") | |
| gr.Markdown("## 4๏ธ. ColPali Retrieval") | |
| with gr.Row(equal_height=True): | |
| output_text = gr.Textbox(label="Query result", | |
| show_label=True, | |
| placeholder="Response from query", | |
| lines=8, | |
| max_lines=20, | |
| interactive=False) | |
| output_imgs = gr.Gallery(label="Most relevant images is...", | |
| show_fullscreen_button=True, | |
| show_label=True, | |
| show_download_button=True, | |
| interactive=False) | |
| # Action for search button | |
| search_button.click( | |
| search_with_llm, | |
| inputs=[query, similarity_top_k, num_children], | |
| outputs=[output_text, output_imgs]) | |
| return demo | |
| async def amain(): | |
| global model_dict, retrievers | |
| model_dict = await initialize_model() | |
| retrievers = model_dict["retrievers_dict"] | |
| demo = build_gui() | |
| demo.queue().launch(debug=True, share=False) | |
| if __name__ == "__main__": | |
| asyncio.run(amain()) | |