import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from typing import List, Union

# Cache of loaded SentenceTransformer models keyed by model name, so repeated
# calls do not reload weights from disk.
_MODEL_CACHE = {}


def get_model(model_name: str):
    """Return a cached SentenceTransformer instance for ``model_name``.

    Loads the model on first use and memoizes it in ``_MODEL_CACHE``.
    Models whose names contain 'nomic', 'qwen' or 'gemma' ship custom code
    on the Hugging Face Hub and require ``trust_remote_code=True``.
    """
    if model_name not in _MODEL_CACHE:
        print(f"Loading embedding model: {model_name}...")
        # Imported lazily so the numpy-only helpers in this module
        # (mrl_slice, generate_synthetic_data) work without the ML stack.
        from sentence_transformers import SentenceTransformer

        trust_remote_code = any(
            tag in model_name for tag in ("nomic", "qwen", "gemma")
        )
        _MODEL_CACHE[model_name] = SentenceTransformer(
            model_name, trust_remote_code=trust_remote_code, device="cpu"
        )
    return _MODEL_CACHE[model_name]


def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
    """
    Loads the specified model and generates embeddings for the given texts.

    Handles 'nomic' and 'qwen' specific requirements (trust_remote_code)
    via :func:`get_model`. Returns a 2-D numpy array, one row per text.
    """
    model = get_model(model_name)
    # convert_to_numpy guarantees an ndarray return (not a torch tensor/list).
    return model.encode(texts, convert_to_numpy=True, show_progress_bar=True)


def get_embedding(text: str, model_name: str = "nomic-ai/nomic-embed-text-v1.5") -> np.ndarray:
    """
    Generates a single embedding (1-D vector) for a query string.
    """
    return get_embeddings(model_name, [text])[0]


def mrl_slice(vectors: np.ndarray, dims: int) -> np.ndarray:
    """
    Slices the vectors to the first ``dims`` dimensions AND applies L2
    normalization *after* slicing.

    Normalizing after truncation is crucial for Matryoshka Representation
    Learning (MRL): truncated vectors are no longer unit length, which would
    distort cosine-similarity comparisons.

    Note: numpy slicing is forgiving — if ``dims`` exceeds the vector width,
    the full width is used.
    """
    sliced = vectors[:, :dims]
    norms = np.linalg.norm(sliced, axis=1, keepdims=True)
    # Guard against division by zero for all-zero rows; such rows stay zero.
    norms[norms == 0] = 1e-10
    return sliced / norms


def load_ms_marco(n_samples: int = 1000) -> List[str]:
    """
    Loads up to ``n_samples`` passage texts from the MS MARCO dataset.

    Streams the dataset to save RAM. Falls back to synthetic data if
    loading fails for any reason (network, schema change, missing package).
    """
    try:
        print(f"Attempting to load {n_samples} samples from MS MARCO...")
        # Lazy import: only needed when actually fetching the dataset.
        from datasets import load_dataset

        dataset = load_dataset(
            "microsoft/ms_marco", "v1.1", split="train", streaming=True
        )
        texts: List[str] = []
        for row in dataset:
            # ms_marco v1.1 row structure:
            # {'query_id': ..., 'query': ...,
            #  'passages': {'is_selected': [...], 'url': [...],
            #               'passage_text': [...]}}
            # For a retrieval engine we index documents, so prefer the first
            # passage text; fall back to the query only if passages are absent.
            if 'passages' in row:
                passage_list = row['passages']['passage_text']
                if passage_list:
                    texts.append(passage_list[0])
            elif 'query' in row:
                texts.append(row['query'])
            if len(texts) >= n_samples:
                break
        if len(texts) < n_samples:
            print("Warning: Could not fetch enough samples from MS MARCO.")
        return texts
    except Exception as e:
        print(f"Error loading MS MARCO: {e}")
        print("Falling back to synthetic data.")
        return generate_synthetic_data(n_samples)


def generate_synthetic_data(n_samples: int) -> List[str]:
    """
    Generates deterministic synthetic text data for testing.

    Cycles through ten base sentences, appending a unique variation index,
    so the result always has exactly ``n_samples`` distinct strings.
    """
    base_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the world.",
        "Vector databases enable fast similarity search.",
        "Machine learning models require data for training.",
        "Python is a popular programming language for data science.",
        "Cloud computing provides scalable resources.",
        "Cybersecurity is essential for protecting digital assets.",
        "Blockchain technology ensures decentralized transactions.",
        "Quantum computing will solve complex problems.",
        "Sustainable energy is the future of the planet."
    ]
    return [
        f"{base_sentences[i % len(base_sentences)]} Variation {i}"
        for i in range(n_samples)
    ]