Spaces:
Sleeping
Sleeping
| from langchain_chroma import Chroma | |
| from langchain_core.vectorstores import VectorStore | |
| #from task1 import LangchainGeminiWrapper #This is from your old task1 file | |
| import chromadb | |
| from llama_index.embeddings.gemini import GeminiEmbedding | |
| from typing import List, Dict | |
| import chromadb | |
| import os | |
| import pickle | |
# Retrieve API keys from environment variables
userdata = {
    "GEMINI_API_KEY": os.getenv("GEMINI_API_KEY"),
}
gemini_key = userdata.get("GEMINI_API_KEY")

# Resolve paths relative to this file so the script works from any CWD.
parent_dir = os.path.dirname(os.path.abspath(__file__))
pkl_path = os.path.join(parent_dir, 'split_docs.pkl')

# Load the pre-split documents produced by an earlier pipeline step.
# NOTE(review): pickle.load executes arbitrary code from the file — only
# load pickles this project produced itself; never an untrusted file.
with open(pkl_path, 'rb') as f:
    docs = pickle.load(f)

# Persistent Chroma client rooted in the same directory as this file.
client = chromadb.PersistentClient(path=parent_dir)
| # For all subsequent usage: | |
class LangchainGeminiWrapper:
    """Adapter exposing a GeminiEmbedding model through Langchain's embedding interface.

    Langchain's Chroma integration expects an object providing
    ``embed_documents`` and ``embed_query``; both calls are forwarded to
    the wrapped llama-index ``GeminiEmbedding`` model.
    """

    def __init__(self, api_key: str, model_name: str = "models/embedding-001"):
        # Underlying llama-index model that performs the actual embedding.
        self.model = GeminiEmbedding(api_key=api_key, model_name=model_name)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed every text in *texts*, preserving input order."""
        return [self.model.get_text_embedding(item) for item in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self.model.get_text_embedding(text)
def load_vector_store(gemini_key: str, persist_directory: str) -> VectorStore:
    """Open the persisted Chroma collection with a Gemini embedder attached.

    Args:
        gemini_key: API key used to construct the Gemini embedding model.
        persist_directory: Directory where the Chroma data is persisted.

    Returns:
        A Chroma vector store bound to the ``example_collection`` collection.
    """
    embedder = LangchainGeminiWrapper(api_key=gemini_key)
    store = Chroma(
        collection_name="example_collection",
        embedding_function=embedder,
        persist_directory=persist_directory,
    )
    return store
class Retriever:
    """Thin wrapper around a vector store for similarity-based retrieval."""

    def __init__(self, vectordb: VectorStore):
        # Vector store queried by retrieve_documents.
        self.vectordb = vectordb

    def retrieve_documents(self, query: str, k: int = 7) -> str:
        """Run a similarity search and format the top-*k* hits as one string.

        Each hit is prefixed with a ``===== Document i =====`` header; the
        hit's content follows the header directly (no separator between one
        document's content and the next header, by existing convention).
        """
        hits = self.vectordb.similarity_search(query, k=k)
        pieces = ["\nRetrieved documents:\n"]
        for index, hit in enumerate(hits):
            pieces.append(f"===== Document {index} =====\n{hit.page_content}")
        return "".join(pieces)