Rajan Sharma
Update rag.py
3982b77 verified
raw
history blame
1.23 kB
from typing import List, Tuple
import numpy as np
from llm_router import cohere_embed
class RAGIndex:
    """In-memory text index ranked by cosine similarity of embeddings.

    Chunks are embedded with ``cohere_embed`` and stored row-wise in
    ``self.vecs``; ``self.texts[i]`` is the chunk whose embedding is row ``i``.
    """

    def __init__(self, fallback_dim: int = 384):
        """Create an empty index.

        Args:
            fallback_dim: width of the random placeholder vectors used when
                ``cohere_embed`` returns nothing and the index is still empty.
                Once the index holds vectors, their width is reused instead.
        """
        self.texts: List[str] = []
        self.vecs: np.ndarray | None = None
        self.fallback_dim = fallback_dim

    def add(self, chunks: List[str]):
        """Embed ``chunks`` and append them (and their vectors) to the index.

        If ``cohere_embed`` returns an empty/falsy result, random normal
        vectors are stored instead so the call does not crash; those chunks
        will then be ranked essentially at random by ``retrieve``.

        Raises:
            ValueError: if the embedding result does not hold exactly one
                vector per chunk, or its width differs from the vectors
                already stored.
        """
        # Materialize once: the chunks are both sent to the embedder and
        # stored, so a one-shot iterable would otherwise be consumed twice.
        chunks = list(chunks)
        if not chunks:
            return
        new_vecs_list = cohere_embed(chunks)
        if not new_vecs_list:
            # Fallback: random vectors avoid a crash but carry no meaning.
            # Reuse the existing width so the vstack below cannot fail.
            dim = self.vecs.shape[1] if self.vecs is not None else self.fallback_dim
            new_vecs = np.random.normal(size=(len(chunks), dim)).astype("float32")
        else:
            new_vecs = np.asarray(new_vecs_list, dtype="float32")
        # Guard the texts/vecs row alignment that retrieve() relies on.
        if new_vecs.ndim != 2 or new_vecs.shape[0] != len(chunks):
            raise ValueError(
                f"expected {len(chunks)} embedding vectors, "
                f"got array of shape {new_vecs.shape}"
            )
        if self.vecs is None:
            self.vecs = new_vecs
            self.texts = chunks
        else:
            if new_vecs.shape[1] != self.vecs.shape[1]:
                raise ValueError(
                    f"embedding width {new_vecs.shape[1]} does not match "
                    f"index width {self.vecs.shape[1]}"
                )
            self.vecs = np.vstack([self.vecs, new_vecs])
            self.texts.extend(chunks)

    def retrieve(self, query: str, k: int = 6) -> List[Tuple[str, float]]:
        """Return up to ``k`` ``(text, cosine_similarity)`` pairs, best first.

        Returns an empty list when the index is empty, ``k`` is not positive,
        or ``cohere_embed`` returns nothing for the query. Equal scores keep
        insertion order.

        Raises:
            ValueError: if the query embedding's width differs from the
                stored vectors.
        """
        if not self.texts or self.vecs is None or k <= 0:
            return []
        qv_list = cohere_embed([query])
        if not qv_list:
            return []
        qv = np.asarray(qv_list[0], dtype="float32")
        if qv.shape != (self.vecs.shape[1],):
            raise ValueError(
                f"query embedding shape {qv.shape} does not match "
                f"index width {self.vecs.shape[1]}"
            )
        # Clamp the denominator so all-zero vectors score 0 instead of NaN.
        denom = np.linalg.norm(self.vecs, axis=1) * np.linalg.norm(qv)
        sims = (self.vecs @ qv) / np.maximum(denom, 1e-9)
        # Stable sort makes tie ordering deterministic (insertion order).
        idx = np.argsort(-sims, kind="stable")[:k]
        return [(self.texts[i], float(sims[i])) for i in idx]