Spaces:
Sleeping
Sleeping
Rajan Sharma
committed on
Update rag.py
Browse files
rag.py
CHANGED
|
@@ -1,9 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
class RAGIndex:
|
| 2 |
def __init__(self):
|
| 3 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
-
def
|
| 6 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
-
def retrieve(self,q,k=5):
|
| 9 |
-
return [(d,1.0) for d in self.docs[:k]]
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cohere
|
| 4 |
+
from settings import COHERE_API_KEY, COHERE_EMBED_MODEL
|
| 5 |
+
|
| 6 |
class RAGIndex:
    """In-memory vector index used to retrieve text chunks by similarity.

    Chunks are embedded with the Cohere client when ``COHERE_API_KEY`` is
    set. Without a key, a deterministic hashed bag-of-words embedding is used
    instead, so the index keeps working offline (at lower quality) and the
    same text always maps to the same vector.
    """

    # Width of the offline fallback embedding and of the empty-input result.
    FALLBACK_DIM = 384

    def __init__(self):
        """Create an empty index and a Cohere client if a key is configured."""
        self.client = cohere.Client(api_key=COHERE_API_KEY) if COHERE_API_KEY else None
        # Stored chunks; row i of ``self.vecs`` is the embedding of texts[i].
        self.texts: List[str] = []
        # (n_chunks, dim) float32 matrix, or None until the first ``add``.
        self.vecs: np.ndarray | None = None

    def _fallback_embed(self, texts: List[str]) -> np.ndarray:
        """Embed ``texts`` offline by hashing lower-cased tokens into buckets.

        Each whitespace-separated token increments one of ``FALLBACK_DIM``
        buckets chosen by CRC32, so texts sharing words get similar vectors.
        A text with no tokens yields an all-zero row.
        """
        import zlib  # local import: only needed on the offline path

        dim = self.FALLBACK_DIM
        out = np.zeros((len(texts), dim), dtype="float32")
        for row, text in enumerate(texts):
            for token in text.lower().split():
                # crc32 is stable across processes, unlike the builtin hash().
                out[row, zlib.crc32(token.encode("utf-8")) % dim] += 1.0
        return out

    def _embed(self, texts: List[str]) -> np.ndarray:
        """Return a float32 array with one embedding row per input text.

        Empty input returns a ``(0, FALLBACK_DIM)`` array. With no client the
        deterministic offline embedding is used.
        """
        if not texts:
            return np.zeros((0, self.FALLBACK_DIM), dtype="float32")
        if not self.client:
            return self._fallback_embed(texts)
        # NOTE(review): some Cohere embed models also require an
        # ``input_type`` argument -- confirm against COHERE_EMBED_MODEL.
        resp = self.client.embed(texts=texts, model=COHERE_EMBED_MODEL)
        # Different SDK response shapes expose the vectors under either name.
        raw = getattr(resp, "embeddings", None) or getattr(resp, "data", None) or []
        return np.array(raw, dtype="float32")

    def add(self, chunks: List[str]):
        """Embed ``chunks`` and append them to the index.

        Raises:
            ValueError: if the embedder does not return exactly one vector
                per chunk (the index is left unchanged in that case), or if
                the new vectors' width differs from the stored ones.
        """
        if not chunks:
            return
        chunks = list(chunks)
        new_vecs = self._embed(chunks)
        # Fail early rather than let texts and vecs silently drift apart.
        if new_vecs.ndim != 2 or new_vecs.shape[0] != len(chunks):
            raise ValueError(
                f"expected {len(chunks)} embeddings, got array of shape {new_vecs.shape}"
            )
        if self.vecs is None:
            self.vecs = new_vecs
            self.texts = chunks
        else:
            # vstack raises before texts is touched if the widths differ.
            self.vecs = np.vstack([self.vecs, new_vecs])
            self.texts.extend(chunks)

    def retrieve(self, query: str, k: int = 6) -> List[Tuple[str, float]]:
        """Return up to ``k`` ``(text, cosine_similarity)`` pairs, best first.

        Returns an empty list when the index is empty or ``k`` is not
        positive.
        """
        if not self.texts or self.vecs is None or k <= 0:
            return []
        qv = self._embed([query])[0]
        doc_norms = np.linalg.norm(self.vecs, axis=1)
        # Epsilon on the whole product guards zero-norm rows as well as a
        # zero-norm query, so scores are never NaN.
        denom = doc_norms * np.linalg.norm(qv) + 1e-9
        sims = (self.vecs @ qv) / denom
        # Stable sort keeps insertion order among equally scored chunks.
        order = np.argsort(-sims, kind="stable")[:k]
        return [(self.texts[i], float(sims[i])) for i in order]
|
| 37 |
|
|
|
|
|
|