# dashVectorSpace/src/data_pipeline.py
# Deployed as part of dashVectorspace v1 (commit b92d96d).
import numpy as np
from sentence_transformers import SentenceTransformer
from datasets import load_dataset
import pandas as pd
from typing import List, Union
import torch
import torch.nn.functional as F
def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
    """
    Load the requested SentenceTransformer model on CPU and encode `texts`.

    Models whose names contain 'nomic' or 'qwen' ship custom modeling code
    on the Hub, so trust_remote_code is enabled for them.
    """
    print(f"Loading embedding model: {model_name}...")
    needs_remote_code = any(tag in model_name for tag in ("nomic", "qwen"))
    encoder = SentenceTransformer(
        model_name,
        trust_remote_code=needs_remote_code,
        device='cpu',
    )
    # With convert_to_numpy=True, encode() already returns an np.ndarray.
    return encoder.encode(texts, convert_to_numpy=True, show_progress_bar=True)
def mrl_slice(vectors: np.ndarray, dims: int) -> np.ndarray:
    """
    Truncate each embedding to its first `dims` components, then re-apply
    L2 normalization to the truncated rows.

    Normalizing *after* truncation (not before) is the key requirement of
    Matryoshka Representation Learning (MRL) embeddings.
    """
    truncated = vectors[:, :dims]
    lengths = np.linalg.norm(truncated, axis=1, keepdims=True)
    # Substitute a tiny epsilon for zero-length rows so the division below
    # never produces NaN/inf.
    safe_lengths = np.where(lengths == 0, 1e-10, lengths)
    return truncated / safe_lengths
def load_ms_marco(n_samples: int = 1000) -> List[str]:
    """
    Load up to `n_samples` passage texts from MS MARCO (v1.1) on Hugging Face.

    The dataset is streamed to keep RAM usage low. If loading raises, or if
    streaming yields no usable rows at all, synthetic data is returned instead
    so callers always receive a non-empty corpus.

    Returns:
        A list of document strings (passages preferred, queries as fallback).
    """
    try:
        print(f"Attempting to load {n_samples} samples from MS MARCO...")
        dataset = load_dataset("microsoft/ms_marco", "v1.1", split="train", streaming=True)
        texts = []
        # ms_marco v1.1 row structure:
        # {'query_id': ..., 'query': ...,
        #  'passages': {'is_selected': [...], 'url': [...], 'passage_text': [...]}}
        # A retrieval engine indexes documents, so we take the first passage
        # text and only fall back to the query when passages are missing.
        for row in dataset:
            if 'passages' in row:
                passage_texts = row['passages']['passage_text']
                if passage_texts:
                    texts.append(passage_texts[0])
            elif 'query' in row:
                texts.append(row['query'])
            if len(texts) >= n_samples:
                break
        if not texts:
            # Streaming "succeeded" but produced nothing usable. Without this
            # fallback the caller would receive an empty corpus (the except
            # branch below only covers raised exceptions).
            print("Warning: MS MARCO yielded no samples. Falling back to synthetic data.")
            return generate_synthetic_data(n_samples)
        if len(texts) < n_samples:
            print("Warning: Could not fetch enough samples from MS MARCO.")
        return texts
    except Exception as e:
        print(f"Error loading MS MARCO: {e}")
        print("Falling back to synthetic data.")
        return generate_synthetic_data(n_samples)
def generate_synthetic_data(n_samples: int) -> List[str]:
    """
    Produce `n_samples` deterministic synthetic sentences for testing.

    Cycles through a fixed pool of base sentences and suffixes a running
    index so every returned string is unique.
    """
    base_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the world.",
        "Vector databases enable fast similarity search.",
        "Machine learning models require data for training.",
        "Python is a popular programming language for data science.",
        "Cloud computing provides scalable resources.",
        "Cybersecurity is essential for protecting digital assets.",
        "Blockchain technology ensures decentralized transactions.",
        "Quantum computing will solve complex problems.",
        "Sustainable energy is the future of the planet."
    ]
    pool_size = len(base_sentences)
    # Round-robin over the pool; the index suffix guarantees uniqueness.
    return [
        f"{base_sentences[idx % pool_size]} Variation {idx}"
        for idx in range(n_samples)
    ]