from pathlib import Path
from urllib.request import urlopen
from uuid import uuid4

import modal

MINUTES = 60

app = modal.App("chat-with-pdf")

CACHE_DIR = "/hf-cache"

model_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("git")
    .pip_install(
        [
            "transformers>=4.45.0",
            "torch==2.4.1",
            "torchvision==0.19.1",
            "git+https://github.com/illuin-tech/colpali.git@782edcd50108d1842d154730ad3ce72476a2d17d",
            "hf_transfer==0.1.8",
            "qwen-vl-utils==0.0.8",
        ]
    )
    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1", "HF_HUB_CACHE": CACHE_DIR})
)

# These imports only resolve inside containers built from model_image.
with model_image.imports():
    import torch
    from colpali_engine.models import ColQwen2, ColQwen2Processor
    from qwen_vl_utils import process_vision_info
    from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
MODEL_REVISION = "aca78372505e6cb469c4fa6a35c60265b00ff5a4"

sessions = modal.Dict.from_name("colqwen-chat-sessions", create_if_missing=True)
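# Chat state lives in a named modal.Dict, so sessions persist across containers
# and survive restarts of the app.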

class Session:
    def __init__(self):
        self.images = None
        self.messages = []
        self.pdf_embeddings = None


pdf_volume = modal.Volume.from_name("colqwen-chat-pdfs", create_if_missing=True)
PDF_ROOT = Path("/vol/pdfs/")

cache_volume = modal.Volume.from_name("hf-hub-cache", create_if_missing=True)

# Decorator added so the weights can be fetched once into the cache Volume;
# the settings here are sensible defaults, not the only valid configuration.
@app.function(image=model_image, volumes={CACHE_DIR: cache_volume})
def download_model():
    from huggingface_hub import snapshot_download

    result = snapshot_download(
        MODEL_NAME,
        revision=MODEL_REVISION,
        ignore_patterns=["*.pt", "*.bin"],
    )
    print(f"Downloaded model weights to {result}")

# The decorator arguments below are reasonable defaults, not the original
# deployment settings: the class needs a GPU, the weight cache, and the PDF Volume.
@app.cls(
    image=model_image,
    gpu="A100",  # assumption: any GPU with enough memory for both models works
    scaledown_window=10 * MINUTES,  # keep the container warm between chat turns
    volumes={CACHE_DIR: cache_volume, str(PDF_ROOT): pdf_volume},
)
class Model:
    @modal.enter()
    def load_models(self):
        import os

        os.environ["TOKENIZERS_PARALLELISM"] = "false"
        os.environ["TRANSFORMERS_OFFLINE"] = "0"

        # Load ColQwen2 (the multi-vector retriever) with explicit configuration.
        try:
            self.colqwen2_model = ColQwen2.from_pretrained(
                "vidore/colqwen2-v0.1",
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            print(f"Error loading ColQwen2: {e}")
            # Fallback: load on CPU, then move to the GPU.
            self.colqwen2_model = ColQwen2.from_pretrained(
                "vidore/colqwen2-v0.1",
                torch_dtype=torch.bfloat16,
                device_map=None,
                trust_remote_code=True,
            )
            self.colqwen2_model = self.colqwen2_model.to("cuda:0")

        self.colqwen2_processor = ColQwen2Processor.from_pretrained(
            "vidore/colqwen2-v0.1"
        )

        # Load Qwen2-VL (the generator) with explicit configuration.
        try:
            self.qwen2_vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            print(f"Error loading Qwen2-VL: {e}")
            # Fallback: load on CPU, then move to the GPU.
            self.qwen2_vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
                MODEL_NAME,
                revision=MODEL_REVISION,
                torch_dtype=torch.bfloat16,
                trust_remote_code=True,
                device_map=None,
            )
            self.qwen2_vl_model = self.qwen2_vl_model.to("cuda:0")

        self.qwen2_vl_processor = AutoProcessor.from_pretrained(
            MODEL_NAME,
            revision=MODEL_REVISION,
            trust_remote_code=True,
        )

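    # Indexing: convert the PDF to page images, save them to the Volume, and
    # embed each page with ColQwen2 so queries can retrieve the most relevant page.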
    @modal.method()
    def index_pdf(self, session_id, target: bytes | list):
        # Retrieve or create the chat session.
        session = sessions.get(session_id)
        if session is None:
            session = Session()

        # `target` is either raw PDF bytes or an already-converted list of page images.
        if isinstance(target, bytes):
            images = convert_pdf_to_images.remote(target)
        else:
            images = target

        # Save the page images to the Volume so other containers can read them.
        session_dir = PDF_ROOT / f"{session_id}"
        session_dir.mkdir(exist_ok=True, parents=True)
        for ii, image in enumerate(images):
            filename = session_dir / f"{str(ii).zfill(3)}.jpg"
            image.save(filename)

        # Embed the pages in small batches and keep the embeddings on CPU.
        BATCH_SZ = 4
        pdf_embeddings = []
        batches = [images[i : i + BATCH_SZ] for i in range(0, len(images), BATCH_SZ)]
        for batch in batches:
            batch_images = self.colqwen2_processor.process_images(batch).to(
                self.colqwen2_model.device
            )
            pdf_embeddings += list(self.colqwen2_model(**batch_images).to("cpu"))

        session.pdf_embeddings = pdf_embeddings
        sessions[session_id] = session

    @modal.method()
    def respond_to_message(self, session_id, message):
        session = sessions.get(session_id)
        if session is None:
            session = Session()

        # Pick up any pages another container wrote to the Volume.
        pdf_volume.reload()
        images = (PDF_ROOT / str(session_id)).glob("*.jpg")
        images = list(sorted(images, key=lambda p: int(p.stem)))

        if not images:
            return "Please upload a PDF first"
        elif session.pdf_embeddings is None:
            return "Indexing PDF..."

        relevant_image = self.get_relevant_image(message, session, images)
        output_text = self.generate_response(message, session, relevant_image)

        append_to_messages(message, session, user_type="user")
        append_to_messages(output_text, session, user_type="assistant")
        sessions[session_id] = session

        return output_text

    def get_relevant_image(self, message, session, images):
        from PIL import Image

        # Embed the query and score it against every page embedding.
        batch_queries = self.colqwen2_processor.process_queries([message]).to(
            self.colqwen2_model.device
        )
        query_embeddings = self.colqwen2_model(**batch_queries)

        scores = self.colqwen2_processor.score_multi_vector(
            query_embeddings, session.pdf_embeddings
        )[0]

        # Return the page with the highest retrieval score.
        max_index = max(range(len(scores)), key=lambda index: scores[index])
        return Image.open(images[max_index])

    def generate_response(self, message, session, image):
        # Build the multimodal prompt: prior turns plus the new message and retrieved page.
        chatbot_message = get_chatbot_message_with_image(message, image)
        query = self.qwen2_vl_processor.apply_chat_template(
            [*session.messages, chatbot_message],
            tokenize=False,
            add_generation_prompt=True,
        )

        image_inputs, _ = process_vision_info([chatbot_message])
        inputs = self.qwen2_vl_processor(
            text=[query],
            images=image_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda:0")

        generated_ids = self.qwen2_vl_model.generate(**inputs, max_new_tokens=512)

        # Strip the prompt tokens so only newly generated text is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids) :]
            for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        output_text = self.qwen2_vl_processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]

        return output_text

# Image for PDF-to-image conversion: pdf2image shells out to poppler's pdftoppm.
pdf_image = (
    modal.Image.debian_slim(python_version="3.12")
    .apt_install("poppler-utils")
    .pip_install("pdf2image==1.17.0", "pillow==10.4.0")
)


# Decorator added so index_pdf can call this function with `.remote()`.
@app.function(image=pdf_image)
def convert_pdf_to_images(pdf_bytes):
    from pdf2image import convert_from_bytes

    images = convert_from_bytes(pdf_bytes, fmt="jpeg")
    return images

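# Quick smoke test (a sketch: it assumes poppler is installed on your machine and
# that a file named paper.pdf exists next to this script):
#   pages = convert_pdf_to_images.local(Path("paper.pdf").read_bytes())
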
@app.local_entrypoint()
def main(question: str = None, pdf_path: str = None, session_id: str = None):
    model = Model()

    if session_id is None:
        session_id = str(uuid4())
        print("Starting a new session with id", session_id)

        # By default, chat with "Attention Is All You Need".
        if pdf_path is None:
            pdf_path = "https://arxiv.org/pdf/1706.03762"

        if pdf_path.startswith("http"):
            pdf_bytes = urlopen(pdf_path).read()
        else:
            pdf_path = Path(pdf_path)
            pdf_bytes = pdf_path.read_bytes()

        print("Indexing PDF from", pdf_path)
        model.index_pdf.remote(session_id, pdf_bytes)
    else:
        if pdf_path is not None:
            raise ValueError("Start a new session to chat with a new PDF")
        print("Resuming session with id", session_id)

    if question is None:
        question = "What is this document about?"

    print("QUESTION:", question)
    print(model.respond_to_message.remote(session_id, question))

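# Example invocations (assuming this file is saved as chat_with_pdf.py):
#   modal run chat_with_pdf.py
#   modal run chat_with_pdf.py --question "What architecture does the paper propose?"
#   modal run chat_with_pdf.py --session-id <ID> --question "How is it trained?"
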
web_image = pdf_image.pip_install(
    "fastapi[standard]==0.115.4",
    "pydantic==2.9.2",
    "starlette==0.41.2",
    "gradio==4.44.1",
    "pillow==10.4.0",
    "gradio-pdf==0.0.15",
    "pdf2image==1.17.0",
)

# Decorators added so the Gradio app is served as an ASGI web endpoint;
# the UI talks to the Model class over Modal remote calls.
@app.function(image=web_image)
@modal.asgi_app()
def ui():
    import uuid

    import gradio as gr
    from fastapi import FastAPI
    from gradio.routes import mount_gradio_app
    from gradio_pdf import PDF
    from pdf2image import convert_from_path

    web_app = FastAPI()
    model = Model()

    def upload_pdf(path, session_id):
        if session_id == "" or session_id is None:
            session_id = str(uuid.uuid4())

        images = convert_from_path(path)
        model.index_pdf.remote(session_id, images)

        return session_id

    def respond_to_message(message, _, session_id):
        return model.respond_to_message.remote(session_id, message)

    with gr.Blocks(theme="soft") as demo:
        session_id = gr.State("")

        gr.Markdown("# Chat with PDF")

        with gr.Row():
            with gr.Column(scale=1):
                gr.ChatInterface(
                    fn=respond_to_message,
                    additional_inputs=[session_id],
                    retry_btn=None,
                    undo_btn=None,
                    clear_btn=None,
                )

            with gr.Column(scale=1):
                pdf = PDF(
                    label="Upload a PDF",
                )
                pdf.upload(upload_pdf, [pdf, session_id], session_id)

    return mount_gradio_app(app=web_app, blocks=demo, path="/")

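# Serve the UI with hot reload during development, or deploy it permanently
# (filename assumed):
#   modal serve chat_with_pdf.py
#   modal deploy chat_with_pdf.py
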
def get_chatbot_message_with_image(message, image):
    # One user turn in the Qwen2-VL chat format: retrieved page image plus the text query.
    return {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": message},
        ],
    }

def append_to_messages(message, session, user_type="user"):
    # Store text-only turns as a single-element content list so the chat template
    # can iterate over content items uniformly.
    session.messages.append(
        {
            "role": user_type,
            "content": [{"type": "text", "text": message}],
        },
    )
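# For reference, a stored turn produced by append_to_messages looks like:
#   {"role": "assistant", "content": [{"type": "text", "text": "The paper introduces..."}]}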