| |
| |
| |
|
|
| import chromadb |
| from chromadb.config import Settings |
| from sentence_transformers import SentenceTransformer |
| import pandas as pd |
| import hashlib |
| from tqdm import tqdm |
| from src.config import CHROMA_DIR, PROCESSED_DATA_DIR, EMBEDDING_MODEL, CHUNK_SIZE, CHUNK_OVERLAP |
| from src.logger import logger |
|
|
|
|
# Lazily-initialized, process-wide singletons. Each is created on first use
# by its corresponding accessor (get_client / get_collection / _get_embed_model)
# so importing this module stays cheap and side-effect free.
_client = None        # chromadb.PersistentClient
_collection = None    # handle to the COLLECTION_NAME collection
_embed_model = None   # SentenceTransformer instance


# Name of the Chroma collection all chunks are stored in.
COLLECTION_NAME = "financial_articles"
|
|
|
|
def _get_embed_model():
    """Return the shared SentenceTransformer, loading it on first call.

    The loaded model is cached in the module-level ``_embed_model`` so the
    (expensive) load happens at most once per process.
    """
    global _embed_model
    if _embed_model is not None:
        return _embed_model
    logger.info(f"Loading embedding model: {EMBEDDING_MODEL}")
    _embed_model = SentenceTransformer(EMBEDDING_MODEL)
    return _embed_model
|
|
|
|
def get_client():
    """Return the process-wide persistent ChromaDB client, creating it lazily.

    Ensures the on-disk directory (CHROMA_DIR) exists before the client is
    first constructed; subsequent calls return the cached instance.
    """
    global _client
    if _client is not None:
        return _client
    CHROMA_DIR.mkdir(parents=True, exist_ok=True)
    _client = chromadb.PersistentClient(path=str(CHROMA_DIR))
    return _client
|
|
|
|
def get_collection():
    """Return the cached handle to the articles collection, creating it lazily.

    The collection is created with cosine distance for its HNSW index
    (``hnsw:space``), which `search` relies on when converting distance
    to a similarity score.
    """
    global _collection
    if _collection is not None:
        return _collection
    _collection = get_client().get_or_create_collection(
        name=COLLECTION_NAME,
        metadata={"hnsw:space": "cosine"},
    )
    return _collection
|
|
|
|
def chunk_text(text: str, chunk_size: int = CHUNK_SIZE, overlap: int = CHUNK_OVERLAP) -> list[str]:
    """Split ``text`` into sentence-aligned chunks of roughly ``chunk_size`` characters.

    Sentences are kept whole; when a chunk fills up, its last ``overlap``
    *words* (note: overlap is counted in words, chunk_size in characters)
    are carried into the start of the next chunk for context continuity.

    Returns a non-empty list of chunk strings; text that already fits in
    one chunk is returned as a single-element list.
    """
    if len(text) <= chunk_size:
        return [text]

    import re
    pieces = re.split(r'(?<=[.!?])\s+', text)

    out: list[str] = []
    buf = ""
    for piece in pieces:
        # Flush only when the buffer is non-empty and adding this sentence
        # would push it past the size budget.
        if buf and len(buf) + len(piece) > chunk_size:
            out.append(buf.strip())
            tokens = buf.split()
            # Seed the next chunk with the tail of the previous one.
            if len(tokens) > overlap:
                carry = " ".join(tokens[-overlap:])
            else:
                carry = buf
            buf = f"{carry} {piece}"
        elif buf:
            buf = f"{buf} {piece}"
        else:
            buf = piece

    tail = buf.strip()
    if tail:
        out.append(tail)

    return out or [text]
|
|
|
|
def build_index(force_rebuild: bool = False):
    """Build (or reuse) the ChromaDB index over the processed news articles.

    Reads ``financial_news.parquet`` from PROCESSED_DATA_DIR, splits each
    article's ``text`` column into chunks, embeds them with the shared
    sentence-transformer, and inserts everything into the
    ``financial_articles`` collection in batches.

    Args:
        force_rebuild: when True, drop and recreate the collection before
            indexing; otherwise an already-populated collection is reused.

    Returns:
        int: number of chunks in the index (existing count on the fast
        path, newly inserted count otherwise).

    Raises:
        FileNotFoundError: if the processed parquet file does not exist.
    """
    collection = get_collection()

    # Fast path: index already populated and caller did not ask for a rebuild.
    if collection.count() > 0 and not force_rebuild:
        logger.info(f"ChromaDB already has {collection.count()} chunks. Skipping index build.")
        return collection.count()

    if force_rebuild:
        client = get_client()
        client.delete_collection(COLLECTION_NAME)
        # Invalidate the module-level cache so get_collection() recreates
        # a fresh, empty collection.
        global _collection
        _collection = None
        collection = get_collection()

    parquet_path = PROCESSED_DATA_DIR / "financial_news.parquet"
    if not parquet_path.exists():
        raise FileNotFoundError(f"News data not found at {parquet_path}. Run ingestion first.")

    df = pd.read_parquet(parquet_path)
    model = _get_embed_model()

    # Accumulate all chunks up front so embedding/insertion can be batched.
    all_chunks = []
    all_ids = []
    all_metadatas = []

    logger.info(f"Chunking {len(df)} articles...")
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Chunking"):
        chunks = chunk_text(row["text"])
        for ci, chunk in enumerate(chunks):
            # Deterministic ID from row index, chunk position, and a 50-char
            # text prefix — stable across re-runs over the same data.
            chunk_id = hashlib.md5(f"{idx}_{ci}_{chunk[:50]}".encode()).hexdigest()
            all_chunks.append(chunk)
            all_ids.append(chunk_id)
            all_metadatas.append({
                "source": str(row.get("source", "unknown")),
                "date": str(row.get("date", "unknown")),
                # NOTE(review): int(idx) assumes a numeric DataFrame index —
                # would raise on a string/datetime index; confirm upstream.
                "article_index": int(idx),
            })

    # Embed + insert in fixed-size batches to bound peak memory per encode
    # call and keep each collection.add() request small.
    batch_size = 2000
    total_inserted = 0

    logger.info(f"Embedding and indexing {len(all_chunks)} chunks...")
    for i in tqdm(range(0, len(all_chunks), batch_size), desc="Indexing"):
        batch_chunks = all_chunks[i:i + batch_size]
        batch_ids = all_ids[i:i + batch_size]
        batch_meta = all_metadatas[i:i + batch_size]

        embeddings = model.encode(batch_chunks, show_progress_bar=False).tolist()

        collection.add(
            documents=batch_chunks,
            embeddings=embeddings,
            ids=batch_ids,
            metadatas=batch_meta,
        )
        total_inserted += len(batch_chunks)

    logger.info(f"ChromaDB: indexed {total_inserted} chunks total")
    return total_inserted
|
|
|
|
def search(query: str, n_results: int = 10) -> list[dict]:
    """Semantic search over the indexed article chunks.

    Embeds ``query`` with the shared model and runs a nearest-neighbour
    query against the collection.

    Returns a list of dicts with keys ``text`` (chunk contents), ``score``
    (1 - cosine distance, rounded to 4 places; higher is more similar),
    and ``metadata`` (the chunk's stored metadata).
    """
    collection = get_collection()
    model = _get_embed_model()

    query_embedding = model.encode([query]).tolist()

    results = collection.query(
        query_embeddings=query_embedding,
        n_results=n_results,
        include=["documents", "distances", "metadatas"],
    )

    # Chroma returns one inner list per query; we sent exactly one query.
    docs = results["documents"][0]
    dists = results["distances"][0]
    metas = results["metadatas"][0]

    return [
        {
            "text": doc,
            # Collection uses cosine distance, so similarity = 1 - distance.
            "score": round(1 - dist, 4),
            "metadata": meta,
        }
        for doc, dist, meta in zip(docs, dists, metas)
    ]
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the index (no-op if already populated), then run
    # a sample query and print the top hits.
    total = build_index()
    print(f"\nIndex built with {total} chunks")
    print("\nTest search: 'Apple revenue growth'")
    hits = search("Apple revenue growth", n_results=3)
    for hit in hits:
        print(f"  [{hit['score']:.3f}] {hit['text'][:100]}...")
|
|