File size: 4,170 Bytes
208266a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import hashlib 
import chromadb
from sentence_transformers import SentenceTransformer
from src.config import CHROMA_DB_PATH, HF_TOKEN


# Make sure the on-disk ChromaDB directory (including any missing parent
# directories) exists before a PersistentClient is pointed at it; calling
# this when the directory is already present is a no-op.
# NOTE(review): CHROMA_DB_PATH is presumably a pathlib.Path (mkdir signature matches).
CHROMA_DB_PATH.mkdir(exist_ok=True, parents=True)


class NewsVectorStore:
    """Store news articles as sentence embeddings in ChromaDB and query them.

    The SentenceTransformer model is cached on the class (``_model``), so every
    instance created in the same process reuses a single loaded model.
    """

    # Shared embedding model; loaded lazily by the first instance constructed.
    _model = None

    def __init__(self, collection_name: str = "news_articles"):
        """Open (or create) the persistent collection and load the embedder.

        Args:
            collection_name: Name of the ChromaDB collection to use.
        """
        print(f"Initializing ChromaDB at {CHROMA_DB_PATH}...")
        self.client = chromadb.PersistentClient(path=str(CHROMA_DB_PATH))
        self.collection = self.client.get_or_create_collection(
            name=collection_name,
            # Cosine distance is what lets query() report 1 - distance as a
            # similarity score.
            metadata={"hnsw:space": "cosine"},
        )

        if NewsVectorStore._model is None:
            print("Loading embedding model (this takes a few seconds)...")
            NewsVectorStore._model = SentenceTransformer(
                'all-MiniLM-L6-v2',
                token=HF_TOKEN,
            )

        self.embedding_model = NewsVectorStore._model
        print("ChromaDB initialized and embedding model loaded.")

    def store_articles(self, articles_data):
        """Embed and upsert a batch of NewsAPI-style article dicts.

        Each article's document ID is the MD5 hex digest of its ``url``, so
        re-ingesting the same article overwrites the stored copy. Articles
        without a URL, or whose combined text is too short, are skipped. When
        several articles in one batch share a URL only the last one is kept,
        because a single ChromaDB upsert must not contain duplicate IDs.

        Args:
            articles_data: Iterable of dicts with ``url`` and optional
                ``title``, ``description``, ``content``, ``source`` (a dict
                with ``name``) and ``publishedAt`` keys.
        """
        if not articles_data:
            print("No articles to store.")
            return

        # doc_id -> (document text, metadata); keying by ID collapses
        # duplicate URLs into one entry.
        records = {}
        for article in articles_data:
            url = article.get('url')
            if not url:
                continue

            title = article.get('title') or ""
            desc = article.get('description') or ""
            content = article.get("content") or ""
            text_to_embed = f"{title}. {desc}. {content}"

            # The ". " separators alone contribute characters, so this
            # rejects articles that carry (almost) no real text.
            if len(text_to_embed.strip()) <= 5:
                continue

            # Fields may arrive as explicit nulls; coerce them so stored
            # metadata values are always strings for later display in the UI.
            source = article.get('source') or {}
            metadata = {
                "source": source.get('name') or 'Unknown',
                "url": url,
                "publishedAt": article.get('publishedAt') or '',
                "title": title,
                "description": desc,
            }
            # MD5 is only used to derive a stable ID, not for security.
            doc_id = hashlib.md5(url.encode(), usedforsecurity=False).hexdigest()
            records[doc_id] = (text_to_embed, metadata)

        if not records:
            print("No valid documents to store.")
            return

        ids = list(records)
        documents = [doc for doc, _ in records.values()]
        metadatas = [meta for _, meta in records.values()]

        print(f"Generating embeddings for {len(documents)} articles...")
        embeddings = self.embedding_model.encode(documents, batch_size=32).tolist()

        self.collection.upsert(
            embeddings=embeddings,
            documents=documents,
            metadatas=metadatas,
            ids=ids,
        )
        print(f"Successfully stored {len(documents)} articles in ChromaDB!")

    def query(self, topic: str, top_k: int = 10) -> list[dict]:
        """Return up to ``top_k`` stored articles most similar to ``topic``.

        ``top_k`` is clamped to the number of stored items, so ChromaDB is
        never asked for more results than exist. An empty collection or a
        non-positive ``top_k`` returns ``[]`` without querying.

        Returns:
            Dicts with ``text``, ``source``, ``url``, ``publishedAt``,
            ``similarity_score`` (1 - cosine distance, rounded to 4 places),
            ``title`` and ``description``, in the order ChromaDB returns them.
        """
        print(f"querying chromaDB for the topic: '{topic}'")
        n_results = min(top_k, self.collection.count())
        if n_results <= 0:
            print("Retrieved 0 articles.")
            return []

        query_embedding = self.embedding_model.encode([topic]).tolist()
        results = self.collection.query(
            query_embeddings=query_embedding,
            n_results=n_results,
            include=["documents", "metadatas", "distances"],
        )

        # A single query embedding was sent, so each result field holds
        # exactly one list at index 0.
        articles = []
        for doc, meta, dist in zip(
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            meta = meta or {}
            articles.append({
                "text": doc,
                "source": meta.get("source", "Unknown"),
                "url": meta.get("url", ""),
                "publishedAt": meta.get("publishedAt", ""),
                "similarity_score": round(1 - dist, 4),
                "title": meta.get("title", ""),
                "description": meta.get("description", ""),
            })

        print(f"Retrieved {len(articles)} articles.")
        return articles

if __name__ == "__main__":
    # Quick inspection: report the collection size and print every stored URL.
    db = NewsVectorStore()
    print(f"Total documents in collection: {db.collection.count()}")
    # Only the metadata is used below, so don't fetch document bodies too.
    results = db.collection.get(include=["metadatas"])
    for meta in results["metadatas"]:
        # Guard against entries stored without any metadata.
        print((meta or {}).get("url"))