Spaces:
Sleeping
Sleeping
| from langchain_chroma import Chroma | |
| from langchain_core.documents import Document | |
| from langchain_core.messages import BaseMessage, HumanMessage | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| import os | |
| import json | |
| from typing import Dict, List, Optional | |
| from .shared_state import SharedState | |
class Retrieval:
    """Graph node that retrieves the most similar solved question from Chroma.

    Documents are loaded from a JSON-lines file into a Chroma collection
    named ``"hf_agent_quitz"``. When called with the shared graph state, the
    instance runs a ``k=1`` similarity search on the first message's content
    and appends the best match (question, file name and final answer) to the
    message list as a reference ``HumanMessage``.
    """

    def __init__(self, doc_path: Optional[str] = "./metadata.jsonl"):
        """Set up embeddings and the Chroma store, then load *doc_path*.

        Args:
            doc_path: JSON-lines file to index. If falsy, nothing is loaded
                and whatever the persisted collection already holds is used.

        Environment variables:
            CHROMA_PERSIST_DIRECTORY: Chroma persistence directory
                (default ``"./chroma_db"``).
            EMBEDDINGS_PROVIDER: ``"hf"`` selects the HuggingFace
                ``sentence-transformers/all-mpnet-base-v2`` model; any other
                value, or unset, selects Google ``models/embedding-001``.
        """
        self._persist_directory = os.getenv(
            "CHROMA_PERSIST_DIRECTORY", "./chroma_db")
        provider = os.getenv("EMBEDDINGS_PROVIDER", None)
        print(provider)
        if provider == "hf":
            self._embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-mpnet-base-v2")
        else:
            self._embeddings = GoogleGenerativeAIEmbeddings(
                model="models/embedding-001")
        self._vectorstore = Chroma(
            persist_directory=self._persist_directory,
            embedding_function=self._embeddings,
            collection_name="hf_agent_quitz",
        )
        # NOTE(review): loading resets the collection, so every construction
        # with a doc_path re-embeds the whole file even though Chroma is given
        # a persist directory -- confirm this is intended.
        if doc_path:
            self.load_vectorstore(doc_path=doc_path)

    def load_vectorstore(self, doc_path: Optional[str] = None):
        """Replace the collection's contents with the records in *doc_path*.

        Each non-blank line is expected to be a JSON object with the keys
        ``Question``, ``file_name``, ``Final answer`` and ``task_id``. Lines
        that are not valid JSON, are not objects, or lack a key are reported
        on stdout and skipped. The file is fully parsed *before* the
        collection is reset, so a missing/unreadable file leaves the existing
        collection untouched.

        Args:
            doc_path: path to the JSON-lines file (read as UTF-8).

        Raises:
            ValueError: if *doc_path* is empty or None.
            OSError: if the file cannot be opened or read.
        """
        if not doc_path:
            raise ValueError("doc_path is required to load the vectorstore")
        print(f"Loading documents from {doc_path}")
        documents: List[Document] = []
        with open(file=doc_path, mode="r", encoding="utf-8") as f:
            for line in f:
                raw = line.strip()
                if not raw:
                    continue  # tolerate blank lines in the JSONL file
                try:
                    data = json.loads(raw)
                except json.JSONDecodeError:
                    print(f"Error decoding JSON: {raw}")
                    continue
                if not isinstance(data, dict):
                    print(f"Skipping non-object JSON line: {raw}")
                    continue
                try:
                    page_content = (
                        f"Question: {data['Question']}, "
                        f"file_name:{data['file_name']}, "
                        f"Final answer: {data['Final answer']}"
                    )
                    doc_id = data["task_id"]
                except KeyError as e:
                    print(f"Missing key in JSON data: {e}")
                    continue
                documents.append(
                    Document(
                        page_content=page_content,
                        # keep the full original record alongside the text
                        metadata={"raw_json": raw},
                        id=doc_id,
                    )
                )
        # Reset only after the file was read successfully, so a bad path
        # does not wipe the existing collection.
        self._vectorstore.reset_collection()
        if documents:
            self._vectorstore.add_documents(documents)
        else:
            print("No documents to add to the vectorstore.")

    def __call__(self, state: SharedState) -> Dict[str, List[BaseMessage]]:
        """Append the most similar stored Q&A to the conversation.

        The content of the first message in ``state["messages"]`` is used as
        the search query (presumably the user's question -- confirm against
        the graph that builds ``SharedState``).

        Returns:
            ``{"messages": [...]}`` holding the existing messages, followed by
            a reference ``HumanMessage`` when a similar document is found.
            With no messages or no match, the messages are returned unchanged.
        """
        messages: List[BaseMessage] = list(state["messages"])
        if not messages:
            return {"messages": messages}
        similar_docs: List[Document] = self._vectorstore.similarity_search(
            messages[0].content, k=1)
        if not similar_docs:
            return {"messages": messages}
        reference = HumanMessage(
            content=f"Here is the similar question and answer for your reference:\n {similar_docs[0].page_content}")
        # Return the history *plus* the reference; returning the reference
        # alone would drop the conversation unless a merging reducer is used.
        return {"messages": messages + [reference]}