import hashlib

import chromadb
from sentence_transformers import SentenceTransformer

from src.config import CHROMA_DB_PATH, HF_TOKEN

CHROMA_DB_PATH.mkdir(parents=True, exist_ok=True)

class NewsVectorStore:
    # Cache the SentenceTransformer at class level so repeated instantiations
    # do not reload the model from disk.
    _model = None

    def __init__(self, collection_name: str = "news_articles"):
        print(f"Initializing ChromaDB at {CHROMA_DB_PATH}...")
        self.client = chromadb.PersistentClient(path=str(CHROMA_DB_PATH))
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            metadata={"hnsw:space": "cosine"}
        )
        if NewsVectorStore._model is None:
            print("Loading embedding model (this takes a few seconds)...")
            NewsVectorStore._model = SentenceTransformer(
                'all-MiniLM-L6-v2',
                token=HF_TOKEN,
            )
        self.embedding_model = NewsVectorStore._model
        print("ChromaDB initialized and embedding model loaded.")
    def store_articles(self, articles_data):
        """
        Expects a list of article dictionaries from NewsAPI.
        """
        if not articles_data:
            print("No articles to store.")
            return

        documents = []
        metadatas = []
        ids = []
        for article in articles_data:
            url = article.get('url')
            if not url:
                continue
            title = article.get('title') or ""
            desc = article.get('description') or ""
            content = article.get('content') or ""
            text_to_embed = f"{title}. {desc}. {content}"
            if len(text_to_embed.strip()) > 5:
                documents.append(text_to_embed)
                # Store metadata so we can display it later in the UI
                metadatas.append({
                    "source": (article.get('source') or {}).get('name', 'Unknown'),
                    "url": url,
                    "publishedAt": article.get('publishedAt', ''),
                    "title": title,
                    "description": desc
                })
                # Hash the URL into a stable ID so re-ingesting the same
                # article upserts instead of creating a duplicate entry.
                doc_id = hashlib.md5(url.encode()).hexdigest()
                ids.append(doc_id)
        if not documents:
            print("No valid documents to store.")
            return

        # Generate embeddings
        print(f"Generating embeddings for {len(documents)} articles...")
        embeddings = self.embedding_model.encode(documents, batch_size=32).tolist()

        # Insert into ChromaDB
        self.collection.upsert(
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
            ids=ids
        )
        print(f"Successfully stored {len(documents)} articles in ChromaDB!")
    def query(self, topic: str, top_k: int = 10) -> list[dict]:
        """
        Embed the query topic and retrieve the top-k most similar articles.
        """
        print(f"Querying ChromaDB for the topic: '{topic}'")
        query_embedding = self.embedding_model.encode([topic]).tolist()
        results = self.collection.query(
            query_embeddings=query_embedding,
            n_results=top_k,
            include=["documents", "metadatas", "distances"]
        )
        articles = []
        for doc, meta, dist in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0]
        ):
            articles.append({
                "text": doc,
                "source": meta.get("source", "Unknown"),
                "url": meta.get("url", ""),
                "publishedAt": meta.get("publishedAt", ""),
                # The collection uses cosine distance, so similarity = 1 - distance.
                "similarity_score": round(1 - dist, 4),
                "title": meta.get("title", ""),
                "description": meta.get("description", ""),
            })
        print(f"Retrieved {len(articles)} articles.")
        return articles
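
# Illustrative usage from another module (a sketch, not part of this file's
# runtime path; `raw_articles` stands in for a list fetched from NewsAPI):
#
#   store = NewsVectorStore()
#   store.store_articles(raw_articles)
#   hits = store.query("renewable energy policy", top_k=5)
#   for hit in hits:
#       print(hit["similarity_score"], hit["title"], hit["url"])
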
if __name__ == "__main__":
    db = NewsVectorStore()
    print(f"Total documents in collection: {db.collection.count()}")
    results = db.collection.get()
    urls = [m.get("url") for m in results["metadatas"]]
    for url in urls:
        print(url)