Amna2024 commited on
Commit
b09d763
·
verified ·
1 Parent(s): 4ca67f9

Update RAG/Retriever.py

Browse files
Files changed (1) hide show
  1. RAG/Retriever.py +50 -1
RAG/Retriever.py CHANGED
@@ -1,8 +1,57 @@
1
  from langchain_chroma import Chroma
2
  from langchain_core.vectorstores import VectorStore
3
- from task1 import LangchainGeminiWrapper #This is from your old task1 file
4
  import chromadb
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def load_vector_store(gemini_key: str, persist_directory: str) -> VectorStore:
7
  gemini_embedder = LangchainGeminiWrapper(api_key=gemini_key)
8
  return Chroma(
 
1
  from langchain_chroma import Chroma
2
  from langchain_core.vectorstores import VectorStore
3
+ #from task1 import LangchainGeminiWrapper #This is from your old task1 file
4
  import chromadb
5
 
6
+ from llama_index.embeddings.gemini import GeminiEmbedding
7
+ from typing import List, Dict
8
+
9
+ import chromadb
10
+ import os
11
+ import pickle
12
+
13
+
14
+ # Retrieve API keys from environment variables
15
+ userdata = {
16
+ "GEMINI_API_KEY":os.getenv("GEMINI_API_KEY"),
17
+ }
18
+ gemini_key = userdata.get("GEMINI_API_KEY")
19
+ parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
20
+
21
+ # sync
22
+ # Load docs later
23
+ with open('split_docs.pkl', 'rb') as f:
24
+ docs = pickle.load(f)
25
+ client = chromadb.PersistentClient(path=parent_dir)
26
+
27
+ # For all subsequent usage:
28
+ class LangchainGeminiWrapper:
29
+ """
30
+ Wrapper class to make GeminiEmbedding compatible with Langchain Chroma's interface
31
+ """
32
+ def __init__(self, api_key: str, model_name: str = "models/embedding-001"):
33
+ self.model = GeminiEmbedding(
34
+ api_key=api_key,
35
+ model_name=model_name
36
+ )
37
+
38
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
39
+ """
40
+ Embed multiple documents
41
+ """
42
+ return [self.model.get_text_embedding(text) for text in texts]
43
+
44
+ def embed_query(self, text: str) -> List[float]:
45
+ """
46
+ Embed a single query
47
+ """
48
+ return self.model.get_text_embedding(text)
49
+
50
+
51
+
52
+
53
+
54
+
55
  def load_vector_store(gemini_key: str, persist_directory: str) -> VectorStore:
56
  gemini_embedder = LangchainGeminiWrapper(api_key=gemini_key)
57
  return Chroma(