# Source: Hugging Face Space file src/embeddings.py (commit 12d1bb1, "Update src/embeddings.py")
# ==========================================================
# πŸ“˜ embeddings.py β€” optimized for Hugging Face + FAISS + E5
# ==========================================================
import os

# ----------------------------
# Hugging Face Cache Bootstrap
# ----------------------------
# IMPORTANT: these env vars must be exported BEFORE importing
# sentence_transformers — huggingface_hub resolves its cache location
# (HF_HOME et al.) at import time, so setting them after the import
# silently has no effect on where models/datasets are cached.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR  # legacy alias; superseded by HF_HOME but kept for older transformers
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
os.environ["HF_MODULES_CACHE"] = CACHE_DIR

import numpy as np
from sentence_transformers import SentenceTransformer

print(f"βœ… Using Hugging Face cache at {CACHE_DIR}")
# ----------------------------
# Load Embedding Model (E5 with fallback)
# ----------------------------
# Eagerly load the primary retrieval model at import time. If that fails
# for any reason (network outage, disk full, bad cached weights), fall
# back to the smaller general-purpose MiniLM model so the module stays
# usable. A failure of the fallback itself propagates to the importer.
try:
    _model = SentenceTransformer(
        "intfloat/e5-small-v2",  # βœ… Trained for retrieval-augmented QA
        cache_folder=CACHE_DIR
    )
    print("βœ… Loaded model: intfloat/e5-small-v2")
except Exception as e:
    # Broad catch is deliberate: any load error should trigger the fallback.
    # NOTE(review): MiniLM embeddings are not in the same vector space as
    # E5 — a FAISS index built with one model must not be queried with the
    # other; verify index/model pairing at the call site.
    print(f"⚠️ Model load failed ({e}), falling back to MiniLM.")
    _model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=CACHE_DIR
    )
    print("βœ… Loaded fallback model: all-MiniLM-L6-v2")
# ----------------------------
# Function: Generate Embeddings
# ----------------------------
def generate_embeddings(chunks: list, prefix: str = "passage: ") -> list:
    """
    Generate normalized embeddings for a list of text chunks.

    Args:
        chunks (list): List of text chunk strings.
        prefix (str): Text prepended to every chunk before encoding.
            E5-family models are trained with "passage: " for documents
            and "query: " for search queries; the default suits indexing,
            pass prefix="query: " when embedding user queries.

    Returns:
        list: List of normalized embedding vectors (plain Python lists),
        in the same order as `chunks` so FAISS ids align with inputs.

    Notes:
        - normalize_embeddings=True yields unit-length vectors, so FAISS
          IndexFlatIP (inner product) is equivalent to cosine similarity.
        - Vectors are converted to Python lists for FAISS / JSON use.
    """
    if not chunks:
        print("⚠️ No chunks provided for embedding generation.")
        return []

    # Step 1: Prefix each chunk per the model's training format (E5 convention).
    prepared_chunks = [f"{prefix}{chunk.strip()}" for chunk in chunks]

    # Step 2: Encode with L2 normalization for cosine-similarity retrieval.
    vectors = _model.encode(
        prepared_chunks,
        convert_to_numpy=True,
        normalize_embeddings=True,  # βœ… Makes FAISS IndexFlatIP accurate
    )

    # Step 3: Convert to Python lists for FAISS / JSON compatibility.
    embeddings = vectors.tolist()
    print(f"βœ… Generated {len(embeddings)} embeddings.")
    return embeddings