# BeRu / vlm2rag2.py
# BeRU Deployer
# Deploy BeRU Streamlit RAG System - Add app, models logic, configs, and optimizations for HF Spaces
# commit: dec533d
import glob
import os
import gc
import time
import re
import hashlib
from pathlib import Path
from typing import List, Dict, Tuple, Optional
import fitz # PyMuPDF
import torch
import numpy as np
from PIL import Image
from transformers import AutoModel, AutoProcessor, AutoTokenizer # Changed from AutoModelForCausalLM
from langchain_core.documents import Document
import pickle
from numpy.linalg import norm
import camelot
import base64
import pytesseract
from pdf2image import convert_from_path
import faiss
from rank_bm25 import BM25Okapi
# ========================================
# 📂 CONFIGURATION
# ========================================
PDF_DIR = r"D:\BeRU\testing"          # input PDFs (Windows path — adjust per machine)
FAISS_INDEX_PATH = "VLM2Vec-V2rag2"   # output dir for FAISS/BM25/pickle artifacts
MODEL_CACHE_DIR = ".cache"            # HuggingFace model download cache
IMAGE_OUTPUT_DIR = "extracted_images2"  # where extracted PDF images are written
# Chunking configuration (units: whitespace-separated words)
CHUNK_SIZE = 450   # words per chunk
OVERLAP = 100      # words shared between consecutive chunks
MIN_CHUNK_SIZE = 50   # trailing chunks shorter than this are dropped
MAX_CHUNK_SIZE = 800  # NOTE(review): defined but never referenced in this file
# Instruction prefixes for better embeddings (asymmetric doc/query prompts)
DOCUMENT_INSTRUCTION = "Represent this technical document for semantic search: "
QUERY_INSTRUCTION = "Represent this question for finding relevant technical information: "
# Hybrid search weights (fused after min-max normalization; see query_with_hybrid)
DENSE_WEIGHT = 0.4   # Weight for semantic search
SPARSE_WEIGHT = 0.6  # Weight for keyword search
# Create directories (idempotent; runs at import time)
os.makedirs(PDF_DIR, exist_ok=True)
os.makedirs(FAISS_INDEX_PATH, exist_ok=True)
os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
os.makedirs(IMAGE_OUTPUT_DIR, exist_ok=True)
# ========================================
# 🤖 VLM2Vec-V2 WRAPPER (ENHANCED)
# ========================================
class VLM2VecEmbeddings:
    """VLM2Vec-V2 embedding class with instruction prefixes.

    Wraps a TIGER-Lab VLM2Vec checkpoint (Qwen2-VL backbone) behind a
    LangChain-style interface: embed_documents / embed_query for text and
    embed_image for images. Embeddings are attention-mask-weighted mean
    pools of the model's last hidden state.
    """
    def __init__(self, model_name: str = "TIGER-Lab/VLM2Vec-Qwen2VL-2B", cache_dir: str = None):
        """Load model, processor and tokenizer, then probe embedding size.

        Uses fp16 on CUDA and fp32 on CPU. Any load failure is logged and
        re-raised.
        """
        print(f"🤖 Loading VLM2Vec-V2 model: {model_name}")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f" Device: {self.device}")
        try:
            self.model = AutoModel.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True,
                # fp16 halves GPU memory; CPU inference needs fp32
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)
            self.processor = AutoProcessor.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=cache_dir,
                trust_remote_code=True
            )
            self.model.eval()
            # Get actual embedding dimension by running one tiny forward pass
            test_input = self.tokenizer("test", return_tensors="pt").to(self.device)
            with torch.no_grad():
                test_output = self.model(**test_input, output_hidden_states=True)
            # hidden size of the last layer == embedding dimension we index with
            self.embedding_dim = test_output.hidden_states[-1].shape[-1]
            print(f" Embedding dimension: {self.embedding_dim}")
            print("✅ VLM2Vec-V2 loaded successfully\n")
        except Exception as e:
            print(f"❌ Error loading VLM2Vec-V2: {e}")
            raise

    def normalize_text(self, text: str) -> str:
        """Normalize text for better embeddings: collapse whitespace and
        strip 'Page N' artifacts left over from PDF extraction."""
        # Remove excessive whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove page numbers
        text = re.sub(r'Page \d+', '', text, flags=re.IGNORECASE)
        # Normalize unicode
        text = text.strip()
        return text

    def embed_documents(self, texts: List[str], add_instruction: bool = True) -> List[List[float]]:
        """Embed documents with instruction prefix and weighted mean pooling.

        Args:
            texts: raw document strings.
            add_instruction: prepend DOCUMENT_INSTRUCTION (disable when the
                caller has already prefixed, e.g. embed_query).

        Returns:
            One embedding (list of floats) per input text.

        Raises:
            RuntimeError: if embedding any single text fails — indexing a
            partial corpus would silently desynchronize FAISS ids.
        """
        embeddings = []
        with torch.no_grad():
            for text in texts:
                try:
                    # NORMALIZE TEXT
                    clean_text = self.normalize_text(text)
                    # ADD INSTRUCTION PREFIX
                    if add_instruction:
                        prefixed_text = DOCUMENT_INSTRUCTION + clean_text
                    else:
                        prefixed_text = clean_text
                    inputs = self.tokenizer(
                        prefixed_text,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        # model_max_length can be a huge sentinel; cap at 2048
                        max_length=min(self.tokenizer.model_max_length or 512, 2048)
                    ).to(self.device)
                    outputs = self.model(**inputs, output_hidden_states=True)
                    if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                        # WEIGHTED MEAN POOLING (ignores padding tokens)
                        hidden_states = outputs.hidden_states[-1]
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                        # Apply attention mask as weights
                        weighted_hidden_states = hidden_states * attention_mask
                        sum_embeddings = weighted_hidden_states.sum(dim=1)
                        # clamp avoids divide-by-zero on an all-padding row
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                        # Weighted mean
                        embedding = (sum_embeddings / sum_mask).squeeze()
                    else:
                        # Fallback to logits when hidden states are unavailable
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                        weighted_logits = outputs.logits * attention_mask
                        sum_embeddings = weighted_logits.sum(dim=1)
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze()
                    embeddings.append(embedding.cpu().numpy().tolist())
                except Exception as e:
                    print(f" ❌ CRITICAL: Failed to embed text: {e}")
                    print(f" Text preview: {text[:100]}")
                    raise RuntimeError(f"Embedding failed for text: {text[:50]}...") from e
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query with query-specific instruction (asymmetric to docs)."""
        # DIFFERENT INSTRUCTION FOR QUERIES
        clean_text = self.normalize_text(text)
        prefixed_text = QUERY_INSTRUCTION + clean_text
        # Don't add document instruction again
        return self.embed_documents([prefixed_text], add_instruction=False)[0]

    def embed_image(self, image_path: str, prompt: str = "Technical diagram") -> Optional[List[float]]:
        """Embed image with Qwen2-VL chat-template format.

        Returns the pooled embedding, or None on any failure (images are
        best-effort; callers skip None results).
        """
        try:
            with torch.no_grad():
                image = Image.open(image_path).convert('RGB')
                # QWEN2-VL CORRECT FORMAT: chat message with image + text parts
                messages = [
                    {
                        "role": "user",
                        "content": [
                            {"type": "image", "image": image},
                            {"type": "text", "text": prompt}
                        ]
                    }
                ]
                # Apply chat template
                text = self.processor.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
                # Process with both text and images
                inputs = self.processor(
                    text=[text],
                    images=[image],
                    return_tensors="pt",
                    padding=True
                ).to(self.device)
                outputs = self.model(**inputs, output_hidden_states=True)
                if hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
                    hidden_states = outputs.hidden_states[-1]
                    # Use weighted mean pooling when a mask is available
                    if 'attention_mask' in inputs:
                        attention_mask = inputs['attention_mask'].unsqueeze(-1).float()
                        weighted_hidden_states = hidden_states * attention_mask
                        sum_embeddings = weighted_hidden_states.sum(dim=1)
                        sum_mask = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
                        embedding = (sum_embeddings / sum_mask).squeeze()
                    else:
                        embedding = hidden_states.mean(dim=1).squeeze()
                else:
                    # Fallback to pooler output if available
                    if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
                        embedding = outputs.pooler_output.squeeze()
                    else:
                        return None
                return embedding.cpu().numpy().tolist()
        except Exception as e:
            print(f" ⚠️ Failed to embed image {Path(image_path).name}: {str(e)[:100]}")
            return None
# ========================================
# 🔍 QUERY PREPROCESSING
# ========================================
def preprocess_query(query: str) -> str:
    """Expand domain abbreviations and tidy the query text.

    Lowercases the query, replaces known safety/process abbreviations
    (H2S, PPM, PPE, SCBA, LEL/HEL/UEL) with their spelled-out forms,
    strips trailing '?'/'!' marks, and collapses whitespace runs.
    """
    expansions = (
        (r'\bh2s\b', 'hydrogen sulfide'),
        (r'\bppm\b', 'parts per million'),
        (r'\bppe\b', 'personal protective equipment'),
        (r'\bscba\b', 'self contained breathing apparatus'),
        (r'\blel\b', 'lower explosive limit'),
        (r'\bhel\b', 'higher explosive limit'),
        (r'\buel\b', 'upper explosive limit'),
    )
    result = query.lower()
    for pattern, replacement in expansions:
        result = re.sub(pattern, replacement, result)
    # Drop trailing punctuation, then normalize spacing.
    result = re.sub(r'[?!]+$', '', result)
    return re.sub(r'\s+', ' ', result).strip()
# ========================================
# 📊 TABLE EXTRACTION
# ========================================
def is_table_of_contents_header(df, page_num):
    """Return True when a table's header row looks like a table of contents.

    Only pages 1-15 are considered TOC candidates. The first row is
    joined and lowercased; matching at least two TOC keywords flags it.
    """
    if page_num > 15 or len(df) == 0:
        return False
    # Check first row (headers)
    header_text = ' '.join(df.iloc[0].astype(str)).lower()
    toc_keywords = ('section', 'subsection', 'description',
                    'page no', 'page number', 'contents')
    # If at least 2 keywords match, it's a TOC
    matches = sum(keyword in header_text for keyword in toc_keywords)
    return matches >= 2
def looks_like_toc_data(df):
    """Heuristically detect TOC continuation tables.

    A TOC-like table has a last column dominated (>70%) by plausible page
    numbers in (50, 300) and a first column dominated (>50%) by section
    numbers such as "10" or "10.1". Row 0 (header) is excluded.
    """
    if len(df) < 2 or len(df.columns) < 2:
        return False

    # Last column should be mostly page numbers.
    page_col = df.iloc[1:, -1].astype(str)
    page_like = sum(
        50 < int(cell.strip()) < 300
        for cell in page_col
        if cell.strip().isdigit()
    )
    if not (len(page_col) > 0 and page_like / len(page_col) > 0.7):
        return False

    # First column should be mostly section numbers like "10.1".
    section_col = df.iloc[1:, 0].astype(str)
    section_like = sum(
        bool(re.match(r'^\d+\.?\d*$', cell.strip())) for cell in section_col
    )
    return section_like / len(section_col) > 0.5
def extract_tables_from_pdf(pdf_path: str) -> List[Document]:
    """Extract bordered (lattice) tables from a PDF, skipping TOC tables.

    Uses Camelot's lattice flavor only, deduplicates tables by
    (page, header row), skips the first five pages and low-accuracy
    results, and filters out a detected table-of-contents plus its
    continuation tables. Surviving tables are rendered to natural
    language and wrapped as langchain Documents.

    Returns:
        Documents with metadata: source, page, heading="Table Data",
        type="table", table_accuracy. Empty list on any extraction error.
    """
    chunks = []
    # FIX: pre-bind so the finally-block `del`s can never hit unbound names;
    # the original wrapped them in a bare `except:` that also swallowed
    # KeyboardInterrupt/SystemExit.
    lattice_tables = None
    all_tables = None
    try:
        lattice_tables = camelot.read_pdf(
            pdf_path,
            pages='all',
            flavor='lattice',  # Only bordered tables
            suppress_stdout=True
        )
        all_tables = list(lattice_tables)
        seen_tables = set()
        # Track whether we are currently inside a TOC run of tables.
        in_toc_section = False
        print(f" 📊 Found {len(all_tables)} bordered tables")
        for table in all_tables:
            df = table.df
            current_page = table.page
            # Unique ID: page number + header-row contents
            table_id = (current_page, tuple(df.iloc[0].tolist()) if len(df) > 0 else ())
            if table_id in seen_tables:
                continue
            seen_tables.add(table_id)
            # Skip first 5 pages (title pages)
            if current_page <= 5:
                continue
            # Basic validation: enough rows/columns and decent accuracy
            if len(df.columns) < 2 or len(df) < 3 or table.accuracy < 80:
                continue
            # Detect TOC start (page whose header row matches TOC keywords)
            if not in_toc_section and is_table_of_contents_header(df, current_page):
                in_toc_section = True
                print(f" 🔍 TOC detected at page {current_page}")
                continue
            # If we're in a TOC section, check if this table continues it
            if in_toc_section:
                if looks_like_toc_data(df):
                    print(f" ⏭️ Skipping TOC continuation on page {current_page}")
                    continue
                else:
                    # TOC ended, resume normal extraction
                    print(f" ✅ TOC ended, found real table on page {current_page}")
                    in_toc_section = False
            # Extract valid table
            table_text = table_to_natural_language_enhanced(table)
            if table_text.strip():
                chunks.append(Document(
                    page_content=table_text,
                    metadata={
                        "source": os.path.basename(pdf_path),
                        "page": current_page,
                        "heading": "Table Data",
                        "type": "table",
                        "table_accuracy": table.accuracy
                    }
                ))
        print(f" ✅ Extracted {len(chunks)} valid tables (after TOC filtering)")
    except Exception as e:
        print(f"⚠️ Table extraction failed: {e}")
    finally:
        # Release Camelot handles promptly; the short sleep gives the OS
        # time to release file locks (relevant on Windows).
        del lattice_tables
        del all_tables
        gc.collect()
        time.sleep(0.1)
    return chunks
def table_to_natural_language_enhanced(table) -> str:
    """Convert a Camelot table into one sentence per data row.

    Row 0 is the header. Each later row with a non-empty first cell
    becomes "<row label> has <header>: <value>, ...". Empty/'nan'/'none'
    cells are ignored; blank headers become "Column_<i>"; rows whose
    first cell is blank produce no sentence.
    """
    blanks = ('', 'nan', 'none')

    def filled(cell):
        # A cell counts only when non-empty and not a pandas NaN artifact.
        return bool(cell) and cell.lower() not in blanks

    df = table.df
    if len(df) < 2:
        return ""

    raw_headers = [str(h).strip() for h in df.iloc[0].astype(str).tolist()]
    headers = [h if filled(h) else f"Column_{i}" for i, h in enumerate(raw_headers)]

    sentences = []
    for row_idx in range(1, len(df)):
        cells = [str(c).strip() for c in df.iloc[row_idx].astype(str).tolist()]
        if not any(filled(c) for c in cells):
            continue  # entirely blank row
        if not (cells and filled(cells[0])):
            continue  # no row label -> no sentence
        pairs = [
            f"{headers[col]}: {cells[col]}"
            for col in range(1, min(len(cells), len(headers)))
            if filled(cells[col])
        ]
        if pairs:
            sentences.append(f"{cells[0]} has {', '.join(pairs)}.")
        else:
            sentences.append(f"{cells[0]}.")
    return "\n".join(sentences)
def extract_tables_with_ocr(pdf_path: str, page_num: int) -> List[Dict]:
    """OCR a single PDF page and return table-like text, if any.

    Fallback for scanned/image-based PDFs: renders the page via
    pdf2image, runs Tesseract, and keeps lines that look columnar
    (tabs or runs of 2+ spaces). Returns at most one dict with keys
    text/page/method, or [] on failure or when too few lines qualify.
    """
    try:
        pages = convert_from_path(pdf_path, first_page=page_num, last_page=page_num)
        if not pages:
            return []
        raw_text = pytesseract.image_to_string(pages[0])
        columnar = [
            line for line in raw_text.split('\n')
            if re.search(r'\s{2,}', line) or '\t' in line
        ]
        if len(columnar) > 2:
            return [{
                "text": "\n".join(columnar),
                "page": page_num,
                "method": "ocr"
            }]
        return []
    except Exception:
        # Best-effort fallback: OCR failures are non-fatal to the pipeline.
        return []
def get_table_regions(pdf_path: str) -> Dict[int, List[tuple]]:
    """Collect table bounding boxes per page, using lattice AND stream flavors.

    TOC-header tables are excluded and duplicate bboxes are filtered.
    Returns {page_number: [bbox, ...]}; any Camelot failure yields
    whatever was collected so far (possibly {}).
    """
    regions = {}
    try:
        detected = []
        for flavor in ('lattice', 'stream'):
            detected.extend(
                camelot.read_pdf(pdf_path, pages='all',
                                 flavor=flavor, suppress_stdout=True)
            )
        for table in detected:
            page = table.page
            if is_table_of_contents_header(table.df, page):
                continue
            bbox = table._bbox  # NOTE(review): private Camelot attribute
            page_boxes = regions.setdefault(page, [])
            if bbox not in page_boxes:
                page_boxes.append(bbox)
    except Exception:
        # Best-effort: region detection is an optimization for text extraction.
        pass
    return regions
# ========================================
# 🖼️ IMAGE EXTRACTION
# ========================================
def extract_images_from_pdf(pdf_path: str, output_dir: str) -> List[Dict]:
    """Extract embedded raster images from a PDF to PNG files on disk.

    Images smaller than 10 KB are skipped (icons, bullets, logos).
    Individual image failures are skipped silently.

    FIX: the document handle is now closed in a finally block, so an
    exception escaping the page loop no longer leaks the open file.

    Returns:
        Dicts with keys path/page/source/type; pages are 1-based.
    """
    doc = fitz.open(pdf_path)
    image_data = []
    try:
        for page_num in range(len(doc)):
            page = doc[page_num]
            images = page.get_images()
            for img_index, img in enumerate(images):
                try:
                    xref = img[0]
                    base_image = doc.extract_image(xref)
                    image_bytes = base_image["image"]
                    # Skip tiny images (decorations, logos, bullets)
                    if len(image_bytes) < 10000:
                        continue
                    image_filename = f"{Path(pdf_path).stem}_p{page_num+1}_img{img_index+1}.png"
                    image_path = os.path.join(output_dir, image_filename)
                    with open(image_path, "wb") as img_file:
                        img_file.write(image_bytes)
                    image_data.append({
                        "path": image_path,
                        "page": page_num + 1,
                        "source": os.path.basename(pdf_path),
                        "type": "image"
                    })
                except Exception:
                    # Corrupt/unsupported image: skip it, keep going.
                    continue
    finally:
        doc.close()
    return image_data
# ========================================
# 📄 TEXT EXTRACTION WITH OVERLAPPING CHUNKS
# ========================================
def is_bold_text(span):
    """Truthy when a PyMuPDF text span is bold (font name or flag bit 4)."""
    font_says_bold = "bold" in span['font'].lower()
    return font_says_bold or (span['flags'] & 16)
def is_likely_heading(text, font_size, is_bold, avg_font_size):
    """Heuristic: does this line look like a section heading?

    Requires bold text of plausible heading length (3-100 chars), then
    accepts when the font is >10% above average, the text is ALL CAPS,
    or it starts with a numbered-section pattern ("3.1 Title").
    """
    if not is_bold:
        return False
    stripped = text.strip()
    if not 3 <= len(stripped) <= 100:
        return False
    oversized = font_size > avg_font_size * 1.1
    numbered = re.match(r'^\d+\.?\d*\s+[A-Z]', stripped) is not None
    return oversized or stripped.isupper() or numbered
def is_inside_table(block_bbox, table_bboxes):
    """Return True when the text block's bbox overlaps any table bbox.

    Boxes are (x1, y1, x2, y2) tuples; standard axis-aligned rectangle
    overlap test (touching edges count as overlapping).
    """
    bx1, by1, bx2, by2 = block_bbox
    for tx1, ty1, tx2, ty2 in table_bboxes:
        separated = bx2 < tx1 or bx1 > tx2 or by2 < ty1 or by1 > ty2
        if not separated:
            return True
    return False
def split_text_with_overlap(text: str, heading: str, source: str, page: int,
                            chunk_size: int = CHUNK_SIZE, overlap: int = OVERLAP) -> List[Document]:
    """Split a section into overlapping word-window chunks.

    Each chunk is prefixed with "Section: <heading>" for embedding
    context and carries the full section text as parent_text so the
    retriever can surface the whole section later. A section at or under
    chunk_size words becomes a single chunk (without word offsets); a
    trailing window shorter than MIN_CHUNK_SIZE is dropped once earlier
    chunks exist.
    """
    words = text.split()

    # Short section: single chunk, no start/end word offsets.
    if len(words) <= chunk_size:
        return [Document(
            page_content=f"Section: {heading}\n\n{text}",
            metadata={
                "source": source,
                "page": page,
                "heading": heading,
                "type": "text",
                "parent_text": text,
                "chunk_index": 0,
                "total_chunks": 1
            }
        )]

    stride = chunk_size - overlap
    chunks = []
    for chunk_index, start in enumerate(range(0, len(words), stride)):
        window = words[start:start + chunk_size]
        # Drop a tiny trailing window once we already have content.
        if len(window) < MIN_CHUNK_SIZE and chunks:
            break
        chunks.append(Document(
            page_content=f"Section: {heading}\n\n{' '.join(window)}",
            metadata={
                "source": source,
                "page": page,
                "heading": heading,
                "type": "text",
                "parent_text": text,
                "chunk_index": chunk_index,
                "start_word": start,
                "end_word": start + len(window)
            }
        ))
    # total_chunks is only known after the loop finishes.
    for chunk in chunks:
        chunk.metadata["total_chunks"] = len(chunks)
    return chunks
def extract_text_chunks_with_overlap(pdf_path: str, table_regions: Dict[int, List[tuple]]) -> List[Document]:
    """Extract text with overlapping chunks.

    Two passes over the PDF:
      1. Scan every text span to compute the average font size, used to
         tell headings apart from body text.
      2. Walk the document line by line, skipping blocks that fall inside
         known table regions, and accumulate body text under the most
         recently detected heading; each new heading closes the previous
         section.

    Sections are then split into overlapping chunks with
    split_text_with_overlap().

    Args:
        pdf_path: path to the PDF file.
        table_regions: {1-based page: [bbox, ...]} from get_table_regions();
            text inside these boxes is assumed to be table content and is
            excluded here (the table extractor handles it).

    Returns:
        Chunked Documents for the whole file.
    """
    doc = fitz.open(pdf_path)
    # Pass 1: collect all span font sizes to estimate body-text size.
    all_font_sizes = []
    for page_num in range(len(doc)):
        page = doc[page_num]
        blocks = page.get_text("dict")["blocks"]
        for block in blocks:
            if "lines" in block:
                for line in block["lines"]:
                    for span in line["spans"]:
                        all_font_sizes.append(span["size"])
    # 12pt fallback when the PDF has no text spans at all.
    avg_font_size = sum(all_font_sizes) / len(all_font_sizes) if all_font_sizes else 12
    # Pass 2: accumulate lines into heading-delimited sections.
    sections = []
    current_section = ""
    current_heading = "Introduction"  # default until the first heading is seen
    current_page = 1
    for page_num in range(len(doc)):
        page = doc[page_num]
        blocks = page.get_text("dict")["blocks"]
        # table_regions is keyed by 1-based page numbers.
        page_tables = table_regions.get(page_num + 1, [])
        for block in blocks:
            if "lines" not in block:
                continue  # non-text block (image/drawing)
            block_bbox = block.get("bbox", (0, 0, 0, 0))
            if is_inside_table(block_bbox, page_tables):
                continue  # table text handled by the table extractor
            for line in block["lines"]:
                # Merge all spans; the line is "bold" if any span is bold,
                # and its size is the largest span size.
                line_text = ""
                line_is_bold = False
                line_font_size = 0
                for span in line["spans"]:
                    line_text += span["text"]
                    if is_bold_text(span):
                        line_is_bold = True
                    line_font_size = max(line_font_size, span["size"])
                line_text = line_text.strip()
                if not line_text:
                    continue
                if is_likely_heading(line_text, line_font_size, line_is_bold, avg_font_size):
                    # Close the current section before starting a new one.
                    if current_section.strip():
                        sections.append({
                            "text": current_section.strip(),
                            "heading": current_heading,
                            "page": current_page,
                            "source": os.path.basename(pdf_path)
                        })
                    current_heading = line_text
                    current_section = ""
                    # NOTE(review): a section's page is where its heading
                    # appears, even if the body spans later pages.
                    current_page = page_num + 1
                else:
                    current_section += line_text + " "
    # Flush the trailing section.
    if current_section.strip():
        sections.append({
            "text": current_section.strip(),
            "heading": current_heading,
            "page": current_page,
            "source": os.path.basename(pdf_path)
        })
    doc.close()
    # Chunk every section with overlap.
    all_chunks = []
    for section in sections:
        chunks = split_text_with_overlap(
            text=section['text'],
            heading=section['heading'],
            source=section['source'],
            page=section['page'],
            chunk_size=CHUNK_SIZE,
            overlap=OVERLAP
        )
        all_chunks.extend(chunks)
    return all_chunks
# ========================================
# 🔄 COMBINED EXTRACTION
# ========================================
def extract_all_content_from_pdf(pdf_path: str) -> Tuple[List[Document], List[Dict]]:
    """Run the full extraction pipeline for one PDF.

    Order matters: table regions are detected first so text extraction
    can exclude anything inside them.

    Returns:
        (text_chunks + table_chunks, image records)
    """
    print(f" 📊 Extracting tables...")
    table_regions = get_table_regions(pdf_path)
    table_chunks = extract_tables_from_pdf(pdf_path)
    print(f" ✅ {len(table_chunks)} table chunks")
    print(f" 📄 Extracting text...")
    text_chunks = extract_text_chunks_with_overlap(pdf_path, table_regions)
    print(f" ✅ {len(text_chunks)} text chunks")
    print(f" 🖼️ Extracting images...")
    images = extract_images_from_pdf(pdf_path, IMAGE_OUTPUT_DIR)
    print(f" ✅ {len(images)} images")
    # Text chunks first, tables appended after — order preserved for indexing.
    return text_chunks + table_chunks, images
# ========================================
# 🏗️ BUILD FAISS INDEX WITH STREAMING
# ========================================
# ========================================
# 🏗️ BUILD FAISS INDEX WITH BM25
# ========================================
def build_multimodal_faiss_streaming(pdf_files: List[str], embedding_model: VLM2VecEmbeddings):
    """Build FAISS index with streaming and BM25.

    Pipeline:
      1. Hash the sorted PDF path list; if an index for the same set
         already exists, ask on stdin before rebuilding.
      2. Extract text/table chunks and images from every PDF.
      3. Embed text chunks in batches of 10, L2-normalize, and add to an
         IndexFlatIP (inner product == cosine after normalization).
      4. Build a BM25Okapi index over lowercased whitespace tokens.
      5. Embed images into a second FAISS index (best-effort).
      6. Persist everything (FAISS files, pickles, hash) under
         FAISS_INDEX_PATH.

    Returns:
        (text_index, all_texts) on success; (None, []) when the user
        declines a rebuild or no content was extracted.
    """
    index_hash_file = f"{FAISS_INDEX_PATH}/index_hash.txt"
    # Fingerprint of the PDF set; used to detect an up-to-date index.
    current_hash = hashlib.md5("".join(sorted(pdf_files)).encode()).hexdigest()
    if os.path.exists(index_hash_file):
        with open(index_hash_file, 'r') as f:
            existing_hash = f.read().strip()
        if existing_hash == current_hash:
            print("⚠️ Index already exists for these PDFs!")
            # NOTE(review): blocking input() — unsuitable for headless runs.
            response = input(" Rebuild anyway? (yes/no): ").strip().lower()
            if response != 'yes':
                return None, []
    all_texts = []
    all_image_paths = []
    print("\n📄 Processing PDFs...\n")
    for pdf_file in pdf_files:
        print(f"📖 Processing: {Path(pdf_file).name}")
        try:
            text_chunks, images = extract_all_content_from_pdf(pdf_file)
            all_texts.extend(text_chunks)
            all_image_paths.extend(images)
        except Exception as e:
            # A broken PDF should not abort the whole build.
            print(f" ❌ Error: {e}")
            continue
        print()
    print(f"✅ Total chunks: {len(all_texts)}")
    print(f"✅ Total images: {len(all_image_paths)}\n")
    if len(all_texts) == 0:
        print("❌ No content extracted!")
        return None, []
    # Build text index incrementally (batches keep peak memory bounded).
    print("🔗 Generating text embeddings...\n")
    text_index = None
    batch_size = 10
    for i in range(0, len(all_texts), batch_size):
        batch = all_texts[i:i+batch_size]
        batch_contents = [doc.page_content for doc in batch]
        try:
            batch_embeddings = embedding_model.embed_documents(batch_contents, add_instruction=True)
            batch_embeddings_np = np.array(batch_embeddings).astype('float32')
            if text_index is None:
                # Index is created lazily with the dimension of the first batch.
                dimension = batch_embeddings_np.shape[1]
                text_index = faiss.IndexFlatIP(dimension)
                print(f" Text embedding dimension: {dimension}")
            # L2-normalize so inner product behaves as cosine similarity.
            faiss.normalize_L2(batch_embeddings_np)
            text_index.add(batch_embeddings_np)
            if (i // batch_size + 1) % 5 == 0:
                print(f" Progress: {i + len(batch)}/{len(all_texts)}")
        except Exception as e:
            # Re-raise: a missing batch would desynchronize FAISS ids
            # from the pickled document list.
            print(f" ❌ Error: {e}")
            raise
    print(f" ✅ Complete")
    # Save FAISS index
    faiss.write_index(text_index, f"{FAISS_INDEX_PATH}/text_index.faiss")
    # Save documents (position in this list == FAISS id)
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "wb") as f:
        pickle.dump(all_texts, f)
    # BUILD AND SAVE BM25 INDEX over the same document order.
    print("\n🔍 Building BM25 index for keyword search...")
    tokenized_docs = [doc.page_content.lower().split() for doc in all_texts]
    bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)
    with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f:
        pickle.dump(bm25_index, f)
    print(" ✅ BM25 index saved")
    # Build image index (best-effort; failed embeds are skipped).
    if len(all_image_paths) > 0:
        print(f"\n🖼️ Embedding images...")
        image_index = None
        successful_images = []
        for idx, img_data in enumerate(all_image_paths):
            img_embedding = embedding_model.embed_image(img_data["path"])
            if img_embedding is None:
                continue
            img_embedding_np = np.array([img_embedding]).astype('float32')
            if image_index is None:
                dimension = img_embedding_np.shape[1]
                image_index = faiss.IndexFlatIP(dimension)
                print(f" Image dimension: {dimension}")
            faiss.normalize_L2(img_embedding_np)
            image_index.add(img_embedding_np)
            successful_images.append(img_data)
            if (len(successful_images)) % 10 == 0:
                print(f" Progress: {len(successful_images)}/{len(all_image_paths)}")
        print(f" ✅ {len(successful_images)} images embedded")
        if image_index is not None and len(successful_images) > 0:
            faiss.write_index(image_index, f"{FAISS_INDEX_PATH}/image_index.faiss")
            with open(f"{FAISS_INDEX_PATH}/image_documents.pkl", "wb") as f:
                pickle.dump(successful_images, f)
    # Save hash last so a crashed build is not mistaken for a complete one.
    with open(index_hash_file, 'w') as f:
        f.write(current_hash)
    print(f"\n✅ Index saved: {FAISS_INDEX_PATH}\n")
    return text_index, all_texts
# ========================================
# 🔍 OPTIMIZED HYBRID SEARCH
# ========================================
# ========================================
# 📊 QUERY WITH BM25 ONLY
# ========================================
def query_with_bm25(query: str, k_text: int = 5, k_images: int = 3):
    """Query using BM25 keyword search only.

    Loads the pickled document list and BM25 index from FAISS_INDEX_PATH
    (building BM25 on the fly if the pickle is missing), ranks documents
    by BM25 score, and attaches images that were extracted from the same
    (source, page) pairs as the top text hits.

    Args:
        query: raw user query (preprocessed before scoring).
        k_text: number of text results to return.
        k_images: maximum number of page-matched images to return.

    Returns:
        Dict with keys: text_results, images, query, processed_query.
    """
    # PREPROCESS QUERY (abbreviation expansion etc.)
    processed_query = preprocess_query(query)
    print(f" 🔍 Processed: {processed_query}")
    # Load documents (list position == BM25/FAISS id)
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
        text_docs = pickle.load(f)
    # LOAD BM25 INDEX
    try:
        with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "rb") as f:
            bm25_index = pickle.load(f)
    except FileNotFoundError:
        # NOTE(review): the on-the-fly rebuild uses BM25 defaults, not the
        # k1=1.3/b=0.65 used at index time — results may differ slightly.
        print(" ⚠️ BM25 index not found, building on-the-fly...")
        tokenized_docs = [doc.page_content.lower().split() for doc in text_docs]
        bm25_index = BM25Okapi(tokenized_docs)
    # BM25 SEARCH ONLY
    tokenized_query = processed_query.lower().split()
    bm25_scores = bm25_index.get_scores(tokenized_query)
    # Get top k results (descending score)
    top_indices = np.argsort(bm25_scores)[::-1][:k_text]
    text_results = []
    relevant_pages = set()
    for rank, idx in enumerate(top_indices, 1):
        doc = text_docs[idx]
        score = float(bm25_scores[idx])
        text_results.append({
            "document": doc,
            "score": score,
            "rank": rank,
            "type": doc.metadata.get('type', 'text')
        })
        relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page')))
    # Get images from relevant pages (no semantic image search here)
    relevant_images = []
    try:
        image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl"
        if os.path.exists(image_docs_path):
            with open(image_docs_path, "rb") as f:
                image_docs = pickle.load(f)
            # Keep images that share a (source, page) with top text results
            for img_doc in image_docs:
                img_page = (img_doc['source'], img_doc['page'])
                if img_page in relevant_pages and len(relevant_images) < k_images:
                    relevant_images.append({
                        "path": img_doc['path'],
                        "source": img_doc['source'],
                        "page": img_doc['page'],
                        "type": "image",
                        "score": 0.0,  # no similarity score for page-matched images
                        "rank": len(relevant_images) + 1,
                        "from_page": True
                    })
    except Exception as e:
        # Images are optional; text results are still returned on failure.
        pass
    return {
        "text_results": text_results,
        "images": relevant_images,
        "query": query,
        "processed_query": processed_query
    }
# ========================================
# 📊 DISPLAY RESULTS (BM25 ONLY)
# ========================================
def display_results_bm25(results: Dict):
    """Pretty-print BM25-only retrieval results to stdout."""
    print("\n📚 TOP RESULTS (BM25 Keyword Search):\n")
    for entry in results['text_results']:
        doc = entry["document"]
        meta = doc.metadata
        print(f"[{entry['rank']}] BM25 Score: {entry['score']:.4f} | {meta.get('type', 'N/A')}")
        print(f" 📄 {meta.get('source')} - Page {meta.get('page')}")
        print(f" 📌 {meta.get('heading', 'N/A')[:60]}")
        total = meta.get('total_chunks', 1)
        if 'total_chunks' in meta and total > 1:
            print(f" 🔗 Chunk {meta.get('chunk_index', 0)+1}/{total}")
        print(f" 📝 {doc.page_content[:200]}...")
        print()
    print("\n🖼️ IMAGES:\n")
    if not results['images']:
        print(" No images found\n")
        return
    for img in results['images']:
        print(f"[{img['rank']}] {img['source']} - Page {img['page']}")
        print(f" {img['path']}\n")
# ========================================
# 🔍 HYBRID SEARCH IMPLEMENTATION
# ========================================
def normalize_scores(scores: np.ndarray) -> np.ndarray:
    """Min-max scale scores into the [0, 1] range.

    Empty input is returned unchanged; a constant vector maps to all
    ones so a degenerate score set still contributes uniformly when
    fused with the other retriever's scores.
    """
    if len(scores) == 0:
        return scores
    lo, hi = np.min(scores), np.max(scores)
    if hi == lo:
        return np.ones_like(scores)
    return (scores - lo) / (hi - lo)
def query_with_hybrid(query: str, embedding_model: VLM2VecEmbeddings,
                      k_text: int = 5, k_images: int = 3,
                      dense_weight: float = DENSE_WEIGHT,
                      sparse_weight: float = SPARSE_WEIGHT):
    """
    Hybrid search combining semantic (FAISS) and keyword (BM25) retrieval.

    Both score sets are min-max normalized and fused as
    dense_weight * semantic + sparse_weight * bm25. Semantic search is
    best-effort: on failure the function degrades to BM25-only scoring.
    Images from the same (source, page) pairs as the top text hits are
    attached, as in query_with_bm25.

    Returns:
        Dict with keys: text_results (each carrying combined, semantic
        and bm25 scores), images, query, processed_query, method.
    """
    processed_query = preprocess_query(query)
    print(f" 🔍 Processed: {processed_query}")
    # Document list position == FAISS id == BM25 id.
    with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
        text_docs = pickle.load(f)
    # SEMANTIC SEARCH (best-effort; failure falls back to BM25 only)
    print(f" 🧠 Running semantic search...")
    try:
        text_index = faiss.read_index(f"{FAISS_INDEX_PATH}/text_index.faiss")
        query_embedding = embedding_model.embed_query(processed_query)
        query_np = np.array([query_embedding]).astype('float32')
        faiss.normalize_L2(query_np)
        # Over-retrieve (3x) so fusion has candidates beyond the final top-k.
        k_retrieve = min(k_text * 3, len(text_docs))
        distances, indices = text_index.search(query_np, k_retrieve)
        semantic_scores = distances[0]
        semantic_indices = indices[0]
        print(f" ✅ Retrieved {len(semantic_indices)} semantic results")
    except Exception as e:
        print(f" ⚠️ Semantic search failed: {e}")
        semantic_scores = np.array([])
        semantic_indices = np.array([])
    # BM25 SEARCH (scores every document)
    print(f" 🔤 Running BM25 keyword search...")
    try:
        with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "rb") as f:
            bm25_index = pickle.load(f)
    except FileNotFoundError:
        tokenized_docs = [doc.page_content.lower().split() for doc in text_docs]
        bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)
    tokenized_query = processed_query.lower().split()
    bm25_scores_all = bm25_index.get_scores(tokenized_query)
    print(f" ✅ Scored {len(bm25_scores_all)} documents")
    # SCORE FUSION: weighted sum of min-max-normalized scores.
    print(f" ⚖️ Fusing scores (semantic: {dense_weight}, BM25: {sparse_weight})...")
    combined_scores = {}
    if len(semantic_scores) > 0:
        # NOTE: semantic_scores_norm is only bound when semantic search
        # succeeded; later reads are guarded by `idx in semantic_indices`,
        # which is always False for the empty fallback array.
        semantic_scores_norm = normalize_scores(semantic_scores)
        for idx, score in zip(semantic_indices, semantic_scores_norm):
            if idx < len(text_docs):
                combined_scores[idx] = dense_weight * score
    bm25_scores_norm = normalize_scores(bm25_scores_all)
    for idx, score in enumerate(bm25_scores_norm):
        if idx in combined_scores:
            combined_scores[idx] += sparse_weight * score
        else:
            combined_scores[idx] = sparse_weight * score
    sorted_indices = sorted(combined_scores.keys(),
                            key=lambda x: combined_scores[x],
                            reverse=True)
    top_indices = sorted_indices[:k_text]
    print(f" ✅ Top {len(top_indices)} results selected")
    # PREPARE RESULTS (per-result component scores for display/debugging)
    text_results = []
    relevant_pages = set()
    for rank, idx in enumerate(top_indices, 1):
        doc = text_docs[idx]
        # Look up the normalized semantic score for this doc id, 0.0 if it
        # was not among the semantic candidates.
        semantic_score = semantic_scores_norm[np.where(semantic_indices == idx)[0][0]] if idx in semantic_indices else 0.0
        bm25_score = bm25_scores_norm[idx]
        combined_score = combined_scores[idx]
        text_results.append({
            "document": doc,
            "score": combined_score,
            "semantic_score": float(semantic_score),
            "bm25_score": float(bm25_score),
            "rank": rank,
            "type": doc.metadata.get('type', 'text')
        })
        relevant_pages.add((doc.metadata.get('source'), doc.metadata.get('page')))
    # GET IMAGES from the same (source, page) pairs as the text hits
    relevant_images = []
    try:
        image_docs_path = f"{FAISS_INDEX_PATH}/image_documents.pkl"
        if os.path.exists(image_docs_path):
            with open(image_docs_path, "rb") as f:
                image_docs = pickle.load(f)
            for img_doc in image_docs:
                img_page = (img_doc['source'], img_doc['page'])
                if img_page in relevant_pages and len(relevant_images) < k_images:
                    relevant_images.append({
                        "path": img_doc['path'],
                        "source": img_doc['source'],
                        "page": img_doc['page'],
                        "type": "image",
                        "score": 0.0,  # no similarity score for page-matched images
                        "rank": len(relevant_images) + 1,
                        "from_page": True
                    })
    except Exception as e:
        # Images are optional; text results are still returned on failure.
        pass
    return {
        "text_results": text_results,
        "images": relevant_images,
        "query": query,
        "processed_query": processed_query,
        "method": "hybrid"
    }
def display_results_hybrid(results: Dict):
    """Pretty-print hybrid (semantic + BM25) retrieval results to stdout."""
    print("\n📚 TOP RESULTS (Hybrid Search: Semantic + BM25):\n")
    for entry in results['text_results']:
        doc = entry["document"]
        meta = doc.metadata
        print(f"[{entry['rank']}] Combined: {entry['score']:.4f} "
              f"(Semantic: {entry['semantic_score']:.4f}, BM25: {entry['bm25_score']:.4f}) "
              f"| {meta.get('type', 'N/A')}")
        print(f" 📄 {meta.get('source')} - Page {meta.get('page')}")
        print(f" 📌 {meta.get('heading', 'N/A')[:60]}")
        total = meta.get('total_chunks', 1)
        if 'total_chunks' in meta and total > 1:
            print(f" 🔗 Chunk {meta.get('chunk_index', 0)+1}/{total}")
        print(f" 📝 {doc.page_content[:200]}...")
        print()
    print("\n🖼️ IMAGES:\n")
    if not results['images']:
        print(" No images found\n")
        return
    for img in results['images']:
        print(f"[{img['rank']}] {img['source']} - Page {img['page']}")
        print(f" {img['path']}\n")
# ========================================
# 📖 GET CONTEXT WITH PARENTS
# ========================================
def get_context_with_parents(results: Dict) -> List[Dict]:
    """Expand retrieved chunks into their full parent sections.

    Chunks carrying parent_text are replaced by the whole section text
    (deduplicated; first occurrence wins, later chunks of the same
    section are dropped). Chunks without a parent (e.g. tables) pass
    through as-is.
    """
    seen_parents = set()
    contexts = []
    for result in results['text_results']:
        doc = result['document']
        parent = doc.metadata.get('parent_text')
        if not parent:
            # No parent section recorded: use the chunk text itself.
            contexts.append({
                "text": doc.page_content,
                "source": doc.metadata['source'],
                "page": doc.metadata['page'],
                "heading": doc.metadata['heading'],
                "type": doc.metadata.get('type', 'text'),
                "is_parent": False
            })
        elif parent not in seen_parents:
            seen_parents.add(parent)
            contexts.append({
                "text": parent,
                "source": doc.metadata['source'],
                "page": doc.metadata['page'],
                "heading": doc.metadata['heading'],
                "type": doc.metadata.get('type', 'text'),
                "is_parent": True
            })
        # else: another chunk of an already-included section — skip.
    return contexts
# ========================================
# 🚀 MAIN EXECUTION (UPDATED FOR HYBRID)
# ========================================
if __name__ == "__main__":
    print("="*70)
    print("🚀 RAG with HYBRID SEARCH (Semantic + BM25)")
    print("="*70 + "\n")
    pdf_files = glob.glob(f"{PDF_DIR}/*.pdf")
    print(f"📂 Found {len(pdf_files)} PDF files\n")
    if len(pdf_files) == 0:
        print("❌ No PDFs found!")
        exit(1)
    # Load the embedding model ONCE; it is reused for both indexing and querying.
    print("\n🤖 Loading VLM2Vec model...")
    embedding_model = VLM2VecEmbeddings(
        model_name="TIGER-Lab/VLM2Vec-Qwen2VL-2B",
        cache_dir=MODEL_CACHE_DIR
    )
    # Load or build index
    if os.path.exists(f"{FAISS_INDEX_PATH}/text_index.faiss"):
        print(f"✅ Loading existing index\n")
        # Older index directories may predate BM25 support; backfill it so
        # hybrid/bm25 modes work with the same k1/b used at build time.
        if not os.path.exists(f"{FAISS_INDEX_PATH}/bm25_index.pkl"):
            print("⚠️ BM25 index missing, building now...")
            with open(f"{FAISS_INDEX_PATH}/text_documents.pkl", "rb") as f:
                all_texts = pickle.load(f)
            print(" Building BM25 index...")
            tokenized_docs = [doc.page_content.lower().split() for doc in all_texts]
            bm25_index = BM25Okapi(tokenized_docs, k1=1.3, b=0.65)
            with open(f"{FAISS_INDEX_PATH}/bm25_index.pkl", "wb") as f:
                pickle.dump(bm25_index, f)
            print(" ✅ BM25 index saved\n")
    else:
        print("🔨 Building new index...\n")
        # FIX: the 2B model was previously instantiated a SECOND time here,
        # doubling load time and GPU memory; reuse the instance created above.
        index, documents = build_multimodal_faiss_streaming(pdf_files, embedding_model)
        if index is None:
            exit(0)
    # Interactive testing
    print("="*70)
    print("🧪 TESTING MODE - HYBRID SEARCH")
    print(f" Weights: Semantic {DENSE_WEIGHT} | BM25 {SPARSE_WEIGHT}")
    print("="*70 + "\n")
    test_queries = [
        "What is the higher and lower explosive limit of butane?",
        "What are the precautions taken while handling H2S?",
        "What are the Personal Protection used for Sulfolane?",
        "What is the Composition of Platforming Feed and Product?",
        "Explain Dual function platforming catalyst chemistry.",
        "Steps to be followed in Amine Regeneration Unit for normal shutdown process.",
        "Could you tell me what De-greasing of Amine System in pre startup wash",
    ]
    print("📋 SUGGESTED QUERIES:")
    for i, q in enumerate(test_queries, 1):
        print(f" {i}. {q}")
    print()
    print("💡 Type 'mode' to switch between hybrid/bm25/semantic")
    print()
    current_mode = "hybrid"
    while True:
        # FIX: the prompt claimed "1-5" but there are len(test_queries) entries.
        user_query = input(f"💬 Query [{current_mode}] (or 1-{len(test_queries)}, 'mode', or 'exit'): ").strip()
        if user_query.lower() == 'exit':
            print("\n✅ Done!")
            break
        if user_query.lower() == 'mode':
            print("\n🔄 Select mode:")
            print(" 1. Hybrid (Semantic + BM25)")
            print(" 2. BM25 only")
            print(" 3. Semantic only")
            mode_choice = input(" Choice (1-3): ").strip()
            if mode_choice == '1':
                current_mode = "hybrid"
            elif mode_choice == '2':
                current_mode = "bm25"
            elif mode_choice == '3':
                current_mode = "semantic"
            print(f" ✅ Mode set to: {current_mode}\n")
            continue
        # Numeric shortcut: run one of the suggested queries.
        if user_query.isdigit() and 1 <= int(user_query) <= len(test_queries):
            user_query = test_queries[int(user_query) - 1]
        if not user_query:
            continue
        print(f"\n{'='*60}")
        print(f"🔍 Query: {user_query}")
        print(f"🔧 Mode: {current_mode.upper()}")
        print(f"{'='*60}\n")
        try:
            if current_mode == "hybrid":
                results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3)
                display_results_hybrid(results)
            elif current_mode == "bm25":
                results = query_with_bm25(user_query, k_text=5, k_images=3)
                display_results_bm25(results)
            else:  # semantic only: hybrid path with all weight on the dense score
                results = query_with_hybrid(user_query, embedding_model, k_text=5, k_images=3,
                                            dense_weight=1.0, sparse_weight=0.0)
                display_results_hybrid(results)
            print("\n📖 FULL CONTEXT:\n")
            contexts = get_context_with_parents(results)
            for i, ctx in enumerate(contexts[:3], 1):
                print(f"[{i}] {ctx['heading'][:50]}")
                if ctx['is_parent']:
                    print(f" ✅ Full section")
                print(f" {ctx['text'][:300]}...\n")
            print("="*60 + "\n")
        except Exception as e:
            print(f"\n❌ Error: {e}\n")
            import traceback
            traceback.print_exc()