Commit 869eb7d
Pavol Liška committed
Parent(s): 15d4751

v1

Files changed:
- Dockerfile +11 -0
- README.md +4 -6
- agent/Agent.py +14 -0
- agent/__init__.py +0 -0
- agent/agents.py +60 -0
- api.py +88 -0
- conversation/__init__.py +0 -0
- conversation/conversation_store.py +94 -0
- dto/__init__.py +0 -0
- dto/conversation.py +18 -0
- dto/document.py +11 -0
- dto/prompt.py +12 -0
- emdedd/ChromaEmbedding.py +28 -0
- emdedd/Embedding.py +20 -0
- emdedd/MongoEmbedding.py +91 -0
- emdedd/Splitter.py +21 -0
- emdedd/__init__.py +0 -0
- emdedd/embeddings.py +10 -0
- prompt/__init__.py +0 -0
- prompt/prompt_store.py +52 -0
- rag.py +400 -0
- rag_langchain.py +129 -0
- requirements.txt +26 -0
- retrieval.py +110 -0
- task_splitting.py +101 -0
Dockerfile
ADDED
@@ -0,0 +1,11 @@
FROM python:3.9

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

CMD ["uvicorn", "api:api", "--host", "0.0.0.0", "--port", "7860"]

README.md
CHANGED
@@ -1,11 +1,9 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Try kopy
+emoji: 💻📚🤖
+colorFrom: blue
+colorTo: gray
 sdk: docker
 pinned: false
 license: agpl-3.0
 ---
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

agent/Agent.py
ADDED
@@ -0,0 +1,14 @@
from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_core.language_models import BaseChatModel

from emdedd.Embedding import Embedding


class Agent:
    embedding: Embedding
    llm: BaseChatModel

    def __init__(self, embedding, llm):
        self.embedding = embedding
        self.llm = llm

agent/__init__.py
ADDED
File without changes

agent/agents.py
ADDED
@@ -0,0 +1,60 @@
import os

from dotenv import load_dotenv
from langchain_cohere.llms import Cohere
from langchain_community.chat_models import ChatDeepInfra
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_together import ChatTogether

load_dotenv()


def chat_openai_llm():
    return ChatOpenAI(
        model_name="gpt-4o",
        temperature=os.environ["temperature"],
        openai_api_key=os.environ["OPENAI_API_KEY"],
        openai_organization=os.environ["OPENAI_ORGANIZATION_ID"]
    )


def groq_chat(model: str):
    return ChatGroq(
        model_name=model,
        temperature=os.environ["temperature"],
        groq_api_key=os.environ["GROQ_API_KEY"],
        max_tokens=2000
    )


def cohere_llm():
    return Cohere(
        model="command-r-plus",
        max_tokens=2048,
        temperature=os.environ["temperature"],
        # p=os.environ["top_p"],
        # frequency_penalty=os.environ["frequency_penalty"],
    )


def together_ai_chat(model, temperature):
    return ChatTogether(
        model_name=model,
        together_api_key=os.environ["TOGETHER_AI_API_KEY"],
        temperature=temperature,
        top_p=os.environ["top_p"],
        max_tokens=4096,
        model_kwargs={"stop": ["%%%%"]},
    )


def deepinfra_chat(model, temperature):
    return ChatDeepInfra(
        model=model,
        deepinfra_api_token=os.environ["DEEPINFRA_API_KEY"],
        temperature=temperature,
        top_p=os.environ["top_p"],
        max_tokens=4096,
        model_kwargs={"stop": ["%%%%"]},
    )

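A minimal usage sketch (illustrative, not part of the commit; assumes GROQ_API_KEY and temperature are set in the environment, and uses FakeEmbeddings as a stand-in vector model so no embedding API key is needed):

from langchain_community.embeddings import FakeEmbeddings

from agent.Agent import Agent
from agent.agents import groq_chat
from emdedd.embeddings import chroma_embedding

# wire a factory LLM and a local Chroma store into the Agent wrapper
agent = Agent(
    embedding=chroma_embedding("demo", FakeEmbeddings(size=1536)),
    llm=groq_chat("llama3-70b-8192"),
)
print(agent.llm.invoke("Hello").content)  # ChatGroq returns an AIMessage
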
api.py
ADDED
@@ -0,0 +1,88 @@
from fastapi import FastAPI, Response, Body, Security
from fastapi.responses import JSONResponse
from fastapi.security import APIKeyHeader

from conversation.conversation_store import ConversationStore
from rag_langchain import LangChainRAG

api = FastAPI()

conversation_store = ConversationStore()

api_key_header = APIKeyHeader(name="Authorization", auto_error=True)
prompt_id = "summarize_rag_1"
check_prompt_id = "check_control_challenge_step_back"
rewrite_prompt_id = "first"
default_llm = "gpt-4o 128k"


@api.get("/")
def read_root():
    return "Empty"


@api.post("/q")
async def q(api_key: str = Security(api_key_header), json_body: dict = Body(...)):
    # Verify the API key
    if not valid_api_key(api_key):
        return Response(status_code=401)

    # Process the JSON body
    data = json_body

    rag = LangChainRAG(
        config={
            "retrieve_documents": data["retrieval_count"],
            "temperature": data["temperature"],
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id
        }
    )

    answer, check_result, sources = rag.rag_chain(data["q"], default_llm)

    oid = conversation_store.save_content(
        q=data["q"],  # the question text; `q` alone would pass the endpoint function itself
        a=answer,
        sources=list(map(lambda doc: doc.page_content, sources)),
        params={
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id,
            "check_result": check_result,
            "temperature": data["temperature"],
            "retrieve_document_count": data["retrieval_count"],
        }
    )

    return JSONResponse(  # a plain Response cannot serialize a dict body
        status_code=200,
        content={
            "response": answer,
            "sources": list(map(lambda doc: doc.page_content, sources)),
            "qid": str(oid)  # ObjectId is not JSON-serializable
        }
    )


@api.post("/emo")
async def emo(api_key: str = Security(api_key_header), json_body: dict = Body(...)):
    # Verify the API key
    if not valid_api_key(api_key):
        return Response(status_code=401)

    qa = conversation_store.get(json_body["qid"])
    new_params = qa.params
    new_params["user_grading"] = str(json_body["helpfulness"])
    conversation_store.update(
        oid=json_body["qid"],
        q=qa.conversation[0].q,
        a=qa.conversation[0].a,
        sources=qa.conversation[0].sources,
        params=new_params
    )


def valid_api_key(api_key: str):
    return api_key == "your_secret_api_key"

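A client-side sketch of the two endpoints (illustrative; assumes the service listens on localhost:7860 as in the Dockerfile, and that the Authorization header carries the configured API key):

import requests

BASE = "http://localhost:7860"
HEADERS = {"Authorization": "your_secret_api_key"}

# /q reads q, retrieval_count and temperature from the JSON body
r = requests.post(BASE + "/q", headers=HEADERS,
                  json={"q": "How much notice must an employer give?",
                        "retrieval_count": 5, "temperature": 0.2})
qid = r.json()["qid"]

# /emo stores the user's grading under params["user_grading"]
requests.post(BASE + "/emo", headers=HEADERS, json={"qid": qid, "helpfulness": 4})
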
conversation/__init__.py
ADDED
File without changes

conversation/conversation_store.py
ADDED
@@ -0,0 +1,94 @@
import os

from dotenv import load_dotenv
from pymongo import MongoClient

from dto.conversation import Conversation, OneShotConversation

load_dotenv()


class ConversationStorage:

    def __init__(self, db_uri, db_name, collection_name):
        self.client = MongoClient(db_uri)
        self.collection = self.client[db_name][collection_name]

    def store_document_mongodb(self, conversation):
        return self.collection.insert_one(conversation.model_dump()).inserted_id

    def get_all(self):
        docs = []
        for q in self.collection.find():
            docs.append(Conversation(**q))
        return docs

    def get_next(self, offset):
        q = self.collection.find().limit(1).skip(offset)[0]
        document = Conversation(**q)
        print("Got conversation {} stored at {}".format(document.id, document.created))
        return document

    def get_one(self, oid) -> Conversation:
        q = self.collection.find({"_id": oid})[0]
        document = Conversation(**q)
        print("Got conversation {} stored at {}".format(document.id, document.created))
        return document

    def count(self):
        return self.collection.estimated_document_count()

    def update(self, oid, conversation):
        self.collection.update_one(
            filter={"_id": oid},
            update={"$set": conversation.model_dump()}  # update_one requires an operator such as $set
        )


class ConversationStore:
    URI = os.environ["DB_CONN_LOG"]
    DB_NAME = os.environ["MONGODB_DB_NAME_MSG_LOG"]
    COLLECTION_NAME = os.environ["MONGODB_COL_NAME_MSG_LOG"]

    storage = ConversationStorage(URI, DB_NAME, COLLECTION_NAME)

    def save_content(self, q, a=None, sources: list[str] = None, params: dict[str, str] = None):
        return self.storage.store_document_mongodb(
            Conversation(
                conversation=[
                    OneShotConversation(
                        q=q,
                        a=a,
                        sources=sources
                    )
                ],
                params=params
            )
        )

    def get_all(self):
        return self.storage.get_all()

    def get_one(self, offset):
        return self.storage.get_next(offset)

    def get(self, oid: str) -> Conversation:
        return self.storage.get_one(oid)

    def count(self):
        return self.storage.count()

    def update(self, oid, q, a, sources, params):
        self.storage.update(
            oid,
            Conversation(
                conversation=[
                    OneShotConversation(
                        q=q,
                        a=a,
                        sources=sources
                    )
                ],
                params=params
            )
        )

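A usage sketch for the store (illustrative; assumes DB_CONN_LOG and the MONGODB_*_MSG_LOG variables point at a reachable MongoDB):

from conversation.conversation_store import ConversationStore

store = ConversationStore()
oid = store.save_content(
    q="What is the maximum weekly working time?",      # hypothetical question
    a="40 hours, per the retrieved paragraphs.",       # hypothetical answer
    sources=["§ 85 ..."],
    params={"prompt_id": "summarize_rag_1", "temperature": "0.2"},
)
print(store.count())  # number of stored conversations
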
dto/__init__.py
ADDED
File without changes

dto/conversation.py
ADDED
@@ -0,0 +1,18 @@
import uuid
from datetime import datetime
from typing import Optional, Annotated

from pydantic import BaseModel, Field, BeforeValidator


class OneShotConversation(BaseModel):
    q: str = Field()
    a: str = Field()
    sources: list[str] = Field()


class Conversation(BaseModel):
    id: Optional[Annotated[str, BeforeValidator(str)]] = Field(alias="_id", default=None)
    created: datetime = Field(default_factory=datetime.now)
    conversation: list[OneShotConversation] = Field()
    params: dict[str, str] = Field()

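A round-trip sketch (illustrative): the _id alias lets a raw MongoDB document be unpacked directly, and BeforeValidator(str) coerces an ObjectId into a plain string id.

from dto.conversation import Conversation

raw = {
    "_id": "665f1c0e9b1e8a0012345678",  # hypothetical id, already stringified
    "conversation": [{"q": "Q?", "a": "A.", "sources": ["§ 100"]}],
    "params": {"prompt_id": "summarize_rag_1"},
}
c = Conversation(**raw)  # created falls back to datetime.now
print(c.id, c.conversation[0].q)
print(c.model_dump())    # dumps the field name "id", not "_id", unless by_alias=True
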
dto/document.py
ADDED
@@ -0,0 +1,11 @@
import uuid
from datetime import datetime

from pydantic import BaseModel, Field


class Document(BaseModel):
    id: str = Field(default_factory=lambda: str(uuid.uuid4()), alias="_id")  # uuid4() itself is a UUID, not a str
    name: str = Field()
    text: str = Field()
    created: datetime = Field(default_factory=datetime.now)  # datetime.now returns a datetime, not a str

dto/prompt.py
ADDED
@@ -0,0 +1,12 @@
import uuid
from typing import Optional, Annotated

from bson import ObjectId
from pydantic import BaseModel, Field, BeforeValidator


class Prompt(BaseModel):
    id: Optional[Annotated[str, BeforeValidator(str)]] = Field(alias="_id", default=None)
    name: str = Field()
    text: str = Field()
    prompt_type: str = Field()

emdedd/ChromaEmbedding.py
ADDED
@@ -0,0 +1,28 @@
from langchain_community.vectorstores.chroma import Chroma

from emdedd.Embedding import Embedding


class ChromaEmbedding(Embedding):
    db: Chroma

    def __init__(self, embedding, path, collection, collection_metadata=None):
        self.db = Chroma(
            embedding_function=embedding,
            persist_directory=path,
            collection_name=collection,
            collection_metadata=collection_metadata
        )

    def embedd(self, chunks, metadata: list[dict] = None):
        self.__store_embeddings(chunks, metadata)

    def __store_embeddings(self, chunks, metadata: list[dict] = None):
        self.db.add_texts(
            texts=chunks,
            metadatas=metadata
        )
        self.db.persist()

    def get_vector_store(self):
        return self.db

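A minimal sketch exercising the wrapper (illustrative; FakeEmbeddings stands in for a real embedding model so it runs offline):

from langchain_community.embeddings import FakeEmbeddings

from emdedd.ChromaEmbedding import ChromaEmbedding

emb = ChromaEmbedding(
    embedding=FakeEmbeddings(size=384),  # stand-in; the real code uses Cohere embeddings
    path="./chromadb/demo",              # hypothetical local directory
    collection="demo",
)
emb.embedd(["first chunk", "second chunk"], metadata=[{"id": "1"}, {"id": "2"}])
print(emb.get_vector_store().similarity_search("first", k=1))
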
emdedd/Embedding.py
ADDED
@@ -0,0 +1,20 @@
from dataclasses import dataclass


@dataclass
class EmbeddingDbConnection:
    connection: str
    database: str
    collection: str
    index: str


class Embedding:
    def embedd(self, chunks, metadata: list[dict] = None):
        pass

    def get_vector_store(self):
        pass

    def search(self, query, search_type, doc_count):
        pass

emdedd/MongoEmbedding.py
ADDED
@@ -0,0 +1,91 @@
from dataclasses import dataclass

from langchain.embeddings import CacheBackedEmbeddings
from langchain.storage import LocalFileStore
from langchain_community.vectorstores.mongodb_atlas import MongoDBAtlasVectorSearch
from langchain_core.embeddings import Embeddings
from langchain_core.stores import InMemoryStore
from pymongo import MongoClient
from bson.objectid import ObjectId

from emdedd.Embedding import Embedding


@dataclass
class EmbeddingDbConnection:
    connection: str
    database: str
    collection: str
    index: str


class MongoEmbedding(Embedding):
    db: EmbeddingDbConnection
    embedding: Embeddings

    def __init__(self, db, embedding, cache: bool = True):
        self.db = db
        if cache:
            self.embedding = CacheBackedEmbeddings.from_bytes_store(
                underlying_embeddings=embedding,
                document_embedding_cache=InMemoryStore(),
                namespace="mongo-embedding-cache"
            )
        else:
            self.embedding = embedding

    def embedd(self, chunks, metadata: list[dict] = None):
        self.__store_embeddings(chunks, metadata)

    def __store_embeddings(self, chunks, metadata: list[dict] = None):
        client = MongoClient(self.db.connection)
        collection = client[self.db.database][self.db.collection]

        # collection.create_search_index(
        #     {"definition":
        #          {"mappings": {"dynamic": True, "fields": {
        #              "embedding": {
        #                  "dimensions": 1536,
        #                  "similarity": "cosine",
        #                  "type": "knnVector"
        #              }}}},
        #      "name": self.MONGODB_INDEX_NAME
        #      }
        # )

        MongoDBAtlasVectorSearch.from_texts(
            texts=chunks,
            metadatas=metadata,
            embedding=self.embedding,
            collection=collection,
            index_name=self.db.index
        )

        self.__add_id_to_metadata(collection)

    def __add_id_to_metadata(self, collection):
        # $exists must be a boolean: the string "false" is truthy in MongoDB and
        # would match documents that already have metadata.id
        for document in collection.find({"metadata.id": {"$exists": False}}):
            metadata: dict = document.get("metadata")
            if metadata is None:
                metadata = {}
            object_id: ObjectId = document["_id"]
            metadata["id"] = str(object_id)
            collection.update_one(
                filter={"_id": object_id},
                update={"$set": {"metadata": metadata}}  # update_one requires an operator such as $set
            )

    def get_vector_store(self):
        return MongoDBAtlasVectorSearch.from_connection_string(
            self.db.connection,
            self.db.database + "." + self.db.collection,
            embedding=self.embedding,
            index_name=self.db.index
        )

    def search(self, query, search_type, doc_count):
        vector_store = self.get_vector_store()
        retriever = vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": doc_count}
        )
        return retriever.get_relevant_documents(query=query)

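A wiring sketch (illustrative; the connection values mirror the ones rag_langchain.py uses and assume a MongoDB Atlas cluster with the named vector search index):

import os

from langchain_cohere import CohereEmbeddings

from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding

mongo = MongoEmbedding(
    db=EmbeddingDbConnection(
        connection=os.environ["DB_CONN_EMBED"],
        database=os.environ["MONGODB_DB_NAME_ZPL_EMBED"],
        collection="zpl-2402-cohere",
        index="knnVector-cosine-index",
    ),
    embedding=CohereEmbeddings(model="embed-multilingual-v3.0"),
)
docs = mongo.search("overtime pay", search_type="similarity", doc_count=5)
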
emdedd/Splitter.py
ADDED
@@ -0,0 +1,21 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter


class Splitter:
    separators = []
    chunk_overlap: int
    chunk_size: int

    def __init__(self, separators, chunk_overlap, chunk_size):
        self.separators = separators
        self.chunk_overlap = chunk_overlap
        self.chunk_size = chunk_size

    def split(self, text):
        text_splitter = RecursiveCharacterTextSplitter(
            separators=self.separators,
            is_separator_regex=True,
            chunk_overlap=self.chunk_overlap,
            chunk_size=self.chunk_size
        )
        return text_splitter.split_text(text)

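A usage sketch (illustrative; the regex separator and sizes are example values, not the ones used for the real corpus):

from emdedd.Splitter import Splitter

splitter = Splitter(
    separators=[r"§ \d+", "\n\n", "\n"],  # regex patterns, since is_separator_regex=True
    chunk_overlap=50,
    chunk_size=500,
)
chunks = splitter.split("§ 100 Dovolenka ...\n\n§ 101 ...")  # any long text
print(len(chunks))
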
emdedd/__init__.py
ADDED
File without changes

emdedd/embeddings.py
ADDED
@@ -0,0 +1,10 @@
from emdedd.ChromaEmbedding import ChromaEmbedding


def chroma_embedding(name: str, embedding, metadata=None) -> ChromaEmbedding:
    return ChromaEmbedding(
        path="./chromadb/zpl",
        embedding=embedding,
        collection=name,
        collection_metadata=metadata
    )

prompt/__init__.py
ADDED
File without changes

prompt/prompt_store.py
ADDED
@@ -0,0 +1,52 @@
import os

from dotenv import load_dotenv
from pymongo import MongoClient

from dto.prompt import Prompt

load_dotenv()


class PromptStore:
    URI = os.environ["DB_CONN_LOG"]
    DB_NAME = os.environ["MONGODB_DB_NAME_MSG_LOG"]
    COLLECTION_NAME = os.environ["MONGODB_COL_NAME_PROMPT"]
    client = MongoClient(URI)
    collection = client[DB_NAME][COLLECTION_NAME]

    def save_content(self, name, text, prompt_type):
        self.collection.insert_one(
            Prompt(
                name=name,
                text=text,
                prompt_type=prompt_type
            ).model_dump()
        )

    def get_all(self):
        docs = []
        for q in self.collection.find():
            docs.append(Prompt(**q))
        return docs

    def get_one(self, offset) -> Prompt:
        q = self.collection.find().limit(1).skip(offset)[0]
        document = Prompt(**q)
        print("Got prompt {}".format(document.name))
        return document

    def get(self, oid: str) -> Prompt:
        q = self.collection.find({"_id": oid})[0]
        document = Prompt(**q)
        print("Got prompt {}".format(document.name))
        return document

    def count(self):
        return self.collection.estimated_document_count()

    def get_by_name(self, name) -> Prompt:
        q = self.collection.find({"name": name})[0]
        document = Prompt(**q)
        print("Got prompt {}".format(document.name))
        return document

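A usage sketch (illustrative; assumes the same MongoDB env vars as the conversation store):

from prompt.prompt_store import PromptStore

store = PromptStore()
store.save_content(
    name="summarize_rag_1",
    text="Answer only from the context.\n{context}\nQuestion: {question}",  # hypothetical template
    prompt_type="rag",
)
print(store.get_by_name("summarize_rag_1").text)
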
rag.py
ADDED
@@ -0,0 +1,400 @@
import datetime
import os
import traceback
from typing import Any

from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.retrieval import create_retrieval_chain
from langchain.retrievers import MultiQueryRetriever, MergerRetriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain_cohere import CohereRerank
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate

from agent.Agent import Agent
from agent.agents import chat_openai_llm, deepinfra_chat
from conversation.conversation_store import ConversationStore
from prompt.prompt_store import PromptStore
from retrieval import retrieve, retrieve_with_rerank

load_dotenv()

conversation_store = ConversationStore()
prompt_store = PromptStore()

# get_by_name returns a Prompt object; the PromptTemplate constructors below
# need the raw template string, hence .text
grammar_check_1 = prompt_store.get_by_name("gramar_check_1").text
rewrite_hyde_1 = prompt_store.get_by_name("rewrite_hyde_1").text
rewrite_hyde_2 = prompt_store.get_by_name("rewrite_hyde_2").text
rewrite_1 = prompt_store.get_by_name("rewrite_1").text
rewrite_2 = prompt_store.get_by_name("rewrite_2").text
rewrite_hyde = prompt_store.get_by_name("rewrite_hyde").text


def replace_nl(input: str) -> str:
    return input.replace('\r\n', '<br>').replace('\n', '<br>').replace('\r', '<br>')


def rag(agent: Agent, q: str, retrieve_document_count: int):
    k = retrieve_document_count

    context_doc = retrieve(agent.embedding, q, k)

    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template=os.environ["RAG_TEMPLATE"]
    )

    llm_chain = LLMChain(
        llm=agent.llm,
        prompt=prompt_template,
        verbose=False
    )

    # llm_chain = prompt_template | agent.llm

    result: dict[str, Any] = llm_chain.invoke(
        input={
            "question": q,
            "context": context_doc
        }
    )

    return result["text"]


def rewrite(agent: Agent, q: str, prompt: str) -> list[str]:
    prompt_template = PromptTemplate(
        input_variables=["question"],
        template=prompt
    )
    llm_chain = LLMChain(
        llm=agent.llm,
        prompt=prompt_template,
        verbose=False
    )
    questions = llm_chain.invoke(
        input={"question": q}
    )["text"].splitlines()

    return [x for x in questions if ("##" not in x and len(str(x).strip()) > 0)]


def rag_with_rerank_check_rewrite(agent: Agent, q: str, retrieve_document_count: int, prompt: str, check_prompt: str,
                                  rewrite_prompt: str):
    rewritten_list: list[str] = rewrite(agent, q, rewrite_prompt)

    if len(rewritten_list) == 0:
        return "Neviem, nemám podklady!", "", ""  # "I don't know, I have no sources!"

    context_doc = retrieve_subqueries(agent, retrieve_document_count, rewritten_list)

    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""  # "I don't know, I have no context!"

    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]

    check_result = check_pipeline(answer, check_prompt, context_doc, q)

    return answer, check_result, context_doc


def rag_with_rerank_check_rewrite_hyde(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                                       check_prompt: str,
                                       rewrite_prompt: str):
    rewritten_list: list[str] = rewrite(agent, q, rewrite_prompt)

    if len(rewritten_list) == 0:
        return "Neviem, nemám podklady!", "", ""

    context_doc = retrieve_subqueries_hyde(agent, retrieve_document_count, rewritten_list)

    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""

    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]

    check_result = check_pipeline(answer, check_prompt, context_doc, q)

    return answer, check_result, context_doc


def rag_with_rerank_check_multi_query_retriever(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
                                                check_prompt: str):
    context_doc = hyde_retrieval(agent, retrieve_document_count).invoke(
        input=q,
        kwargs={"k": retrieve_document_count}
    )

    if len(context_doc) == 0:
        return "Neviem, nemám kontext!", "", ""

    result = answer_pipeline(agent, context_doc, prompt, q)
    answer = result["text"]

    check_result = check_pipeline(answer, check_prompt, context_doc, q)

    return answer, check_result, context_doc


def rag_chain(agent: Agent, q: str, retrieve_document_count: int, prompt: str,
              check_prompt: str):
    result = create_retrieval_chain(
        retriever=hyde_2_retrieval(agent, retrieve_document_count),
        combine_docs_chain=create_stuff_documents_chain(
            llm=agent.llm,
            prompt=PromptTemplate(
                input_variables=["context", "question", "actual_date"],
                template=prompt
            )
        )
    ).invoke(
        input={
            "question": q,
            "input": q,
            "actual_date": datetime.date.today().isoformat()
        }
    )

    check_result = check_pipeline(result["answer"], check_prompt, result["context"], q)

    return result["answer"], check_result, result["context"]


def hyde_retrieval(agent, retrieve_document_count):
    retriever_1 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": retrieve_document_count}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_hyde_1
        )
    )
    retriever_2 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": retrieve_document_count}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_hyde_2
        )
    )
    merge_retriever = MergerRetriever(
        retrievers=[retriever_1, retriever_2],
    )
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        top_n=retrieve_document_count
    )
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=merge_retriever,
        search_kwargs={"k": retrieve_document_count},
    )

    return compression_retriever


def hyde_2_retrieval(agent, retrieve_document_count):
    compressor = CohereRerank(
        model="rerank-multilingual-v3.0",
        top_n=retrieve_document_count // 2  # top_n must be an int; / would yield a float
    )
    retriever_1 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 500)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_1
        )
    )
    compression_retriever_1 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_1
    )

    retriever_2 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 500)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_2
        )
    )
    compression_retriever_2 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_2
    )

    retriever_3 = MultiQueryRetriever.from_llm(
        llm=agent.llm,
        retriever=agent.embedding.get_vector_store().as_retriever(
            search_type="similarity",
            search_kwargs={"k": min(retrieve_document_count * 10, 500)}
        ),
        prompt=PromptTemplate(
            input_variables=["question"],
            template=rewrite_hyde
        )
    )
    compression_retriever_3 = ContextualCompressionRetriever(
        base_compressor=compressor,
        base_retriever=retriever_3
    )

    merge_retriever = EnsembleRetriever(
        retrievers=[compression_retriever_1, compression_retriever_2, compression_retriever_3],
        weights=[1.0, 1.0, 1.0]
    )

    return merge_retriever


def retrieve_subqueries(agent, retrieve_document_count, rewritten_list) -> list[Document]:
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        contexts.extend(retrieve_with_rerank(agent.embedding, rewritten, retrieve_document_count))

    contexts.sort(key=lambda x: -x.metadata["relevance_score"])

    deduplicated: list[Document] = []
    for doc in contexts:
        already_in = False
        for de_doc in deduplicated:
            if doc.page_content == de_doc.page_content:
                already_in = True
        if not already_in:
            deduplicated.append(doc)

    return deduplicated[:retrieve_document_count]


def retrieve_subqueries_hyde(agent, retrieve_document_count, rewritten_list) -> list[Document]:
    contexts: list[Document] = []
    for rewritten in rewritten_list:
        answer = agent.llm.invoke(rewritten).content
        contexts.extend(retrieve_with_rerank(agent.embedding, rewritten + "\n" + answer, retrieve_document_count))

    contexts.sort(key=lambda x: -x.metadata["relevance_score"])

    deduplicated: list[Document] = []
    for doc in contexts:
        already_in = False
        for de_doc in deduplicated:
            if doc.page_content == de_doc.page_content:
                already_in = True
        if not already_in:
            deduplicated.append(doc)

    return deduplicated[:retrieve_document_count]


def answer_pipeline(agent, context_doc, prompt, q):
    prompt_template = PromptTemplate(
        input_variables=["context", "question"],
        template=prompt
    )
    llm_chain = LLMChain(
        llm=agent.llm,
        prompt=prompt_template,
        verbose=False
    )

    result: dict[str, Any] = llm_chain.invoke(
        input={
            "question": q,
            "context": context_doc,
            "actual_date": datetime.date.today().isoformat()
        }
    )
    return result


def check_pipeline(answer, check_prompt, context_doc, q):
    prompt_template = PromptTemplate(
        input_variables=["context", "question", "answer"],
        template=check_prompt
    )
    llm_chain = LLMChain(
        llm=deepinfra_chat("meta-llama/Meta-Llama-3-70B-Instruct", "0.4"),
        prompt=prompt_template,
        verbose=False
    )
    try:
        check_result = llm_chain.invoke(
            input={
                "question": q[:2000],
                "context": context_doc,
                "answer": answer
            }
        )["text"]
    except Exception as e:
        check_result = traceback.format_exc()

    return check_result


def rag_with_rerank(agent: Agent, q: str, retrieve_document_count: int, prompt: str = None, check_prompt: str = None):
    context_doc: list[Document] = retrieve_with_rerank(agent.embedding, q, retrieve_document_count)

    try:
        result: dict[str, Any] = answer_pipeline(agent, context_doc, prompt, q)

        answer = result["text"]
        check_result = ""

        if check_prompt is not None:
            check_result = check_pipeline(answer, check_prompt, context_doc, q)

        return answer, check_result, context_doc
    except Exception as e:
        return "", traceback.format_exc(), ""


def save_conversation(answer: str, check_result: str, context_doc: list[Document], gramatika: str, question: str,
                      prompt_id: str, check_prompt_id: str, grammar_prompt_id: str):
    if len(answer) > 0:
        conversation_store.save_content(
            q=question,
            a=answer,
            sources=list(map(lambda doc: doc.page_content, context_doc)),
            params={
                "prompt_id": prompt_id,
                "check_prompt_id": check_prompt_id,
                "grammar_prompt_id": grammar_prompt_id,
                "check_result": check_result,
                "grammar_result": gramatika,
                "temperature": os.environ["temperature"],
            }
        )


def check_slovak_agent(text: str) -> str:
    prompt_template = PromptTemplate(
        input_variables=["text"],
        template=grammar_check_1
    )

    llm_chain = LLMChain(
        llm=chat_openai_llm(),
        prompt=prompt_template,
        verbose=False
    )

    result: dict[str, Any] = llm_chain.invoke(input={"text": text})

    return result["text"]

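The deduplication loops in retrieve_subqueries and retrieve_subqueries_hyde are O(n²); since the list is already sorted by descending relevance_score, a seen-set keeps the best-scored copy of each chunk in linear time. A behavior-equivalent sketch:

from langchain_core.documents import Document


def dedup_by_content(contexts: list[Document], limit: int) -> list[Document]:
    seen: set[str] = set()
    deduplicated: list[Document] = []
    for doc in contexts:  # contexts sorted best-first, so the first occurrence wins
        if doc.page_content not in seen:
            seen.add(doc.page_content)
            deduplicated.append(doc)
    return deduplicated[:limit]
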
rag_langchain.py
ADDED
@@ -0,0 +1,129 @@
import os

from dotenv import load_dotenv
from gptcache import Cache
from gptcache.manager.factory import manager_factory
from gptcache.processor.pre import get_prompt
from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere import CohereRerank, CohereEmbeddings
from langchain_community.cache import GPTCache
from langchain_core.language_models import BaseChatModel
from langchain_core.prompts import PromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_google_genai import ChatGoogleGenerativeAI, HarmCategory, HarmBlockThreshold
from langchain_openai import ChatOpenAI

from agent.Agent import Agent
from agent.agents import deepinfra_chat, \
    together_ai_chat, groq_chat, cohere_llm
from emdedd.Embedding import Embedding
from emdedd.MongoEmbedding import EmbeddingDbConnection, MongoEmbedding
from emdedd.embeddings import chroma_embedding
from prompt.prompt_store import PromptStore
from rag import rag_chain

load_dotenv()


class LangChainRAG:
    embedding: Embedding
    llms: dict[str, BaseChatModel]
    retriever: BaseRetriever
    prompt_template: PromptTemplate
    config: dict
    semantic_cache: GPTCache
    prompt_store = PromptStore()

    def __init__(self, config):
        self.config = config
        self.semantic_cache = GPTCache(_init_gptcache)
        self.embedding = MongoEmbedding(
            db=EmbeddingDbConnection(
                connection=os.environ["DB_CONN_EMBED"],
                database=os.environ["MONGODB_DB_NAME_ZPL_EMBED"],
                collection="zpl-2402-cohere",
                index="knnVector-cosine-index"
            ),
            embedding=CohereEmbeddings(model="embed-multilingual-v3.0")
        )

        self.llms = {
            "gpt-4o 128k": ChatOpenAI(
                model_name="gpt-4o",
                temperature=config["temperature"],
                openai_api_key=os.environ["OPENAI_API_KEY"],
                openai_organization=os.environ["OPENAI_ORGANIZATION_ID"]
            ),
            "llama-3 70B deepinfra 8k": deepinfra_chat("meta-llama/Meta-Llama-3-70B-Instruct",
                                                       self.config["temperature"]),
            "llama-3 8B deepinfra 8k": deepinfra_chat("meta-llama/Meta-Llama-3-8B-Instruct",
                                                      self.config["temperature"]),
            "Mixtral-8x22B-Instruct-v0.1 deepinfra 32k": deepinfra_chat("mistralai/Mixtral-8x22B-Instruct-v0.1",
                                                                        self.config["temperature"]),
            "gemini-pro 128k": ChatGoogleGenerativeAI(
                model="gemini-pro",
                convert_system_message_to_human=True,
                safety_settings={
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_DEROGATORY: HarmBlockThreshold.BLOCK_NONE,
                    HarmCategory.HARM_CATEGORY_UNSPECIFIED: HarmBlockThreshold.BLOCK_NONE,
                },
                transport="rest",
                stopSequence=["%%%%"],
                temperature=config["temperature"],
                cache=self.semantic_cache
            ),
            "Mistral (7B) Instruct v0.3 together.ai 32k": together_ai_chat(
                model="mistralai/Mistral-7B-Instruct-v0.3",
                temperature=config["temperature"]
            ),
            "OpenHermes-2.5 Mistral 7B together.ai 32k": together_ai_chat(
                model="teknium/OpenHermes-2p5-Mistral-7B",
                temperature=config["temperature"]
            ),
            "chat_groq_llm": groq_chat("mixtral-8x7b-32768"),
            "chat_groq_llama3_70": groq_chat("llama3-70b-8192"),
            "command_r_plus": cohere_llm(),
        }

        self.retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(model="rerank-multilingual-v3.0", top_n=config["retrieve_documents"]),
            base_retriever=self.get_vector_store_mongodb().as_retriever(
                search_type="similarity",
                search_kwargs={"k": config["retrieve_documents"] * 10}
            )
        )

    def get_vector_store_mongodb(self):
        return self.embedding.get_vector_store()

    def get_llms(self):
        return self.llms.keys()

    def rag_chain(self, query, choice):
        # answer, check_result, context_doc = rag_with_rerank_check_rewrite(
        # answer, check_result, context_doc = rag_with_rerank_check_rewrite_hyde(
        # answer, check_result, context_doc = rag_with_rerank_check_multi_query_retriever(
        answer, check_result, context_doc = rag_chain(
            Agent(embedding=self.embedding, llm=self.llms[choice]),
            query,
            self.config["retrieve_documents"],
            self.prompt_store.get_by_name(self.config["prompt_id"]).text,
            self.prompt_store.get_by_name(self.config["check_prompt_id"]).text
        )

        return answer, check_result, context_doc


def _init_gptcache(cache_obj: Cache, llm: str):
    cache_obj.init(
        pre_embedding_func=get_prompt,
        data_manager=manager_factory(data_dir="map_cache"),
        # data_manager=get_data_manager(
        #     cache_base=CacheBase("mongo", url="mongodb://localhost:27017/"),
        #     vector_base=Chromadb(
        #         persist_directory="./chromadb/cache",
        #     ),
        # )
    )

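A construction sketch (illustrative; assumes all the environment variables referenced above are set; these are exactly the config keys the class and rag.rag_chain read, with example values):

from rag_langchain import LangChainRAG

rag = LangChainRAG(config={
    "retrieve_documents": 5,
    "temperature": 0.2,
    "prompt_id": "summarize_rag_1",
    "check_prompt_id": "check_control_challenge_step_back",
    "rewrite_prompt_id": "first",
})
print(list(rag.get_llms()))  # the names accepted as `choice`
answer, check_result, context = rag.rag_chain("What is the probation period?", "gpt-4o 128k")
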
requirements.txt
ADDED
@@ -0,0 +1,26 @@
langchain
langchain-community
langchain-openai
langchain-groq
langchain-mistralai
langchain-cohere
langchain-google-genai
langchain-together

fitz
pypdf
tools
python-dotenv
pymongo
pydantic
chromadb
bs4
fastapi
gptcache

fastapi
requests
uvicorn[standard]

sentence-transformers
text_generation

retrieval.py
ADDED
@@ -0,0 +1,110 @@
import datetime

from langchain.retrievers import ContextualCompressionRetriever
from langchain_cohere.rerank import CohereRerank
from langchain_core.vectorstores import VectorStoreRetriever

from emdedd.Embedding import Embedding
from emdedd.embeddings import embed_zakonnik_prace
from questions import questions


def retrieve(embedding, q, retrieve_document_count):
    retriever: VectorStoreRetriever = embedding.get_vector_store().as_retriever(
        search_type="similarity",
        search_kwargs={"k": retrieve_document_count}
    )

    context_doc = retriever.get_relevant_documents(
        query=q,
        kwargs={"k": retrieve_document_count}
    )

    return context_doc


def retrieve_with_rerank(embedding, q, retrieve_document_count):
    compression_retriever = reranking_retriever(embedding, retrieve_document_count)

    context_doc = compression_retriever.invoke(
        input=q,
        kwargs={"k": retrieve_document_count}
    )

    # for doc in context_doc:
    #     text = doc.page_content
    #     print(" kontext: " + text.replace('\n', ' ').replace('\r', ' '))

    return context_doc


def reranking_retriever(embedding, retrieve_document_count):
    retriever: VectorStoreRetriever = embedding.get_vector_store().as_retriever(
        search_type="similarity",
        search_kwargs={"k": retrieve_document_count * 10}
    )
    compressor = CohereRerank(model="rerank-multilingual-v3.0")
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    return compression_retriever


# todo
# def hyde(agent: Agent, q, retrieve_document_count):
#     retriever: VectorStoreRetriever = agent.embedding.get_vector_store().as_retriever(
#         search_type="similarity",
#         search_kwargs={"k": retrieve_document_count * 10}
#     )
#
#     context_doc = compression_retriever.get_relevant_documents(
#         query=q,
#         kwargs={"k": retrieve_document_count}
#     )
#
#     for doc in context_doc:
#         text = doc.page_content
#         print(" kontext: " + text.replace('\n', ' ').replace('\r', ' '))
#
#     return context_doc


def retrieve_test(name: str, embed_dict: dict[str, Embedding], emded: bool = False):
    try:
        result_file = open(name + "_retrieve_test.md", "a")
        for embed_key, embedding in embed_dict.items():
            if emded:
                embed_zakonnik_prace(embedding)
            print("--- Running on " + embed_key)
            result_file.write("\n\n| " + embed_key + " | " + str(datetime.datetime.now()) + " |")
            result_file.write("\n|-------|-----------|")
            dobre: int = 0  # "dobre" = number of retrieved chunks that hit an expected paragraph
            for q in questions:
                print(q)
                context_doc = retrieve(embedding, q, 5)
                for doc in context_doc:
                    text = doc.page_content
                    print(" kontext: " + text.replace('\n', ' ').replace('\r', ' '))
                    result_file.write("\n| " + q + " | " + text.replace('\n', ' ').replace('\r', ' ') + " |")
                    # paragraphs of the Labour Code that count as a correct retrieval
                    for paragraf in ("§ 100", "§ 101", "§ 103", "§ 104", "§ 105", "§ 106",
                                     "§ 107", "§ 109", "§ 110", "§ 111", "§ 112", "§ 113",
                                     "§ 114", "§ 115", "§ 116", "§ 117"):
                        if paragraf in text:
                            dobre = dobre + 1
            result_file.write("\n| Dobre: | " + str(dobre) + " |")
    finally:
        result_file.write("\n\n")
        result_file.flush()
        result_file.close()

task_splitting.py
ADDED
@@ -0,0 +1,101 @@
import datetime
from time import sleep

from langchain.chains import LLMChain
from langchain_core.prompts import PromptTemplate

from agent.Agent import Agent
from agent.agents import chat_groq_llama3_70
from emdedd.embeddings import cohere_embeddings, chroma_embedding, embed_zakonnik_prace
from promts import for_tree_llama3_rag_sub, for_tree_llama3_rag_tree, for_tree_llama3_rag_group
from retrieval import retrieve_with_rerank
from questions import questions


def rag_tree(agent: Agent, q: str, retrieve_document_count: int) -> str:
    tree_template = PromptTemplate(
        input_variables=["context", "question"],
        template=for_tree_llama3_rag_tree
    )

    context_doc = retrieve_with_rerank(agent.embedding, q, retrieve_document_count * 2)

    sub_qs = LLMChain(
        llm=agent.llm,
        prompt=tree_template,
        verbose=False
    ).invoke(
        input={
            "question": q,
            "context": context_doc
        }
    )["text"]

    print(sub_qs)
    sleep(60)  # crude rate limiting between LLM calls

    print("_________")

    sub_template = PromptTemplate(
        input_variables=["context", "question"],
        template=for_tree_llama3_rag_sub
    )

    sub_answers: dict[str, str] = {}

    for sub_q in sub_qs.splitlines():
        if "?" not in sub_q:
            continue
        print(sub_q)
        sub_answers[sub_q] = LLMChain(
            llm=agent.llm,
            prompt=sub_template,
            verbose=False
        ).invoke(
            input={
                "question": sub_q,
                "context": retrieve_with_rerank(agent.embedding, sub_q, retrieve_document_count)
            }
        )["text"]
        print(sub_answers[sub_q])
        sleep(60)

    final_template = PromptTemplate(
        input_variables=["context", "question", "subs"],
        template=for_tree_llama3_rag_group
    )

    result = LLMChain(
        llm=agent.llm,
        prompt=final_template,
        verbose=True
    ).invoke(
        input={
            "question": q,
            "context": context_doc,
            "subs": sub_answers.items()
        }
    )

    return result["text"]


def tree_of_thought(name: str, agent: Agent, emded: bool = False, retrieve_document_count=5):
    try:
        result_file = open(name + "_test.md", "a")
        if emded:
            embed_zakonnik_prace(agent.embedding)
        for q in questions:
            print("--- Q: " + q)
            result_file.write("\n\n| " + name + str(datetime.datetime.now()) + " | " + q + " |")
            result_file.write("\n|-------|-----------|")

            answer = rag_tree(agent, q, retrieve_document_count)
            print(answer)
            result_file.write(
                "\n| tree | " + answer.replace('\r\n', '<br>').replace('\n', '<br>').replace('\r', '<br>') + " |")
            sleep(60)
    finally:
        result_file.write("\n\n")
        result_file.flush()
        result_file.close()