anup220799 committed on
Commit
fa13c61
·
verified ·
1 Parent(s): 3fc8191

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +77 -0
agent.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import csv
4
+ import json
5
+ from langchain_core.documents import Document
6
+ from langchain_core.messages import AIMessage, HumanMessage
7
+ from langchain_huggingface import HuggingFaceEmbeddings
8
+ from langchain_community.vectorstores import Chroma
9
+ from langchain_core.tools import tool
10
+ from langgraph.graph import StateGraph, MessagesState
11
+
12
+ INPUT_CSV = "data_clean.csv"  # CSV passed to load_docs() at import time; rows are read via their "content" and "metadata" columns
13
+
14
def load_docs(csv_path: str) -> "list[Document]":
    """Load the rows of a CSV file as ``Document`` objects.

    Each row's ``content`` column becomes the page content. The optional
    ``metadata`` column is parsed as a JSON object; a missing, blank,
    malformed or non-object value falls back to empty metadata so that a
    single bad cell does not abort the whole load.

    :param csv_path: Path to a UTF-8 encoded CSV file with a header row.
    :return: One ``Document`` per data row, in file order.
    :raises KeyError: If the file has no ``content`` column.
    """
    docs = []
    with open(csv_path, newline="", encoding="utf-8") as f:
        for row in csv.DictReader(f):
            content = row["content"]

            # DictReader yields None for cells missing from a short row and
            # "" for blank cells; both mean "no metadata" here. Passing None
            # to json.loads would raise TypeError, which was not handled.
            raw_metadata = row.get("metadata") or "{}"
            try:
                metadata = json.loads(raw_metadata)
            except json.JSONDecodeError:
                metadata = {}

            # Valid JSON that is not an object (list, number, null, ...)
            # cannot be used as a metadata mapping.
            if not isinstance(metadata, dict):
                metadata = {}

            docs.append(Document(page_content=content, metadata=metadata))
    return docs
28
+
29
+
30
# Read the whole corpus once, when the module is imported.
docs = load_docs(INPUT_CSV)

# Embedding model handed to the vector store below.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2",
)

# Index the corpus in a Chroma store that lives in the "chroma_db" directory.
vector_store = Chroma.from_documents(
    documents=docs,
    embedding=embeddings,
    persist_directory="chroma_db",
)
# NOTE(review): recent Chroma releases reportedly persist automatically and
# deprecate persist() - confirm whether this call is still required.
vector_store.persist()
print("vector store created and stored in 'chroma_db'")
41
+
42
+
43
def find_answer(query, k=1) -> str:
    """Look up the best-matching document and extract its answer text.

    Runs a similarity search on the module-level ``vector_store`` and
    inspects only the top hit. The answer is taken, in priority order, from:

    1. the text following ``"Final answer :"``;
    2. the text following ``"Answer:"``;
    3. the last line of the document.

    :param query: User query.
    :param k: Number of documents requested from the vector store; only the
        first result is used.
    :return: The extracted answer, or ``"Ответ не найден"`` (Russian for
        "Answer not found") when the search returns nothing or the top
        document is blank.
    """
    not_found = "Ответ не найден"

    results = vector_store.similarity_search(query, k=k)
    if not results:
        return not_found

    content = results[0].page_content

    # Markers are tried in order, so "Final answer :" wins over "Answer:"
    # when both occur in the same document.
    for marker in ("Final answer :", "Answer:"):
        _, found, tail = content.partition(marker)
        if found:
            return tail.strip()

    # A blank document has no lines; indexing [-1] on the empty list used to
    # raise IndexError, so report "not found" instead.
    lines = content.strip().splitlines()
    return lines[-1] if lines else not_found
63
+
64
+
65
def build_graph():
    """Build and compile a one-node graph that answers from the vector store.

    The graph consists of a single ``retriever`` node that serves as both
    the entry point and the finish point.

    :return: The compiled graph produced by ``StateGraph.compile()``.
    """

    def retriever_node(state: MessagesState):
        """Answer the most recent message in the state via ``find_answer``."""
        user_query = state["messages"][-1].content
        answer_text = find_answer(user_query)
        # MessagesState merges a node's returned messages into the existing
        # history, so only the new message is returned rather than the whole
        # history plus the new message.
        return {"messages": [AIMessage(content=answer_text)]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever_node)
    builder.set_entry_point("retriever")
    builder.set_finish_point("retriever")
    return builder.compile()
76
+
77
+ graph = build_graph()  # module-level compiled graph, built once at import time