BeRU Deployer
Deploy BeRU Streamlit RAG System - Add app, models logic, configs, and optimizations for HF Spaces
dec533d | import glob | |
| import os | |
| import gc | |
| import time | |
| import re | |
| import hashlib | |
| from pathlib import Path | |
| from typing import List, Dict, Tuple, Optional | |
| import fitz # PyMuPDF | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from transformers import AutoModel, AutoProcessor, AutoTokenizer # Changed from AutoModelForCausalLM | |
| from langchain_core.documents import Document | |
| import pickle | |
| from numpy.linalg import norm | |
| import camelot | |
| import base64 | |
| import pytesseract | |
| from pdf2image import convert_from_path | |
| import faiss | |
| from rank_bm25 import BM25Okapi | |
| # ======================================== | |
| # π CONFIGURATION | |
| # ======================================== | |
| PDF_DIR = r"D:\BeRU\testing" | |
| FAISS_INDEX_PATH = "VLM2Vec-V2rag2" | |
| MODEL_CACHE_DIR = ".cache" | |
| IMAGE_OUTPUT_DIR = "extracted_images2" | |
| # Chunking configuration | |
| CHUNK_SIZE = 450 # words | |
| OVERLAP = 100 # words | |
| MIN_CHUNK_SIZE = 50 | |
| MAX_CHUNK_SIZE = 800 | |
| # Instruction prefixes for better embeddings | |
| DOCUMENT_INSTRUCTION = "Represent this technical document for semantic search: " | |
| QUERY_INSTRUCTION = "Represent this question for finding relevant technical information: " | |
| # Hybrid search weights | |
| DENSE_WEIGHT = 0.4 # Weight for semantic search | |
| SPARSE_WEIGHT = 0.6 # Weight for keyword search | |
| # Create directories | |
| os.makedirs(PDF_DIR, exist_ok=True) | |
| os.makedirs(FAISS_INDEX_PATH, exist_ok=True) | |
| os.makedirs(MODEL_CACHE_DIR, exist_ok=True) | |
| os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True) | |
| # ======================================== | |
| # π€ VLM2Vec-V2 WRAPPER (ENHANCED) | |
| # ======================================== | |
class VLM2VecEmbeddings:
    """VLM2Vec-V2 embedding wrapper with instruction prefixes.

    Loads a Qwen2-VL based embedding checkpoint and exposes three entry
    points: embed_documents / embed_query for text and embed_image for
    figures. All embeddings are attention-mask-weighted mean pools over
    the model's last hidden layer.
    """

    def __init__(self, model_name: str = "TIGER-Lab/VLM2Vec-Qwen2VL-2B", cache_dir: str = None):
        print(f"π€ Loading VLM2Vec-V2 model: {model_name}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f" Device: {self.device}")
        try:
            # trust_remote_code is required because the checkpoint ships
            # custom modeling code. fp16 on GPU halves memory; fp32 on CPU.
            self.model = AutoModel.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            self.model.eval()
            # Probe the model once with a dummy input to discover the hidden
            # size; this is the embedding dimension the FAISS index will use.
            test_input = self.tokenizer("test", return_tensors="pt").to(self.device)
            with torch.no_grad():
                test_output = self.model(**test_input, output_hidden_states=True)
            self.embedding_dim = test_output.hidden_states[-1].shape[-1]
            print(f" Embedding dimension: {self.embedding_dim}")
            print("β VLM2Vec-V2 loaded successfully\n")
        except Exception as e:
            print(f"β Error loading VLM2Vec-V2: {e}")
            raise

    def normalize_text(self, text: str) -> str:
        """Normalize text for better embeddings."""
        # Collapse all runs of whitespace to single spaces.
        text = re.sub(r'\s+', ' ', text)
        # Strip "Page N" artifacts left over from PDF extraction.
        text = re.sub(r'Page \d+', '', text, flags=re.IGNORECASE)
        text = text.strip()
        return text

    def embed_documents(self, texts: List[str], add_instruction: bool = True) -> List[List[float]]:
        """Embed documents with instruction prefix and weighted mean pooling.

        Raises RuntimeError on the first text that fails to embed, so a
        partially-built index is never silently produced.
        """
        embeddings = []
        with torch.no_grad():
            for text in texts:
                try:
                    clean_text = self.normalize_text(text)
                    if add_instruction:
                        prefixed_text = DOCUMENT_INSTRUCTION + clean_text
                    else:
                        prefixed_text = clean_text
                    inputs = self.tokenizer(
                        prefixed_text,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        # Cap at 2048 tokens; model_max_length can be a huge
                        # sentinel value on some tokenizers.
                        max_length=min(self.tokenizer.model_max_length or 512, 2048)
                    ).to(self.device)
                    outputs = self.model(**inputs, output_hidden_states=True)
                    if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                        # Weighted mean pooling: padding positions are zeroed
                        # out via the attention mask before averaging.
                        hidden_states = outputs.hidden_states[-1]
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                        weighted_hidden_states = hidden_states * attention_mask
                        sum_embeddings = weighted_hidden_states.sum(dim=1)
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze()
                    else:
                        # Fallback: pool the logits the same way when the
                        # model does not expose hidden states.
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                        weighted_logits = outputs.logits * attention_mask
                        sum_embeddings = weighted_logits.sum(dim=1)
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze()
                    embeddings.append(embedding.cpu().numpy().tolist())
                except Exception as e:
                    print(f" β CRITICAL: Failed to embed text: {e}")
                    print(f" Text preview: {text[:100]}")
                    raise RuntimeError(f"Embedding failed for text: {text[:50]}...") from e
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query with the query-specific instruction prefix."""
        clean_text = self.normalize_text(text)
        prefixed_text = QUERY_INSTRUCTION + clean_text
        # add_instruction=False: the query prefix is already applied, so
        # embed_documents must not prepend the document instruction too.
        return self.embed_documents([prefixed_text], add_instruction=False)[0]

    def embed_image(self, image_path: str, prompt: str = "Technical diagram") -> Optional[List[float]]:
        """Embed an image; returns None when embedding fails.

        Uses the Qwen2-VL chat-message format so the processor pairs the
        image with the text prompt correctly.
        """
        try:
            with torch.no_grad():
                image = Image.open(image_path).convert('RGB')
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image},
                            {"type": "text", "text": prompt}
                        ]
                    }
                ]
                # Render the messages into the model's chat prompt format.
                text = self.processor.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                inputs = self.processor(
                    text=[text],
                    images=[image],
                    return_tensors="pt",
                    padding=True
                ).to(self.device)
                outputs = self.model(**inputs, output_hidden_states=True)
                if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                    hidden_states = outputs.hidden_states[-1]
                    if 'attention_mask' in inputs:
                        # Same weighted mean pooling as for text.
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                        weighted_hidden_states = hidden_states * attention_mask
                        sum_embeddings = weighted_hidden_states.sum(dim=1)
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze()
                    else:
                        embedding = hidden_states.mean(dim=1).squeeze()
                else:
                    # Last resort: use the pooler output, if present.
                    if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                        embedding = outputs.pooler_output.squeeze()
                    else:
                        return None
                return embedding.cpu().numpy().tolist()
        except Exception as e:
            print(f" β οΈ Failed to embed image {Path(image_path).name}: {str(e)[:100]}")
            return None
| # ======================================== | |
| # π QUERY PREPROCESSING | |
| # ======================================== | |
def preprocess_query(query: str) -> str:
    """Lowercase a query, expand known safety abbreviations, and tidy it.

    Trailing question/exclamation marks are stripped and whitespace is
    collapsed so BM25 tokenization sees clean terms.
    """
    expansions = {
        r'\bh2s\b': 'hydrogen sulfide',
        r'\bppm\b': 'parts per million',
        r'\bppe\b': 'personal protective equipment',
        r'\bscba\b': 'self contained breathing apparatus',
        r'\blel\b': 'lower explosive limit',
        r'\bhel\b': 'higher explosive limit',
        r'\buel\b': 'upper explosive limit'
    }
    result = query.lower()
    for pattern, replacement in expansions.items():
        result = re.sub(pattern, replacement, result)
    # Drop trailing question/exclamation marks, then normalize whitespace.
    result = re.sub(r'[?!]+$', '', result)
    return re.sub(r'\s+', ' ', result).strip()
| # ======================================== | |
| # π TABLE EXTRACTION | |
| # ======================================== | |
def is_table_of_contents_header(df, page_num):
    """Return True when a table's first row looks like a TOC header.

    Only tables within the first 15 pages are considered; the header
    qualifies when at least two known TOC keywords appear in it.
    """
    if page_num > 15 or len(df) == 0:
        return False
    # Join the first (header) row into one lowercase string.
    header_text = ' '.join(df.iloc[0].astype(str)).lower()
    toc_keywords = ['section', 'subsection', 'description', 'page no', 'page number', 'contents']
    matches = sum(kw in header_text for kw in toc_keywords)
    return matches >= 2
def looks_like_toc_data(df):
    """Heuristically detect a TOC continuation table.

    A table counts as TOC data when more than 70% of its last column
    (header row excluded) are page numbers in the 50-300 range AND more
    than 50% of its first column look like section numbers ("10.1").
    """
    if len(df) < 2 or len(df.columns) < 2:
        return False
    # Last column, skipping the header row: expect page numbers.
    page_values = df.iloc[1:, -1].astype(str)
    page_like = sum(
        1 for v in page_values
        if v.strip().isdigit() and 50 < int(v.strip()) < 300
    )
    if not (len(page_values) > 0 and page_like / len(page_values) > 0.7):
        return False
    # First column: expect dotted section numbers.
    section_values = df.iloc[1:, 0].astype(str)
    section_like = sum(
        1 for v in section_values if re.match(r'^\d+\.?\d*$', v.strip())
    )
    return section_like / len(section_values) > 0.5
def extract_tables_from_pdf(pdf_path: str) -> List[Document]:
    """Extract bordered tables with smart TOC detection.

    Uses camelot's lattice flavor (bordered tables only), de-duplicates,
    skips the first 5 pages, filters out table-of-contents tables, and
    converts each surviving table into a natural-language Document.
    Returns [] on any extraction failure.
    """
    chunks = []
    try:
        lattice_tables = camelot.read_pdf(
            pdf_path,
            pages='all',
            flavor='lattice',  # Only bordered tables
            suppress_stdout=True
        )
        all_tables = list(lattice_tables)
        seen_tables = set()
        # Track whether we are currently inside a (possibly multi-page) TOC.
        in_toc_section = False
        toc_start_page = None
        print(f" π Found {len(all_tables)} bordered tables")
        for table in all_tables:
            df = table.df
            current_page = table.page
            # De-dupe by (page, header row) — camelot can report the same
            # table more than once.
            table_id = (current_page, tuple(df.iloc[0].tolist()) if len(df) > 0 else ())
            if table_id in seen_tables:
                continue
            seen_tables.add(table_id)
            # Skip first 5 pages (title pages)
            if current_page <= 5:
                continue
            # Basic validation: need at least 2 columns, 3 rows, and a
            # decent camelot accuracy score.
            if len(df.columns) < 2 or len(df) < 3 or table.accuracy < 80:
                continue
            # Detect TOC start (a page carrying a TOC-style header row).
            if not in_toc_section and is_table_of_contents_header(df, current_page):
                in_toc_section = True
                toc_start_page = current_page
                print(f" π TOC detected at page {current_page}")
                continue
            # While inside the TOC, skip continuation pages until a table
            # stops looking like TOC data.
            if in_toc_section:
                if looks_like_toc_data(df):
                    print(f" βοΈ Skipping TOC continuation on page {current_page}")
                    continue
                else:
                    # TOC ended; resume normal extraction with this table.
                    print(f" β TOC ended, found real table on page {current_page}")
                    in_toc_section = False
            # Convert the surviving table to natural-language text.
            table_text = table_to_natural_language_enhanced(table)
            if table_text.strip():
                chunks.append(Document(
                    page_content=table_text,
                    metadata={
                        "source": os.path.basename(pdf_path),
                        "page": current_page,
                        "heading": "Table Data",
                        "type": "table",
                        "table_accuracy": table.accuracy
                    }
                ))
        print(f" β Extracted {len(chunks)} valid tables (after TOC filtering)")
    except Exception as e:
        print(f"β οΈ Table extraction failed: {e}")
    finally:
        # Best-effort cleanup: the names may be unbound if read_pdf raised,
        # which is why the bare except is tolerated here.
        try:
            del lattice_tables
            del all_tables
            gc.collect()
            time.sleep(0.1)
        except:
            pass
    return chunks
def table_to_natural_language_enhanced(table) -> str:
    """Convert a camelot table into one descriptive sentence per data row.

    The first row supplies column headers (blank headers become
    "Column_<i>"); each subsequent row with a non-empty first cell becomes
    "<first cell> has <header>: <value>, ...". Rows that are entirely
    blank or lack a leading label are dropped.
    """
    df = table.df
    if len(df) < 2:
        return ""

    def _is_blank(cell: str) -> bool:
        # Treat empty strings and pandas NaN artifacts as blanks.
        return not cell or cell.lower() in ['', 'nan', 'none']

    raw_headers = [str(h).strip() for h in df.iloc[0].astype(str).tolist()]
    headers = [
        h if not _is_blank(h) else f"Column_{i}"
        for i, h in enumerate(raw_headers)
    ]

    sentences = []
    for row_idx in range(1, len(df)):
        cells = [str(c).strip() for c in df.iloc[row_idx].astype(str).tolist()]
        if all(_is_blank(c) for c in cells):
            continue  # entirely empty row
        if not cells or _is_blank(cells[0]):
            continue  # rows without a leading label are dropped
        pairs = [
            f"{headers[col]}: {cells[col]}"
            for col in range(1, min(len(cells), len(headers)))
            if not _is_blank(cells[col])
        ]
        if pairs:
            sentences.append(f"{cells[0]} has {', '.join(pairs)}.")
        else:
            sentences.append(f"{cells[0]}.")
    return "\n".join(sentences)
def extract_tables_with_ocr(pdf_path: str, page_num: int) -> List[Dict]:
    """OCR fallback for image-based PDFs.

    Renders the given page to an image, OCRs it with Tesseract, and keeps
    only lines that look tabular (wide gaps or tabs). Returns at most one
    result dict, or an empty list when nothing table-like was found or
    OCR failed.
    """
    try:
        images = convert_from_path(pdf_path, first_page=page_num, last_page=page_num)
        if not images:
            return []
        ocr_text = pytesseract.image_to_string(images[0])
        lines = ocr_text.split('\n')
        table_lines = []
        for line in lines:
            # Heuristic: table columns appear as runs of 2+ spaces or tabs.
            if re.search(r'\s{2,}', line) or '\t' in line:
                table_lines.append(line)
        # Require at least 3 tabular-looking lines to call it a table.
        if len(table_lines) > 2:
            return [{
                "text": "\n".join(table_lines),
                "page": page_num,
                "method": "ocr"
            }]
        return []
    except Exception as e:
        # OCR failure is non-fatal; caller treats the page as table-free.
        return []
def get_table_regions(pdf_path: str) -> Dict[int, List[tuple]]:
    """Get table bounding boxes using BOTH lattice and stream methods.

    Returns {page_number: [bbox, ...]} so that text extraction can skip
    blocks inside table regions. TOC tables are excluded. Any camelot
    failure silently yields an empty mapping (tables treated as absent).
    """
    table_regions = {}
    try:
        lattice_tables = camelot.read_pdf(pdf_path, pages='all', flavor='lattice', suppress_stdout=True)
        stream_tables = camelot.read_pdf(pdf_path, pages='all', flavor='stream', suppress_stdout=True)
        all_tables = list(lattice_tables) + list(stream_tables)
        for table in all_tables:
            page = table.page
            if is_table_of_contents_header(table.df, page):
                continue
            # NOTE(review): _bbox is a private camelot attribute — verify
            # it still exists when upgrading camelot.
            bbox = table._bbox
            if page not in table_regions:
                table_regions[page] = []
            if bbox not in table_regions[page]:
                table_regions[page].append(bbox)
    except Exception as e:
        pass
    return table_regions
| # ======================================== | |
| # πΌοΈ IMAGE EXTRACTION | |
| # ======================================== | |
def extract_images_from_pdf(pdf_path: str, output_dir: str) -> List[Dict]:
    """Extract embedded images from a PDF into output_dir.

    Images smaller than 10 KB are skipped (icons, bullets, artifacts).
    Returns one metadata dict per saved image: path, 1-based page number,
    source filename, and type="image".
    """
    doc = fitz.open(pdf_path)
    image_data = []
    for page_num in range(len(doc)):
        page = doc[page_num]
        images = page.get_images()
        for img_index, img in enumerate(images):
            try:
                xref = img[0]
                base_image = doc.extract_image(xref)
                image_bytes = base_image["image"]
                # Filter out tiny decorative images (< 10 KB).
                if len(image_bytes) < 10000:
                    continue
                image_filename = f"{Path(pdf_path).stem}_p{page_num+1}_img{img_index+1}.png"
                image_path = os.path.join(output_dir, image_filename)
                with open(image_path, "wb") as img_file:
                    img_file.write(image_bytes)
                image_data.append({
                    "path": image_path,
                    "page": page_num + 1,
                    "source": os.path.basename(pdf_path),
                    "type": "image"
                })
            except Exception as e:
                # One bad image must not abort the whole PDF.
                continue
    doc.close()
    return image_data
| # ======================================== | |
| # π TEXT EXTRACTION WITH OVERLAPPING CHUNKS | |
| # ======================================== | |
def is_bold_text(span):
    """Return True when a PyMuPDF text span is rendered bold.

    Checks both the font name (e.g. "Arial-Bold") and bit 4 of the span
    flags. Fixed to return an actual bool instead of leaking the raw flag
    bits (the original returned the int 16 when only the flag matched);
    truthiness is unchanged, so all callers behave identically.
    """
    return "bold" in span['font'].lower() or bool(span['flags'] & 2**4)
def is_likely_heading(text, font_size, is_bold, avg_font_size):
    """Decide whether a line of text is probably a section heading.

    Bold text of sensible length (3-100 chars) qualifies when it is
    noticeably larger than the document average, fully uppercase, or
    numbered like "3.1 Foo".
    """
    if not is_bold:
        return False
    stripped = text.strip()
    if not (3 <= len(stripped) <= 100):
        return False
    large_font = font_size > avg_font_size * 1.1
    numbered = bool(re.match(r'^\d+\.?\d*\s+[A-Z]', stripped))
    return large_font or stripped.isupper() or numbered
def is_inside_table(block_bbox, table_bboxes):
    """Return True when a text block's bbox overlaps any table bbox.

    Boxes are (x1, y1, x2, y2) tuples; touching edges count as overlap.
    """
    bx1, by1, bx2, by2 = block_bbox
    for tx1, ty1, tx2, ty2 in table_bboxes:
        # Standard axis-aligned separation test, inverted.
        separated = bx2 < tx1 or bx1 > tx2 or by2 < ty1 or by1 > ty2
        if not separated:
            return True
    return False
def split_text_with_overlap(text: str, heading: str, source: str, page: int,
                            chunk_size: int = CHUNK_SIZE, overlap: int = OVERLAP) -> List[Document]:
    """Split a section's text into overlapping word-window chunks.

    Every chunk is prefixed with "Section: <heading>" so heading context
    travels with the embedding; the full section text is kept in metadata
    as parent_text for later display. A trailing fragment shorter than
    MIN_CHUNK_SIZE words is dropped.
    """
    words = text.split()

    def _make_doc(body: str, extra_meta: dict) -> Document:
        # Shared constructor: prepend heading context + base metadata.
        meta = {
            "source": source,
            "page": page,
            "heading": heading,
            "type": "text",
            "parent_text": text,
        }
        meta.update(extra_meta)
        return Document(page_content=f"Section: {heading}\n\n{body}", metadata=meta)

    # Short sections fit into a single chunk.
    if len(words) <= chunk_size:
        return [_make_doc(text, {"chunk_index": 0, "total_chunks": 1})]

    docs = []
    step = chunk_size - overlap
    for start in range(0, len(words), step):
        window = words[start:start + chunk_size]
        # Drop a trailing fragment too small to stand alone.
        if len(window) < MIN_CHUNK_SIZE and docs:
            break
        docs.append(_make_doc(" ".join(window), {
            "chunk_index": len(docs),
            "start_word": start,
            "end_word": start + len(window),
        }))
    # Backfill the final chunk count on every chunk.
    for doc in docs:
        doc.metadata["total_chunks"] = len(docs)
    return docs
def extract_text_chunks_with_overlap(pdf_path: str, table_regions: Dict[int, List[tuple]]) -> List[Document]:
    """Extract text as heading-delimited sections, then chunk with overlap.

    Pass 1 computes the average font size (used to spot headings);
    pass 2 walks every line, starts a new section at each detected
    heading, and skips blocks that fall inside known table regions.
    """
    doc = fitz.open(pdf_path)
    # Pass 1: collect all span font sizes to establish a document average.
    all_font_sizes = []
    for page_num in range(len(doc)):
        page = doc[page_num]
        blocks = page.get_text("dict")["blocks"]
        for block in blocks:
            if "lines" in block:
                for line in block["lines"]:
                    for span in line["spans"]:
                        all_font_sizes.append(span["size"])
    avg_font_size = sum(all_font_sizes) / len(all_font_sizes) if all_font_sizes else 12
    # Pass 2: accumulate line text into sections delimited by headings.
    sections = []
    current_section = ""
    current_heading = "Introduction"  # fallback heading before the first real one
    current_page = 1
    for page_num in range(len(doc)):
        page = doc[page_num]
        blocks = page.get_text("dict")["blocks"]
        # table_regions is keyed by 1-based page numbers.
        page_tables = table_regions.get(page_num + 1, [])
        for block in blocks:
            if "lines" not in block:
                continue
            block_bbox = block.get("bbox", (0, 0, 0, 0))
            # Table text is handled by the table extractor; skip it here.
            if is_inside_table(block_bbox, page_tables):
                continue
            for line in block["lines"]:
                # Merge all spans of the line; a line is "bold" if any span
                # is, and its size is the max span size.
                line_text = ""
                line_is_bold = False
                line_font_size = 0
                for span in line["spans"]:
                    line_text += span["text"]
                    if is_bold_text(span):
                        line_is_bold = True
                    line_font_size = max(line_font_size, span["size"])
                line_text = line_text.strip()
                if not line_text:
                    continue
                if is_likely_heading(line_text, line_font_size, line_is_bold, avg_font_size):
                    # Close the running section before starting a new one.
                    if current_section.strip():
                        sections.append({
                            "text": current_section.strip(),
                            "heading": current_heading,
                            "page": current_page,
                            "source": os.path.basename(pdf_path)
                        })
                    current_heading = line_text
                    current_section = ""
                    current_page = page_num + 1
                else:
                    current_section += line_text + " "
    # Flush the final section.
    if current_section.strip():
        sections.append({
            "text": current_section.strip(),
            "heading": current_heading,
            "page": current_page,
            "source": os.path.basename(pdf_path)
        })
    doc.close()
    # Chunk each section with overlap, carrying heading context along.
    all_chunks = []
    for section in sections:
        chunks = split_text_with_overlap(
            text=section['text'],
            heading=section['heading'],
            source=section['source'],
            page=section['page'],
            chunk_size=CHUNK_SIZE,
            overlap=OVERLAP
        )
        all_chunks.extend(chunks)
    return all_chunks
| # ======================================== | |
| # π COMBINED EXTRACTION | |
| # ======================================== | |
def extract_all_content_from_pdf(pdf_path: str) -> Tuple[List[Document], List[Dict]]:
    """Extract text, tables, and images from one PDF.

    Tables are extracted first because their bounding boxes are needed so
    text extraction can skip blocks living inside a table region.
    Returns (text+table chunks, image metadata dicts).
    """
    print(f" π Extracting tables...")
    table_regions = get_table_regions(pdf_path)
    table_chunks = extract_tables_from_pdf(pdf_path)
    print(f" β {len(table_chunks)} table chunks")
    print(f" π Extracting text...")
    text_chunks = extract_text_chunks_with_overlap(pdf_path, table_regions)
    print(f" β {len(text_chunks)} text chunks")
    print(f" πΌοΈ Extracting images...")
    images = extract_images_from_pdf(pdf_path, IMAGE_OUTPUT_DIR)
    print(f" β {len(images)} images")
    # Single combined list: text chunks followed by table chunks.
    all_chunks = text_chunks + table_chunks
    return all_chunks, images
| # ======================================== | |
| # ποΈ BUILD FAISS INDEX WITH STREAMING | |
| # ======================================== | |
| # Replace the HybridRetriever class and related functions with this optimized version: | |
| # ======================================== | |
| # ποΈ BUILD FAISS INDEX WITH BM25 | |
| # ======================================== | |
def build_multimodal_faiss_streaming(pdf_files: List[str], embedding_model: VLM2VecEmbeddings):
    """Build FAISS (text + image) and BM25 indexes for the given PDFs.

    Skips rebuilding when an index for the exact same PDF set already
    exists (MD5 over the sorted file-name list — contents are NOT hashed),
    unless the user confirms interactively. Returns (text_index, docs);
    (None, []) when the build is skipped or nothing was extracted.
    """
    index_hash_file = f"{FAISS_INDEX_PATH}/index_hash.txt"
    # The hash identifies the PDF *set by name*, not file contents.
    current_hash = hashlib.md5("".join(sorted(pdf_files)).encode()).hexdigest()
    if os.path.exists(index_hash_file):
        with open(index_hash_file, 'r') as f:
            existing_hash = f.read().strip()
        if existing_hash == current_hash:
            print("β οΈ Index already exists for these PDFs!")
            response = input(" Rebuild anyway? (yes/no): ").strip().lower()
            if response != 'yes':
                return None, []
    all_texts = []
    all_image_paths = []
    print("\nπ Processing PDFs...\n")
    for pdf_file in pdf_files:
        print(f"π Processing: {Path(pdf_file).name}")
        try:
            text_chunks, images = extract_all_content_from_pdf(pdf_file)
            all_texts.extend(text_chunks)
            all_image_paths.extend(images)
        except Exception as e:
            # One broken PDF must not abort the whole build.
            print(f" β Error: {e}")
            continue
        print()
    print(f"β Total chunks: {len(all_texts)}")
    print(f"β Total images: {len(all_image_paths)}\n")
    if len(all_texts) == 0:
        print("β No content extracted!")
        return None, []
    # Build the text index incrementally in small batches to bound memory.
    print("π Generating text embeddings...\n")
    text_index = None
    batch_size = 10
    for i in range(0, len(all_texts), batch_size):
        batch = all_texts[i:i+batch_size]
        batch_contents = [doc.page_content for doc in batch]
        try:
            batch_embeddings = embedding_model.embed_documents(batch_contents, add_instruction=True)
            batch_embeddings_np = np.array(batch_embeddings).astype('float32')
            if text_index is None:
                # Create the index lazily, once the dimension is known.
                dimension = batch_embeddings_np.shape[1]
                text_index = faiss.IndexFlatIP(dimension)
                print(f" Text embedding dimension: {dimension}")
            # L2-normalize so inner product behaves as cosine similarity.
            faiss.normalize_L2(batch_embeddings_np)
            text_index.add(batch_embeddings_np)
            if (i // batch_size + 1) % 5 == 0:
                print(f" Progress: {i + len(batch)}/{len(all_texts)}")
        except Exception as e:
            print(f" β Error: {e}")
            raise
    print(f" β Complete")
    # Persist the FAISS index and the document list (aligned by position).
    faiss.write_index(text_index, f"{FAISS_INDEX_PATH}/text_index.faiss")
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "wb") as f:
        pickle.dump(all_texts, f)
    # Build and save the BM25 keyword index over the same documents.
    print("\nπ Building BM25 index for keyword search...")
    tokenized_docs = [doc.page_content.lower().split() for doc in all_texts]
    bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)
    with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f:
        pickle.dump(bm25_index, f)
    print(" β BM25 index saved")
    # Build the image index; images that fail to embed are silently dropped,
    # so image_documents.pkl stays aligned with image_index.faiss.
    if len(all_image_paths) > 0:
        print(f"\nπΌοΈ Embedding images...")
        image_index = None
        successful_images = []
        for idx, img_data in enumerate(all_image_paths):
            img_embedding = embedding_model.embed_image(img_data["path"])
            if img_embedding is None:
                continue
            img_embedding_np = np.array([img_embedding]).astype('float32')
            if image_index is None:
                dimension = img_embedding_np.shape[1]
                image_index = faiss.IndexFlatIP(dimension)
                print(f" Image dimension: {dimension}")
            faiss.normalize_L2(img_embedding_np)
            image_index.add(img_embedding_np)
            successful_images.append(img_data)
            if (len(successful_images)) % 10 == 0:
                print(f" Progress: {len(successful_images)}/{len(all_image_paths)}")
        print(f" β {len(successful_images)} images embedded")
        if image_index is not None and len(successful_images) > 0:
            faiss.write_index(image_index, f"{FAISS_INDEX_PATH}/image_index.faiss")
            with open(f"{FAISS_INDEX_PATH}/image_documents.pkl", "wb") as f:
                pickle.dump(successful_images, f)
    # Record the hash last, so a crashed build is not mistaken for complete.
    with open(index_hash_file, 'w') as f:
        f.write(current_hash)
    print(f"\nβ Index saved: {FAISS_INDEX_PATH}\n")
    return text_index, all_texts
| # ======================================== | |
| # π OPTIMIZED HYBRID SEARCH | |
| # ======================================== | |
| # ======================================== | |
| # π QUERY WITH BM25 ONLY | |
| # ======================================== | |
def query_with_bm25(query: str, k_text: int = 5, k_images: int = 3):
    """Query using BM25 keyword search only.

    Returns a dict with ranked text_results, up to k_images images drawn
    from the same pages as the top text hits (page co-location, not
    semantic matching), and both the raw and preprocessed query strings.
    """
    processed_query = preprocess_query(query)
    print(f" π Processed: {processed_query}")
    # Load the document list saved at index-build time.
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
        text_docs = pickle.load(f)
    # Load the persisted BM25 index; rebuild in memory when missing.
    try:
        with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "rb") as f:
            bm25_index = pickle.load(f)
    except FileNotFoundError:
        print(" β οΈ BM25 index not found, building on-the-fly...")
        tokenized_docs = [doc.page_content.lower().split() for doc in text_docs]
        bm25_index = BM25Okapi(tokenized_docs)
    # Score every document against the tokenized query.
    tokenized_query = processed_query.lower().split()
    bm25_scores = bm25_index.get_scores(tokenized_query)
    # Top-k indices by descending BM25 score.
    top_indices = np.argsort(bm25_scores)[::-1][:k_text]
    text_results = []
    relevant_pages = set()
    for rank, idx in enumerate(top_indices, 1):
        doc = text_docs[idx]
        score = float(bm25_scores[idx])
        text_results.append({
            "document": doc,
            "score": score,
            "rank": rank,
            "type": doc.metadata.get('type', 'text')
        })
        relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page')))
    # Attach images by page co-location rather than semantic similarity.
    relevant_images = []
    try:
        image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl"
        if os.path.exists(image_docs_path):
            with open(image_docs_path, "rb") as f:
                image_docs = pickle.load(f)
            for img_doc in image_docs:
                img_page = (img_doc['source'], img_doc['page'])
                if img_page in relevant_pages and len(relevant_images) < k_images:
                    relevant_images.append({
                        "path": img_doc['path'],
                        "source": img_doc['source'],
                        "page": img_doc['page'],
                        "type": "image",
                        "score": 0.0,  # no similarity score for page-matched images
                        "rank": len(relevant_images) + 1,
                        "from_page": True
                    })
    except Exception as e:
        # Image attachment is best-effort; text results are still returned.
        pass
    return {
        "text_results": text_results,
        "images": relevant_images,
        "query": query,
        "processed_query": processed_query
    }
| # ======================================== | |
| # π DISPLAY RESULTS (BM25 ONLY) | |
| # ======================================== | |
def display_results_bm25(results: Dict):
    """Pretty-print the output of query_with_bm25 to stdout.

    Expects the dict shape produced by query_with_bm25: "text_results"
    (each with document/score/rank) and "images" (each with
    rank/source/page/path).
    """
    print("\nπ TOP RESULTS (BM25 Keyword Search):\n")
    for result in results['text_results']:
        doc = result["document"]
        print(f"[{result['rank']}] BM25 Score: {result['score']:.4f} | {doc.metadata.get('type', 'N/A')}")
        print(f" π {doc.metadata.get('source')} - Page {doc.metadata.get('page')}")
        print(f" π {doc.metadata.get('heading', 'N/A')[:60]}")
        # Show chunk position only for multi-chunk sections.
        if 'total_chunks' in doc.metadata and doc.metadata.get('total_chunks', 1) > 1:
            print(f" π Chunk {doc.metadata.get('chunk_index', 0)+1}/{doc.metadata.get('total_chunks')}")
        # Preview only the first 200 characters of the chunk.
        print(f" π {doc.page_content[:200]}...")
        print()
    print("\nπΌοΈ IMAGES:\n")
    if results['images']:
        for img in results['images']:
            print(f"[{img['rank']}] {img['source']} - Page {img['page']}")
            print(f" {img['path']}\n")
    else:
        print(" No images found\n")
| # ======================================== | |
| # π HYBRID SEARCH IMPLEMENTATION | |
| # ======================================== | |
def normalize_scores(scores: np.ndarray) -> np.ndarray:
    """Min-max scale scores into the [0, 1] range.

    An empty array is returned unchanged; a constant array maps to all
    ones (identical scores are treated as equally relevant).
    """
    if len(scores) == 0:
        return scores
    lo, hi = np.min(scores), np.max(scores)
    if hi == lo:
        return np.ones_like(scores)
    return (scores - lo) / (hi - lo)
def query_with_hybrid(query: str, embedding_model: VLM2VecEmbeddings,
                      k_text: int = 5, k_images: int = 3,
                      dense_weight: float = DENSE_WEIGHT,
                      sparse_weight: float = SPARSE_WEIGHT):
    """
    Hybrid search combining semantic (FAISS) and keyword (BM25) retrieval.

    Args:
        query: Raw user query.
        embedding_model: Model used to embed the (preprocessed) query.
        k_text: Number of text chunks to return after score fusion.
        k_images: Maximum number of page images to attach to the results.
        dense_weight: Fusion weight applied to normalized semantic scores.
        sparse_weight: Fusion weight applied to normalized BM25 scores.

    Returns:
        Dict with keys: text_results, images, query, processed_query, method.
    """
    processed_query = preprocess_query(query)
    print(f" π Processed: {processed_query}")
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
        text_docs = pickle.load(f)
    # SEMANTIC SEARCH — best-effort: on failure we fall back to BM25 only.
    print(f" π§ Running semantic search...")
    semantic_lookup = {}  # doc index -> normalized semantic score (filled after fusion)
    try:
        text_index = faiss.read_index(f"{FAISS_INDEX_PATH}/text_index.faiss")
        query_embedding = embedding_model.embed_query(processed_query)
        query_np = np.array([query_embedding]).astype('float32')
        faiss.normalize_L2(query_np)
        # Over-retrieve 3x so fusion has candidates beyond the final top-k.
        k_retrieve = min(k_text * 3, len(text_docs))
        distances, indices = text_index.search(query_np, k_retrieve)
        semantic_scores = distances[0]
        semantic_indices = indices[0]
        print(f" β Retrieved {len(semantic_indices)} semantic results")
    except Exception as e:
        print(f" β οΈ Semantic search failed: {e}")
        semantic_scores = np.array([])
        semantic_indices = np.array([])
    # BM25 SEARCH
    print(f" π€ Running BM25 keyword search...")
    bm25_index = _load_or_build_bm25(text_docs)
    tokenized_query = processed_query.lower().split()
    bm25_scores_all = bm25_index.get_scores(tokenized_query)
    print(f" β Scored {len(bm25_scores_all)} documents")
    # SCORE FUSION — weighted sum of min-max-normalized score lists.
    print(f" βοΈ Fusing scores (semantic: {dense_weight}, BM25: {sparse_weight})...")
    combined_scores = {}
    if len(semantic_scores) > 0:
        semantic_scores_norm = normalize_scores(semantic_scores)
        for idx, score in zip(semantic_indices, semantic_scores_norm):
            if idx < len(text_docs):
                # Keep an O(1) lookup table instead of np.where per result.
                semantic_lookup[int(idx)] = float(score)
                combined_scores[int(idx)] = dense_weight * float(score)
    bm25_scores_norm = normalize_scores(bm25_scores_all)
    for idx, score in enumerate(bm25_scores_norm):
        combined_scores[idx] = combined_scores.get(idx, 0.0) + sparse_weight * float(score)
    top_indices = sorted(combined_scores,
                         key=lambda x: combined_scores[x],
                         reverse=True)[:k_text]
    print(f" β Top {len(top_indices)} results selected")
    # PREPARE RESULTS
    text_results = []
    relevant_pages = set()
    for rank, idx in enumerate(top_indices, 1):
        doc = text_docs[idx]
        text_results.append({
            "document": doc,
            "score": combined_scores[idx],
            # 0.0 when the chunk was outside the semantic top-k (or semantic failed).
            "semantic_score": semantic_lookup.get(idx, 0.0),
            "bm25_score": float(bm25_scores_norm[idx]),
            "rank": rank,
            "type": doc.metadata.get('type', 'text')
        })
        relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page')))
    relevant_images = _collect_page_images(relevant_pages, k_images)
    return {
        "text_results": text_results,
        "images": relevant_images,
        "query": query,
        "processed_query": processed_query,
        "method": "hybrid"
    }

def _load_or_build_bm25(text_docs):
    """Load the cached BM25 index; rebuild AND persist it if the cache is missing.

    Bug fix: the original rebuilt the index on every query after a cache miss
    but never saved it — persisting mirrors the __main__ block's behavior.
    """
    bm25_path = f"{FAISS_INDEX_PATH}/bm25_index.pkl"
    try:
        with open(bm25_path, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        tokenized_docs = [doc.page_content.lower().split() for doc in text_docs]
        bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)
        with open(bm25_path, "wb") as f:
            pickle.dump(bm25_index, f)
        return bm25_index

def _collect_page_images(relevant_pages: set, k_images: int):
    """Return up to k_images extracted images whose (source, page) matches a retrieved chunk.

    Best-effort: any failure is reported (not silently swallowed as before)
    and an empty/partial list is returned.
    """
    relevant_images = []
    image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl"
    try:
        if os.path.exists(image_docs_path):
            with open(image_docs_path, "rb") as f:
                image_docs = pickle.load(f)
            for img_doc in image_docs:
                if len(relevant_images) >= k_images:
                    break
                if (img_doc['source'], img_doc['page']) in relevant_pages:
                    relevant_images.append({
                        "path": img_doc['path'],
                        "source": img_doc['source'],
                        "page": img_doc['page'],
                        "type": "image",
                        "score": 0.0,
                        "rank": len(relevant_images) + 1,
                        "from_page": True
                    })
    except Exception as e:
        print(f" β οΈ Image lookup failed: {e}")
    return relevant_images
def display_results_hybrid(results: Dict):
    """Pretty-print the text hits and attached images of a hybrid search run."""
    print("\nπ TOP RESULTS (Hybrid Search: Semantic + BM25):\n")
    for hit in results['text_results']:
        doc = hit["document"]
        meta = doc.metadata
        print(f"[{hit['rank']}] Combined: {hit['score']:.4f} "
              f"(Semantic: {hit['semantic_score']:.4f}, BM25: {hit['bm25_score']:.4f}) "
              f"| {meta.get('type', 'N/A')}")
        print(f" π {meta.get('source')} - Page {meta.get('page')}")
        print(f" π {meta.get('heading', 'N/A')[:60]}")
        # Only multi-chunk sections get a chunk position line.
        if 'total_chunks' in meta and meta.get('total_chunks', 1) > 1:
            print(f" π Chunk {meta.get('chunk_index', 0)+1}/{meta.get('total_chunks')}")
        print(f" π {doc.page_content[:200]}...")
        print()
    print("\nπΌοΈ IMAGES:\n")
    if not results['images']:
        print(" No images found\n")
    else:
        for img in results['images']:
            print(f"[{img['rank']}] {img['source']} - Page {img['page']}")
            print(f" {img['path']}\n")
| # ======================================== | |
| # π GET CONTEXT WITH PARENTS | |
| # ======================================== | |
def get_context_with_parents(results: Dict) -> List[Dict]:
    """Extract deduplicated contexts for answer generation.

    Chunks carrying a 'parent_text' are expanded to their full parent section
    (emitted once per distinct parent); chunks without one fall back to their
    own page_content.

    Args:
        results: Output of a query_* function; must contain 'text_results'
            entries each holding a 'document'.

    Returns:
        List of dicts with text/source/page/heading/type/is_parent keys.
    """
    seen_parents = set()
    contexts = []
    for result in results['text_results']:
        doc = result['document']
        parent = doc.metadata.get('parent_text')
        if parent and parent in seen_parents:
            # Another chunk of an already-collected section: skip duplicate.
            continue
        if parent:
            seen_parents.add(parent)
        contexts.append({
            "text": parent or doc.page_content,
            # .get() instead of [] so sparse metadata cannot crash context
            # assembly (consistent with display_results_hybrid).
            "source": doc.metadata.get('source'),
            "page": doc.metadata.get('page'),
            "heading": doc.metadata.get('heading', 'N/A'),
            "type": doc.metadata.get('type', 'text'),
            "is_parent": bool(parent)
        })
    return contexts
| # ======================================== | |
| # π MAIN EXECUTION (UPDATED FOR HYBRID) | |
| # ======================================== | |
if __name__ == "__main__":
    print("="*70)
    print("π RAG with HYBRID SEARCH (Semantic + BM25)")
    print("="*70 + "\n")
    pdf_files = glob.glob(f"{PDF_DIR}/*.pdf")
    print(f"π Found {len(pdf_files)} PDF files\n")
    if len(pdf_files) == 0:
        print("β No PDFs found!")
        raise SystemExit(1)
    # Load the embedding model exactly once. (Bug fix: the original
    # re-instantiated VLM2VecEmbeddings inside the index-building branch,
    # loading the 2B model twice.)
    print("\nπ€ Loading VLM2Vec model...")
    embedding_model = VLM2VecEmbeddings(
        model_name="TIGER-Lab/VLM2Vec-Qwen2VL-2B",
        cache_dir=MODEL_CACHE_DIR
    )
    # Load or build index
    if os.path.exists(f"{FAISS_INDEX_PATH}/text_index.faiss"):
        print(f"β Loading existing index\n")
        # Older indexes may predate hybrid search; backfill the BM25 side.
        if not os.path.exists(f"{FAISS_INDEX_PATH}/bm25_index.pkl"):
            print("β οΈ BM25 index missing, building now...")
            with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
                all_texts = pickle.load(f)
            print(" Building BM25 index...")
            tokenized_docs = [doc.page_content.lower().split() for doc in all_texts]
            bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)
            with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f:
                pickle.dump(bm25_index, f)
            print(" β BM25 index saved\n")
    else:
        print("π¨ Building new index...\n")
        index, documents = build_multimodal_faiss_streaming(pdf_files, embedding_model)
        if index is None:
            raise SystemExit(0)
    # Interactive testing
    print("="*70)
    print("π§ͺ TESTING MODE - HYBRID SEARCH")
    print(f" Weights: Semantic {DENSE_WEIGHT} | BM25 {SPARSE_WEIGHT}")
    print("="*70 + "\n")
    test_queries = [
        "What is the higher and lower explosive limit of butane?",
        "What are the precautions taken while handling H2S?",
        "What are the Personal Protection used for Sulfolane?",
        "What is the Composition of Platforming Feed and Product?",
        "Explain Dual function platforming catalyst chemistry.",
        "Steps to be followed in Amine Regeneration Unit for normal shutdown process.",
        "Could you tell me what De-greasing of Amine System in pre startup wash",
    ]
    print("π SUGGESTED QUERIES:")
    for i, q in enumerate(test_queries, 1):
        print(f" {i}. {q}")
    print()
    print("π‘ Type 'mode' to switch between hybrid/bm25/semantic")
    print()
    current_mode = "hybrid"
    while True:
        # Bug fix: prompt previously said "1-5" while 7 queries are offered.
        user_query = input(f"π¬ Query [{current_mode}] (or 1-{len(test_queries)}, 'mode', or 'exit'): ").strip()
        if user_query.lower() == 'exit':
            print("\nβ Done!")
            break
        if user_query.lower() == 'mode':
            print("\nπ Select mode:")
            print(" 1. Hybrid (Semantic + BM25)")
            print(" 2. BM25 only")
            print(" 3. Semantic only")
            mode_choice = input(" Choice (1-3): ").strip()
            if mode_choice == '1':
                current_mode = "hybrid"
            elif mode_choice == '2':
                current_mode = "bm25"
            elif mode_choice == '3':
                current_mode = "semantic"
            print(f" β Mode set to: {current_mode}\n")
            continue
        # A bare number selects one of the suggested queries.
        if user_query.isdigit() and 1 <= int(user_query) <= len(test_queries):
            user_query = test_queries[int(user_query) - 1]
        if not user_query:
            continue
        print(f"\n{'='*60}")
        print(f"π Query: {user_query}")
        print(f"π§ Mode: {current_mode.upper()}")
        print(f"{'='*60}\n")
        try:
            if current_mode == "hybrid":
                results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3)
                display_results_hybrid(results)
            elif current_mode == "bm25":
                results = query_with_bm25(user_query, k_text=5, k_images=3)
                display_results_bm25(results)
            else:  # semantic only: hybrid pipeline with the BM25 weight zeroed
                results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3,
                                            dense_weight=1.0, sparse_weight=0.0)
                display_results_hybrid(results)
            print("\nπ FULL CONTEXT:\n")
            contexts = get_context_with_parents(results)
            for i, ctx in enumerate(contexts[:3], 1):
                print(f"[{i}] {ctx['heading'][:50]}")
                if ctx['is_parent']:
                    print(f" β Full section")
                print(f" {ctx['text'][:300]}...\n")
            print("="*60 + "\n")
        except Exception as e:
            print(f"\nβ Error: {e}\n")
            import traceback
            traceback.print_exc()