Spaces:
Runtime error
Runtime error
| """ | |
| Embedding Generation Module | |
| This module generates embeddings for assessments and queries using | |
| Hugging Face sentence transformers and creates a FAISS index for fast retrieval. | |
| """ | |
| import numpy as np | |
| import pandas as pd | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import pickle | |
| import logging | |
| import os | |
| from typing import List, Dict, Tuple | |
| import torch | |
# Configure root logging at import time (INFO level, timestamped records) and
# expose the module-level logger used by everything in this file.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class EmbeddingGenerator:
    """Generate sentence embeddings for a catalog and index them with FAISS.

    Typical use is ``build_index()`` to go from a catalog CSV to saved
    artifacts, or ``load_artifacts()`` to restore previously saved ones,
    followed by ``embed_query()`` / ``embed_queries()`` to embed search text.

    State is populated lazily: ``model`` by ``load_model()``, ``catalog_df``
    by ``load_catalog()``, ``assessment_mapping`` by
    ``create_assessment_texts()`` or ``load_artifacts()``, and
    ``embeddings`` / ``faiss_index`` by ``build_index()`` or
    ``load_artifacts()``.
    """

    def __init__(self, model_name: str = 'sentence-transformers/all-MiniLM-L6-v2'):
        """Store the model name and choose the compute device.

        The model itself is not loaded here; ``load_model()`` does that on
        demand.

        Args:
            model_name: Name passed to ``SentenceTransformer`` when the model
                is loaded.
        """
        self.model_name = model_name
        self.model = None
        self.faiss_index = None
        self.embeddings = None
        self.catalog_df = None
        self.assessment_mapping = {}
        # Set device
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info("Using device: %s", self.device)

    def load_model(self) -> None:
        """Load the sentence transformer model and move it to ``self.device``.

        Raises:
            Exception: Whatever the model loader raises; it is logged and
                re-raised unchanged.
        """
        try:
            logger.info("Loading model: %s", self.model_name)
            self.model = SentenceTransformer(self.model_name)
            self.model.to(self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

    def load_catalog(self, catalog_path: str = 'data/shl_catalog.csv') -> pd.DataFrame:
        """Read the catalog CSV into ``self.catalog_df`` and return it.

        Args:
            catalog_path: Path of the CSV file to read.

        Raises:
            Exception: Whatever ``pd.read_csv`` raises; it is logged and
                re-raised unchanged.
        """
        try:
            self.catalog_df = pd.read_csv(catalog_path)
            logger.info("Loaded catalog with %s assessments", len(self.catalog_df))
            return self.catalog_df
        except Exception as e:
            logger.error("Error loading catalog: %s", e)
            raise

    @staticmethod
    def _row_to_text(row) -> str:
        """Join the non-missing descriptive fields of one catalog row."""
        parts = []
        if pd.notna(row['assessment_name']):
            parts.append(str(row['assessment_name']))
        if pd.notna(row['category']):
            parts.append(f"Category: {row['category']}")
        if pd.notna(row['test_type']):
            # Only 'K' is labelled Knowledge/Skill; every other value is
            # labelled Personality/Behavior.
            type_full = 'Knowledge/Skill' if row['test_type'] == 'K' else 'Personality/Behavior'
            parts.append(f"Type: {type_full}")
        if pd.notna(row['description']):
            parts.append(str(row['description']))
        return ' | '.join(parts)

    def create_assessment_texts(self) -> List[str]:
        """Create one text per catalog row and rebuild ``assessment_mapping``.

        The mapping is keyed by row *position* (0, 1, 2, ...), not by the
        DataFrame index label, because FAISS search returns the position at
        which a vector was added. The mapping is rebuilt from scratch on each
        call so entries from an earlier catalog cannot linger.

        Returns:
            The texts, in catalog row order.

        Raises:
            ValueError: If no catalog has been loaded yet.
        """
        if self.catalog_df is None:
            raise ValueError("Catalog not loaded; call load_catalog() first")

        texts = []
        mapping = {}
        for position, (_, row) in enumerate(self.catalog_df.iterrows()):
            texts.append(self._row_to_text(row))
            # Create mapping from vector position to assessment details
            mapping[position] = {
                'assessment_name': row['assessment_name'],
                'assessment_url': row['assessment_url'],
                'category': row['category'],
                'test_type': row['test_type'],
                'description': row['description'],
            }
        self.assessment_mapping = mapping
        logger.info("Created %s assessment texts", len(texts))
        return texts

    def generate_embeddings(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode ``texts`` with the model, loading the model first if needed.

        Embeddings are requested as L2-normalized NumPy arrays so that an
        inner-product search over them corresponds to cosine similarity.

        Args:
            texts: Strings to embed.
            batch_size: Batch size forwarded to ``model.encode``.

        Raises:
            Exception: Whatever the encoder raises; it is logged and
                re-raised unchanged.
        """
        if self.model is None:
            self.load_model()
        logger.info("Generating embeddings for %s texts...", len(texts))
        try:
            # Generate embeddings in batches
            embeddings = self.model.encode(
                texts,
                batch_size=batch_size,
                show_progress_bar=True,
                convert_to_numpy=True,
                normalize_embeddings=True,  # L2 normalization for cosine similarity
            )
            logger.info("Generated embeddings with shape: %s", embeddings.shape)
            return embeddings
        except Exception as e:
            logger.error("Error generating embeddings: %s", e)
            raise

    def create_faiss_index(self, embeddings: np.ndarray) -> "faiss.Index":
        """Build a flat inner-product FAISS index over ``embeddings``.

        Args:
            embeddings: 2-D array with one row per vector.

        Returns:
            A ``faiss.IndexFlatIP`` containing every row of ``embeddings``.

        Raises:
            ValueError: If ``embeddings`` is not 2-D (for example the empty
                1-D array an encoder returns for an empty input list).
        """
        try:
            logger.info("Creating FAISS index...")
            # FAISS needs C-contiguous float32 data; astype() alone does not
            # guarantee contiguity.
            vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
            if vectors.ndim != 2:
                raise ValueError(
                    f"Expected a 2-D array of embeddings, got shape {vectors.shape}"
                )
            # Dimensions of embeddings
            dimension = vectors.shape[1]
            # IndexFlatIP = inner product (cosine similarity with normalized vectors)
            index = faiss.IndexFlatIP(dimension)
            index.add(vectors)
            logger.info("FAISS index created with %s vectors", index.ntotal)
            return index
        except Exception as e:
            logger.error("Error creating FAISS index: %s", e)
            raise

    def save_artifacts(self,
                       index_path: str = 'models/faiss_index.faiss',
                       embeddings_path: str = 'models/embeddings.npy',
                       mapping_path: str = 'models/mapping.pkl') -> None:
        """Save the FAISS index, embeddings, and mapping to disk.

        Parent directories of all three paths are created when missing. A
        path with no directory component is written to the current working
        directory.

        Raises:
            ValueError: If no FAISS index has been built or loaded yet.
            Exception: Any I/O or serialization error; it is logged and
                re-raised unchanged.
        """
        try:
            if self.faiss_index is None:
                raise ValueError("No FAISS index to save; build or load one first")

            # os.makedirs('') raises, so skip paths without a directory part.
            for path in (index_path, embeddings_path, mapping_path):
                parent = os.path.dirname(path)
                if parent:
                    os.makedirs(parent, exist_ok=True)

            # Save FAISS index
            faiss.write_index(self.faiss_index, index_path)
            logger.info("FAISS index saved to %s", index_path)
            # Save embeddings
            np.save(embeddings_path, self.embeddings)
            logger.info("Embeddings saved to %s", embeddings_path)
            # Save mapping
            with open(mapping_path, 'wb') as f:
                pickle.dump(self.assessment_mapping, f)
            logger.info("Assessment mapping saved to %s", mapping_path)
        except Exception as e:
            logger.error("Error saving artifacts: %s", e)
            raise

    def load_artifacts(self,
                       index_path: str = 'models/faiss_index.faiss',
                       embeddings_path: str = 'models/embeddings.npy',
                       mapping_path: str = 'models/mapping.pkl') -> bool:
        """Load the FAISS index, embeddings, and mapping from disk.

        All three artifacts are read before any attribute is assigned, so a
        failure partway through leaves the previous state intact.

        Returns:
            True when everything loaded; False when any step failed (the
            error is logged, not raised).
        """
        try:
            # Load FAISS index
            faiss_index = faiss.read_index(index_path)
            logger.info("FAISS index loaded from %s", index_path)
            # Load embeddings
            embeddings = np.load(embeddings_path)
            logger.info("Embeddings loaded from %s", embeddings_path)
            # Load mapping
            # SECURITY: pickle can execute arbitrary code while loading; only
            # point mapping_path at files this application wrote itself.
            with open(mapping_path, 'rb') as f:
                mapping = pickle.load(f)
            logger.info("Assessment mapping loaded from %s", mapping_path)
        except Exception as e:
            logger.error("Error loading artifacts: %s", e)
            return False

        self.faiss_index = faiss_index
        self.embeddings = embeddings
        self.assessment_mapping = mapping
        return True

    def build_index(self, catalog_path: str = 'data/shl_catalog.csv') -> "Tuple[faiss.Index, np.ndarray, Dict[int, dict]]":
        """Run the full pipeline: load catalog, embed, index, and save.

        Artifacts are saved to the default paths of ``save_artifacts()``.

        Args:
            catalog_path: CSV file to build the index from.

        Returns:
            ``(faiss_index, embeddings, assessment_mapping)``.
        """
        # Load catalog
        self.load_catalog(catalog_path)
        # Create assessment texts
        assessment_texts = self.create_assessment_texts()
        # Generate embeddings
        self.embeddings = self.generate_embeddings(assessment_texts)
        # Create FAISS index
        self.faiss_index = self.create_faiss_index(self.embeddings)
        # Save artifacts
        self.save_artifacts()
        logger.info("Index building complete!")
        return self.faiss_index, self.embeddings, self.assessment_mapping

    def embed_query(self, query: str) -> np.ndarray:
        """Return the normalized embedding vector for a single query string.

        The model is loaded on first use.
        """
        if self.model is None:
            self.load_model()
        embedding = self.model.encode(
            [query],
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        # encode() was given a one-element list; return that single row.
        return embedding[0]

    def embed_queries(self, queries: List[str], batch_size: int = 32) -> np.ndarray:
        """Generate embeddings for multiple queries via ``generate_embeddings``."""
        return self.generate_embeddings(queries, batch_size)
def main():
    """Build the index from the default catalog and run a sample search.

    Prints a short summary of the built index, then the best matches for a
    fixed test query.

    Returns:
        The ``EmbeddingGenerator`` instance that was used.
    """
    # Initialize embedder
    embedder = EmbeddingGenerator()
    # Build index
    index, embeddings, mapping = embedder.build_index()
    print("\n=== Embedding Generation Summary ===")
    print(f"Total assessments indexed: {index.ntotal}")
    print(f"Embedding dimension: {embeddings.shape[1]}")
    print(f"Assessment mapping entries: {len(mapping)}")
    # Test with a sample query
    test_query = "Looking for a Java developer with strong programming skills"
    query_embedding = embedder.embed_query(test_query)
    print(f"\nTest query embedding shape: {query_embedding.shape}")
    # Search test
    # Never ask for more neighbours than the index holds: FAISS pads missing
    # results with label -1, which is not a key of the mapping.
    k = min(5, index.ntotal)
    print(f"\nTop {k} matches for test query:")
    if k > 0:
        distances, indices = index.search(
            query_embedding.reshape(1, -1).astype('float32'), k
        )
        for rank, (idx, dist) in enumerate(zip(indices[0], distances[0]), start=1):
            if idx < 0:
                # Padding entry from FAISS, not a real match.
                continue
            assessment = mapping[int(idx)]
            print(f"\n{rank}. {assessment['assessment_name']}")
            print(f" Score: {dist:.4f}")
            print(f" Type: {assessment['test_type']}")
    return embedder


if __name__ == "__main__":
    main()