Spaces:
Sleeping
Sleeping
Commit ·
6874d8b
1
Parent(s): 61f25c3
phase 1 - data storage in qdrant and retrieval
Browse files- .gitignore +3 -0
- database/README.md +27 -0
- database/ingest.py +115 -0
- database/qdrant_manager.py +137 -0
- database/requirements.txt +15 -0
- database/test_retrieval.py +93 -0
- database/utils.py +117 -0
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
venv/
|
| 3 |
+
__pycache__/
|
database/README.md
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Database Module - Math Agentic RAG
|
| 2 |
+
|
| 3 |
+
This module handles the knowledge base creation and retrieval for the Math Agentic RAG system.
|
| 4 |
+
|
| 5 |
+
## Files Overview
|
| 6 |
+
|
| 7 |
+
### Core Files
|
| 8 |
+
- **`utils.py`** - Utility functions for embedding generation and data processing
|
| 9 |
+
- **`qdrant_manager.py`** - Qdrant vector database client wrapper
|
| 10 |
+
- **`ingest.py`** - Main ingestion script for loading dataset into Qdrant (includes config)
|
| 11 |
+
- **`test_retrieval.py`** - Testing script for validating retrieval functionality (includes config)
|
| 12 |
+
|
| 13 |
+
### Dependencies
|
| 14 |
+
- **`requirements.txt`** - Python package dependencies
|
| 15 |
+
|
| 16 |
+
## Usage
|
| 17 |
+
|
| 18 |
+
1. **Setup Environment Variables**: Ensure `.env` file has Qdrant credentials
|
| 19 |
+
2. **Install Dependencies**: `pip install -r requirements.txt`
|
| 20 |
+
3. **Ingest Data**: `python ingest.py`
|
| 21 |
+
4. **Test Retrieval**: `python test_retrieval.py`
|
| 22 |
+
|
| 23 |
+
## Current Status
|
| 24 |
+
- ✅ Dataset: NuminaMath (5,000 mathematical problems)
|
| 25 |
+
- ✅ Vector DB: Qdrant Cloud
|
| 26 |
+
- ✅ Embedding Model: all-MiniLM-L6-v2 (384 dimensions)
|
| 27 |
+
- ✅ Status: Ready for Phase 2
|
database/ingest.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Main ingestion script for loading Nuinamath dataset into Qdrant.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import os
|
| 6 |
+
from datasets import load_dataset
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import time
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
# Load environment variables
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
# Configuration settings
|
| 15 |
+
QDRANT_URL = os.getenv("QDRANT_URL")
|
| 16 |
+
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
| 17 |
+
QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "nuinamath")
|
| 18 |
+
DATASET_NAME = "AI-MO/NuminaMath-CoT"
|
| 19 |
+
DATASET_SPLIT = "train"
|
| 20 |
+
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
| 21 |
+
VECTOR_SIZE = 384
|
| 22 |
+
DISTANCE_METRIC = "Cosine"
|
| 23 |
+
BATCH_SIZE = 100
|
| 24 |
+
MAX_SAMPLES = None
|
| 25 |
+
|
| 26 |
+
# Validation
|
| 27 |
+
if not QDRANT_URL or not QDRANT_API_KEY:
|
| 28 |
+
raise ValueError("Please set QDRANT_URL and QDRANT_API_KEY in your .env file")
|
| 29 |
+
|
| 30 |
+
from utils import EmbeddingGenerator, batch_process_dataset
|
| 31 |
+
from qdrant_manager import QdrantManager
|
| 32 |
+
|
| 33 |
+
# Set up logging
|
| 34 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 35 |
+
logger = logging.getLogger(__name__)
|
| 36 |
+
|
| 37 |
+
def main():
    """Run the full ingestion pipeline: load data, embed it, upsert to Qdrant."""
    try:
        # Set up the embedding model and the Qdrant connection.
        logger.info("Initializing components...")
        embedder = EmbeddingGenerator(EMBEDDING_MODEL)
        manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)

        # Load the (optionally truncated) dataset split.
        logger.info(f"Loading dataset: {DATASET_NAME}")
        split_spec = f"{DATASET_SPLIT}[:{MAX_SAMPLES}]" if MAX_SAMPLES else DATASET_SPLIT
        dataset = load_dataset(DATASET_NAME, split=split_spec)
        if MAX_SAMPLES:
            logger.info(f"Loaded {len(dataset)} samples (limited)")
        else:
            logger.info(f"Loaded full dataset: {len(dataset)} samples")

        # Ensure the target collection exists before uploading anything.
        logger.info(f"Creating collection: {QDRANT_COLLECTION}")
        created = manager.create_collection(
            collection_name=QDRANT_COLLECTION,
            vector_size=VECTOR_SIZE,
            distance=DISTANCE_METRIC,
        )
        if not created:
            logger.error("Failed to create collection")
            return

        # Embed and upload the dataset batch by batch.
        logger.info("Processing dataset in batches...")
        batches = batch_process_dataset(dataset, BATCH_SIZE)

        done = 0
        n_batches = len(batches)

        for idx, batch in enumerate(tqdm(batches, desc="Processing batches")):
            try:
                # The precomputed 'text' field is what gets embedded.
                texts = [item['text'] for item in batch]

                logger.info(f"Generating embeddings for batch {idx + 1}/{n_batches}")
                vectors = embedder.embed_text(texts)

                logger.info(f"Uploading batch {idx + 1} to Qdrant...")
                manager.upsert_points(
                    collection_name=QDRANT_COLLECTION,
                    points_data=batch,
                    embeddings=vectors,
                )

                done += len(batch)
                logger.info(f"Progress: {done}/{len(dataset)} items processed")

                # Throttle slightly between batches to avoid hammering the API.
                time.sleep(0.5)

            except Exception as e:
                # Skip a failing batch rather than aborting the whole run.
                logger.error(f"Error processing batch {idx + 1}: {e}")
                continue

        # Final summary
        logger.info("Ingestion completed!")
        logger.info(f"Total items processed: {done}")

        # Report the collection state after ingestion.
        info = manager.get_collection_info(QDRANT_COLLECTION)
        if info:
            logger.info(f"Collection status: {info.status}")
            logger.info(f"Vectors count: {info.vectors_count}")

    except Exception as e:
        logger.error(f"Fatal error in ingestion pipeline: {e}")
        raise

if __name__ == "__main__":
    main()
|
database/qdrant_manager.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Qdrant client wrapper for vector database operations.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
from qdrant_client import QdrantClient
|
| 7 |
+
from qdrant_client.models import Distance, VectorParams, PointStruct
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class QdrantManager:
    """Manages Qdrant vector database operations (create, upsert, search)."""

    def __init__(self, url: str, api_key: str):
        """Initialize the Qdrant client.

        Args:
            url: Qdrant server URL (e.g. a Qdrant Cloud endpoint)
            api_key: API key used to authenticate with the server
        """
        self.client = QdrantClient(url=url, api_key=api_key)
        logger.info(f"Connected to Qdrant at {url}")

    def create_collection(self, collection_name: str, vector_size: int, distance: str = "Cosine") -> bool:
        """
        Create a new collection in Qdrant (no-op if it already exists).

        Args:
            collection_name: Name of the collection
            vector_size: Dimension of vectors
            distance: Distance metric (Cosine, Euclidean, Dot); unknown
                values fall back to Cosine

        Returns:
            True if the collection exists or was created, False on error.
        """
        try:
            # Check if collection already exists
            existing_names = {col.name for col in self.client.get_collections().collections}
            if collection_name in existing_names:
                logger.info(f"Collection '{collection_name}' already exists")
                return True

            # Map the human-readable metric name onto the client enum.
            distance_map = {
                "Cosine": Distance.COSINE,
                "Euclidean": Distance.EUCLID,
                "Dot": Distance.DOT,
            }

            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=vector_size,
                    distance=distance_map.get(distance, Distance.COSINE),
                ),
            )
            logger.info(f"Created collection '{collection_name}' with vector size {vector_size}")
            return True

        except Exception as e:
            logger.error(f"Error creating collection: {e}")
            return False

    def upsert_points(self, collection_name: str, points_data: List[Dict[str, Any]],
                      embeddings: List[List[float]], max_retries: int = 3) -> bool:
        """
        Upsert points into a Qdrant collection with retry logic.

        Args:
            collection_name: Name of the collection
            points_data: Point dicts with 'id', 'problem', 'solution', 'source' keys
            embeddings: One embedding vector per entry in points_data
            max_retries: Maximum number of upload attempts

        Returns:
            True on success.

        Raises:
            Exception: re-raises the last upload error once max_retries is exhausted.
        """
        # Pair each payload with its vector. Fixed: the original looped with
        # enumerate() but never used the index.
        # NOTE(review): zip silently truncates on a length mismatch — callers
        # must pass lists of equal length.
        points = [
            PointStruct(
                id=data['id'],
                vector=embedding,
                payload={
                    'problem': data['problem'],
                    'solution': data['solution'],
                    'source': data['source'],
                },
            )
            for data, embedding in zip(points_data, embeddings)
        ]

        # Retry with exponential backoff to ride out transient network issues.
        for attempt in range(max_retries):
            try:
                self.client.upsert(
                    collection_name=collection_name,
                    points=points
                )
                logger.info(f"Successfully upserted {len(points)} points")
                return True

            except Exception as e:
                logger.warning(f"Attempt {attempt + 1} failed: {e}")
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)  # Exponential backoff
                else:
                    logger.error(f"Failed to upsert points after {max_retries} attempts")
                    raise  # bare raise preserves the original traceback

    def search_similar(self, collection_name: str, query_vector: List[float],
                       limit: int = 3, score_threshold: float = 0.0):
        """
        Search for similar vectors in the collection.

        Args:
            collection_name: Name of the collection
            query_vector: Query embedding vector
            limit: Number of results to return
            score_threshold: Minimum similarity score

        Returns:
            Search results from Qdrant, or an empty list on error.
        """
        try:
            results = self.client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=limit,
                score_threshold=score_threshold
            )
            logger.info(f"Found {len(results)} similar results")
            return results

        except Exception as e:
            # Best-effort: callers treat an empty list as "no results".
            logger.error(f"Error searching collection: {e}")
            return []

    def get_collection_info(self, collection_name: str):
        """Return collection metadata from Qdrant, or None on error."""
        try:
            info = self.client.get_collection(collection_name)
            logger.info(f"Collection info: {info}")
            return info
        except Exception as e:
            logger.error(f"Error getting collection info: {e}")
            return None
|
database/requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Dataset loading and processing
|
| 2 |
+
datasets==2.18.0
|
| 3 |
+
pandas
|
| 4 |
+
|
| 5 |
+
# For embedding generation
|
| 6 |
+
sentence-transformers==2.2.2
|
| 7 |
+
|
| 8 |
+
# For Qdrant client (VectorDB)
|
| 9 |
+
qdrant-client==1.8.0
|
| 10 |
+
|
| 11 |
+
# For environment variables
|
| 12 |
+
python-dotenv
|
| 13 |
+
|
| 14 |
+
# For progress tracking
|
| 15 |
+
tqdm
|
database/test_retrieval.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Test script for retrieving similar math problems from Qdrant.
"""
import logging
import os
from dotenv import load_dotenv

# Load environment variables from a local .env file.
load_dotenv()

# Configuration settings — must match the values used during ingestion
# (see ingest.py), otherwise queries hit the wrong collection/model.
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_COLLECTION = os.getenv("QDRANT_COLLECTION", "nuinamath")
EMBEDDING_MODEL = "all-MiniLM-L6-v2"

from utils import EmbeddingGenerator, format_retrieval_results
from qdrant_manager import QdrantManager

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
def test_retrieval():
    """Run a handful of sample math queries against the Qdrant collection."""

    # Sample test questions
    test_questions = [
        "What is the value of x in 3x + 5 = 20?",
        "How do you find the area of a triangle given 3 sides?",
        "Solve for y: 2y - 7 = 15",
        "What is the derivative of x^2 + 3x?",
        "Find the arithmetic sequence common difference"
    ]

    try:
        # Set up the embedder and the Qdrant connection once for all queries.
        logger.info("Initializing retrieval system...")
        embedder = EmbeddingGenerator(EMBEDDING_MODEL)
        manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)

        separator = '=' * 60
        for idx, question in enumerate(test_questions, 1):
            print(f"\n{separator}")
            print(f"TEST QUERY {idx}: {question}")
            print(separator)

            # Embed the query, then look up its nearest stored problems.
            query_vector = embedder.embed_single_text(question)
            hits = manager.search_similar(
                collection_name=QDRANT_COLLECTION,
                query_vector=query_vector,
                limit=3,
                score_threshold=0.1,
            )

            # Pretty-print whatever came back.
            print(format_retrieval_results(hits))

    except Exception as e:
        logger.error(f"Error in retrieval test: {e}")
|
| 65 |
+
|
| 66 |
+
def test_collection_status():
    """Print basic status information about the Qdrant collection."""
    try:
        manager = QdrantManager(QDRANT_URL, QDRANT_API_KEY)

        bar = '=' * 40
        print(f"\n{bar}")
        print("COLLECTION STATUS")
        print(bar)

        info = manager.get_collection_info(QDRANT_COLLECTION)
        if not info:
            print("Collection not found or error occurred")
        else:
            print(f"Collection Name: {QDRANT_COLLECTION}")
            print(f"Status: {info.status}")
            print(f"Vectors Count: {info.vectors_count}")
            print(f"Vector Size: {info.config.params.vectors.size}")
            print(f"Distance Metric: {info.config.params.vectors.distance}")

    except Exception as e:
        logger.error(f"Error checking collection status: {e}")

if __name__ == "__main__":
    print("Testing Qdrant Collection Status...")
    test_collection_status()

    print("\n\nTesting Retrieval System...")
    test_retrieval()
|
database/utils.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Utility functions for data processing and embedding generation.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
from datasets import Dataset
|
| 8 |
+
import uuid
|
| 9 |
+
|
| 10 |
+
# Set up logging
|
| 11 |
+
logging.basicConfig(level=logging.INFO)
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
class EmbeddingGenerator:
    """Wraps a SentenceTransformer model for batch and single-text embedding."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load the sentence-transformers model identified by *model_name*."""
        logger.info(f"Loading embedding model: {model_name}")
        self.model = SentenceTransformer(model_name)
        self.model_name = model_name

    def embed_text(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts; returns one vector (list of floats) per input."""
        logger.info(f"Generating embeddings for {len(texts)} texts")
        matrix = self.model.encode(texts, show_progress_bar=True)
        return matrix.tolist()

    def embed_single_text(self, text: str) -> List[float]:
        """Embed one text and return its vector as a flat list of floats."""
        vectors = self.model.encode([text])
        return vectors[0].tolist()
|
| 33 |
+
|
| 34 |
+
def preprocess_dataset_entry(entry: Dict[str, Any]) -> Dict[str, Any]:
    """
    Build the record stored in Qdrant from one raw dataset entry.

    Combines the problem and solution into a single 'text' field used for
    embedding, and assigns a fresh random UUID as the point id.

    Args:
        entry: Dictionary with 'problem' and 'solution' keys (missing keys
            default to empty strings; 'source' defaults to 'unknown')

    Returns:
        Dict with 'id', 'text', 'problem', 'solution', 'source' fields.
    """
    problem = entry.get('problem', '')
    solution = entry.get('solution', '')

    record = {
        'id': str(uuid.uuid4()),
        # The combined question/answer text is what gets embedded.
        'text': f"Question: {problem}\nAnswer: {solution}",
        'problem': problem,
        'solution': solution,
        'source': entry.get('source', 'unknown'),
    }
    return record
|
| 57 |
+
|
| 58 |
+
def batch_process_dataset(dataset: Dataset, batch_size: int = 100) -> List[List[Dict[str, Any]]]:
    """
    Slice the dataset into batches of preprocessed entries.

    Args:
        dataset: HuggingFace dataset with 'problem', 'solution', 'source' columns
        batch_size: Number of items per batch

    Returns:
        List of batches, each a list of processed entry dicts.
    """
    total_items = len(dataset)
    n_batches = (total_items + batch_size - 1) // batch_size

    logger.info(f"Processing {total_items} items in batches of {batch_size}")

    batches = []
    for start in range(0, total_items, batch_size):
        stop = min(start + batch_size, total_items)
        # HF dataset slicing yields a dict of parallel column lists.
        columns = dataset[start:stop]

        processed = [
            preprocess_dataset_entry({
                'problem': columns['problem'][j],
                'solution': columns['solution'][j],
                'source': columns['source'][j],
            })
            for j in range(len(columns['problem']))
        ]

        batches.append(processed)
        logger.info(f"Processed batch {len(batches)}/{n_batches}")

    return batches
|
| 93 |
+
|
| 94 |
+
def format_retrieval_results(results: List[Any]) -> str:
    """
    Format retrieval results for display.

    Args:
        results: Search hits from Qdrant; each hit exposes a `.payload` dict
            (with 'problem' and 'solution' keys) and a `.score` float.
            (Fixed: the old `List[Dict]` annotation was wrong — the code does
            attribute access, not dict lookups.)

    Returns:
        Formatted multi-line string, or a placeholder message when there
        are no results.
    """
    if not results:
        return "No results found."

    output = []
    for i, result in enumerate(results, 1):
        payload = result.payload
        score = result.score

        # Truncate long answers; append an ellipsis only when text was
        # actually cut (the old code added "..." unconditionally).
        solution = payload['solution']
        shown = solution[:200] + "..." if len(solution) > 200 else solution

        output.append(f"\n--- Result {i} (Score: {score:.4f}) ---")
        output.append(f"Question: {payload['problem']}")
        output.append(f"Answer: {shown}")
        output.append("-" * 50)

    return "\n".join(output)
|