fixes and changes, new faiss uploaded

Files changed:
- build_faiss_index.py (+36 -34)
- old_build_faiss_index.py (+339 -0)

build_faiss_index.py
CHANGED
@@ -1,24 +1,21 @@
 """
-Build FAISS Index from Scratch
-Creates faiss_index.pkl from your knowledge base and trained retriever model
+Build FAISS Index from Scratch - COMPATIBLE VERSION
+Creates faiss_index.pkl with proper serialization for version compatibility
 
 Run this ONCE before starting the backend:
     python build_faiss_index.py
 
 Author: Banking RAG Chatbot
-Date: October 2025
+Date: November 2025
 """
 
-
-# Add these lines at the very top (after docstring)
+# Suppress warnings
 import os
-os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow info/warnings
-os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # Disable oneDNN messages
-
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
 import warnings
-warnings.filterwarnings('ignore')  # Suppress all warnings
+warnings.filterwarnings('ignore')
 
-import os
 import pickle
 import json
 import torch
@@ -30,7 +27,6 @@ from pathlib import Path
 from transformers import AutoTokenizer, AutoModel
 from typing import List
 
-
 # ============================================================================
 # CONFIGURATION - UPDATE THESE PATHS!
 # ============================================================================
@@ -50,7 +46,6 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 # Batch size for encoding (reduce if you get OOM errors)
 BATCH_SIZE = 32
 
-
 # ============================================================================
 # CUSTOM SENTENCE TRANSFORMER (Same as retriever.py)
 # ============================================================================
@@ -87,7 +82,6 @@ class CustomSentenceTransformer(nn.Module):
     def encode(self, sentences: List[str], batch_size: int = 32) -> np.ndarray:
         """Encode sentences - same as training"""
         self.eval()
-
         if isinstance(sentences, str):
             sentences = [sentences]
 
@@ -95,7 +89,6 @@ class CustomSentenceTransformer(nn.Module):
         processed_sentences = [f"query: {s.strip()}" for s in sentences]
 
         all_embeddings = []
-
         with torch.no_grad():
             for i in range(0, len(processed_sentences), batch_size):
                 batch_sentences = processed_sentences[i:i + batch_size]
@@ -115,7 +108,6 @@ class CustomSentenceTransformer(nn.Module):
 
         return np.vstack(all_embeddings)
 
-
 # ============================================================================
 # RETRIEVER MODEL (Wrapper)
 # ============================================================================
@@ -126,7 +118,6 @@ class RetrieverModel:
     def __init__(self, model_path: str, device: str = "cpu"):
         print(f"\n🤖 Loading retriever model...")
         print(f" Device: {device}")
-
         self.device = device
         self.model = CustomSentenceTransformer("intfloat/e5-base-v2").to(device)
 
@@ -137,7 +128,7 @@ class RetrieverModel:
             self.model.load_state_dict(state_dict)
             print(f" ✅ Trained weights loaded")
         except Exception as e:
-            print(f" ⚠️ Warning: Could not load trained weights: {e}")
+            print(f" ⚠️ Warning: Could not load trained weights: {e}")
             print(f" Using base e5-base-v2 model instead")
 
         self.model.eval()
@@ -146,7 +137,6 @@ class RetrieverModel:
         """Encode documents"""
         return self.model.encode(documents, batch_size=batch_size)
 
-
 # ============================================================================
 # MAIN: BUILD FAISS INDEX
 # ============================================================================
@@ -175,7 +165,7 @@ def build_faiss_index():
            try:
                 kb_data.append(json.loads(line))
             except json.JSONDecodeError as e:
-                print(f" ⚠️ Warning: Skipping invalid JSON on line {line_num}: {e}")
+                print(f" ⚠️ Warning: Skipping invalid JSON on line {line_num}: {e}")
 
     print(f" ✅ Loaded {len(kb_data)} documents")
 
@@ -202,7 +192,7 @@ def build_faiss_index():
         elif response:
             text = response
         else:
-            print(f" ⚠️ Warning: Document {i} has no content, using placeholder")
+            print(f" ⚠️ Warning: Document {i} has no content, using placeholder")
             text = "empty document"
 
         documents.append(text)
@@ -239,7 +229,7 @@ def build_faiss_index():
         return False
 
     # ========================================================================
-    # STEP 4: BUILD FAISS INDEX
+    # STEP 4: BUILD FAISS INDEX WITH PROPER SERIALIZATION
     # ========================================================================
     print(f"\n🔍 STEP 4: Building FAISS index...")
 
@@ -247,6 +237,7 @@ def build_faiss_index():
     print(f" Dimension: {dimension}")
 
     # Create FAISS index (Inner Product = Cosine similarity after normalization)
+    print(f" Creating IndexFlatIP...")
     index = faiss.IndexFlatIP(dimension)
 
     # Normalize embeddings for cosine similarity
@@ -261,24 +252,31 @@ def build_faiss_index():
     print(f" Total vectors: {index.ntotal}")
 
     # ========================================================================
-    # STEP 5: SAVE AS PICKLE FILE
+    # STEP 5: SAVE WITH PROPER FAISS SERIALIZATION (VERSION COMPATIBLE!)
     # ========================================================================
-    print(f"\n💾 STEP 5: Saving as pickle file...")
+    print(f"\n💾 STEP 5: Saving with FAISS serialization (version-compatible)...")
 
     # Create models directory if it doesn't exist
     os.makedirs(os.path.dirname(OUTPUT_PKL_FILE), exist_ok=True)
 
-    # Save tuple of (index, kb_data)
-    print(f" Pickling (index, kb_data) tuple...")
+    # ✅ PROPER WAY: Serialize FAISS index to bytes first
+    print(f" Serializing FAISS index to bytes...")
     try:
+        # Write FAISS index to bytes (works across FAISS versions!)
+        index_bytes = faiss.serialize_index(index)
+
+        # Now pickle the bytes + kb_data
+        print(f" Pickling (index_bytes, kb_data) tuple...")
         with open(OUTPUT_PKL_FILE, 'wb') as f:
-            pickle.dump((index, kb_data), f)
+            pickle.dump((index_bytes, kb_data), f, protocol=pickle.HIGHEST_PROTOCOL)
 
         file_size_mb = Path(OUTPUT_PKL_FILE).stat().st_size / (1024 * 1024)
         print(f" ✅ Saved: {OUTPUT_PKL_FILE}")
         print(f" File size: {file_size_mb:.2f} MB")
     except Exception as e:
         print(f" ❌ ERROR saving pickle: {e}")
+        import traceback
+        traceback.print_exc()
         return False
 
     # ========================================================================
@@ -288,17 +286,21 @@ def build_faiss_index():
 
     try:
         with open(OUTPUT_PKL_FILE, 'rb') as f:
-            loaded_index, loaded_kb = pickle.load(f)
+            loaded_index_bytes, loaded_kb = pickle.load(f)
+
+        # Deserialize FAISS index from bytes
+        loaded_index = faiss.deserialize_index(loaded_index_bytes)
 
         print(f" ✅ Verification successful")
         print(f" Index vectors: {loaded_index.ntotal}")
         print(f" KB documents: {len(loaded_kb)}")
 
         if loaded_index.ntotal != len(loaded_kb):
-            print(f" ⚠️ WARNING: Size mismatch detected!")
-
+            print(f" ⚠️ WARNING: Size mismatch detected!")
     except Exception as e:
         print(f" ❌ ERROR verifying file: {e}")
+        import traceback
+        traceback.print_exc()
         return False
 
     # ========================================================================
@@ -312,14 +314,14 @@ def build_faiss_index():
     print(f" Vectors: {index.ntotal}")
     print(f" Dimension: {dimension}")
     print(f" File: {OUTPUT_PKL_FILE} ({file_size_mb:.2f} MB)")
-    print(f"\n🚀 You can now start the backend:")
-    print(f" cd backend")
-    print(f" uvicorn app.main:app --reload")
+    print(f"\n📝 Next steps:")
+    print(f" 1. Upload {OUTPUT_PKL_FILE} to HuggingFace Hub")
+    print(f" 2. Restart your backend")
+    print(f" 3. Test retrieval - should work now!")
     print("=" * 80 + "\n")
 
     return True
 
-
 # ============================================================================
 # RUN SCRIPT
 # ============================================================================
@@ -333,7 +335,7 @@ if __name__ == "__main__":
         print("=" * 80)
         print("\nPlease check:")
         print("1. Knowledge base file exists: data/final_knowledge_base.jsonl")
-        print("2. Retriever model exists: models/best_retriever_model.pth")
+        print("2. Retriever model exists: app/models/best_retriever_model.pth")
         print("3. You have enough RAM (embeddings need ~1GB for 10k docs)")
         print("=" * 80 + "\n")
         exit(1)
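Note on the format change above: the saved artifact is now a (index_bytes, kb_data) tuple rather than a pickled faiss.Index object, so any consumer must rebuild the index with faiss.deserialize_index after unpickling, exactly as the STEP 6 verification does. A minimal loader sketch for the backend side (the function name load_faiss_index is illustrative, not part of this commit):

import pickle

import faiss


def load_faiss_index(pkl_path: str = "app/models/faiss_index.pkl"):
    """Load the (index_bytes, kb_data) tuple written by build_faiss_index.py."""
    with open(pkl_path, 'rb') as f:
        index_bytes, kb_data = pickle.load(f)
    # Rebuild the FAISS index from its serialized byte buffer
    index = faiss.deserialize_index(index_bytes)
    return index, kb_data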
old_build_faiss_index.py
ADDED
@@ -0,0 +1,339 @@
+"""
+Build FAISS Index from Scratch
+Creates faiss_index.pkl from your knowledge base and trained retriever model
+
+Run this ONCE before starting the backend:
+    python build_faiss_index.py
+
+Author: Banking RAG Chatbot
+Date: October 2025
+"""
+
+
+# Add these lines at the very top (after docstring)
+import os
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # Suppress TensorFlow info/warnings
+os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # Disable oneDNN messages
+
+import warnings
+warnings.filterwarnings('ignore')  # Suppress all warnings
+
+import os
+import pickle
+import json
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import faiss
+import numpy as np
+from pathlib import Path
+from transformers import AutoTokenizer, AutoModel
+from typing import List
+
+
+# ============================================================================
+# CONFIGURATION - UPDATE THESE PATHS!
+# ============================================================================
+
+# Where is your knowledge base JSONL file?
+KB_JSONL_FILE = "data/final_knowledge_base.jsonl"
+
+# Where is your trained retriever model?
+RETRIEVER_MODEL_PATH = "app/models/best_retriever_model.pth"
+
+# Where to save the output FAISS pickle?
+OUTPUT_PKL_FILE = "app/models/faiss_index.pkl"
+
+# Device (auto-detect GPU/CPU)
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Batch size for encoding (reduce if you get OOM errors)
+BATCH_SIZE = 32
+
+
+# ============================================================================
+# CUSTOM SENTENCE TRANSFORMER (Same as retriever.py)
+# ============================================================================
+
+class CustomSentenceTransformer(nn.Module):
+    """
+    Custom SentenceTransformer - exact copy from retriever.py
+    Uses e5-base-v2 with mean pooling and L2 normalization
+    """
+
+    def __init__(self, model_name: str = "intfloat/e5-base-v2"):
+        super().__init__()
+        print(f" Loading base model: {model_name}...")
+        self.encoder = AutoModel.from_pretrained(model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.config = self.encoder.config
+        print(f" ✅ Base model loaded")
+
+    def forward(self, input_ids, attention_mask):
+        """Forward pass through BERT encoder"""
+        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
+
+        # Mean pooling
+        token_embeddings = outputs.last_hidden_state
+        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+        embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+            input_mask_expanded.sum(1), min=1e-9
+        )
+
+        # L2 normalize
+        embeddings = F.normalize(embeddings, p=2, dim=1)
+        return embeddings
+
+    def encode(self, sentences: List[str], batch_size: int = 32) -> np.ndarray:
+        """Encode sentences - same as training"""
+        self.eval()
+
+        if isinstance(sentences, str):
+            sentences = [sentences]
+
+        # Add 'query: ' prefix for e5-base-v2
+        processed_sentences = [f"query: {s.strip()}" for s in sentences]
+
+        all_embeddings = []
+
+        with torch.no_grad():
+            for i in range(0, len(processed_sentences), batch_size):
+                batch_sentences = processed_sentences[i:i + batch_size]
+
+                # Tokenize
+                tokens = self.tokenizer(
+                    batch_sentences,
+                    truncation=True,
+                    padding=True,
+                    max_length=128,
+                    return_tensors='pt'
+                ).to(self.encoder.device)
+
+                # Get embeddings
+                embeddings = self.forward(tokens['input_ids'], tokens['attention_mask'])
+                all_embeddings.append(embeddings.cpu().numpy())
+
+        return np.vstack(all_embeddings)
+
+
+# ============================================================================
+# RETRIEVER MODEL (Wrapper)
+# ============================================================================
+
+class RetrieverModel:
+    """Wrapper for trained retriever model"""
+
+    def __init__(self, model_path: str, device: str = "cpu"):
+        print(f"\n🤖 Loading retriever model...")
+        print(f" Device: {device}")
+
+        self.device = device
+        self.model = CustomSentenceTransformer("intfloat/e5-base-v2").to(device)
+
+        # Load trained weights
+        print(f" Loading weights from: {model_path}")
+        try:
+            state_dict = torch.load(model_path, map_location=device)
+            self.model.load_state_dict(state_dict)
+            print(f" ✅ Trained weights loaded")
+        except Exception as e:
+            print(f" ⚠️ Warning: Could not load trained weights: {e}")
+            print(f" Using base e5-base-v2 model instead")
+
+        self.model.eval()
+
+    def encode_documents(self, documents: List[str], batch_size: int = 32) -> np.ndarray:
+        """Encode documents"""
+        return self.model.encode(documents, batch_size=batch_size)
+
+
+# ============================================================================
+# MAIN: BUILD FAISS INDEX
+# ============================================================================
+
+def build_faiss_index():
+    """Main function to build FAISS index from scratch"""
+
+    print("=" * 80)
+    print("🏗️ BUILDING FAISS INDEX FROM SCRATCH")
+    print("=" * 80)
+
+    # ========================================================================
+    # STEP 1: LOAD KNOWLEDGE BASE
+    # ========================================================================
+    print(f"\n📚 STEP 1: Loading knowledge base...")
+    print(f" File: {KB_JSONL_FILE}")
+
+    if not os.path.exists(KB_JSONL_FILE):
+        print(f" ❌ ERROR: File not found!")
+        print(f" Please copy your knowledge base to: {KB_JSONL_FILE}")
+        return False
+
+    kb_data = []
+    with open(KB_JSONL_FILE, 'r', encoding='utf-8') as f:
+        for line_num, line in enumerate(f, 1):
+            try:
+                kb_data.append(json.loads(line))
+            except json.JSONDecodeError as e:
+                print(f" ⚠️ Warning: Skipping invalid JSON on line {line_num}: {e}")
+
+    print(f" ✅ Loaded {len(kb_data)} documents")
+
+    if len(kb_data) == 0:
+        print(f" ❌ ERROR: Knowledge base is empty!")
+        return False
+
+    # ========================================================================
+    # STEP 2: PREPARE DOCUMENTS FOR ENCODING
+    # ========================================================================
+    print(f"\n📝 STEP 2: Preparing documents for encoding...")
+
+    documents = []
+    for i, item in enumerate(kb_data):
+        # Combine instruction + response for embedding (same as training)
+        instruction = item.get('instruction', '')
+        response = item.get('response', '')
+
+        # Create combined text
+        if instruction and response:
+            text = f"{instruction} {response}"
+        elif instruction:
+            text = instruction
+        elif response:
+            text = response
+        else:
+            print(f" ⚠️ Warning: Document {i} has no content, using placeholder")
+            text = "empty document"
+
+        documents.append(text)
+
+    print(f" ✅ Prepared {len(documents)} documents for encoding")
+    print(f" Average length: {sum(len(d) for d in documents) / len(documents):.1f} chars")
+
+    # ========================================================================
+    # STEP 3: LOAD RETRIEVER AND ENCODE DOCUMENTS
+    # ========================================================================
+    print(f"\n🔮 STEP 3: Encoding documents with trained retriever...")
+
+    if not os.path.exists(RETRIEVER_MODEL_PATH):
+        print(f" ❌ ERROR: Retriever model not found!")
+        print(f" Please copy your trained model to: {RETRIEVER_MODEL_PATH}")
+        return False
+
+    # Load retriever
+    retriever = RetrieverModel(RETRIEVER_MODEL_PATH, device=DEVICE)
+
+    # Encode all documents
+    print(f" Encoding {len(documents)} documents...")
+    print(f" Batch size: {BATCH_SIZE}")
+    print(f" This may take a few minutes... ⌛")
+
+    try:
+        embeddings = retriever.encode_documents(documents, batch_size=BATCH_SIZE)
+        print(f" ✅ Encoded {embeddings.shape[0]} documents")
+        print(f" Embedding dimension: {embeddings.shape[1]}")
+    except Exception as e:
+        print(f" ❌ ERROR during encoding: {e}")
+        import traceback
+        traceback.print_exc()
+        return False
+
+    # ========================================================================
+    # STEP 4: BUILD FAISS INDEX
+    # ========================================================================
+    print(f"\n🔍 STEP 4: Building FAISS index...")
+
+    dimension = embeddings.shape[1]
+    print(f" Dimension: {dimension}")
+
+    # Create FAISS index (Inner Product = Cosine similarity after normalization)
+    index = faiss.IndexFlatIP(dimension)
+
+    # Normalize embeddings for cosine similarity
+    print(f" Normalizing embeddings...")
+    faiss.normalize_L2(embeddings)
+
+    # Add to index
+    print(f" Adding {embeddings.shape[0]} vectors to FAISS index...")
+    index.add(embeddings.astype('float32'))
+
+    print(f" ✅ FAISS index built successfully")
+    print(f" Total vectors: {index.ntotal}")
+
+    # ========================================================================
+    # STEP 5: SAVE AS PICKLE FILE
+    # ========================================================================
+    print(f"\n💾 STEP 5: Saving as pickle file...")
+
+    # Create models directory if it doesn't exist
+    os.makedirs(os.path.dirname(OUTPUT_PKL_FILE), exist_ok=True)
+
+    # Save tuple of (index, kb_data)
+    print(f" Pickling (index, kb_data) tuple...")
+    try:
+        with open(OUTPUT_PKL_FILE, 'wb') as f:
+            pickle.dump((index, kb_data), f)
+
+        file_size_mb = Path(OUTPUT_PKL_FILE).stat().st_size / (1024 * 1024)
+        print(f" ✅ Saved: {OUTPUT_PKL_FILE}")
+        print(f" File size: {file_size_mb:.2f} MB")
+    except Exception as e:
+        print(f" ❌ ERROR saving pickle: {e}")
+        return False
+
+    # ========================================================================
+    # STEP 6: VERIFY SAVED FILE
+    # ========================================================================
+    print(f"\n✅ STEP 6: Verifying saved file...")
+
+    try:
+        with open(OUTPUT_PKL_FILE, 'rb') as f:
+            loaded_index, loaded_kb = pickle.load(f)
+
+        print(f" ✅ Verification successful")
+        print(f" Index vectors: {loaded_index.ntotal}")
+        print(f" KB documents: {len(loaded_kb)}")
+
+        if loaded_index.ntotal != len(loaded_kb):
+            print(f" ⚠️ WARNING: Size mismatch detected!")
+
+    except Exception as e:
+        print(f" ❌ ERROR verifying file: {e}")
+        return False
+
+    # ========================================================================
+    # SUCCESS!
+    # ========================================================================
+    print("\n" + "=" * 80)
+    print("🎉 SUCCESS! FAISS INDEX BUILT AND SAVED")
+    print("=" * 80)
+    print(f"\n📊 Summary:")
+    print(f" Documents: {len(kb_data)}")
+    print(f" Vectors: {index.ntotal}")
+    print(f" Dimension: {dimension}")
+    print(f" File: {OUTPUT_PKL_FILE} ({file_size_mb:.2f} MB)")
+    print(f"\n🚀 You can now start the backend:")
+    print(f" cd backend")
+    print(f" uvicorn app.main:app --reload")
+    print("=" * 80 + "\n")
+
+    return True
+
+
+# ============================================================================
+# RUN SCRIPT
+# ============================================================================
+
+if __name__ == "__main__":
+    success = build_faiss_index()
+
+    if not success:
+        print("\n" + "=" * 80)
+        print("❌ FAILED TO BUILD FAISS INDEX")
+        print("=" * 80)
+        print("\nPlease check:")
+        print("1. Knowledge base file exists: data/final_knowledge_base.jsonl")
+        print("2. Retriever model exists: models/best_retriever_model.pth")
+        print("3. You have enough RAM (embeddings need ~1GB for 10k docs)")
+        print("=" * 80 + "\n")
+        exit(1)
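Why the rewrite replaced the plain pickle.dump((index, kb_data), f) used in this old version: pickling a raw faiss.Index object ties the file to the exact FAISS build that wrote it, while faiss.serialize_index turns the index into a flat uint8 buffer that faiss.deserialize_index can read back across versions. A small self-contained round-trip sketch (toy data, not from this repo):

import pickle

import faiss
import numpy as np

# Toy embeddings: 10 vectors of dimension 8, normalized for inner-product search
vectors = np.random.rand(10, 8).astype('float32')
faiss.normalize_L2(vectors)

index = faiss.IndexFlatIP(8)
index.add(vectors)

# Version-stable save: serialize the index to a byte buffer, then pickle that
blob = pickle.dumps(faiss.serialize_index(index))

# Load side: unpickle the buffer and rebuild the index from it
restored = faiss.deserialize_index(pickle.loads(blob))
assert restored.ntotal == index.ntotal  # same 10 vectors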