"""
Embedding Generation Module
This module generates embeddings for assessments and queries using
Hugging Face sentence transformers and creates a FAISS index for fast retrieval.
"""
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
import pickle
import logging
import os
from typing import List
import torch
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class EmbeddingGenerator:
"""Generates embeddings and creates FAISS index"""
def __init__(self, model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'):
self.model_name = model_name
self.model = None
self.faiss_index = None
self.embeddings = None
self.catalog_df = None
self.assessment_mapping = {}
# Set device
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info(f"Using device: {self.device}")
def load_model(self):
"""Load the sentence transformer model"""
try:
logger.info(f"Loading model: {self.model_name}")
self.model = SentenceTransformer(self.model_name)
self.model.to(self.device)
logger.info("Model loaded successfully")
except Exception as e:
logger.error(f"Error loading model: {e}")
raise
def load_catalog(self, catalog_path: str = 'data/shl_catalog.csv') -> pd.DataFrame:
"""Load the SHL catalog"""
try:
self.catalog_df = pd.read_csv(catalog_path)
logger.info(f"Loaded catalog with {len(self.catalog_df)} assessments")
return self.catalog_df
except Exception as e:
logger.error(f"Error loading catalog: {e}")
raise
def create_assessment_texts(self) -> List[str]:
"""Create text representations of assessments for embedding"""
texts = []
for idx, row in self.catalog_df.iterrows():
# Combine relevant fields for embedding
text_parts = []
if pd.notna(row['assessment_name']):
text_parts.append(str(row['assessment_name']))
if pd.notna(row['category']):
text_parts.append(f"Category: {row['category']}")
if pd.notna(row['test_type']):
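                # Only 'K' (Knowledge & Skills) is distinguished here; every other
                # test_type code in the catalog falls back to Personality/Behavior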
type_full = 'Knowledge/Skill' if row['test_type'] == 'K' else 'Personality/Behavior'
text_parts.append(f"Type: {type_full}")
if pd.notna(row['description']):
text_parts.append(str(row['description']))
text = ' | '.join(text_parts)
texts.append(text)
            # Map the catalog row index to assessment details; these keys must line up
            # with the FAISS vector ids, which holds for a default RangeIndex
self.assessment_mapping[idx] = {
'assessment_name': row['assessment_name'],
'assessment_url': row['assessment_url'],
'category': row['category'],
'test_type': row['test_type'],
'description': row['description']
}
logger.info(f"Created {len(texts)} assessment texts")
return texts
def generate_embeddings(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
"""Generate embeddings for a list of texts"""
if self.model is None:
self.load_model()
logger.info(f"Generating embeddings for {len(texts)} texts...")
try:
# Generate embeddings in batches
embeddings = self.model.encode(
texts,
batch_size=batch_size,
show_progress_bar=True,
convert_to_numpy=True,
normalize_embeddings=True # L2 normalization for cosine similarity
)
logger.info(f"Generated embeddings with shape: {embeddings.shape}")
return embeddings
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
raise
def create_faiss_index(self, embeddings: np.ndarray) -> faiss.Index:
"""Create FAISS index for fast similarity search"""
try:
logger.info("Creating FAISS index...")
            # Embedding dimensionality
            dimension = embeddings.shape[1]
            # IndexFlatIP uses inner product, which equals cosine similarity
            # for L2-normalized vectors
            index = faiss.IndexFlatIP(dimension)
            # Add embeddings to the index (FAISS expects float32)
            index.add(embeddings.astype('float32'))
logger.info(f"FAISS index created with {index.ntotal} vectors")
return index
except Exception as e:
logger.error(f"Error creating FAISS index: {e}")
raise
def save_artifacts(self,
index_path: str = 'models/faiss_index.faiss',
embeddings_path: str = 'models/embeddings.npy',
mapping_path: str = 'models/mapping.pkl'):
"""Save FAISS index, embeddings, and mapping"""
try:
            # Create output directories for all artifacts if they don't exist
            for path in (index_path, embeddings_path, mapping_path):
                os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
# Save FAISS index
faiss.write_index(self.faiss_index, index_path)
logger.info(f"FAISS index saved to {index_path}")
# Save embeddings
np.save(embeddings_path, self.embeddings)
logger.info(f"Embeddings saved to {embeddings_path}")
# Save mapping
with open(mapping_path, 'wb') as f:
pickle.dump(self.assessment_mapping, f)
logger.info(f"Assessment mapping saved to {mapping_path}")
except Exception as e:
logger.error(f"Error saving artifacts: {e}")
raise
def load_artifacts(self,
index_path: str = 'models/faiss_index.faiss',
embeddings_path: str = 'models/embeddings.npy',
mapping_path: str = 'models/mapping.pkl'):
"""Load FAISS index, embeddings, and mapping"""
try:
# Load FAISS index
self.faiss_index = faiss.read_index(index_path)
logger.info(f"FAISS index loaded from {index_path}")
# Load embeddings
self.embeddings = np.load(embeddings_path)
logger.info(f"Embeddings loaded from {embeddings_path}")
# Load mapping
with open(mapping_path, 'rb') as f:
self.assessment_mapping = pickle.load(f)
logger.info(f"Assessment mapping loaded from {mapping_path}")
return True
except Exception as e:
logger.error(f"Error loading artifacts: {e}")
return False
def build_index(self, catalog_path: str = 'data/shl_catalog.csv'):
"""Main method to build the complete index"""
# Load catalog
self.load_catalog(catalog_path)
# Create assessment texts
assessment_texts = self.create_assessment_texts()
# Generate embeddings
self.embeddings = self.generate_embeddings(assessment_texts)
# Create FAISS index
self.faiss_index = self.create_faiss_index(self.embeddings)
# Save artifacts
self.save_artifacts()
logger.info("Index building complete!")
return self.faiss_index, self.embeddings, self.assessment_mapping
def embed_query(self, query: str) -> np.ndarray:
"""Generate embedding for a single query"""
if self.model is None:
self.load_model()
embedding = self.model.encode(
[query],
convert_to_numpy=True,
normalize_embeddings=True
)
return embedding[0]
def embed_queries(self, queries: List[str], batch_size: int = 32) -> np.ndarray:
"""Generate embeddings for multiple queries"""
return self.generate_embeddings(queries, batch_size)
def main():
"""Main execution function"""
# Initialize embedder
embedder = EmbeddingGenerator()
# Build index
index, embeddings, mapping = embedder.build_index()
print("\n=== Embedding Generation Summary ===")
print(f"Total assessments indexed: {index.ntotal}")
print(f"Embedding dimension: {embeddings.shape[1]}")
print(f"Assessment mapping entries: {len(mapping)}")
# Test with a sample query
test_query = "Looking for a Java developer with strong programming skills"
query_embedding = embedder.embed_query(test_query)
print(f"\nTest query embedding shape: {query_embedding.shape}")
# Search test
k = 5
distances, indices = index.search(query_embedding.reshape(1, -1).astype('float32'), k)
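    # With IndexFlatIP and L2-normalized vectors, 'distances' are cosine
    # similarity scores, returned in descending order of relevance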
print(f"\nTop {k} matches for test query:")
for i, (idx, dist) in enumerate(zip(indices[0], distances[0])):
assessment = mapping[idx]
print(f"\n{i+1}. {assessment['assessment_name']}")
print(f" Score: {dist:.4f}")
print(f" Type: {assessment['test_type']}")
return embedder
if __name__ == "__main__":
main()