| """Main Mulitmodal-RAG pipeline script.""" | |
| import os | |
| import torch | |
| import fitz #PyMuPDF | |
| import faiss | |
| import re | |
| import gc | |
| import numpy as np | |
| from typing import List, Dict, Tuple | |
| from PIL import Image | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from transformers import TextIteratorStreamer | |
| from utils import ( | |
| save_cache, load_cache, | |
| init_faiss_indexflatip, add_embeddings_to_index, | |
| search_faiss_index, save_faiss_index, load_faiss_index, cleanup_images, clear_gpu_cache | |
| ) | |
| from model_setup import embedding_model, model, processor | |
| torch.set_num_threads(4) # Just being agnostic | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
# Extract text and images from each page of the PDF using PyMuPDF (fitz)
def extract_pages_text_and_images(pdf_path, image_dir):
    """Extract text and images page-wise."""
    doc = fitz.open(pdf_path)
    os.makedirs(image_dir, exist_ok=True)
    page_texts = []
    page_images = []
    for page_num in range(len(doc)):
        page = doc.load_page(page_num)
        text = page.get_text()
        # Save every image on this page to disk and keep its path
        images = []
        for img_index, img in enumerate(page.get_images(full=True)):
            xref = img[0]
            base_image = doc.extract_image(xref)
            image_bytes = base_image["image"]
            image_ext = base_image["ext"]
            image_filename = f"page_{page_num + 1}_img_{img_index}.{image_ext}"
            image_path = os.path.join(image_dir, image_filename)
            with open(image_path, "wb") as img_file:
                img_file.write(image_bytes)
            images.append(image_path)
        page_texts.append(text)
        page_images.append(images)
    doc.close()
    return page_texts, page_images
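
# Illustrative sketch (not part of the pipeline): the function returns two parallel
# lists, one entry per page. The file names below are hypothetical examples.
#
#   page_texts, page_images = extract_pages_text_and_images("paper.pdf", "extracted_images")
#   # page_texts  -> ["Abstract ...", "1 Introduction ...", ...]            (List[str])
#   # page_images -> [["extracted_images/page_1_img_0.png"], [], ...]       (List[List[str]])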
# Generate image descriptions using the Gemma 3 model
# This function is called once per page, for that page's list of images
def generate_image_descriptions(image_paths):
    """Generate text descriptions for a page's images and tables."""
    captions = []
    if not processor or not model:
        print("[ERROR] Model or processor not loaded. Cannot generate image descriptions.")
        return []
    for image_path in image_paths:
        raw_image = Image.open(image_path)
        if raw_image.mode != "RGB":
            image = raw_image.convert("RGB")
        else:
            image = raw_image
        width, height = image.size
        if width < 32 or height < 32:  # Skip tiny images (icons, artifacts) that add noise
            continue
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {"role": "user", "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": "Describe the factual content visible in the image. Be concise and accurate as the descriptions will be used for retrieval."}
            ]}
        ]
        try:
            inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True,
                                                   return_dict=True, return_tensors="pt").to("cpu", dtype=torch.bfloat16)
            input_len = inputs["input_ids"].shape[-1]  # Prompt length, used to drop the prompt echo
            with torch.inference_mode():
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=512,
                    do_sample=False,
                    cache_implementation="offloaded_static"
                )
            # Decode only the newly generated tokens so the prompt is not echoed back
            raw = processor.decode(generated_ids[0][input_len:], skip_special_tokens=True)
            caption = clean_caption(raw)
            captions.append({"image_path": image_path, "caption": caption})
        except Exception as e:
            print(f"[ERROR] Failed to generate caption for image {image_path}: {e}")
            captions.append({"image_path": image_path, "caption": "<---image---> (Captioning failed)"})  # Placeholder caption
            continue
        finally:
            gc.collect()
            clear_gpu_cache()
    return captions
# Clean the captions decoded from the model output
# Regex: match everything from "model\n" up through the first double newline after "sent:"
prefix_re = re.compile(
    r"model\s*\n.*?\bsent:\s*\n\n",
    flags=re.IGNORECASE | re.DOTALL,
)

def clean_caption(raw: str) -> str:
    """Strip any prompt/header residue from a decoded caption."""
    # 1. Strip off the prompt/header by splitting once.
    parts = prefix_re.split(raw.strip(), maxsplit=1)
    if len(parts) == 2:
        return parts[1].strip()
    # 2. Fallback: if the caption begins with ** (bold header), return from there.
    bold_index = raw.find("**")
    if bold_index >= 0:
        return raw[bold_index:].strip()
    # 3. Last resort: return everything except the first paragraph.
    paras = raw.strip().split('\n\n', 1)
    return paras[-1].strip()  # might still include some leading noise
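
# Illustrative sketch (the raw string below is an assumed model output, for documentation only):
# given a decoded string that still carries a chat-template echo such as
#   "model\nHere is a description of the image you sent:\n\n**Figure 1:** A bar chart ..."
# prefix_re strips everything up to and including the blank line, so clean_caption() returns
# "**Figure 1:** A bar chart ...". When no such prefix is present, fallbacks 2 and 3 apply.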
# Generate captions for all images on each page
def generate_captions_per_page(page_image_paths_list):
    """Generate captions for each page's images."""
    page_captions = []
    for image_paths in page_image_paths_list:
        captions = generate_image_descriptions(image_paths)
        # Keep only the 'caption' strings
        captions_texts = [cap['caption'] for cap in captions]
        page_captions.append(captions_texts)
    return page_captions
# Merge the text and captions of each page into a single string
def merge_text_and_captions(page_texts, page_captions):
    """Merge extracted text, image captions and table descriptions per page."""
    combined_pages = []
    for text, captions in zip(page_texts, page_captions):
        page_content = text.strip() + "\n\n"
        for cap in captions:
            page_content += f"[Image Description]: {cap}\n\n"
        combined_pages.append(page_content)
    return combined_pages
# Split the merged page contents into smaller text chunks with metadata
def chunk_text_with_metadata(merged_pages):
    """
    Given a list of pages (strings) with combined text and image captions,
    split each page's content into chunks, attach metadata, and collect all chunks.

    Args:
        merged_pages (List[str]): List where each item is the content (text + captions) of a single page.

    Returns:
        List[dict]: List of chunk dicts with keys: content, page, chunk_id, chunk_number_on_page, type
    """
    text_splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ".", " ", ""],  # Recursive splitting separators, from paragraphs down to characters
        chunk_size=1000,
        chunk_overlap=200,
        add_start_index=True
    )
    all_chunks = []
    chunk_global_id = 0
    for page_num, page_content in enumerate(merged_pages, start=1):
        # Split page content into chunks
        page_chunks = text_splitter.split_text(page_content)
        # Tag metadata on each chunk
        for chunk_num, chunk_text in enumerate(page_chunks, start=1):
            chunk_dict = {
                "content": chunk_text,
                "page": page_num,
                "chunk_id": chunk_global_id,
                "chunk_number_on_page": chunk_num,
                "type": "extracted_texts_and_captions_descriptions"
            }
            all_chunks.append(chunk_dict)
            chunk_global_id += 1
    return all_chunks
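
# Illustrative sketch of a single chunk dict produced above (all values are hypothetical):
#
#   {
#       "content": "...first ~1000 characters of page 3...\n\n[Image Description]: A bar chart ...",
#       "page": 3,
#       "chunk_id": 17,
#       "chunk_number_on_page": 2,
#       "type": "extracted_texts_and_captions_descriptions"
#   }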
# Preprocess the uploaded PDF
def preprocess_pdf(file_path: str, image_dir: str, embedding_model,
                   index_file: str = "index.faiss",
                   chunks_file: str = "chunks.json",
                   use_cache: bool = True) -> Tuple[faiss.IndexFlatIP, List[Dict]]:
    """Extract, caption, chunk and embed a PDF, returning a FAISS index and its chunks."""
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"PDF not found at {file_path}")
    # Load the cache to save compute time and resources every time a query is made
    if use_cache and os.path.exists(index_file) and os.path.exists(chunks_file):
        print("[INFO] Loading cached FAISS index and chunks...")
        index = load_faiss_index(index_file)
        chunks = load_cache(chunks_file)
        return index, chunks
    # Clean up stale cache files if caching is disabled or the cache is incomplete
    if not use_cache or not (os.path.exists(index_file) and os.path.exists(chunks_file)):
        if os.path.exists(index_file):
            os.remove(index_file)
        if os.path.exists(chunks_file):
            os.remove(chunks_file)
    # Otherwise run the full processing pipeline
    try:
        page_texts, page_images = extract_pages_text_and_images(file_path, image_dir)
    except Exception as e:
        print(f"Error reading PDF: {e}")
        raise
    page_captions = generate_captions_per_page(page_images)
    merged_pages = merge_text_and_captions(page_texts, page_captions)
    # Delete extracted images after captioning
    cleanup_images(image_dir)
    # Chunk the merged pages
    chunks = chunk_text_with_metadata(merged_pages)
    texts = [chunk['content'] for chunk in chunks]
    # Generate embeddings and initialize a FAISS index with the embedding dimension
    embeddings = embedding_model.encode(texts, normalize_embeddings=True)
    embeddings = embeddings.astype(np.float32)  # FAISS expects float32
    embedding_dim = embeddings.shape[1]
    index = init_faiss_indexflatip(embedding_dim=embedding_dim)
    # Add embeddings to the index
    add_embeddings_to_index(index=index, embeddings=embeddings)
    # Save index and chunks
    if use_cache:
        save_faiss_index(index, index_file)
        save_cache(chunks, chunks_file)
    return index, chunks
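
# Illustrative sketch (file names are hypothetical): building or reloading the index.
# On the first call the PDF is captioned, chunked and embedded; on later calls the
# cached index.faiss / chunks.json pair is loaded instead.
#
#   index, chunks = preprocess_pdf("paper.pdf", "extracted_images", embedding_model,
#                                  index_file="index.faiss", chunks_file="chunks.json",
#                                  use_cache=True)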
# Semantic search function that uses the preprocessed data
def semantic_search(query, embedding_model, index, chunks, top_k=10):
    """Embed the query and return the top_k most similar chunks from the FAISS index."""
    # Embed the user query
    query_embedding = embedding_model.encode([query], normalize_embeddings=True)
    # Retrieve top matches from FAISS
    distances, indices = search_faiss_index(index, query_embedding, k=top_k)
    # Retrieve the matched chunks
    retrieved_chunks = [chunks[i] for i in indices[0]]
    return retrieved_chunks
# Generate an answer for the Gradio interface
def generate_answer_stream(query, retrieved_chunks, model, processor):
    """Stream the answer from the LLM token by token."""
    context_texts = [chunk['content'] for chunk in retrieved_chunks]
    # Combine system instruction, context, and query into a single string for the user role
    system_instruction = """You are a helpful and precise assistant for question-answering tasks.
Use only the following pieces of retrieved context to answer the question.
You may format the response as structured markdown if necessary.
If the answer is not found in the provided context, state that the information is not available in the document. Do not use any external knowledge or make assumptions."""
    # Build the core prompt string without explicit turn markers;
    # processor.apply_chat_template handles the proper formatting
    rag_prompt_content = ""
    if system_instruction:
        rag_prompt_content += f"{system_instruction.strip()}\n\n"
    if context_texts:
        rag_prompt_content += "Context:\n- " + "\n- ".join(context_texts).strip() + "\n\n"
    rag_prompt_content += f"Question: {query.strip()}\nAnswer:"
    # Robust format for the multimodal processor
    messages = [
        {"role": "user", "content": [{"type": "text", "text": rag_prompt_content}]}
    ]
    # Prepare model inputs using apply_chat_template,
    # which formats the prompt correctly for Gemma 3
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,  # Mark the start of the model's turn
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        truncation=True,
        max_length=4096  # Truncate overly long prompts
    ).to("cpu")
    streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Run generation in a background thread so tokens can be yielded as they arrive
    def _generate():
        with torch.inference_mode():
            model.generate(**inputs, streamer=streamer, use_cache=True, max_new_tokens=512)

    thread = Thread(target=_generate)
    thread.start()
    accumulated = ""
    for new_text in streamer:
        accumulated += new_text
        yield accumulated
    thread.join()
    # Free memory after streaming is complete
    gc.collect()
    clear_gpu_cache()
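
# Minimal end-to-end sketch (the file name and query below are assumed, not part of the deployed app):
# build or load the index, retrieve context for a question, then stream the answer.
if __name__ == "__main__":
    demo_query = "What is this document about?"
    index, chunks = preprocess_pdf("sample.pdf", "extracted_images", embedding_model)
    retrieved = semantic_search(demo_query, embedding_model, index, chunks, top_k=5)
    answer = ""
    for answer in generate_answer_stream(demo_query, retrieved, model, processor):
        pass  # each `answer` value is the accumulated text so far; a UI would render it live
    print(answer)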