import glob import os import gc import time import re import hashlib from pathlib import Path from typing import List, Dict, Tuple, Optional import fitz # PyMuPDF import torch import numpy as np from PIL import Image from transformers import AutoModel, AutoProcessor, AutoTokenizer # Changed from AutoModelForCausalLM from langchain_core.documents import Document import pickle from numpy.linalg import norm import camelot import base64 import pytesseract from pdf2image import convert_from_path import faiss from rank_bm25 import BM25Okapi # ======================================== # ๐Ÿ“‚ CONFIGURATION # ======================================== PDF_DIR = r"D:\BeRU\testing" FAISS_INDEX_PATH = "VLM2Vec-V2rag2" MODEL_CACHE_DIR = ".cache" IMAGE_OUTPUT_DIR = "extracted_images2" # Chunking configuration CHUNK_SIZE = 450 # words OVERLAP = 100 # words MIN_CHUNK_SIZE = 50 MAX_CHUNK_SIZE = 800 # Instruction prefixes for better embeddings DOCUMENT_INSTRUCTION = "Represent this technical document for semantic search: " QUERY_INSTRUCTION = "Represent this question for finding relevant technical information: " # Hybrid search weights DENSE_WEIGHT = 0.4 # Weight for semantic search SPARSE_WEIGHT = 0.6 # Weight for keyword search # Create directories os.makedirs(PDF_DIR, exist_ok=True) os.makedirs(FAISS_INDEX_PATH, exist_ok=True) os.makedirs(MODEL_CACHE_DIR, exist_ok=True) os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True) # ======================================== # ๐Ÿค– VLM2Vec-V2 WRAPPER (ENHANCED) # ======================================== class VLM2VecEmbeddings: """VLM2Vec-V2 embedding class with instruction prefixes.""" def __init__(self, model_name: str = "TIGER-Lab/VLM2Vec-Qwen2VL-2B", cache_dir: str = None): print(f"๐Ÿค– Loading VLM2Vec-V2 model: {model_name}") self.device = "cuda" if torch.cuda.is_available() else "cpu" print(f" Device: {self.device}") try: self.model = AutoModel.from_pretrained( model_name, cache_dir=cache_dir, trust_remote_code=True, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 ).to(self.device) self.processor = AutoProcessor.from_pretrained( model_name, cache_dir=cache_dir, trust_remote_code=True ) self.tokenizer = AutoTokenizer.from_pretrained( model_name, cache_dir=cache_dir, trust_remote_code=True ) self.model.eval() # Get actual embedding dimension test_input = self.tokenizer("test", return_tensors="pt").to(self.device) with torch.no_grad(): test_output = self.model(**test_input, output_hidden_states=True) self.embedding_dim = test_output.hidden_states[-1].shape[-1] print(f" Embedding dimension: {self.embedding_dim}") print("โœ… VLM2Vec-V2 loaded successfully\n") except Exception as e: print(f"โŒ Error loading VLM2Vec-V2: {e}") raise def normalize_text(self, text: str) -> str: """Normalize text for better embeddings.""" # Remove excessive whitespace text = re.sub(r'\s+', ' ', text) # Remove page numbers text = re.sub(r'Page \d+', '', text, flags=re.IGNORECASE) # Normalize unicode text = text.strip() return text def embed_documents(self, texts: List[str], add_instruction: bool = True) -> List[List[float]]: """Embed documents with instruction prefix and weighted mean pooling.""" embeddings = [] with torch.no_grad(): for text in texts: try: # โœ… NORMALIZE TEXT clean_text = self.normalize_text(text) # โœ… ADD INSTRUCTION PREFIX if add_instruction: prefixed_text = DOCUMENT_INSTRUCTION + clean_text else: prefixed_text = clean_text inputs = self.tokenizer( prefixed_text, return_tensors="pt", padding=True, truncation=True, max_length=min(self.tokenizer.model_max_length or 512, 2048) ).to(self.device) outputs = self.model(**inputs, output_hidden_states=True) if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None: # โœ… WEIGHTED MEAN POOLING (ignores padding) hidden_states = outputs.hidden_states[-1] attention_mask = inputs['attention_mask'].unsqueeze(-1).float() # Apply attention mask as weights weighted_hidden_states = hidden_states * attention_mask sum_embeddings = weighted_hidden_states.sum(dim=1) sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9) # Weighted mean embedding = (sum_embeddings / sum_mask).squeeze() else: # Fallback to logits attention_mask = inputs['attention_mask'].unsqueeze(-1).float() weighted_logits = outputs.logits * attention_mask sum_embeddings = weighted_logits.sum(dim=1) sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9) embedding = (sum_embeddings / sum_mask).squeeze() embeddings.append(embedding.cpu().numpy().tolist()) except Exception as e: print(f" โŒ CRITICAL: Failed to embed text: {e}") print(f" Text preview: {text[:100]}") raise RuntimeError(f"Embedding failed for text: {text[:50]}...") from e return embeddings def embed_query(self, text: str) -> List[float]: """Embed query with query-specific instruction.""" # โœ… DIFFERENT INSTRUCTION FOR QUERIES clean_text = self.normalize_text(text) prefixed_text = QUERY_INSTRUCTION + clean_text # Don't add document instruction again return self.embed_documents([prefixed_text], add_instruction=False)[0] def embed_image(self, image_path: str, prompt: str = "Technical diagram") -> Optional[List[float]]: """Embed image with Qwen2-VL proper format.""" try: with torch.no_grad(): image = Image.open(image_path).convert('RGB') # โœ… QWEN2-VL CORRECT FORMAT messages = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": prompt} ] } ] # Apply chat template text = self.processor.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) # Process with both text and images inputs = self.processor( text=[text], images=[image], return_tensors="pt", padding=True ).to(self.device) outputs = self.model(**inputs, output_hidden_states=True) if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None: hidden_states = outputs.hidden_states[-1] # Use weighted mean pooling if 'attention_mask' in inputs: attention_mask = inputs['attention_mask'].unsqueeze(-1).float() weighted_hidden_states = hidden_states * attention_mask sum_embeddings = weighted_hidden_states.sum(dim=1) sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9) embedding = (sum_embeddings / sum_mask).squeeze() else: embedding = hidden_states.mean(dim=1).squeeze() else: # Fallback to pooler output if available if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None: embedding = outputs.pooler_output.squeeze() else: return None return embedding.cpu().numpy().tolist() except Exception as e: print(f" โš ๏ธ Failed to embed image {Path(image_path).name}: {str(e)[:100]}") return None # ======================================== # ๐Ÿ” QUERY PREPROCESSING # ======================================== def preprocess_query(query: str) -> str: """Preprocess query by expanding abbreviations.""" abbreviations = { r'\bh2s\b': 'hydrogen sulfide', r'\bppm\b': 'parts per million', r'\bppe\b': 'personal protective equipment', r'\bscba\b': 'self contained breathing apparatus', r'\blel\b': 'lower explosive limit', r'\bhel\b': 'higher explosive limit', r'\buel\b': 'upper explosive limit' } query_lower = query.lower() for abbr, full in abbreviations.items(): query_lower = re.sub(abbr, full, query_lower) # Remove excessive punctuation query_lower = re.sub(r'[?!]+$', '', query_lower) # Clean extra spaces query_lower = re.sub(r'\s+', ' ', query_lower).strip() return query_lower # ======================================== # ๐Ÿ“Š TABLE EXTRACTION # ======================================== def is_table_of_contents_header(df, page_num): """Detect TOC by checking first row for keywords.""" if len(df) == 0 or page_num > 15: return False # Check first row (headers) first_row = ' '.join(df.iloc[0].astype(str)).lower() # TOC keywords in your images toc_keywords = ['section', 'subsection', 'description', 'page no', 'page number', 'contents'] # If at least 2 keywords match, it's TOC keyword_count = sum(1 for keyword in toc_keywords if keyword in first_row) return keyword_count >= 2 def looks_like_toc_data(df): """Check if table data looks like TOC (section numbers + page numbers).""" if len(df) < 2 or len(df.columns) < 2: return False # Check last column: should be mostly page numbers (182-246 range in your case) last_col = df.iloc[1:, -1].astype(str) # Skip header row numeric_count = sum(val.strip().isdigit() and 50 < int(val.strip()) < 300 for val in last_col if val.strip().isdigit()) if len(last_col) > 0 and numeric_count / len(last_col) > 0.7: # Check first column: should have section numbers like "10.1", "10.2" first_col = df.iloc[1:, 0].astype(str) section_pattern = sum(1 for val in first_col if re.match(r'^\d+\.?\d*$', val.strip())) if section_pattern / len(first_col) > 0.5: return True return False def extract_tables_from_pdf(pdf_path: str) -> List[Document]: """Extract bordered tables with smart TOC detection.""" chunks = [] try: lattice_tables = camelot.read_pdf( pdf_path, pages='all', flavor='lattice', # Only bordered tables suppress_stdout=True ) all_tables = list(lattice_tables) seen_tables = set() # Track TOC state in_toc_section = False toc_start_page = None print(f" ๐Ÿ“Š Found {len(all_tables)} bordered tables") for table in all_tables: df = table.df current_page = table.page # Unique ID table_id = (current_page, tuple(df.iloc[0].tolist()) if len(df) > 0 else ()) if table_id in seen_tables: continue seen_tables.add(table_id) # Skip first 5 pages (title pages) if current_page <= 5: continue # Basic validation if len(df.columns) < 2 or len(df) < 3 or table.accuracy < 80: continue # โœ… Detect TOC start (page with header row) if not in_toc_section and is_table_of_contents_header(df, current_page): in_toc_section = True toc_start_page = current_page print(f" ๐Ÿ” TOC detected at page {current_page}") continue # โœ… If we're in TOC section, check if this continues the pattern if in_toc_section: if looks_like_toc_data(df): print(f" โญ๏ธ Skipping TOC continuation on page {current_page}") continue else: # TOC ended, resume normal extraction print(f" โœ… TOC ended, found real table on page {current_page}") in_toc_section = False # Extract valid table table_text = table_to_natural_language_enhanced(table) if table_text.strip(): chunks.append(Document( page_content=table_text, metadata={ "source": os.path.basename(pdf_path), "page": current_page, "heading": "Table Data", "type": "table", "table_accuracy": table.accuracy } )) print(f" โœ… Extracted {len(chunks)} valid tables (after TOC filtering)") except Exception as e: print(f"โš ๏ธ Table extraction failed: {e}") finally: try: del lattice_tables del all_tables gc.collect() time.sleep(0.1) except: pass return chunks def table_to_natural_language_enhanced(table) -> str: """Enhanced table-to-natural-language conversion.""" df = table.df if len(df) < 2: return "" headers = [str(h).strip() for h in df.iloc[0].astype(str).tolist()] headers = [h if h and h.lower() not in ['', 'nan', 'none'] else f"Column_{i}" for i, h in enumerate(headers)] descriptions = [] for idx in range(1, len(df)): row = [str(cell).strip() for cell in df.iloc[idx].astype(str).tolist()] if not any(cell and cell.lower() not in ['', 'nan', 'none'] for cell in row): continue if len(row) > 0 and row[0] and row[0].lower() not in ['', 'nan', 'none']: sentence_parts = [] for i in range(1, min(len(row), len(headers))): if row[i] and row[i].lower() not in ['', 'nan', 'none']: sentence_parts.append(f"{headers[i]}: {row[i]}") if sentence_parts: descriptions.append(f"{row[0]} has {', '.join(sentence_parts)}.") else: descriptions.append(f"{row[0]}.") return "\n".join(descriptions) def extract_tables_with_ocr(pdf_path: str, page_num: int) -> List[Dict]: """OCR fallback for image-based PDFs.""" try: images = convert_from_path(pdf_path, first_page=page_num, last_page=page_num) if not images: return [] ocr_text = pytesseract.image_to_string(images[0]) lines = ocr_text.split('\n') table_lines = [] for line in lines: if re.search(r'\s{2,}', line) or '\t' in line: table_lines.append(line) if len(table_lines) > 2: return [{ "text": "\n".join(table_lines), "page": page_num, "method": "ocr" }] return [] except Exception as e: return [] def get_table_regions(pdf_path: str) -> Dict[int, List[tuple]]: """Get bounding boxes using BOTH lattice and stream methods.""" table_regions = {} try: lattice_tables = camelot.read_pdf(pdf_path, pages='all', flavor='lattice', suppress_stdout=True) stream_tables = camelot.read_pdf(pdf_path, pages='all', flavor='stream', suppress_stdout=True) all_tables = list(lattice_tables) + list(stream_tables) for table in all_tables: page = table.page if is_table_of_contents_header(table.df, page): continue bbox = table._bbox if page not in table_regions: table_regions[page] = [] if bbox not in table_regions[page]: table_regions[page].append(bbox) except Exception as e: pass return table_regions # ======================================== # ๐Ÿ–ผ๏ธ IMAGE EXTRACTION # ======================================== def extract_images_from_pdf(pdf_path: str, output_dir: str) -> List[Dict]: """Extract images from PDF.""" doc = fitz.open(pdf_path) image_data = [] for page_num in range(len(doc)): page = doc[page_num] images = page.get_images() for img_index, img in enumerate(images): try: xref = img[0] base_image = doc.extract_image(xref) image_bytes = base_image["image"] if len(image_bytes) < 10000: continue image_filename = f"{Path(pdf_path).stem}_p{page_num+1}_img{img_index+1}.png" image_path = os.path.join(output_dir, image_filename) with open(image_path, "wb") as img_file: img_file.write(image_bytes) image_data.append({ "path": image_path, "page": page_num + 1, "source": os.path.basename(pdf_path), "type": "image" }) except Exception as e: continue doc.close() return image_data # ======================================== # ๐Ÿ“„ TEXT EXTRACTION WITH OVERLAPPING CHUNKS # ======================================== def is_bold_text(span): return "bold" in span['font'].lower() or (span['flags'] & 2**4) def is_likely_heading(text, font_size, is_bold, avg_font_size): if not is_bold: return False text = text.strip() if len(text) > 100 or len(text) < 3: return False if font_size > avg_font_size * 1.1: return True if text.isupper() or re.match(r'^\d+\.?\d*\s+[A-Z]', text): return True return False def is_inside_table(block_bbox, table_bboxes): """Check if text block overlaps with table region.""" bx1, by1, bx2, by2 = block_bbox for table_bbox in table_bboxes: tx1, ty1, tx2, ty2 = table_bbox if not (bx2 < tx1 or bx1 > tx2 or by2 < ty1 or by1 > ty2): return True return False def split_text_with_overlap(text: str, heading: str, source: str, page: int, chunk_size: int = CHUNK_SIZE, overlap: int = OVERLAP) -> List[Document]: """Split text with overlap and heading context.""" words = text.split() if len(words) <= chunk_size: # โœ… ADD HEADING CONTEXT content_with_context = f"Section: {heading}\n\n{text}" return [Document( page_content=content_with_context, metadata={ "source": source, "page": page, "heading": heading, "type": "text", "parent_text": text, "chunk_index": 0, "total_chunks": 1 } )] chunks = [] chunk_index = 0 for i in range(0, len(words), chunk_size - overlap): chunk_words = words[i:i + chunk_size] if len(chunk_words) < MIN_CHUNK_SIZE and len(chunks) > 0: break chunk_text = " ".join(chunk_words) # โœ… ADD HEADING CONTEXT TO EACH CHUNK content_with_context = f"Section: {heading}\n\n{chunk_text}" chunks.append(Document( page_content=content_with_context, metadata={ "source": source, "page": page, "heading": heading, "type": "text", "parent_text": text, "chunk_index": chunk_index, "start_word": i, "end_word": i + len(chunk_words) } )) chunk_index += 1 for chunk in chunks: chunk.metadata["total_chunks"] = len(chunks) return chunks def extract_text_chunks_with_overlap(pdf_path: str, table_regions: Dict[int, List[tuple]]) -> List[Document]: """Extract text with overlapping chunks.""" doc = fitz.open(pdf_path) all_font_sizes = [] for page_num in range(len(doc)): page = doc[page_num] blocks = page.get_text("dict")["blocks"] for block in blocks: if "lines" in block: for line in block["lines"]: for span in line["spans"]: all_font_sizes.append(span["size"]) avg_font_size = sum(all_font_sizes) / len(all_font_sizes) if all_font_sizes else 12 sections = [] current_section = "" current_heading = "Introduction" current_page = 1 for page_num in range(len(doc)): page = doc[page_num] blocks = page.get_text("dict")["blocks"] page_tables = table_regions.get(page_num + 1, []) for block in blocks: if "lines" not in block: continue block_bbox = block.get("bbox", (0, 0, 0, 0)) if is_inside_table(block_bbox, page_tables): continue for line in block["lines"]: line_text = "" line_is_bold = False line_font_size = 0 for span in line["spans"]: line_text += span["text"] if is_bold_text(span): line_is_bold = True line_font_size = max(line_font_size, span["size"]) line_text = line_text.strip() if not line_text: continue if is_likely_heading(line_text, line_font_size, line_is_bold, avg_font_size): if current_section.strip(): sections.append({ "text": current_section.strip(), "heading": current_heading, "page": current_page, "source": os.path.basename(pdf_path) }) current_heading = line_text current_section = "" current_page = page_num + 1 else: current_section += line_text + " " if current_section.strip(): sections.append({ "text": current_section.strip(), "heading": current_heading, "page": current_page, "source": os.path.basename(pdf_path) }) doc.close() all_chunks = [] for section in sections: chunks = split_text_with_overlap( text=section['text'], heading=section['heading'], source=section['source'], page=section['page'], chunk_size=CHUNK_SIZE, overlap=OVERLAP ) all_chunks.extend(chunks) return all_chunks # ======================================== # ๐Ÿ”„ COMBINED EXTRACTION # ======================================== def extract_all_content_from_pdf(pdf_path: str) -> Tuple[List[Document], List[Dict]]: """Extract text, tables, and images.""" print(f" ๐Ÿ“Š Extracting tables...") table_regions = get_table_regions(pdf_path) table_chunks = extract_tables_from_pdf(pdf_path) print(f" โœ… {len(table_chunks)} table chunks") print(f" ๐Ÿ“„ Extracting text...") text_chunks = extract_text_chunks_with_overlap(pdf_path, table_regions) print(f" โœ… {len(text_chunks)} text chunks") print(f" ๐Ÿ–ผ๏ธ Extracting images...") images = extract_images_from_pdf(pdf_path, IMAGE_OUTPUT_DIR) print(f" โœ… {len(images)} images") all_chunks = text_chunks + table_chunks return all_chunks, images # ======================================== # ๐Ÿ—๏ธ BUILD FAISS INDEX WITH STREAMING # ======================================== # Replace the HybridRetriever class and related functions with this optimized version: # ======================================== # ๐Ÿ—๏ธ BUILD FAISS INDEX WITH BM25 # ======================================== def build_multimodal_faiss_streaming(pdf_files: List[str], embedding_model: VLM2VecEmbeddings): """Build FAISS index with streaming and BM25.""" index_hash_file = f"{FAISS_INDEX_PATH}/index_hash.txt" current_hash = hashlib.md5("".join(sorted(pdf_files)).encode()).hexdigest() if os.path.exists(index_hash_file): with open(index_hash_file, 'r') as f: existing_hash = f.read().strip() if existing_hash == current_hash: print("โš ๏ธ Index already exists for these PDFs!") response = input(" Rebuild anyway? (yes/no): ").strip().lower() if response != 'yes': return None, [] all_texts = [] all_image_paths = [] print("\n๐Ÿ“„ Processing PDFs...\n") for pdf_file in pdf_files: print(f"๐Ÿ“– Processing: {Path(pdf_file).name}") try: text_chunks, images = extract_all_content_from_pdf(pdf_file) all_texts.extend(text_chunks) all_image_paths.extend(images) except Exception as e: print(f" โŒ Error: {e}") continue print() print(f"โœ… Total chunks: {len(all_texts)}") print(f"โœ… Total images: {len(all_image_paths)}\n") if len(all_texts) == 0: print("โŒ No content extracted!") return None, [] # Build text index print("๐Ÿ”— Generating text embeddings...\n") text_index = None batch_size = 10 for i in range(0, len(all_texts), batch_size): batch = all_texts[i:i+batch_size] batch_contents = [doc.page_content for doc in batch] try: batch_embeddings = embedding_model.embed_documents(batch_contents, add_instruction=True) batch_embeddings_np = np.array(batch_embeddings).astype('float32') if text_index is None: dimension = batch_embeddings_np.shape[1] text_index = faiss.IndexFlatIP(dimension) print(f" Text embedding dimension: {dimension}") faiss.normalize_L2(batch_embeddings_np) text_index.add(batch_embeddings_np) if (i // batch_size + 1) % 5 == 0: print(f" Progress: {i + len(batch)}/{len(all_texts)}") except Exception as e: print(f" โŒ Error: {e}") raise print(f" โœ… Complete") # Save FAISS index faiss.write_index(text_index, f"{FAISS_INDEX_PATH}/text_index.faiss") # Save documents with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "wb") as f: pickle.dump(all_texts, f) # โœ… BUILD AND SAVE BM25 INDEX print("\n๐Ÿ” Building BM25 index for keyword search...") tokenized_docs = [doc.page_content.lower().split() for doc in all_texts] bm25_index = BM25Okapi(tokenized_docs,k1=1.3, b=0.65) with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f: pickle.dump(bm25_index, f) print(" โœ… BM25 index saved") # Build image index if len(all_image_paths) > 0: print(f"\n๐Ÿ–ผ๏ธ Embedding images...") image_index = None successful_images = [] for idx, img_data in enumerate(all_image_paths): img_embedding = embedding_model.embed_image(img_data["path"]) if img_embedding is None: continue img_embedding_np = np.array([img_embedding]).astype('float32') if image_index is None: dimension = img_embedding_np.shape[1] image_index = faiss.IndexFlatIP(dimension) print(f" Image dimension: {dimension}") faiss.normalize_L2(img_embedding_np) image_index.add(img_embedding_np) successful_images.append(img_data) if (len(successful_images)) % 10 == 0: print(f" Progress: {len(successful_images)}/{len(all_image_paths)}") print(f" โœ… {len(successful_images)} images embedded") if image_index is not None and len(successful_images) > 0: faiss.write_index(image_index, f"{FAISS_INDEX_PATH}/image_index.faiss") with open(f"{FAISS_INDEX_PATH}/image_documents.pkl", "wb") as f: pickle.dump(successful_images, f) # Save hash with open(index_hash_file, 'w') as f: f.write(current_hash) print(f"\nโœ… Index saved: {FAISS_INDEX_PATH}\n") return text_index, all_texts # ======================================== # ๐Ÿ” OPTIMIZED HYBRID SEARCH # ======================================== # ======================================== # ๐Ÿ“Š QUERY WITH BM25 ONLY # ======================================== def query_with_bm25(query: str, k_text: int = 5, k_images: int = 3): """Query using BM25 keyword search only.""" # โœ… PREPROCESS QUERY processed_query = preprocess_query(query) print(f" ๐Ÿ” Processed: {processed_query}") # Load documents with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f: text_docs = pickle.load(f) # โœ… LOAD BM25 INDEX try: with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "rb") as f: bm25_index = pickle.load(f) except FileNotFoundError: print(" โš ๏ธ BM25 index not found, building on-the-fly...") tokenized_docs = [doc.page_content.lower().split() for doc in text_docs] bm25_index = BM25Okapi(tokenized_docs) # BM25 SEARCH ONLY tokenized_query = processed_query.lower().split() bm25_scores = bm25_index.get_scores(tokenized_query) # Get top k results top_indices = np.argsort(bm25_scores)[::-1][:k_text] text_results = [] relevant_pages = set() for rank, idx in enumerate(top_indices, 1): doc = text_docs[idx] score = float(bm25_scores[idx]) text_results.append({ "document": doc, "score": score, "rank": rank, "type": doc.metadata.get('type', 'text') }) relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page'))) # Get images from relevant pages (not semantic search) relevant_images = [] try: image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl" if os.path.exists(image_docs_path): with open(image_docs_path, "rb") as f: image_docs = pickle.load(f) # Get images from same pages as top text results for img_doc in image_docs: img_page = (img_doc['source'], img_doc['page']) if img_page in relevant_pages and len(relevant_images) < k_images: relevant_images.append({ "path": img_doc['path'], "source": img_doc['source'], "page": img_doc['page'], "type": "image", "score": 0.0, "rank": len(relevant_images) + 1, "from_page": True }) except Exception as e: pass return { "text_results": text_results, "images": relevant_images, "query": query, "processed_query": processed_query } # ======================================== # ๐Ÿ“Š DISPLAY RESULTS (BM25 ONLY) # ======================================== def display_results_bm25(results: Dict): """Display BM25 results.""" print("\n๐Ÿ“š TOP RESULTS (BM25 Keyword Search):\n") for result in results['text_results']: doc = result["document"] print(f"[{result['rank']}] BM25 Score: {result['score']:.4f} | {doc.metadata.get('type', 'N/A')}") print(f" ๐Ÿ“„ {doc.metadata.get('source')} - Page {doc.metadata.get('page')}") print(f" ๐Ÿ“Œ {doc.metadata.get('heading', 'N/A')[:60]}") if 'total_chunks' in doc.metadata and doc.metadata.get('total_chunks', 1) > 1: print(f" ๐Ÿ”— Chunk {doc.metadata.get('chunk_index', 0)+1}/{doc.metadata.get('total_chunks')}") print(f" ๐Ÿ“ {doc.page_content[:200]}...") print() print("\n๐Ÿ–ผ๏ธ IMAGES:\n") if results['images']: for img in results['images']: print(f"[{img['rank']}] {img['source']} - Page {img['page']}") print(f" {img['path']}\n") else: print(" No images found\n") # ======================================== # ๐Ÿ” HYBRID SEARCH IMPLEMENTATION # ======================================== def normalize_scores(scores: np.ndarray) -> np.ndarray: """Min-max normalization to 0-1 range.""" if len(scores) == 0: return scores min_score = np.min(scores) max_score = np.max(scores) if max_score == min_score: return np.ones_like(scores) return (scores - min_score) / (max_score - min_score) def query_with_hybrid(query: str, embedding_model: VLM2VecEmbeddings, k_text: int = 5, k_images: int = 3, dense_weight: float = DENSE_WEIGHT, sparse_weight: float = SPARSE_WEIGHT): """ Hybrid search combining semantic (FAISS) and keyword (BM25) retrieval. """ processed_query = preprocess_query(query) print(f" ๐Ÿ” Processed: {processed_query}") with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f: text_docs = pickle.load(f) # SEMANTIC SEARCH print(f" ๐Ÿง  Running semantic search...") try: text_index = faiss.read_index(f"{FAISS_INDEX_PATH}/text_index.faiss") query_embedding = embedding_model.embed_query(processed_query) query_np = np.array([query_embedding]).astype('float32') faiss.normalize_L2(query_np) k_retrieve = min(k_text * 3, len(text_docs)) distances, indices = text_index.search(query_np, k_retrieve) semantic_scores = distances[0] semantic_indices = indices[0] print(f" โœ… Retrieved {len(semantic_indices)} semantic results") except Exception as e: print(f" โš ๏ธ Semantic search failed: {e}") semantic_scores = np.array([]) semantic_indices = np.array([]) # BM25 SEARCH print(f" ๐Ÿ”ค Running BM25 keyword search...") try: with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "rb") as f: bm25_index = pickle.load(f) except FileNotFoundError: tokenized_docs = [doc.page_content.lower().split() for doc in text_docs] bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65) tokenized_query = processed_query.lower().split() bm25_scores_all = bm25_index.get_scores(tokenized_query) print(f" โœ… Scored {len(bm25_scores_all)} documents") # SCORE FUSION print(f" โš–๏ธ Fusing scores (semantic: {dense_weight}, BM25: {sparse_weight})...") combined_scores = {} if len(semantic_scores) > 0: semantic_scores_norm = normalize_scores(semantic_scores) for idx, score in zip(semantic_indices, semantic_scores_norm): if idx < len(text_docs): combined_scores[idx] = dense_weight * score bm25_scores_norm = normalize_scores(bm25_scores_all) for idx, score in enumerate(bm25_scores_norm): if idx in combined_scores: combined_scores[idx] += sparse_weight * score else: combined_scores[idx] = sparse_weight * score sorted_indices = sorted(combined_scores.keys(), key=lambda x: combined_scores[x], reverse=True) top_indices = sorted_indices[:k_text] print(f" โœ… Top {len(top_indices)} results selected") # PREPARE RESULTS text_results = [] relevant_pages = set() for rank, idx in enumerate(top_indices, 1): doc = text_docs[idx] semantic_score = semantic_scores_norm[np.where(semantic_indices == idx)[0][0]] if idx in semantic_indices else 0.0 bm25_score = bm25_scores_norm[idx] combined_score = combined_scores[idx] text_results.append({ "document": doc, "score": combined_score, "semantic_score": float(semantic_score), "bm25_score": float(bm25_score), "rank": rank, "type": doc.metadata.get('type', 'text') }) relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page'))) # GET IMAGES relevant_images = [] try: image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl" if os.path.exists(image_docs_path): with open(image_docs_path, "rb") as f: image_docs = pickle.load(f) for img_doc in image_docs: img_page = (img_doc['source'], img_doc['page']) if img_page in relevant_pages and len(relevant_images) < k_images: relevant_images.append({ "path": img_doc['path'], "source": img_doc['source'], "page": img_doc['page'], "type": "image", "score": 0.0, "rank": len(relevant_images) + 1, "from_page": True }) except Exception as e: pass return { "text_results": text_results, "images": relevant_images, "query": query, "processed_query": processed_query, "method": "hybrid" } def display_results_hybrid(results: Dict): """Display hybrid search results.""" print("\n๐Ÿ“š TOP RESULTS (Hybrid Search: Semantic + BM25):\n") for result in results['text_results']: doc = result["document"] print(f"[{result['rank']}] Combined: {result['score']:.4f} " f"(Semantic: {result['semantic_score']:.4f}, BM25: {result['bm25_score']:.4f}) " f"| {doc.metadata.get('type', 'N/A')}") print(f" ๐Ÿ“„ {doc.metadata.get('source')} - Page {doc.metadata.get('page')}") print(f" ๐Ÿ“Œ {doc.metadata.get('heading', 'N/A')[:60]}") if 'total_chunks' in doc.metadata and doc.metadata.get('total_chunks', 1) > 1: print(f" ๐Ÿ”— Chunk {doc.metadata.get('chunk_index', 0)+1}/{doc.metadata.get('total_chunks')}") print(f" ๐Ÿ“ {doc.page_content[:200]}...") print() print("\n๐Ÿ–ผ๏ธ IMAGES:\n") if results['images']: for img in results['images']: print(f"[{img['rank']}] {img['source']} - Page {img['page']}") print(f" {img['path']}\n") else: print(" No images found\n") # ======================================== # ๐Ÿ“– GET CONTEXT WITH PARENTS # ======================================== def get_context_with_parents(results: Dict) -> List[Dict]: """Extract full parent contexts.""" seen_parents = set() contexts = [] for result in results['text_results']: doc = result['document'] parent = doc.metadata.get('parent_text') if parent and parent not in seen_parents: contexts.append({ "text": parent, "source": doc.metadata['source'], "page": doc.metadata['page'], "heading": doc.metadata['heading'], "type": doc.metadata.get('type', 'text'), "is_parent": True }) seen_parents.add(parent) elif not parent: contexts.append({ "text": doc.page_content, "source": doc.metadata['source'], "page": doc.metadata['page'], "heading": doc.metadata['heading'], "type": doc.metadata.get('type', 'text'), "is_parent": False }) return contexts # ======================================== # ๐Ÿš€ MAIN EXECUTION (UPDATED FOR HYBRID) # ======================================== if __name__ == "__main__": print("="*70) print("๐Ÿš€ RAG with HYBRID SEARCH (Semantic + BM25)") print("="*70 + "\n") pdf_files = glob.glob(f"{PDF_DIR}/*.pdf") print(f"๐Ÿ“‚ Found {len(pdf_files)} PDF files\n") if len(pdf_files) == 0: print("โŒ No PDFs found!") exit(1) print("\n๐Ÿค– Loading VLM2Vec model...") embedding_model = VLM2VecEmbeddings( model_name="TIGER-Lab/VLM2Vec-Qwen2VL-2B", cache_dir=MODEL_CACHE_DIR ) # Load or build index if os.path.exists(f"{FAISS_INDEX_PATH}/text_index.faiss"): print(f"โœ… Loading existing index\n") if not os.path.exists(f"{FAISS_INDEX_PATH}/bm25_index.pkl"): print("โš ๏ธ BM25 index missing, building now...") with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f: all_texts = pickle.load(f) print(" Building BM25 index...") tokenized_docs = [doc.page_content.lower().split() for doc in all_texts] bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65) with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f: pickle.dump(bm25_index, f) print(" โœ… BM25 index saved\n") else: print("๐Ÿ”จ Building new index...\n") embedding_model = VLM2VecEmbeddings( model_name="TIGER-Lab/VLM2Vec-Qwen2VL-2B", cache_dir=MODEL_CACHE_DIR ) index, documents = build_multimodal_faiss_streaming(pdf_files, embedding_model) if index is None: exit(0) # Interactive testing print("="*70) print("๐Ÿงช TESTING MODE - HYBRID SEARCH") print(f" Weights: Semantic {DENSE_WEIGHT} | BM25 {SPARSE_WEIGHT}") print("="*70 + "\n") test_queries = [ "What is the higher and lower explosive limit of butane?", "What are the precautions taken while handling H2S?", "What are the Personal Protection used for Sulfolane?", "What is the Composition of Platforming Feed and Product?", "Explain Dual function platforming catalyst chemistry.", "Steps to be followed in Amine Regeneration Unit for normal shutdown process.", "Could you tell me what De-greasing of Amine System in pre startup wash", ] print("๐Ÿ“‹ SUGGESTED QUERIES:") for i, q in enumerate(test_queries, 1): print(f" {i}. {q}") print() print("๐Ÿ’ก Type 'mode' to switch between hybrid/bm25/semantic") print() current_mode = "hybrid" while True: user_query = input(f"๐Ÿ’ฌ Query [{current_mode}] (or 1-5, 'mode', or 'exit'): ").strip() if user_query.lower() == 'exit': print("\nโœ… Done!") break if user_query.lower() == 'mode': print("\n๐Ÿ”„ Select mode:") print(" 1. Hybrid (Semantic + BM25)") print(" 2. BM25 only") print(" 3. Semantic only") mode_choice = input(" Choice (1-3): ").strip() if mode_choice == '1': current_mode = "hybrid" elif mode_choice == '2': current_mode = "bm25" elif mode_choice == '3': current_mode = "semantic" print(f" โœ… Mode set to: {current_mode}\n") continue if user_query.isdigit() and 1 <= int(user_query) <= len(test_queries): user_query = test_queries[int(user_query) - 1] if not user_query: continue print(f"\n{'='*60}") print(f"๐Ÿ” Query: {user_query}") print(f"๐Ÿ”ง Mode: {current_mode.upper()}") print(f"{'='*60}\n") try: if current_mode == "hybrid": results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3) display_results_hybrid(results) elif current_mode == "bm25": results = query_with_bm25(user_query, k_text=5, k_images=3) display_results_bm25(results) else: # semantic only results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3, dense_weight=1.0, sparse_weight=0.0) display_results_hybrid(results) print("\n๐Ÿ“– FULL CONTEXT:\n") contexts = get_context_with_parents(results) for i, ctx in enumerate(contexts[:3], 1): print(f"[{i}] {ctx['heading'][:50]}") if ctx['is_parent']: print(f" โœ… Full section") print(f" {ctx['text'][:300]}...\n") print("="*60 + "\n") except Exception as e: print(f"\nโŒ Error: {e}\n") import traceback traceback.print_exc()