Rajan Sharma commited on
Commit
7a11028
·
verified ·
1 Parent(s): e9fc2e1

Update rag.py

Browse files
Files changed (1) hide show
  1. rag.py +33 -5
rag.py CHANGED
@@ -1,9 +1,37 @@
 
 
 
 
 
1
  class RAGIndex:
2
  def __init__(self):
3
- self.docs=[]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- def add(self,chunks):
6
- self.docs.extend(chunks)
 
 
 
 
7
 
8
- def retrieve(self,q,k=5):
9
- return [(d,1.0) for d in self.docs[:k]]
 
1
+ from typing import List, Tuple
2
+ import numpy as np
3
+ import cohere
4
+ from settings import COHERE_API_KEY, COHERE_EMBED_MODEL
5
+
6
  class RAGIndex:
7
  def __init__(self):
8
+ self.client = cohere.Client(api_key=COHERE_API_KEY) if COHERE_API_KEY else None
9
+ self.texts: List[str] = []
10
+ self.vecs: np.ndarray | None = None
11
+
12
+ def _embed(self, texts: List[str]) -> np.ndarray:
13
+ if not texts: return np.zeros((0, 384), dtype="float32")
14
+ if not self.client:
15
+ # Fallback: random embeddings (avoid crash; not ideal)
16
+ return np.random.normal(size=(len(texts), 384)).astype("float32")
17
+ resp = self.client.embed(texts=texts, model=COHERE_EMBED_MODEL)
18
+ vecs = np.array(getattr(resp, "embeddings", []) or getattr(resp, "data", []), dtype="float32")
19
+ return vecs
20
+
21
+ def add(self, chunks: List[str]):
22
+ if not chunks: return
23
+ new_vecs = self._embed(chunks)
24
+ if self.vecs is None:
25
+ self.vecs = new_vecs
26
+ self.texts = list(chunks)
27
+ else:
28
+ self.vecs = np.vstack([self.vecs, new_vecs])
29
+ self.texts.extend(chunks)
30
 
31
+ def retrieve(self, query: str, k: int = 6) -> List[Tuple[str, float]]:
32
+ if not self.texts: return []
33
+ qv = self._embed([query])[0]
34
+ sims = (self.vecs @ qv) / (np.linalg.norm(self.vecs, axis=1) * (np.linalg.norm(qv) + 1e-9))
35
+ idx = np.argsort(-sims)[:k]
36
+ return [(self.texts[i], float(sims[i])) for i in idx]
37