import hashlib
import os
import re
from pathlib import Path
from typing import Dict, List, Optional

import chromadb
from llama_index import VectorStoreIndex, StorageContext
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.vector_stores import ChromaVectorStore

from indexes.csv_index_builder import EnhancedCSVReader


class CSVIndexManager:
    """Manage creation and retrieval of vector indexes for CSV files.

    Each CSV file is loaded with ``EnhancedCSVReader``, embedded with a
    ``HuggingFaceEmbedding`` model and stored in its own Chroma collection.
    Built indexes are registered in ``self.indexes`` keyed by the file's stem.
    """

    def __init__(self, embedding_model_name: str = "all-MiniLM-L6-v2"):
        """Set up the CSV reader, embedding model and Chroma client.

        Args:
            embedding_model_name: Model name passed to ``HuggingFaceEmbedding``.
        """
        self.csv_reader = EnhancedCSVReader()
        self.embed_model = HuggingFaceEmbedding(model_name=embedding_model_name)
        self.chroma_client = chromadb.Client()
        # file_id (CSV stem) -> {"index": VectorStoreIndex, "metadata": dict}.
        # find_relevant_csvs() additionally caches "description_embedding".
        self.indexes = {}

    def create_index(self, file_path: str) -> VectorStoreIndex:
        """Build a vector index for one CSV file and register it.

        The index is stored in ``self.indexes`` under the file's stem together
        with the metadata of the first loaded document (``{}`` if the reader
        returned no documents). Calling this again for a file with the same
        stem rebuilds that entry from scratch.

        Args:
            file_path: Path to the CSV file.

        Returns:
            The newly built ``VectorStoreIndex``.
        """
        file_id = Path(file_path).stem
        documents = self.csv_reader.load_data(file_path)

        collection_name = self._collection_name(file_id)
        if file_id in self.indexes:
            # Rebuilding an index this manager created earlier: drop the old
            # collection so its vectors are not mixed with the new ones.
            # Any previously returned index for this file_id is invalidated.
            self.chroma_client.delete_collection(collection_name)
        # create_collection fails when the name is already taken;
        # get_or_create_collection does not.
        chroma_collection = self.chroma_client.get_or_create_collection(collection_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        index = VectorStoreIndex.from_documents(
            documents,
            storage_context=storage_context,
            embed_model=self.embed_model,
        )

        self.indexes[file_id] = {
            "index": index,
            "metadata": documents[0].metadata if documents else {},
        }
        return index

    def index_directory(self, directory_path: str) -> Dict[str, VectorStoreIndex]:
        """Index every ``.csv`` file (case-insensitive) directly in a directory.

        Entries are processed in sorted name order so results are reproducible.
        Subdirectories are skipped even if their name ends in ``.csv``. Files
        that share a stem map to the same key, so a later one replaces an
        earlier one.

        Args:
            directory_path: Directory to scan (not recursive).

        Returns:
            Mapping of file stem to the index built for that file.
        """
        indexed_files = {}
        for entry in sorted(os.listdir(directory_path)):
            file_path = os.path.join(directory_path, entry)
            if not entry.lower().endswith(".csv") or not os.path.isfile(file_path):
                continue
            indexed_files[Path(file_path).stem] = self.create_index(file_path)
        return indexed_files

    def find_relevant_csvs(self, query: str, top_k: int = 3) -> List[str]:
        """Rank the indexed CSV files by similarity to *query*.

        Each file is represented by a short text description built from its
        metadata. The query embedding and the description embeddings are
        compared with cosine similarity. Description embeddings are cached
        in ``self.indexes``, so repeated queries only embed the query itself.

        Args:
            query: Natural-language query.
            top_k: Maximum number of file ids to return; values <= 0 give ``[]``.

        Returns:
            Up to ``top_k`` file ids (CSV stems), most similar first. Ties keep
            the insertion order of ``self.indexes``.
        """
        if not self.indexes or top_k <= 0:
            return []

        query_embedding = self.embed_model.get_text_embedding(query)
        similarities = {
            file_id: self._cosine_similarity(
                query_embedding, self._description_embedding(file_id, index_info)
            )
            for file_id, index_info in self.indexes.items()
        }
        # sorted() is stable even with reverse=True, so ties keep dict order.
        ranked = sorted(similarities, key=similarities.get, reverse=True)
        return ranked[:top_k]

    def _description_embedding(self, file_id: str, index_info: dict) -> List[float]:
        """Return (and cache in *index_info*) the embedding of a file's description."""
        cached = index_info.get("description_embedding")
        if cached is None:
            description = self._describe_csv(file_id, index_info["metadata"])
            cached = self.embed_model.get_text_embedding(description)
            index_info["description_embedding"] = cached
        return cached

    @staticmethod
    def _describe_csv(file_id: str, metadata: dict) -> str:
        """Build the text used to embed a CSV file for relevance ranking.

        Missing or empty metadata keys fall back to the file id, no columns,
        0 rows and no samples instead of raising ``KeyError``. This happens
        for a CSV that produced no documents, whose metadata is ``{}``.
        """
        filename = metadata.get("filename", file_id)
        columns = ", ".join(str(col) for col in (metadata.get("columns") or []))
        row_count = metadata.get("row_count", 0)
        parts = [
            f"CSV file {filename} with columns: {columns}. ",
            f"Contains {row_count} rows. ",
            "Sample data: ",
        ]
        for col, samples in (metadata.get("samples") or {}).items():
            if samples:
                shown = ", ".join(str(s) for s in samples[:2])
                parts.append(f"{col}: {shown}; ")
        return "".join(parts)

    @staticmethod
    def _collection_name(file_id: str) -> str:
        """Map a file id to a string Chroma accepts as a collection name.

        Chroma requires 3-63 characters that start and end with an
        alphanumeric character. A file id that already meets this and uses
        only letters, digits, ``_`` and ``-`` is returned unchanged.
        Otherwise, invalid characters become ``_`` and a short hash of the
        original id is appended, so distinct ids keep distinct collections.
        """
        name = re.sub(r"[^A-Za-z0-9_-]", "_", file_id).strip("_-")
        if name == file_id and 3 <= len(name) <= 63:
            return name
        digest = hashlib.sha256(
            file_id.encode("utf-8", "surrogatepass")
        ).hexdigest()[:8]
        # 54 + len("_") + 8 = 63, the longest name Chroma allows.
        return f"{name[:54]}_{digest}" if name else f"csv_{digest}"

    def _cosine_similarity(self, vec1, vec2):
        """Return the cosine similarity of two equal-length numeric vectors.

        Returns 0 when either vector has zero norm. ``zip`` stops at the
        shorter vector if the lengths differ.
        """
        dot_product = sum(a * b for a, b in zip(vec1, vec2))
        norm_a = sum(a * a for a in vec1) ** 0.5
        norm_b = sum(b * b for b in vec2) ** 0.5
        return dot_product / (norm_a * norm_b) if norm_a * norm_b != 0 else 0