File size: 2,877 Bytes
2f18493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_huggingface import HuggingFaceEmbeddings
import os
import json
from typing import Dict, List, Optional
from .shared_state import SharedState


class Retrieval:
    """Callable that fetches the stored question most similar to a query.

    Documents are read from a JSON Lines file into a persisted Chroma
    collection; calling the instance runs a similarity search using the
    content of the first message in the state.
    NOTE(review): presumably used as a graph node over ``SharedState`` —
    confirm against the caller.
    """

    def __init__(self, doc_path: Optional[str] = "./metadata.jsonl"):
        """Create the embeddings and vector store, optionally loading documents.

        Environment variables read:
            CHROMA_PERSIST_DIRECTORY: Chroma persistence directory
                (defaults to ``./chroma_db``).
            EMBEDDINGS_PROVIDER: ``"hf"`` selects the HuggingFace
                ``all-mpnet-base-v2`` model; any other value (or unset)
                selects the Google Generative AI embedding model.

        Args:
            doc_path: JSON Lines file to index on construction. A falsy value
                skips loading.
        """
        self._persist_directory = os.getenv(
            "CHROMA_PERSIST_DIRECTORY", "./chroma_db")
        # Read the provider once so the printed value and the branch agree.
        provider = os.getenv("EMBEDDINGS_PROVIDER", None)
        print(provider)
        if provider == "hf":
            self._embeddings = HuggingFaceEmbeddings(
                model_name="sentence-transformers/all-mpnet-base-v2")
        else:
            self._embeddings = GoogleGenerativeAIEmbeddings(
                model="models/embedding-001")
        self._vectorstore = Chroma(
            persist_directory=self._persist_directory,
            embedding_function=self._embeddings,
            collection_name="hf_agent_quitz",
        )
        if doc_path:
            self.load_vectorstore(doc_path=doc_path)

    def load_vectorstore(self, doc_path: Optional[str] = None):
        """Replace the collection contents with documents from *doc_path*.

        Each non-blank line must be a JSON object with the keys ``Question``,
        ``file_name``, ``Final answer`` and ``task_id``. Lines that fail to
        parse or lack a key are reported with ``print`` and skipped.

        Args:
            doc_path: Path to the JSON Lines file. When falsy, nothing is
                loaded and the collection is left untouched.

        Raises:
            OSError: If the file cannot be opened (e.g. it does not exist).
                The collection is not reset in that case.
        """
        if not doc_path:
            print("No document path provided; vectorstore left unchanged.")
            return

        print(f"Loading documents from {doc_path}")
        documents: List[Document] = []
        with open(file=doc_path, mode="r", encoding="utf-8") as f:
            for line in f:
                raw = line.strip()
                if not raw:
                    # Blank lines are not records; skip them quietly.
                    continue
                try:
                    data = json.loads(raw)
                except json.JSONDecodeError:
                    print(f"Error decoding JSON: {raw}")
                    continue
                if not isinstance(data, dict):
                    # Valid JSON but not an object, so it has no fields.
                    print(f"Skipping non-object JSON line: {raw}")
                    continue
                try:
                    documents.append(Document(
                        page_content=(
                            f"Question: {data['Question']}, "
                            f"file_name:{data['file_name']}, "
                            f"Final answer: {data['Final answer']}"
                        ),
                        # Keep the full original record alongside the text.
                        metadata={"raw_json": raw},
                        id=data["task_id"],
                    ))
                except KeyError as e:
                    print(f"Missing key in JSON data: {e}")

        # Reset only after the file was read successfully, so an unreadable
        # path cannot wipe a previously persisted collection.
        self._vectorstore.reset_collection()
        if documents:
            self._vectorstore.add_documents(documents)
        else:
            print("No documents to add to the vectorstore.")

    def __call__(self, state: SharedState) -> Dict[str, List[BaseMessage]]:
        """Append the closest stored question/answer as a reference message.

        Args:
            state: Mapping with a ``"messages"`` list; the content of the
                first message is used as the search query.

        Returns:
            A dict whose ``"messages"`` value is a copy of the incoming
            messages, followed by one ``HumanMessage`` holding the matched
            document when the search returns a result.
        """
        messages = list(state["messages"])
        if not messages:
            # Nothing to search with; pass the (empty) history through.
            return {"messages": messages}

        similar_docs: List[Document] = self._vectorstore.similarity_search(
            messages[0].content, k=1)
        if similar_docs:
            response = f"Here is the similar question and answer for your reference:\n {similar_docs[0].page_content}"
            # The earlier one-line form `a + [] if c else [b]` parsed as
            # `(a + []) if c else [b]`, dropping the history on a match.
            messages.append(HumanMessage(content=response))
        return {"messages": messages}