import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
from typing import List

# Module-level cache so repeated calls reuse an already-loaded model.
_MODEL_CACHE = {}

def get_model(model_name: str) -> SentenceTransformer:
    """Loads a SentenceTransformer model, caching it so each model loads once."""
    if model_name not in _MODEL_CACHE:
        print(f"Loading embedding model: {model_name}...")
        # nomic, qwen, and gemma repos ship custom modeling code that must be
        # explicitly trusted before sentence-transformers will execute it.
        trust_remote_code = any(key in model_name for key in ("nomic", "qwen", "gemma"))
        _MODEL_CACHE[model_name] = SentenceTransformer(
            model_name, trust_remote_code=trust_remote_code, device="cpu"
        )
    return _MODEL_CACHE[model_name]

def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
    """
    Generates embeddings for the given texts with the (cached) model.
    trust_remote_code handling for nomic/qwen/gemma models lives in get_model.
    """
    model = get_model(model_name)

    # convert_to_numpy guarantees an np.ndarray rather than a tensor or list
    embeddings = model.encode(texts, convert_to_numpy=True, show_progress_bar=True)

    return embeddings

def get_embedding(text: str, model_name: str = "nomic-ai/nomic-embed-text-v1.5") -> np.ndarray:
    """
    Generates a single embedding for a query string.
    """
    embeddings = get_embeddings(model_name, [text])
    return embeddings[0]

def mrl_slice(vectors: np.ndarray, dims: int) -> np.ndarray:
    """
    Slices the vectors to the specified dimensions AND applies L2 normalization *after* slicing.
    This is crucial for Matryoshka Representation Learning (MRL).
    """
    # 1. Slice to the first `dims` dimensions
    sliced_vectors = vectors[:, :dims]

    # 2. L2-normalize with numpy; a truncated vector is no longer unit-length,
    #    so cosine similarities would be skewed without re-normalization.
    norms = np.linalg.norm(sliced_vectors, axis=1, keepdims=True)
    norms[norms == 0] = 1e-10  # avoid division by zero
    normalized_sliced_vectors = sliced_vectors / norms

    return normalized_sliced_vectors
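
# A minimal sketch (not part of the original module) showing how mrl_slice can
# be used to compare query-document similarities at several truncation sizes.
# The helper name and the dimension list below are illustrative assumptions.
def mrl_similarity_demo(query: str, docs: List[str],
                        model_name: str = "nomic-ai/nomic-embed-text-v1.5",
                        dims_to_test: tuple = (768, 256, 64)) -> None:
    """Hypothetical helper: prints query-document cosine similarities per MRL dim."""
    doc_vecs = get_embeddings(model_name, docs)
    query_vec = get_embeddings(model_name, [query])
    for dims in dims_to_test:
        q = mrl_slice(query_vec, dims)
        d = mrl_slice(doc_vecs, dims)
        sims = (d @ q.T).ravel()  # unit vectors, so dot product == cosine similarity
        print(f"dims={dims}: {np.round(sims, 3)}")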

def load_ms_marco(n_samples: int = 1000) -> List[str]:
    """
    Loads the MS MARCO dataset from Hugging Face.
    Streams the dataset to save RAM.
    Falls back to synthetic data if loading fails.
    """
    try:
        print(f"Attempting to load {n_samples} samples from MS MARCO...")
        dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train", streaming=True)
        
        texts = []
        count = 0
        for row in dataset:
            # ms_marco v1.1 rows have the structure:
            # {'query_id': ..., 'query': ...,
            #  'passages': {'is_selected': [...], 'url': [...], 'passage_text': [...]}}
            # A retrieval engine indexes documents, so prefer the passage text
            # and fall back to the query text only if passages are missing.
            if 'passages' in row:
                # Take the first passage text
                passage_list = row['passages']['passage_text']
                if passage_list:
                    texts.append(passage_list[0])
                    count += 1
            elif 'query' in row:
                texts.append(row['query'])
                count += 1
            
            if count >= n_samples:
                break
                
        if len(texts) < n_samples:
            print(f"Warning: only fetched {len(texts)} of {n_samples} samples from MS MARCO.")
            
        return texts

    except Exception as e:
        print(f"Error loading MS MARCO: {e}")
        print("Falling back to synthetic data.")
        return generate_synthetic_data(n_samples)

def generate_synthetic_data(n_samples: int) -> List[str]:
    """
    Generates synthetic text data for testing.
    """
    base_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the world.",
        "Vector databases enable fast similarity search.",
        "Machine learning models require data for training.",
        "Python is a popular programming language for data science.",
        "Cloud computing provides scalable resources.",
        "Cybersecurity is essential for protecting digital assets.",
        "Blockchain technology ensures decentralized transactions.",
        "Quantum computing will solve complex problems.",
        "Sustainable energy is the future of the planet."
    ]
    
    data = []
    for i in range(n_samples):
        # Cycle through the base sentences, appending an index so each entry is unique
        base = base_sentences[i % len(base_sentences)]
        data.append(f"{base} Variation {i}")
        
    return data
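
# A minimal smoke test, assuming the default nomic model can be downloaded;
# the corpus size, slice width, and query below are illustrative only.
if __name__ == "__main__":
    corpus = generate_synthetic_data(5)
    vectors = get_embeddings("nomic-ai/nomic-embed-text-v1.5", corpus)
    print("Full embedding shape:", vectors.shape)
    sliced = mrl_slice(vectors, 64)
    print("Sliced embedding shape:", sliced.shape)
    mrl_similarity_demo("fast similarity search", corpus[:3])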