Pavol Liška committed on
Commit
3c35194
·
1 Parent(s): ae95c3d
Files changed (3) hide show
  1. api.py +22 -12
  2. rag.py +1 -1
  3. retrieval.py +0 -46
api.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI, Response, Body, Security
2
  from fastapi.security import APIKeyHeader
 
3
 
4
  from conversation.conversation_store import ConversationStore
5
  from rag_langchain import LangChainRAG
@@ -15,31 +16,40 @@ rewrite_prompt_id = "first"
15
  default_llm = "gpt-4o 128k"
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @api.get("/")
19
  def read_root():
20
  return "Empty"
21
 
22
 
23
  @api.post("/q")
24
- async def q(api_key: str = Security(api_key_header), json_body: dict = Body(...)):
25
  # Verify the API key
26
  if not valid_api_key(api_key):
27
  return Response(status_code=401)
28
 
29
- # Process the JSON body
30
- data = json_body
31
-
32
  rag = LangChainRAG(
33
  config={
34
- "retrieve_documents": data["retrieval_count"],
35
- "temperature": data["temperature"],
36
  "prompt_id": prompt_id,
37
  "check_prompt_id": check_prompt_id,
38
  "rewrite_prompt_id": rewrite_prompt_id
39
  }
40
  )
41
 
42
- answer, check_result, sources = rag.rag_chain(data["q"], default_llm)
43
 
44
  oid = conversation_store.save_content(
45
  q=q,
@@ -51,8 +61,8 @@ async def q(api_key: str = Security(api_key_header), json_body: dict = Body(...)
51
  "check_prompt_id": check_prompt_id,
52
  "rewrite_prompt_id": rewrite_prompt_id,
53
  "check_result": check_result,
54
- "temperature": data["temperature"],
55
- "retrieve_document_count": data["retrieval_count"],
56
  }
57
  )
58
 
@@ -67,14 +77,14 @@ async def q(api_key: str = Security(api_key_header), json_body: dict = Body(...)
67
 
68
 
69
  @api.post("/emo")
70
- async def emo(api_key: str = Security(api_key_header), json_body: dict = Body(...)):
71
  # Verify the API key
72
  if not valid_api_key(api_key):
73
  return Response(status_code=401)
74
 
75
- qa = conversation_store.get(json_body["qid"])
76
  new_params = qa.params
77
- new_params["user_grading"] = str(json_body["helpfulness"])
78
  conversation_store.update(
79
  oid=json_body["qid"],
80
  q=qa.conversation[0].q,
 
1
  from fastapi import FastAPI, Response, Body, Security
2
  from fastapi.security import APIKeyHeader
3
+ from pydantic import BaseModel
4
 
5
  from conversation.conversation_store import ConversationStore
6
  from rag_langchain import LangChainRAG
 
16
  default_llm = "gpt-4o 128k"
17
 
18
 
19
+ class QModel(BaseModel):
20
+ q: str
21
+ retrieval_count: int = 10
22
+ temperature: str = "0.2"
23
+ llm: str = default_llm
24
+
25
+
26
+ class EmoModel(BaseModel):
27
+ qid: str
28
+ helpfulness: str
29
+
30
+
31
  @api.get("/")
32
  def read_root():
33
  return "Empty"
34
 
35
 
36
  @api.post("/q")
37
+ async def q(api_key: str = Security(api_key_header), json_body: QModel = Body(...)):
38
  # Verify the API key
39
  if not valid_api_key(api_key):
40
  return Response(status_code=401)
41
 
 
 
 
42
  rag = LangChainRAG(
43
  config={
44
+ "retrieve_documents": json_body.retrieval_count,
45
+ "temperature": json_body.temperature,
46
  "prompt_id": prompt_id,
47
  "check_prompt_id": check_prompt_id,
48
  "rewrite_prompt_id": rewrite_prompt_id
49
  }
50
  )
51
 
52
+ answer, check_result, sources = rag.rag_chain(json_body.q, json_body.llm)
53
 
54
  oid = conversation_store.save_content(
55
  q=q,
 
61
  "check_prompt_id": check_prompt_id,
62
  "rewrite_prompt_id": rewrite_prompt_id,
63
  "check_result": check_result,
64
+ "temperature": json_body.temperature,
65
+ "retrieve_document_count": json_body.retrieval_count,
66
  }
67
  )
68
 
 
77
 
78
 
79
  @api.post("/emo")
80
+ async def emo(api_key: str = Security(api_key_header), json_body: EmoModel = Body(...)):
81
  # Verify the API key
82
  if not valid_api_key(api_key):
83
  return Response(status_code=401)
84
 
85
+ qa = conversation_store.get(json_body.qid)
86
  new_params = qa.params
87
+ new_params["user_grading"] = str(json_body.helpfulness)
88
  conversation_store.update(
89
 oid=json_body.qid,
90
  q=qa.conversation[0].q,
rag.py CHANGED
@@ -16,7 +16,7 @@ from agent.Agent import Agent
16
  from agent.agents import chat_openai_llm, deepinfra_chat
17
  from conversation.conversation_store import ConversationStore
18
  from prompt.prompt_store import PromptStore
19
- from retrieval import retrieve, retrieve_with_rerank
20
 
21
  load_dotenv()
22
 
 
16
  from agent.agents import chat_openai_llm, deepinfra_chat
17
  from conversation.conversation_store import ConversationStore
18
  from prompt.prompt_store import PromptStore
19
+ from retrieval import retrieve_with_rerank
20
 
21
  load_dotenv()
22
 
retrieval.py CHANGED
@@ -1,13 +1,7 @@
1
- import datetime
2
-
3
  from langchain.retrievers import ContextualCompressionRetriever
4
  from langchain_cohere.rerank import CohereRerank
5
  from langchain_core.vectorstores import VectorStoreRetriever
6
 
7
- from emdedd.Embedding import Embedding
8
- from emdedd.embeddings import embed_zakonnik_prace
9
- from questions import questions
10
-
11
 
12
  def retrieve(embedding, q, retrieve_document_count):
13
  retriever: VectorStoreRetriever = embedding.get_vector_store().as_retriever(
@@ -68,43 +62,3 @@ def reranking_retriever(embedding, retrieve_document_count):
68
  # print(" kontext: " + text.replace('\n', ' ').replace('\r', ' '))
69
  #
70
  # return context_doc
71
-
72
-
73
- def retrieve_test(name: str, embed_dict: dict[str, Embedding], emded: bool = False):
74
- try:
75
- result_file = open(name + "_retrieve_test.md", "a")
76
- for embed_key, embedding in embed_dict.items():
77
- if emded:
78
- embed_zakonnik_prace(embedding)
79
- print("--- Running on " + embed_key)
80
- result_file.write("\n\n| " + embed_key + " | " + str(datetime.datetime.now()) + " |")
81
- result_file.write("\n|-------|-----------|")
82
- dobre: int = 0
83
- for q in questions:
84
- print(q)
85
- context_doc = retrieve(embedding, q, 5)
86
- for doc in context_doc:
87
- text = doc.page_content
88
- print(" kontext: " + text.replace('\n', ' ').replace('\r', ' '))
89
- result_file.write("\n| " + q + " | " + text.replace('\n', ' ').replace('\r', ' ') + " |")
90
- dobre = dobre + 1 if "§ 100" in text else dobre
91
- dobre = dobre + 1 if "§ 101" in text else dobre
92
- dobre = dobre + 1 if "§ 103" in text else dobre
93
- dobre = dobre + 1 if "§ 104" in text else dobre
94
- dobre = dobre + 1 if "§ 105" in text else dobre
95
- dobre = dobre + 1 if "§ 106" in text else dobre
96
- dobre = dobre + 1 if "§ 107" in text else dobre
97
- dobre = dobre + 1 if "§ 109" in text else dobre
98
- dobre = dobre + 1 if "§ 110" in text else dobre
99
- dobre = dobre + 1 if "§ 111" in text else dobre
100
- dobre = dobre + 1 if "§ 112" in text else dobre
101
- dobre = dobre + 1 if "§ 113" in text else dobre
102
- dobre = dobre + 1 if "§ 114" in text else dobre
103
- dobre = dobre + 1 if "§ 115" in text else dobre
104
- dobre = dobre + 1 if "§ 116" in text else dobre
105
- dobre = dobre + 1 if "§ 117" in text else dobre
106
- result_file.write("\n| Dobre: | " + str(dobre) + " |")
107
- finally:
108
- result_file.write("\n\n")
109
- result_file.flush()
110
- result_file.close()
 
 
 
1
  from langchain.retrievers import ContextualCompressionRetriever
2
  from langchain_cohere.rerank import CohereRerank
3
  from langchain_core.vectorstores import VectorStoreRetriever
4
 
 
 
 
 
5
 
6
  def retrieve(embedding, q, retrieve_document_count):
7
  retriever: VectorStoreRetriever = embedding.get_vector_store().as_retriever(
 
62
  # print(" kontext: " + text.replace('\n', ' ').replace('\r', ' '))
63
  #
64
  # return context_doc