Pavol Liška commited on
Commit
869eb7d
·
1 Parent(s): 15d4751
Dockerfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
# Container image for the FastAPI service defined in api.py.
FROM python:3.9

# All subsequent paths are relative to /code.
WORKDIR /code

# Copy the dependency manifest first so Docker layer caching skips
# re-installing packages when only application code changes.
COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Copy the application source into the image.
COPY . .

# Serve the FastAPI app object `api` from module api.py; 7860 is the port
# a Hugging Face Docker Space is expected to listen on.
CMD ["uvicorn", "api:api", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,11 +1,9 @@
1
  ---
2
- title: Trykopy
3
- emoji: 📈
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: docker
7
  pinned: false
8
  license: agpl-3.0
9
  ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: Try kopy
3
+ emoji: 💻📚🤖
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: docker
7
  pinned: false
8
  license: agpl-3.0
9
  ---
 
 
agent/Agent.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.embeddings import CacheBackedEmbeddings
2
+ from langchain.storage import LocalFileStore
3
+ from langchain_core.language_models import BaseChatModel
4
+
5
+ from emdedd.Embedding import Embedding
6
+
7
+
class Agent:
    """Bundle of a retrieval backend and a chat model used by the RAG pipelines.

    Attributes:
        embedding: Vector-store wrapper used for document retrieval.
        llm: LangChain chat model used for generation.
    """

    embedding: Embedding
    llm: BaseChatModel

    def __init__(self, embedding: Embedding, llm: BaseChatModel):
        """Store the retrieval backend and the language model.

        Args:
            embedding: Embedding/vector-store wrapper (see emdedd.Embedding).
            llm: Chat model driving generation.
        """
        self.embedding = embedding
        self.llm = llm
agent/__init__.py ADDED
File without changes
agent/agents.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from dotenv import load_dotenv
4
+ from langchain_cohere.llms import Cohere
5
+ from langchain_community.chat_models import ChatDeepInfra
6
+ from langchain_groq import ChatGroq
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_together import ChatTogether
9
+
10
+ load_dotenv()
11
+
12
+
def chat_openai_llm():
    """Build a GPT-4o chat model configured from environment variables.

    Reads `temperature`, `OPENAI_API_KEY` and `OPENAI_ORGANIZATION_ID`
    (loaded via load_dotenv at module import).
    """
    return ChatOpenAI(
        model_name="gpt-4o",
        # os.environ values are strings; convert explicitly rather than
        # relying on pydantic coercion.
        temperature=float(os.environ["temperature"]),
        openai_api_key=os.environ["OPENAI_API_KEY"],
        openai_organization=os.environ["OPENAI_ORGANIZATION_ID"]
    )
20
+
21
+
def groq_chat(model: str):
    """Build a Groq chat model for `model`, capped at 2000 output tokens.

    Temperature and API key come from the environment.
    """
    return ChatGroq(
        model_name=model,
        # Explicit float: environment values are strings.
        temperature=float(os.environ["temperature"]),
        groq_api_key=os.environ["GROQ_API_KEY"],
        max_tokens=2000
    )
29
+
30
+
def cohere_llm():
    """Build a Cohere command-r-plus completion model (non-chat).

    Temperature comes from the environment; top_p / frequency_penalty are
    intentionally left at their defaults (see commented hints).
    """
    return Cohere(
        model="command-r-plus",
        max_tokens=2048,
        # Explicit float: environment values are strings.
        temperature=float(os.environ["temperature"]),
        # p=os.environ["top_p"],
        # frequency_penalty=os.environ["frequency_penalty"],
    )
39
+
40
+
def together_ai_chat(model, temperature):
    """Build a Together AI chat model.

    Args:
        model: Together model identifier.
        temperature: sampling temperature (passed through unchanged).

    Generation stops on the "%%%%" marker used by the prompt templates.
    """
    return ChatTogether(
        model_name=model,
        together_api_key=os.environ["TOGETHER_AI_API_KEY"],
        temperature=temperature,
        # Explicit float: environment values are strings.
        top_p=float(os.environ["top_p"]),
        max_tokens=4096,
        model_kwargs={"stop": ["%%%%"]},
    )
50
+
51
+
def deepinfra_chat(model, temperature):
    """Build a DeepInfra chat model.

    Args:
        model: DeepInfra model identifier.
        temperature: sampling temperature (passed through unchanged).

    Generation stops on the "%%%%" marker used by the prompt templates.
    """
    return ChatDeepInfra(
        model=model,
        deepinfra_api_token=os.environ["DEEPINFRA_API_KEY"],
        temperature=temperature,
        # Explicit float: environment values are strings.
        top_p=float(os.environ["top_p"]),
        max_tokens=4096,
        model_kwargs={"stop": ["%%%%"]},
    )
api.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Response, Body, Security
2
+ from fastapi.security import APIKeyHeader
3
+
4
+ from conversation.conversation_store import ConversationStore
5
+ from rag_langchain import LangChainRAG
6
+
7
+ api = FastAPI()
8
+
9
+ conversation_store = ConversationStore()
10
+
11
+ api_key_header = APIKeyHeader(name="Authorization", auto_error=True)
12
+ prompt_id = "summarize_rag_1"
13
+ check_prompt_id = "check_control_challenge_step_back"
14
+ rewrite_prompt_id = "first"
15
+ default_llm = "gpt-4o 128k"
16
+
17
+
@api.get("/")
def read_root():
    """Health-check root endpoint; returns a static placeholder string."""
    return "Empty"
21
+
22
+
@api.post("/q")
async def q(api_key: str = Security(api_key_header), json_body: dict = Body(...)):
    """Answer a question through the RAG pipeline and persist the exchange.

    Expects a JSON body with keys "q" (the question), "temperature" and
    "retrieval_count". Returns the answer, its source passages and the id
    of the stored conversation record; 401 on a bad API key.
    """
    # Verify the API key
    if not valid_api_key(api_key):
        return Response(status_code=401)

    # Process the JSON body
    data = json_body

    rag = LangChainRAG(
        config={
            "retrieve_documents": data["retrieval_count"],
            "temperature": data["temperature"],
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id
        }
    )

    answer, check_result, sources = rag.rag_chain(data["q"], default_llm)

    source_texts = [doc.page_content for doc in sources]

    oid = conversation_store.save_content(
        # BUG FIX: the original passed `q=q`, i.e. this endpoint function
        # object, instead of the question text from the request body.
        q=data["q"],
        a=answer,
        sources=source_texts,
        params={
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id,
            "check_result": check_result,
            "temperature": data["temperature"],
            "retrieve_document_count": data["retrieval_count"],
        }
    )

    # BUG FIX: Response(content=...) requires str/bytes; returning a plain
    # dict lets FastAPI serialize it as JSON with status 200.
    return {
        "response": answer,
        "sources": source_texts,
        # Mongo ObjectId is not JSON-serializable; send its string form.
        "qid": str(oid)
    }
67
+
68
+
@api.post("/emo")
async def emo(api_key: str = Security(api_key_header), json_body: dict = Body(...)):
    """Attach a user helpfulness rating to a previously stored Q&A record.

    Expects a JSON body with "qid" (id returned by /q) and "helpfulness".
    Returns an implicit 200 with no body; 401 on a bad API key.
    """
    # Verify the API key
    if not valid_api_key(api_key):
        return Response(status_code=401)

    # Load the stored single-turn conversation and merge the grade into
    # its params, then re-save the whole record.
    qa = conversation_store.get(json_body["qid"])
    new_params = qa.params
    new_params["user_grading"] = str(json_body["helpfulness"])
    conversation_store.update(
        oid=json_body["qid"],
        q=qa.conversation[0].q,
        a=qa.conversation[0].a,
        sources=qa.conversation[0].sources,
        params=new_params
    )
85
+
86
+
def valid_api_key(api_key: str):
    """Check the caller's API key against the expected secret.

    Uses a constant-time comparison to avoid leaking key prefixes through
    timing. NOTE(review): the secret is hard-coded; it should come from an
    environment variable or secret store.
    """
    import hmac
    return hmac.compare_digest(api_key, "your_secret_api_key")
conversation/__init__.py ADDED
File without changes
conversation/conversation_store.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from dotenv import load_dotenv
4
+ from pymongo import MongoClient
5
+
6
+ from dto.conversation import Conversation, OneShotConversation
7
+
8
+ load_dotenv()
9
+
10
+
class ConversationStorage:
    """Thin MongoDB-backed persistence layer for Conversation documents."""

    def __init__(self, db_uri, db_name, collection_name):
        """Connect to MongoDB and select the target collection."""
        self.client = MongoClient(db_uri)
        self.collection = self.client[db_name][collection_name]

    def store_document_mongodb(self, conversation):
        """Insert a Conversation (pydantic model) and return its ObjectId."""
        return self.collection.insert_one(conversation.model_dump()).inserted_id

    def get_all(self):
        """Load every stored document as a Conversation model."""
        return [Conversation(**doc) for doc in self.collection.find()]

    def get_next(self, offset):
        """Return the single document at `offset` in natural order."""
        q = self.collection.find().limit(1).skip(offset)[0]
        document = Conversation(**q)
        # BUG FIX: the original logged `document.name`, an attribute that
        # Conversation does not define, with a format string that had one
        # placeholder for two arguments.
        print("Got {} stored at {}".format(document.id, document.created))
        return document

    def get_one(self, oid) -> Conversation:
        """Fetch one document by its _id."""
        q = self.collection.find({"_id": oid})[0]
        document = Conversation(**q)
        print("Got {} stored at {}".format(document.id, document.created))
        return document

    def count(self):
        """Approximate number of stored conversations (estimated count)."""
        return self.collection.estimated_document_count()

    def update(self, oid, conversation):
        """Overwrite the stored fields of document `oid`.

        BUG FIX: pymongo's update_one requires an update-operator document;
        the original passed the raw field dict, which raises ValueError.
        """
        self.collection.update_one(
            filter={"_id": oid},
            update={"$set": conversation.model_dump()}
        )
46
+
47
+
class ConversationStore:
    """High-level facade over ConversationStorage for single-turn Q&A logs."""

    # NOTE(review): these are evaluated at import time, so importing this
    # module fails when the env vars are missing, and one Mongo connection
    # is shared by every ConversationStore instance.
    URI = os.environ["DB_CONN_LOG"]
    DB_NAME = os.environ["MONGODB_DB_NAME_MSG_LOG"]
    COLLECTION_NAME = os.environ["MONGODB_COL_NAME_MSG_LOG"]

    # Class-level storage: shared by all instances.
    storage = ConversationStorage(URI, DB_NAME, COLLECTION_NAME)

    def save_content(self, q, a=None, sources: list[str] = None, params: dict[str, str] = None):
        """Persist one question/answer pair; returns the inserted id."""
        return self.storage.store_document_mongodb(
            Conversation(
                conversation=[
                    OneShotConversation(
                        q=q,
                        a=a,
                        sources=sources
                    )
                ],
                params=params
            )
        )

    def get_all(self):
        """All stored conversations as Conversation models."""
        return self.storage.get_all()

    def get_one(self, offset):
        """Conversation at `offset` in natural order."""
        return self.storage.get_next(offset)

    def get(self, oid: str) -> Conversation:
        """Fetch a conversation by _id.

        NOTE(review): `oid` is a str while inserted ids are ObjectId —
        confirm callers convert, otherwise the lookup never matches.
        """
        return self.storage.get_one(oid)

    def count(self):
        """Approximate number of stored conversations."""
        return self.storage.count()

    def update(self, oid, q, a, sources, params):
        """Replace conversation `oid` with a new single-turn record."""
        self.storage.update(
            oid,
            Conversation(
                conversation=[
                    OneShotConversation(
                        q=q,
                        a=a,
                        sources=sources
                    )
                ],
                params=params
            )
        )
dto/__init__.py ADDED
File without changes
dto/conversation.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from datetime import datetime
3
+ from typing import Optional, Annotated
4
+
5
+ from pydantic import BaseModel, Field, BeforeValidator
6
+
7
+
class OneShotConversation(BaseModel):
    # One question/answer exchange plus the retrieved source passages.
    q: str = Field()
    a: str = Field()
    sources: list[str] = Field()


class Conversation(BaseModel):
    # Stored conversation record. Mongo's "_id" (ObjectId) is coerced to
    # str by the BeforeValidator when loading from the database.
    id: Optional[Annotated[str, BeforeValidator(str)]] = Field(alias="_id", default=None)
    # NOTE(review): datetime.now is naive local time — confirm UTC is not
    # expected by consumers.
    created: datetime = Field(default_factory=datetime.now)
    conversation: list[OneShotConversation] = Field()
    params: dict[str, str] = Field()
dto/document.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from datetime import datetime
3
+
4
+ from pydantic import BaseModel, Field
5
+
6
+
class Document(BaseModel):
    """Raw source document as stored in MongoDB."""

    # BUG FIX: the fields are declared `str` but the original factories
    # produced uuid.UUID / datetime objects, which fail pydantic v2 string
    # validation; convert to strings explicitly.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")
    name: str = Field()
    text: str = Field()
    created: str = Field(default_factory=lambda: datetime.now().isoformat())
dto/prompt.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uuid
2
+ from typing import Optional, Annotated
3
+
4
+ from bson import ObjectId
5
+ from pydantic import BaseModel, Field, BeforeValidator
6
+
7
+
class Prompt(BaseModel):
    # Prompt template record loaded from MongoDB.
    # "_id" (ObjectId) is coerced to str by the BeforeValidator.
    id: Optional[Annotated[str, BeforeValidator(str)]] = Field(alias="_id", default=None)
    name: str = Field()  # lookup key used by PromptStore.get_by_name
    text: str = Field()  # template body with {placeholders}
    prompt_type: str = Field()
emdedd/ChromaEmbedding.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores.chroma import Chroma
2
+
3
+ from emdedd.Embedding import Embedding
4
+
5
+
class ChromaEmbedding(Embedding):
    """Embedding backend persisting vectors in a local Chroma database."""

    db: Chroma

    def __init__(self, embedding, path, collection, collection_metadata = None):
        """Open (or create) the Chroma collection at `path`.

        Args:
            embedding: LangChain embedding function.
            path: persist directory on disk.
            collection: collection name.
            collection_metadata: optional Chroma collection metadata.
        """
        self.db = Chroma(
            embedding_function=embedding,
            persist_directory=path,
            collection_name=collection,
            collection_metadata=collection_metadata
        )

    def embedd(self, chunks, metadata: list[dict] = None):
        """Embed `chunks` and persist them (Embedding interface)."""
        self.__store_embeddings(chunks, metadata)

    def __store_embeddings(self, chunks, metadata: list[dict] = None):
        # Add the texts (Chroma computes embeddings via embedding_function)
        # and flush the store to disk.
        self.db.add_texts(
            texts=chunks,
            metadatas=metadata
        )
        # NOTE(review): persist() is deprecated/no-op in newer Chroma
        # releases — confirm against the pinned chromadb version.
        self.db.persist()

    def get_vector_store(self):
        """Return the underlying Chroma store for retriever construction."""
        return self.db
emdedd/Embedding.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+
@dataclass
class EmbeddingDbConnection:
    """Connection coordinates for a vector-search collection."""
    connection: str  # MongoDB connection string
    database: str
    collection: str
    index: str  # name of the vector-search index
10
+
11
+
class Embedding:
    """Interface for embedding backends (see ChromaEmbedding, MongoEmbedding).

    NOTE(review): methods silently return None as written; consider
    abc.ABC/@abstractmethod so missing overrides fail loudly.
    """

    def embedd(self, chunks, metadata: list[dict] = None):
        """Embed `chunks` (texts) and store them with optional metadata."""
        pass

    def get_vector_store(self):
        """Return the underlying LangChain vector store."""
        pass

    def search(self, query, search_type, doc_count):
        """Return up to `doc_count` documents relevant to `query`."""
        pass
emdedd/MongoEmbedding.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ from langchain.embeddings import CacheBackedEmbeddings
4
+ from langchain.storage import LocalFileStore
5
+ from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
6
+ from langchain_core.embeddings import Embeddings
7
+ from langchain_core.stores import InMemoryStore
8
+ from pymongo import MongoClient
9
+ from bson.objectid import ObjectId
10
+
11
+ from emdedd.Embedding import Embedding
12
+
13
+
@dataclass
class EmbeddingDbConnection:
    """Connection coordinates for an Atlas vector-search collection.

    NOTE(review): duplicates emdedd.Embedding.EmbeddingDbConnection —
    consider importing that one instead of redefining it here.
    """
    connection: str  # MongoDB connection string
    database: str
    collection: str
    index: str  # Atlas vector-search index name
20
+
21
+
class MongoEmbedding(Embedding):
    """Embedding backend storing vectors in MongoDB Atlas Vector Search."""

    db: EmbeddingDbConnection
    embedding: Embeddings

    def __init__(self, db, embedding, cache: bool = True):
        """Wrap `embedding`, optionally behind an in-memory cache.

        Args:
            db: Atlas connection coordinates.
            embedding: underlying LangChain embedder.
            cache: when True, identical texts are embedded only once per
                process via CacheBackedEmbeddings.
        """
        self.db = db
        if cache:
            self.embedding = CacheBackedEmbeddings.from_bytes_store(
                underlying_embeddings=embedding,
                document_embedding_cache=InMemoryStore(),
                namespace="mongo-embedding-cache"
            )
        else:
            self.embedding = embedding

    def embedd(self, chunks, metadata: list[dict] = None):
        """Embed `chunks` and store them in the Atlas collection."""
        self.__store_embeddings(chunks, metadata)

    def __store_embeddings(self, chunks, metadata: list[dict] = None):
        client = MongoClient(self.db.connection)
        collection = client[self.db.database][self.db.collection]

        MongoDBAtlasVectorSearch.from_texts(
            texts=chunks,
            metadatas=metadata,
            embedding=self.embedding,
            collection=collection,
            index_name=self.db.index
        )

        self.__add_id_to_metadata(collection)

    def __add_id_to_metadata(self, collection):
        """Copy each new document's ObjectId into metadata["id"]."""
        # BUG FIX: the original used {"$exists": "false"} — a non-empty
        # string is truthy, so it selected exactly the documents that
        # ALREADY had metadata.id. Use the boolean False.
        for document in collection.find({"metadata.id": {"$exists": False}}):
            metadata: dict = document.get("metadata") or {}
            object_id: ObjectId = document["_id"]
            metadata["id"] = str(object_id)
            # BUG FIX: update_one requires an update-operator document;
            # the raw dict raises ValueError in pymongo.
            collection.update_one(
                filter={"_id": object_id},
                update={"$set": {"metadata": metadata}}
            )

    def get_vector_store(self):
        """Return a MongoDBAtlasVectorSearch over the configured collection."""
        return MongoDBAtlasVectorSearch.from_connection_string(
            self.db.connection,
            self.db.database + "." + self.db.collection,
            embedding=self.embedding,
            index_name=self.db.index
        )

    def search(self, query, search_type, doc_count):
        """Retrieve up to `doc_count` documents relevant to `query`."""
        vector_store = self.get_vector_store()
        # BUG FIX: honor the caller's search_type instead of always
        # forcing "similarity".
        retriever = vector_store.as_retriever(
            search_type=search_type,
            search_kwargs={"k": doc_count}
        )
        return retriever.get_relevant_documents(query=query)
emdedd/Splitter.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
2
+
3
+
class Splitter:
    """Thin wrapper around RecursiveCharacterTextSplitter."""

    # Regex separators tried in order when splitting (is_separator_regex).
    separators = []
    chunk_overlap: int
    chunk_size: int

    def __init__(self, separators, chunk_overlap, chunk_size):
        """Store the splitting configuration."""
        self.separators = separators
        self.chunk_overlap = chunk_overlap
        self.chunk_size = chunk_size

    def split(self, text):
        """Split `text` into overlapping chunks; returns a list of strings."""
        text_splitter = RecursiveCharacterTextSplitter(
            separators=self.separators,
            is_separator_regex=True,
            chunk_overlap=self.chunk_overlap,
            chunk_size=self.chunk_size
        )
        return text_splitter.split_text(text)
emdedd/__init__.py ADDED
File without changes
emdedd/embeddings.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from emdedd.ChromaEmbedding import ChromaEmbedding
2
+
3
+
def chroma_embedding(name: str, embedding, metadata=None) -> ChromaEmbedding:
    """Factory for a ChromaEmbedding over the local ./chromadb/zpl store.

    Args:
        name: Chroma collection name.
        embedding: LangChain embedding function.
        metadata: optional collection metadata passed through to Chroma.
    """
    return ChromaEmbedding(
        path="./chromadb/zpl",
        embedding=embedding,
        collection=name,
        collection_metadata=metadata
    )
prompt/__init__.py ADDED
File without changes
prompt/prompt_store.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from dotenv import load_dotenv
4
+ from pymongo import MongoClient
5
+
6
+ from dto.prompt import Prompt
7
+
8
+ load_dotenv()
9
+
10
+
class PromptStore:
    """MongoDB-backed store of prompt templates (dto.prompt.Prompt)."""

    # NOTE(review): evaluated at import time; a missing env var makes this
    # module unimportable, and one client is shared by all instances.
    URI = os.environ["DB_CONN_LOG"]
    DB_NAME = os.environ["MONGODB_DB_NAME_MSG_LOG"]
    COLLECTION_NAME = os.environ["MONGODB_COL_NAME_PROMPT"]
    client = MongoClient(URI)
    collection = client[DB_NAME][COLLECTION_NAME]

    def save_content(self, name, text, prompt_type):
        """Insert a new prompt document."""
        self.collection.insert_one(
            Prompt(
                name=name,
                text=text,
                prompt_type=prompt_type
            ).model_dump()
        )

    def get_all(self):
        """Load every stored prompt as a Prompt model."""
        return [Prompt(**doc) for doc in self.collection.find()]

    def _first(self, cursor) -> Prompt:
        """Materialize the first document of `cursor` as a Prompt and log it.

        Private helper deduplicating the fetch/log pattern that was
        repeated in get_one/get/get_by_name.
        """
        document = Prompt(**cursor[0])
        print("Got prompt {}".format(document.name))
        return document

    def get_one(self, offset) -> Prompt:
        """Return the prompt at `offset` in natural order."""
        return self._first(self.collection.find().limit(1).skip(offset))

    def get(self, oid: str) -> Prompt:
        """Fetch a prompt by _id.

        NOTE(review): `oid` is a str while Mongo stores ObjectId — confirm
        ids are saved as strings, otherwise this lookup never matches.
        """
        return self._first(self.collection.find({"_id": oid}))

    def count(self):
        """Approximate number of stored prompts."""
        return self.collection.estimated_document_count()

    def get_by_name(self, name) -> Prompt:
        """Fetch a prompt by its name."""
        return self._first(self.collection.find({"name": name}))
rag.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os
3
+ import traceback
4
+ from typing import Any
5
+
6
+ from dotenv import load_dotenv
7
+ from langchain.chains import LLMChain
8
+ from langchain.chains.combine_documents import create_stuff_documents_chain
9
+ from langchain.chains.retrieval import create_retrieval_chain
10
+ from langchain.retrievers import MultiQueryRetriever, MergerRetriever, ContextualCompressionRetriever, EnsembleRetriever
11
+ from langchain_cohere import CohereRerank
12
+ from langchain_core.documents import Document
13
+ from langchain_core.prompts import PromptTemplate
14
+
15
+ from agent.Agent import Agent
16
+ from agent.agents import chat_openai_llm, deepinfra_chat
17
+ from conversation.conversation_store import ConversationStore
18
+ from prompt.prompt_store import PromptStore
19
+ from retrieval import retrieve, retrieve_with_rerank
20
+
21
+ load_dotenv()
22
+
23
+ conversation_store = ConversationStore()
24
+ prompt_store = PromptStore()
25
+
26
+ grammar_check_1 = prompt_store.get_by_name("gramar_check_1")
27
+ rewrite_hyde_1 = prompt_store.get_by_name("rewrite_hyde_1")
28
+ rewrite_hyde_2 = prompt_store.get_by_name("rewrite_hyde_2")
29
+ rewrite_1 = prompt_store.get_by_name("rewrite_1")
30
+ rewrite_2 = prompt_store.get_by_name("rewrite_2")
31
+ rewrite_hyde = prompt_store.get_by_name("rewrite_hyde")
32
+
33
+
def replace_nl(input: str) -> str:
    """Convert every newline convention (CRLF, LF, CR) to an HTML <br> tag."""
    normalized = input
    # CRLF is handled first so each Windows line break yields one <br>,
    # not two.
    for newline in ('\r\n', '\n', '\r'):
        normalized = normalized.replace(newline, '<br>')
    return normalized
36
+
37
+
def rag(agent: Agent, q: str, retrieve_document_count: int):
    """Plain RAG: retrieve top-k context for `q` and answer with RAG_TEMPLATE.

    The prompt template is read from the RAG_TEMPLATE env var and must
    declare {context} and {question}. Returns the answer text.
    """
    k = retrieve_document_count

    context_doc = retrieve(agent.embedding, q, k)

    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template=os.environ["RAG_TEMPLATE"]
    )

    # NOTE(review): LLMChain is the legacy chain API; the LCEL equivalent
    # is sketched below.
    llm_chain = LLMChain(
        llm=agent.llm,
        prompt=prompt_template,
        verbose=False
    )

    # llm_chain = prompt_template | agent.llm

    result: dict[str, Any] = llm_chain.invoke(
        input={
            "question": q,
            "context": context_doc
        }
    )

    return result["text"]
64
+
65
+
def rewrite(agent: Agent, q: str, prompt: str) -> list[str]:
    """Ask the LLM to rewrite `q` into sub-questions, one per output line.

    Lines containing "##" (section markers) and blank lines are filtered
    out; an empty list means the model produced nothing usable.
    """
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template=prompt
    )
    llm_chain = LLMChain(
        llm=agent.llm,
        prompt=prompt_template,
        verbose=False
    )
    questions = llm_chain.invoke(
        input={"question": q}
    )["text"].splitlines()

    return [x for x in questions if ("##" not in x and len(str(x).strip()) > 0)]
81
+
82
+
def rag_with_rerank_check_rewrite(agent: Agent, q: str, retrieve_document_count: int, prompt: str, check_prompt: str,
                                  rewrite_prompt: str):
    """RAG with query rewriting, reranked retrieval and answer self-check.

    Pipeline: rewrite q into sub-questions -> retrieve+rerank per
    sub-question -> answer over the merged context -> LLM check of the
    answer. Returns (answer, check_result, context_docs). The Slovak
    fallback strings mean "I don't know, no sources / no context".
    """
    rewritten_list: list[str] = rewrite(agent, q, rewrite_prompt)

    if len(rewritten_list) == 0:
        return "Neviem, nemám podklady!", "", ""

    context_doc = retrieve_subqueries(agent, retrieve_document_count, rewritten_list)

    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""

    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]

    check_result = check_pipeline(answer, check_prompt, context_doc, q)

    return answer, check_result, context_doc
101
+
102
+
def rag_with_rerank_check_rewrite_hyde(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                                       check_prompt: str,
                                       rewrite_prompt: str):
    """HyDE variant of rag_with_rerank_check_rewrite.

    Identical pipeline, but each rewritten sub-question is retrieved
    together with a hypothetical LLM answer (retrieve_subqueries_hyde).
    Returns (answer, check_result, context_docs); Slovak fallbacks mean
    "I don't know, no sources / no context".
    """
    rewritten_list: list[str] = rewrite(agent, q, rewrite_prompt)

    if len(rewritten_list) == 0:
        return "Neviem, nemám podklady!", "", ""

    context_doc = retrieve_subqueries_hyde(agent, retrieve_document_count, rewritten_list)

    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""

    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]

    check_result = check_pipeline(answer, check_prompt, context_doc, q)

    return answer, check_result, context_doc
122
+
123
+
def rag_with_rerank_check_multi_query_retriever(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                                                check_prompt: str):
    """RAG driven by the merged multi-query retriever from hyde_retrieval.

    Returns (answer, check_result, context_docs); the Slovak fallback
    means "I don't know, no context".
    """
    context_doc = hyde_retrieval(agent, retrieve_document_count).invoke(
        input=q,
        kwargs={"k": retrieve_document_count}
    )

    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""

    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]

    check_result = check_pipeline(answer, check_prompt, context_doc, q)

    return answer, check_result, context_doc
140
+
141
+
def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
              check_prompt: str):
    """Answer `q` with a LangChain retrieval chain over hyde_2_retrieval.

    The answering prompt receives "context", "question" and "actual_date";
    create_retrieval_chain consumes the "input" key for retrieval, hence
    the question is passed under both "question" and "input". Returns
    (answer, check_result, context_docs).
    """
    result = create_retrieval_chain(
        retriever=hyde_2_retrieval(agent, retrieve_document_count),
        combine_docs_chain=create_stuff_documents_chain(
            llm=agent.llm,
            prompt=PromptTemplate(
                input_variables=["context", "question", "actual_date"],
                template=prompt
            )
        )
    ).invoke(
        input={
            "question": q,
            "input": q,
            "actual_date": datetime.date.today().isoformat()
        }
    )

    check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)

    return result["answer"], check_result, result["context"]
164
+
165
+
def hyde_retrieval(agent, retrieve_document_count):
    """Build a reranked multi-query (HyDE-style) retriever.

    Two MultiQueryRetrievers, each driven by a different rewrite prompt,
    are merged; Cohere's multilingual reranker then keeps the best
    `retrieve_document_count` documents.
    """
    retriever_1 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": retrieve_document_count}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            # BUG FIX: rewrite_hyde_1 is a Prompt model returned by
            # prompt_store.get_by_name; PromptTemplate needs the template
            # string itself.
            template=rewrite_hyde_1.text
        )
    )
    retriever_2 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": retrieve_document_count}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            # BUG FIX: same Prompt-model-vs-string fix as above.
            template=rewrite_hyde_2.text
        )
    )
    merge_retriever = MergerRetriever(
        retrievers=[retriever_1, retriever_2],
    )
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        top_n=retrieve_document_count
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=merge_retriever,
        search_kwargs={"k": retrieve_document_count},
    )

    return compression_retriever
203
+
204
+
def hyde_2_retrieval(agent, retrieve_document_count):
    """Build an ensemble of three reranked multi-query retrievers.

    Each branch over-fetches (up to 10x the requested count, capped at
    500), is compressed by the Cohere reranker to half the requested
    count, and the three branches are ensembled with equal weights.
    """
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        # BUG FIX: `/` yields a float (e.g. 2.5) where top_n must be an
        # int; use integer division with a floor of 1.
        top_n=max(1, retrieve_document_count // 2)
    )
    retriever_1 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 500)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            # BUG FIX: rewrite_1 is a Prompt model from
            # prompt_store.get_by_name; PromptTemplate needs the string.
            template=rewrite_1.text
        )
    )
    compression_retriever_1 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_1
    )

    retriever_2 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 500)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            # BUG FIX: same Prompt-model-vs-string fix.
            template=rewrite_2.text
        )
    )
    compression_retriever_2 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_2
    )

    retriever_3 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 500)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            # BUG FIX: same Prompt-model-vs-string fix.
            template=rewrite_hyde.text
        )
    )
    compression_retriever_3 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_3
    )

    merge_retriever = EnsembleRetriever(
        retrievers=[compression_retriever_1, compression_retriever_2, compression_retriever_3],
        weights=[1.0, 1.0, 1.0]
    )

    return merge_retriever
264
+
265
+
def retrieve_subqueries(agent, retrieve_document_count, rewritten_list) -> list[Document]:
    """Retrieve reranked documents for each rewritten query, merge, dedupe.

    Results are ordered by descending rerank relevance_score; duplicates
    (same page_content) keep their highest-scored occurrence. At most
    `retrieve_document_count` documents are returned.
    """
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        contexts.extend(retrieve_with_rerank(agent.embedding, rewritten, retrieve_document_count))

    # Highest relevance first (score attached by the reranker).
    contexts.sort(key=lambda x: -x.metadata["relevance_score"])

    # Linear-time, order-preserving dedup by page_content (original was
    # an O(n^2) nested scan).
    seen: set[str] = set()
    deduplicated: list[Document] = []
    for doc in contexts:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            deduplicated.append(doc)

    return deduplicated[:retrieve_document_count]
283
+
284
+
def retrieve_subqueries_hyde(agent, retrieve_document_count, rewritten_list) -> list[Document]:
    """HyDE variant of retrieve_subqueries.

    For each rewritten query a hypothetical LLM answer is generated and
    the query+answer pair is used for reranked retrieval. Results are
    sorted by descending relevance_score, deduplicated by page_content
    and capped at `retrieve_document_count`.
    """
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        # HyDE: retrieve with the question plus a hypothetical answer.
        answer = agent.llm.invoke(rewritten).content
        contexts.extend(retrieve_with_rerank(agent.embedding, rewritten + "\n" + answer, retrieve_document_count))

    contexts.sort(key=lambda x: -x.metadata["relevance_score"])

    # Linear-time, order-preserving dedup by page_content (original was
    # an O(n^2) nested scan).
    seen: set[str] = set()
    deduplicated: list[Document] = []
    for doc in contexts:
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            deduplicated.append(doc)

    return deduplicated[:retrieve_document_count]
303
+
304
+
def answer_pipeline(agent, context_doc, prompt, q):
    """Run the answering LLM chain over the retrieved context.

    Returns the raw LLMChain result dict (answer text under "text").
    NOTE(review): "actual_date" is always supplied even though the
    declared input_variables omit it — confirm every template tolerates
    the extra key.
    """
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template=prompt
    )
    llm_chain = LLMChain(
        llm=agent.llm,
        prompt=prompt_template,
        verbose=False
    )

    result: dict[str, Any] = llm_chain.invoke(
        input={
            "question": q,
            "context": context_doc,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    return result
324
+
325
+
def check_pipeline(answer, check_prompt, context_doc, q):
    """Ask a separate LLM to verify `answer` against `context_doc`.

    Returns the checker's text, or the formatted traceback when the check
    call fails — checking is best-effort and must not break answering.
    """
    prompt_template = PromptTemplate(
        input_variables=["context", "question", "answer"],
        template=check_prompt
    )
    llm_chain = LLMChain(
        # Checking always uses a fixed Llama-3 model, independent of the
        # agent's answering LLM.
        llm=deepinfra_chat("meta-llama/Meta-Llama-3-70B-Instruct", "0.4"),
        prompt=prompt_template,
        verbose=False
    )
    try:
        check_result = llm_chain.invoke(
            input={
                # Truncate the question to keep the check prompt bounded.
                "question": q[:2000],
                "context": context_doc,
                "answer": answer
            }
        )["text"]
    except Exception:
        # Surface the failure text instead of raising (unused exception
        # binding removed).
        check_result = traceback.format_exc()

    return check_result
348
+
349
+
def rag_with_rerank(agent: Agent, q: str, retrieve_document_count: int, prompt: str = None, check_prompt: str = None):
    """Retrieve reranked context for `q`, answer, and optionally self-check.

    Returns (answer, check_result, context_docs). On failure returns
    ("", formatted traceback, []) — the sources slot is an empty list so
    its type matches the success path.
    """
    context_doc: list[Document] = retrieve_with_rerank(agent.embedding, q, retrieve_document_count)

    try:
        result: dict[str, Any] = answer_pipeline(agent, context_doc, prompt, q)

        answer = result["text"]
        check_result = ""

        # Self-check is optional: only run it when a check prompt is given.
        if check_prompt is not None:
            check_result = check_pipeline(answer, check_prompt, context_doc, q)

        return answer, check_result, context_doc
    except Exception:
        # BUG FIX: the error path returned "" for the sources slot while
        # the success path returns a list of Documents.
        return "", traceback.format_exc(), []
365
+
366
+
def save_conversation(answer: str, check_result: str, context_doc: list[Document], gramatika: str, question: str,
                      prompt_id: str, check_prompt_id: str, grammar_prompt_id: str):
    """Persist a successful Q&A with its audit parameters.

    `gramatika` is the grammar-check output, stored as "grammar_result".
    Empty answers are silently skipped (nothing useful to log).
    """
    if len(answer) > 0:
        conversation_store.save_content(
            q=question,
            a=answer,
            sources=list(map(lambda doc: doc.page_content, context_doc)),
            params=
            {
                "prompt_id": prompt_id,
                "check_prompt_id": check_prompt_id,
                "grammar_prompt_id": grammar_prompt_id,
                "check_result": check_result,
                "grammar_result": gramatika,
                "temperature": os.environ["temperature"],
            }
        )
384
+
385
+
def check_slovak_agent(text: str) -> str:
    """Run the Slovak grammar-check prompt over `text` with GPT-4o.

    Returns the checker's raw text output.
    """
    prompt_template = PromptTemplate(
        input_variables=["text"],
        # BUG FIX: grammar_check_1 is a Prompt model returned by
        # prompt_store.get_by_name; PromptTemplate needs the template
        # string itself.
        template=grammar_check_1.text
    )

    llm_chain = LLMChain(
        llm=chat_openai_llm(),
        prompt=prompt_template,
        verbose=False
    )

    result: dict[str, Any] = llm_chain.invoke(input={"text": text})

    return result["text"]
rag_langchain.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from dotenv import load_dotenv
4
+ from gptcache import Cache
5
+ from gptcache.manager.factory import manager_factory
6
+ from gptcache.processor.pre import get_prompt
7
+ from langchain.retrievers import ContextualCompressionRetriever
8
+ from langchain_cohere import CohereRerank
9
+ from langchain_community.cache import GPTCache
10
+ from langchain_core.language_models import BaseChatModel
11
+ from langchain_core.prompts import PromptTemplate
12
+ from langchain_core.retrievers import BaseRetriever
13
+ from langchain_google_genai import ChatGoogleGenerativeAI, HarmCategory, HarmBlockThreshold
14
+ from langchain_openai import ChatOpenAI
15
+
16
+ from agent.Agent import Agent
17
+ from agent.agents import deepinfra_chat, \
18
+ together_ai_chat, groq_chat, cohere_llm
19
+ from emdedd.Embedding import Embedding
20
+ from emdedd.embeddings import chroma_embedding, cohere_embeddings
21
+ from prompt.prompt_store import PromptStore
22
+ from rag import rag_chain
23
+
24
+ load_dotenv()
25
+
26
+
27
class LangChainRAG:
    """Wires one embedding store, a pool of chat LLMs, and a reranking retriever.

    Config keys read: "temperature", "retrieve_documents", "prompt_id",
    "check_prompt_id".
    """

    # Single-element tuple: the rest of the class reads self.embedding[0].
    embedding: tuple[Embedding]
    llms: dict[str, BaseChatModel]
    retriever: BaseRetriever
    prompt_template: PromptTemplate
    config: dict
    semantic_cache: GPTCache
    # Class attribute, shared across instances (matches the original).
    prompt_store = PromptStore()

    def __init__(self, config):
        self.config = config
        self.semantic_cache = GPTCache(_init_gptcache)
        # FIX: the attribute is declared tuple[Embedding] and is indexed with
        # [0] below, but the original assigned a bare MongoEmbedding object
        # (TypeError on first self.embedding[0]) - wrap it in a tuple.
        # NOTE(review): MongoEmbedding, EmbeddingDbConnection and
        # CohereEmbeddings are not imported in this module - confirm imports.
        self.embedding = (
            MongoEmbedding(
                db=EmbeddingDbConnection(
                    connection=os.environ["DB_CONN_EMBED"],
                    database=os.environ["MONGODB_DB_NAME_ZPL_EMBED"],
                    collection="zpl-2402-cohere",
                    index="knnVector-cosine-index"
                ),
                embedding=CohereEmbeddings(model="embed-multilingual-v3.0")
            ),
        )

        self.llms = {
            "gpt-4o 128k": ChatOpenAI(
                model_name="gpt-4o",
                temperature=config["temperature"],
                openai_api_key=os.environ["OPENAI_API_KEY"],
                openai_organization=os.environ["OPENAI_ORGANIZATION_ID"]
            ),
            "llama-3 70B deepinfra 8k": deepinfra_chat("meta-llama/Meta-Llama-3-70B-Instruct",
                                                       self.config["temperature"]),
            "llama-3 8B deepinfra 8k": deepinfra_chat("meta-llama/Meta-Llama-3-8B-Instruct",
                                                      self.config["temperature"]),
            "Mixtral-8x22B-Instruct-v0.1 deepinfra 32k": deepinfra_chat("mistralai/Mixtral-8x22B-Instruct-v0.1",
                                                                        self.config["temperature"]),
            "gemini-pro 128k": ChatGoogleGenerativeAI(
                model="gemini-pro",
                convert_system_message_to_human=True,
                # NOTE(review): HARM_CATEGORY_DEROGATORY / _UNSPECIFIED are
                # PaLM-era categories; recent google-generativeai versions
                # reject them for Gemini models - confirm against the
                # installed SDK version.
                safety_settings={
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DEROGATORY: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                },
                transport="rest",
                stopSequence=["%%%%"],
                temperature=config["temperature"],
                cache=self.semantic_cache
            ),
            "Mistral (7B) Instruct v0.3 together.ai 32k": together_ai_chat(
                model="mistralai/Mistral-7B-Instruct-v0.3",
                temperature=config["temperature"]
            ),
            "OpenHermes-2.5 Mistral 7B together.ai 32k": together_ai_chat(
                model="teknium/OpenHermes-2p5-Mistral-7B",
                temperature=config["temperature"]
            ),
            "chat_groq_llm": groq_chat("mixtral-8x7b-32768"),
            "chat_groq_llama3_70": groq_chat("llama3-70b-8192"),
            "command_r_plus": cohere_llm(),
        }

        # FIX: CohereRerank.top_n must be an int; os.getenv returns a string
        # (or None when the variable is unset). Fall back to the config value.
        top_n = int(os.getenv("retrieve_documents", config["retrieve_documents"]))
        self.retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(model="rerank-multilingual-v3.0", top_n=top_n),
            base_retriever=self.get_vector_store_mongodb().as_retriever(
                search_type="similarity",
                # Over-fetch 10x candidates for the reranker to compress.
                search_kwargs={"k": config["retrieve_documents"] * 10}
            )
        )

    def get_vector_store_mongodb(self):
        """Vector store of the primary (index 0) embedding."""
        return self.embedding[0].get_vector_store()

    def get_llms(self):
        """Display names of the configured chat models."""
        return self.llms.keys()

    def rag_chain(self, query, choice):
        """Answer *query* with the LLM selected by *choice* (a key of self.llms).

        Delegates to the module-level ``rag_chain`` imported from ``rag``
        (the method name shadows it only as an attribute, not inside this body).
        """
        # answer, check_result, context_doc = rag_with_rerank_check_rewrite(
        # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
        # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
        answer, check_result, context_doc = rag_chain(
            Agent(embedding=self.embedding[0], llm=self.llms[choice]),
            query,
            self.config["retrieve_documents"],
            self.prompt_store.get_by_name(self.config["prompt_id"]).text,
            self.prompt_store.get_by_name(self.config["check_prompt_id"]).text
        )

        return answer, check_result, context_doc
117
+
118
+
119
def _init_gptcache(cache_obj: Cache, llm: str):
    """Initializer passed to LangChain's GPTCache wrapper.

    Args:
        cache_obj: the cache instance to initialize.
        llm: model identifier the cache is scoped to (currently unused -
             every model shares the same on-disk "map_cache" store).
    """
    cache_obj.init(
        pre_embedding_func=get_prompt,
        # FIX: was f"map_cache" - an f-string with no placeholders; plain literal.
        data_manager=manager_factory(data_dir="map_cache"),
        # Alternative persistent backend, kept for reference:
        # data_manager=get_data_manager(
        #     cache_base=CacheBase("mongo", url="mongodb://localhost:27017/"),
        #     vector_base=Chromadb(
        #         persist_directory="./chromadb/cache",
        #     ),
        # )
    )
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-community
3
+ langchain-openai
4
+ langchain-groq
5
+ langchain-mistralai
6
+ langchain-cohere
7
+ langchain-google-genai
8
+ langchain-together
9
+
10
+ PyMuPDF  # PyPI "fitz" is an unrelated, broken package; PyMuPDF provides the fitz module
11
+ pypdf
12
+ tools
13
+ python-dotenv
14
+ pymongo
15
+ pydantic
16
+ chromadb
17
+ bs4
18
+ fastapi
19
+ gptcache
20
+
21
+ # fastapi is already listed above (duplicate entry removed)
22
+ requests
23
+ uvicorn[standard]
24
+
25
+ sentence-transformers
26
+ text_generation
retrieval.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+
3
+ from langchain.retrievers import ContextualCompressionRetriever
4
+ from langchain_cohere.rerank import CohereRerank
5
+ from langchain_core.vectorstores import VectorStoreRetriever
6
+
7
+ from emdedd.Embedding import Embedding
8
+ from emdedd.embeddings import embed_zakonnik_prace
9
+ from questions import questions
10
+
11
+
12
def retrieve(embedding, q, retrieve_document_count):
    """Similarity-search the vector store behind *embedding* for *q*.

    Args:
        embedding: object exposing get_vector_store() (project Embedding type).
        q: query string.
        retrieve_document_count: number of documents to return.

    Returns:
        The top-k matching documents.
    """
    retriever: VectorStoreRetriever = embedding.get_vector_store().as_retriever(
        search_type="similarity",
        search_kwargs={"k": retrieve_document_count}
    )

    # FIX: get_relevant_documents is deprecated in recent LangChain; invoke()
    # is the supported entry point. The original also passed kwargs={"k": ...},
    # which was silently ignored - k is already fixed by search_kwargs above.
    return retriever.invoke(q)
24
+
25
+
26
def retrieve_with_rerank(embedding, q, retrieve_document_count):
    """Retrieve candidate documents for *q* and rerank them with Cohere.

    Args:
        embedding: object exposing get_vector_store() (project Embedding type).
        q: query string.
        retrieve_document_count: document count the base retriever over-fetches
            from (10x) before reranking.

    Returns:
        The reranked documents.
    """
    compression_retriever = reranking_retriever(embedding, retrieve_document_count)

    # FIX: the original passed kwargs={"k": ...} to invoke(); Runnable.invoke()
    # does not interpret that keyword, so it was a no-op - the final count is
    # controlled by the retriever configuration itself.
    return compression_retriever.invoke(input=q)
39
+
40
+
41
def reranking_retriever(embedding, retrieve_document_count):
    """Build a retriever that over-fetches 10x candidates and reranks them with Cohere."""
    base = embedding.get_vector_store().as_retriever(
        search_type="similarity",
        search_kwargs={"k": retrieve_document_count * 10},
    )
    return ContextualCompressionRetriever(
        base_compressor=CohereRerank(model="rerank-multilingual-v3.0"),
        base_retriever=base,
    )
51
+
52
+
53
+ # todo
54
+ # def hyde(agent: Agent, q, retrieve_document_count):
55
+ # retriever: VectorStoreRetriever = agent.embedding.get_vector_store().as_retriever(
56
+ # search_type="similarity",
57
+ # search_kwargs={"k": retrieve_document_count * 10}
58
+ # )
59
+
60
+ #
61
+ # context_doc = compression_retriever.get_relevant_documents(
62
+ # query=q,
63
+ # kwargs={"k": retrieve_document_count}
64
+ # )
65
+ #
66
+ # for doc in context_doc:
67
+ # text = doc.page_content
68
+ # print(" kontext: " + text.replace('\n', ' ').replace('\r', ' '))
69
+ #
70
+ # return context_doc
71
+
72
+
73
def retrieve_test(name: str, embed_dict: dict[str, Embedding], emded: bool = False):
    """Benchmark every embedding in *embed_dict* against the test questions.

    For each embedding, retrieves 5 documents per question and appends a
    markdown table to "<name>_retrieve_test.md", counting how many retrieved
    chunks mention one of the expected Labour Code sections.

    Args:
        name: prefix of the output markdown file.
        embed_dict: display name -> Embedding instance to benchmark.
        emded: when True, (re)embed the Labour Code before querying.
    """
    # Sections counting as correct hits; 102 and 108 were absent in the
    # original check list as well.
    expected_sections = [f"§ {n}" for n in (100, 101, 103, 104, 105, 106, 107, 109,
                                            110, 111, 112, 113, 114, 115, 116, 117)]
    # FIX: open outside the try - the original opened inside it, so a failed
    # open() raised NameError from the finally block instead of the real error.
    # utf-8 keeps the "§" characters portable across platforms.
    with open(name + "_retrieve_test.md", "a", encoding="utf-8") as result_file:
        try:
            for embed_key, embedding in embed_dict.items():
                if emded:
                    embed_zakonnik_prace(embedding)
                print("--- Running on " + embed_key)
                result_file.write("\n\n| " + embed_key + " | " + str(datetime.datetime.now()) + " |")
                result_file.write("\n|-------|-----------|")
                dobre: int = 0
                for q in questions:
                    print(q)
                    context_doc = retrieve(embedding, q, 5)
                    for doc in context_doc:
                        text = doc.page_content
                        flat = text.replace('\n', ' ').replace('\r', ' ')
                        print(" kontext: " + flat)
                        result_file.write("\n| " + q + " | " + flat + " |")
                        # One point per expected section mentioned in this chunk
                        # (replaces 16 copy-pasted conditional lines).
                        dobre += sum(1 for sec in expected_sections if sec in text)
                result_file.write("\n| Dobre: | " + str(dobre) + " |")
        finally:
            result_file.write("\n\n")
            result_file.flush()
task_splitting.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ from time import sleep
3
+
4
+ from langchain.chains import LLMChain
5
+ from langchain_core.prompts import PromptTemplate
6
+
7
+ from agent.Agent import Agent
8
+ from agent.agents import chat_groq_llama3_70
9
+ from emdedd.embeddings import cohere_embeddings, chroma_embedding, embed_zakonnik_prace
10
+ from promts import for_tree_llama3_rag_sub, for_tree_llama3_rag_tree, for_tree_llama3_rag_group
11
+ from retrieval import retrieve_with_rerank
12
+ from questions import questions
13
+
14
+
15
def rag_tree(agent: Agent, q: str, retrieve_document_count: int) -> str:
    """Answer *q* with a tree-of-thought RAG pipeline.

    Steps:
      1. Retrieve reranked context (2x over-fetched) and ask the LLM to
         decompose *q* into sub-questions.
      2. Answer each sub-question against its own freshly retrieved context.
      3. Merge the sub-answers plus the original context into a final answer.

    Args:
        agent: embedding + LLM pair used for retrieval and generation.
        q: the top-level question.
        retrieve_document_count: documents retrieved per sub-question.

    Returns:
        The final merged answer text.
    """
    tree_template = PromptTemplate(
        input_variables=["context", "question"],
        template=for_tree_llama3_rag_tree
    )

    # Over-fetch context for the decomposition step (twice the usual count).
    context_doc = retrieve_with_rerank(agent.embedding, q, retrieve_document_count * 2)

    sub_qs = LLMChain(
        llm=agent.llm,
        prompt=tree_template,
        verbose=False
    ).invoke(
        input={
            "question": q,
            "context": context_doc
        }
    )["text"]

    print(sub_qs)
    # NOTE(review): fixed 60 s pause between LLM calls - presumably provider
    # rate limiting (Groq); confirm before removing.
    sleep(60)

    print("_________")

    sub_template = PromptTemplate(
        input_variables=["context", "question"],
        template=for_tree_llama3_rag_sub
    )

    # sub-question -> its answer, in generation order.
    sub_answers: dict[str, str] = {}

    for sub_q in sub_qs.splitlines():
        # Only lines that look like questions are answered.
        if "?" not in sub_q: continue
        print(sub_q)
        sub_answers[sub_q] = LLMChain(
            llm=agent.llm,
            prompt=sub_template,
            verbose=False
        ).invoke(
            input={
                "question": sub_q,
                "context": retrieve_with_rerank(agent.embedding, sub_q, retrieve_document_count)
            }
        )["text"]
        print(sub_answers[sub_q])
        sleep(60)


    final_template = PromptTemplate(
        input_variables=["context", "question", "subs"],
        template=for_tree_llama3_rag_group
    )

    result = LLMChain(
        llm=agent.llm,
        prompt=final_template,
        verbose=True
    ).invoke(
        input={
            "question": q,
            "context": context_doc,
            # dict_items of (sub-question, sub-answer) pairs, interpolated
            # into the grouping prompt as-is.
            "subs": sub_answers.items()
        }
    )

    return result["text"]
81
+
82
+
83
def tree_of_thought(name: str, agent: Agent, emded: bool = False, retrieve_document_count=5):
    """Run rag_tree over the benchmark questions, appending results to <name>_test.md.

    Args:
        name: output file prefix and row label.
        agent: embedding + LLM pair used for retrieval and generation.
        emded: when True, (re)embed the Labour Code first.
        retrieve_document_count: documents retrieved per (sub-)question.
    """
    # FIX: open outside the try - the original opened inside it, so a failed
    # open() raised NameError from the finally block instead of the real error.
    # utf-8 keeps Slovak answers portable across platforms.
    with open(name + "_test.md", "a", encoding="utf-8") as result_file:
        try:
            if emded:
                embed_zakonnik_prace(agent.embedding)
            for q in questions:
                print("--- Q: " + q)
                result_file.write("\n\n| " + name + str(datetime.datetime.now()) + " | " + q + " |")
                result_file.write("\n|-------|-----------|")

                answer = rag_tree(agent, q, retrieve_document_count)
                print(answer)
                # Flatten newlines so the answer stays inside one markdown cell.
                result_file.write(
                    "\n| tree | " + answer.replace('\r\n', '<br>').replace('\n', '<br>').replace('\r', '<br>') + " |")
                # NOTE(review): presumably provider rate limiting - confirm.
                sleep(60)
        finally:
            result_file.write("\n\n")
            result_file.flush()