from typing import List, Union

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import load_dataset
from sentence_transformers import SentenceTransformer


def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
    """
    Loads the specified model and generates embeddings for the given texts.
    Handles Nomic- and Qwen-specific requirements (trust_remote_code=True).
    """
    print(f"Loading embedding model: {model_name}...")
    # Case-insensitive check so hub names like "Qwen/..." also match.
    lowered = model_name.lower()
    trust_remote_code = "nomic" in lowered or "qwen" in lowered
    model = SentenceTransformer(model_name, trust_remote_code=trust_remote_code, device='cpu')
    # encode() returns a numpy array directly when convert_to_numpy=True.
    embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=True)
    return embeddings
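

# Minimal usage sketch for get_embeddings(). The checkpoint name below is an
# illustrative assumption; any SentenceTransformer-compatible model works, and
# downloading the weights requires network access.
def _demo_get_embeddings() -> None:
    docs = [
        "Vector search retrieves semantically similar text.",
        "Embeddings map sentences to dense vectors.",
    ]
    vecs = get_embeddings("sentence-transformers/all-MiniLM-L6-v2", docs)
    print(vecs.shape)  # (2, 384) for this particular model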


def mrl_slice(vectors: np.ndarray, dims: int) -> np.ndarray:
    """
    Slices the vectors to the specified number of dimensions AND applies L2
    normalization *after* slicing. Re-normalizing post-slice is crucial for
    Matryoshka Representation Learning (MRL): truncated vectors are no longer
    unit-length, so cosine similarity would be distorted without it.
    """
    # 1. Keep only the leading `dims` components.
    sliced_vectors = vectors[:, :dims]
    # 2. L2-normalize row-wise with plain numpy (no sklearn dependency needed).
    norms = np.linalg.norm(sliced_vectors, axis=1, keepdims=True)
    # Guard against division by zero for all-zero rows.
    norms[norms == 0] = 1e-10
    return sliced_vectors / norms
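

# Sanity-check sketch for mrl_slice(): after slicing and re-normalization every
# row has unit L2 norm, so cosine similarity reduces to a plain dot product.
# Random vectors are used here so no model download is needed.
def _demo_mrl_slice() -> None:
    rng = np.random.default_rng(0)
    full = rng.normal(size=(4, 768)).astype(np.float32)
    sliced = mrl_slice(full, dims=128)
    assert sliced.shape == (4, 128)
    print(np.linalg.norm(sliced, axis=1))  # all ~1.0 after normalization
    print(float(sliced[0] @ sliced[1]))    # cosine similarity of rows 0 and 1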


def load_ms_marco(n_samples: int = 1000) -> List[str]:
    """
    Loads the MS MARCO dataset from Hugging Face, streaming to save RAM.
    Falls back to synthetic data if loading fails.
    """
    try:
        print(f"Attempting to load {n_samples} samples from MS MARCO...")
        dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train", streaming=True)
        texts = []
        count = 0
        for row in dataset:
            # For a retrieval engine we index documents, so we prefer passage
            # text over queries. ms_marco v1.1 rows look like:
            # {'query_id': ..., 'query': ...,
            #  'passages': {'is_selected': [...], 'url': [...], 'passage_text': [...]}}
            if 'passages' in row:
                # Take the first passage text of each row.
                passage_list = row['passages']['passage_text']
                if passage_list:
                    texts.append(passage_list[0])
                    count += 1
            elif 'query' in row:
                # Fall back to the query text if no passages are present.
                texts.append(row['query'])
                count += 1
            if count >= n_samples:
                break
        if len(texts) < n_samples:
            print("Warning: Could not fetch enough samples from MS MARCO.")
        return texts
    except Exception as e:
        print(f"Error loading MS MARCO: {e}")
        print("Falling back to synthetic data.")
        return generate_synthetic_data(n_samples)
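

# Quick check sketch for load_ms_marco(): fetches a handful of passages, or
# synthetic fallback text if the streaming download fails. The sample size is
# arbitrary, and the real dataset requires network access.
def _demo_load_ms_marco() -> None:
    docs = load_ms_marco(n_samples=5)
    print(f"Fetched {len(docs)} documents; first: {docs[0][:80]}...")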


def generate_synthetic_data(n_samples: int) -> List[str]:
    """
    Generates synthetic text data for testing.
    """
    base_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the world.",
        "Vector databases enable fast similarity search.",
        "Machine learning models require data for training.",
        "Python is a popular programming language for data science.",
        "Cloud computing provides scalable resources.",
        "Cybersecurity is essential for protecting digital assets.",
        "Blockchain technology ensures decentralized transactions.",
        "Quantum computing will solve complex problems.",
        "Sustainable energy is the future of the planet.",
    ]
    data = []
    for i in range(n_samples):
        # Cycle through the base sentences, numbering each variation so that
        # every generated document is unique.
        base = base_sentences[i % len(base_sentences)]
        data.append(f"{base} Variation {i}")
    return data
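

# End-to-end sketch tying the pieces together: embed a small synthetic corpus
# once, then rank it against a query at full width and at a 256-dim MRL slice.
# The checkpoint name and slice width are illustrative assumptions; MiniLM is
# not Matryoshka-trained, so the truncated ranking is only a mechanical demo.
# Swap in an MRL model (e.g. nomic-ai/nomic-embed-text-v1.5) for meaningful
# truncation.
if __name__ == "__main__":
    corpus = generate_synthetic_data(20)
    query = "fast similarity search over vectors"
    # Embed corpus and query in one pass so the model is only loaded once.
    vecs = get_embeddings("sentence-transformers/all-MiniLM-L6-v2", corpus + [query])
    corpus_vecs, query_vec = vecs[:-1], vecs[-1:]
    for dims in (corpus_vecs.shape[1], 256):
        db = mrl_slice(corpus_vecs, dims)
        q = mrl_slice(query_vec, dims)
        scores = db @ q[0]
        best = int(np.argmax(scores))
        print(f"[{dims}d] best match: {corpus[best]!r} (score={scores[best]:.3f})")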