import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import pandas as pd
from typing import List, Union
import torch
import torch.nn.functional as F


def get_embeddings(model_name: str, texts: List[str], device: str = 'cpu') -> np.ndarray:
    """
    Load a SentenceTransformer model and embed the given texts.

    Model families whose names contain 'nomic' or 'qwen' ship custom
    modeling code on the Hub, so ``trust_remote_code`` is enabled for them.

    Args:
        model_name: Hugging Face model id understood by SentenceTransformer.
        texts: Documents to embed.
        device: Torch device string to load the model on (default ``'cpu'``).

    Returns:
        2-D numpy array of shape ``(len(texts), model_dim)``.
    """
    print(f"Loading embedding model: {model_name}...")
    # These families execute repo-provided modeling code at load time.
    trust_remote_code = "nomic" in model_name or "qwen" in model_name
    model = SentenceTransformer(model_name, trust_remote_code=trust_remote_code, device=device)
    # convert_to_numpy guarantees an ndarray regardless of backend output type.
    embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
    return embeddings


def mrl_slice(vectors: np.ndarray, dims: int) -> np.ndarray:
    """
    Truncate embeddings to their first ``dims`` dimensions and re-normalize.

    Matryoshka Representation Learning (MRL) models are trained so that
    prefixes of the embedding stay meaningful, but cosine / inner-product
    search requires L2 normalization to be re-applied *after* truncation.

    Args:
        vectors: 2-D array of shape ``(n, full_dim)``.
        dims: Number of leading dimensions to keep (must be >= 1).

    Returns:
        Array of shape ``(n, min(dims, full_dim))`` with unit-norm rows.

    Raises:
        ValueError: If ``dims`` is less than 1.
    """
    if dims < 1:
        raise ValueError(f"dims must be >= 1, got {dims}")
    # NumPy slicing tolerates dims > full_dim: it simply keeps all columns.
    sliced = vectors[:, :dims]
    norms = np.linalg.norm(sliced, axis=1, keepdims=True)
    # Guard against division by zero for all-zero rows.
    norms[norms == 0] = 1e-10
    return sliced / norms


def load_ms_marco(n_samples: int = 1000) -> List[str]:
    """
    Stream up to ``n_samples`` passage texts from MS MARCO (v1.1).

    Streaming avoids materializing the full dataset in RAM. If loading
    fails, or succeeds but yields no usable text at all, the function
    falls back to :func:`generate_synthetic_data`.

    Args:
        n_samples: Maximum number of documents to collect.

    Returns:
        List of passage texts (queries as a last resort), or synthetic
        sentences on failure.
    """
    try:
        print(f"Attempting to load {n_samples} samples from MS MARCO...")
        dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train", streaming=True)
        texts: List[str] = []
        for row in dataset:
            # v1.1 row structure:
            # {'query_id': ..., 'query': ...,
            #  'passages': {'is_selected': [...], 'url': [...], 'passage_text': [...]}}
            if 'passages' in row:
                passage_list = row['passages']['passage_text']
                if passage_list:
                    # Index the first passage as the document text.
                    texts.append(passage_list[0])
            elif 'query' in row:
                # Fallback: index queries when the passage field is absent.
                texts.append(row['query'])
            if len(texts) >= n_samples:
                break
        if not texts:
            # Streaming "succeeded" but produced nothing usable — honor the
            # documented synthetic-data fallback instead of returning [].
            print("Warning: Could not fetch enough samples from MS MARCO.")
            print("Falling back to synthetic data.")
            return generate_synthetic_data(n_samples)
        if len(texts) < n_samples:
            print("Warning: Could not fetch enough samples from MS MARCO.")
        return texts
    except Exception as e:
        print(f"Error loading MS MARCO: {e}")
        print("Falling back to synthetic data.")
        return generate_synthetic_data(n_samples)


def generate_synthetic_data(n_samples: int) -> List[str]:
    """
    Generate deterministic synthetic text data for testing.

    Cycles through a fixed pool of base sentences, appending a variation
    index so every document is unique.

    Args:
        n_samples: Number of documents to generate.

    Returns:
        List of ``n_samples`` distinct sentences.
    """
    base_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the world.",
        "Vector databases enable fast similarity search.",
        "Machine learning models require data for training.",
        "Python is a popular programming language for data science.",
        "Cloud computing provides scalable resources.",
        "Cybersecurity is essential for protecting digital assets.",
        "Blockchain technology ensures decentralized transactions.",
        "Quantum computing will solve complex problems.",
        "Sustainable energy is the future of the planet.",
    ]
    return [
        f"{base_sentences[i % len(base_sentences)]} Variation {i}"
        for i in range(n_samples)
    ]