# Embedding pipeline CLI: load documents, create embeddings, and store them
# in a vector database. (Removed non-Python scraper residue that preceded
# the imports — it was not part of the program.)
import asyncio
import logging
from typing import List, Dict, Any
from document_loader import DocumentLoader
from embedder import Embedder
from vector_store import VectorStore
from preprocessor import TextPreprocessor
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class EmbeddingPipeline:
    """
    Manage the entire embedding pipeline:

    1. Load documents
    2. Preprocess text
    3. Create embeddings
    4. Store in vector database
    """

    def __init__(self) -> None:
        # One collaborator per pipeline stage.
        self.document_loader = DocumentLoader()
        self.embedder = Embedder()
        self.vector_store = VectorStore()
        # NOTE(review): preprocessor is instantiated but never used by any
        # method visible in this file — confirm whether it is still needed.
        self.preprocessor = TextPreprocessor()

    async def process_directory(self, directory_path: str, chunk_size: int = 512, overlap: int = 50) -> int:
        """
        Process all documents in a directory: load, embed, and store.

        Args:
            directory_path: Path to the directory containing documents
            chunk_size: Size of text chunks
            overlap: Overlap between chunks

        Returns:
            Number of documents successfully embedded and stored
        """
        # Create the collection if it doesn't exist.
        self.vector_store.create_collection()

        # Load documents from the directory.
        logger.info("Loading documents from %s", directory_path)
        documents = self.document_loader.load_documents_from_directory(
            directory_path,
            chunk_size=chunk_size,
            overlap=overlap,
        )
        logger.info("Loaded %d documents", len(documents))

        if not documents:
            logger.warning("No documents found to process")
            return 0

        # Embed the documents.
        logger.info("Creating embeddings...")
        embedded_documents = await self.embedder.embed_documents(documents)

        # Drop documents whose embedding is missing or empty; the explicit
        # len() check is kept from the original in case 'embedding' is a
        # sequence type with non-standard truthiness.
        valid_documents = [
            doc for doc in embedded_documents
            if doc.get('embedding') and len(doc['embedding']) > 0
        ]
        logger.info("Successfully embedded %d documents", len(valid_documents))

        # Persist the surviving documents.
        if valid_documents:
            self.vector_store.add_documents(valid_documents)
            logger.info("Added %d documents to vector store", len(valid_documents))

        return len(valid_documents)

    def search(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """
        Search for documents similar to the query.

        Args:
            query: The search query
            top_k: Number of results to return

        Returns:
            List of matching documents with scores

        Note:
            Uses asyncio.run() internally, so this must be called from
            synchronous code only — it raises RuntimeError if invoked
            while an event loop is already running.
        """
        # Embed the query, then find its nearest neighbours in the store.
        query_embedding = asyncio.run(self.embedder.create_embedding(query))
        return self.vector_store.search_similar(query_embedding, top_k)
def main() -> None:
    """
    Command-line entry point for the embedding pipeline.

    With --search, runs a similarity search against the existing vector
    store; otherwise processes every document in --directory.
    """
    # Removed an unused 'import os' that the original carried here.
    import argparse

    parser = argparse.ArgumentParser(description="Physical AI Textbook Embedding Pipeline")
    # NOTE(review): --directory is required even when only --search is used —
    # confirm whether that is intended before relaxing it.
    parser.add_argument("--directory", type=str, required=True,
                        help="Directory containing documents to process")
    parser.add_argument("--chunk-size", type=int, default=512,
                        help="Size of text chunks")
    parser.add_argument("--overlap", type=int, default=50,
                        help="Overlap between chunks")
    parser.add_argument("--search", type=str,
                        help="Search query to test the vector store")
    args = parser.parse_args()

    pipeline = EmbeddingPipeline()

    if args.search:
        # Search mode: query the existing collection and print previews.
        logger.info("Searching for: %s", args.search)
        results = pipeline.search(args.search)
        for i, result in enumerate(results):
            print(f"\nResult {i+1} (Score: {result['score']:.4f}):")
            print(f"Source: {result['source']}")
            print(f"Content preview: {result['content'][:200]}...")
    else:
        # Ingest mode: process every document in the directory.
        logger.info("Starting embedding pipeline...")
        processed_count = asyncio.run(
            pipeline.process_directory(
                args.directory,
                chunk_size=args.chunk_size,
                overlap=args.overlap,
            )
        )
        logger.info("Processed %d documents", processed_count)

        # Report the total document count now in the collection.
        doc_count = pipeline.vector_store.get_all_documents_count()
        logger.info("Total documents in vector store: %d", doc_count)
# Script entry point. (Removed a stray trailing "|" scrape artifact that made
# the original 'main() |' a syntax error.)
if __name__ == "__main__":
    main()