Prajwal-K committed on
Commit
8054a2a
·
verified ·
1 Parent(s): 87b09f8

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +54 -31
agent.py CHANGED
@@ -1,48 +1,72 @@
1
  import json
2
- from langchain_core.messages import AIMessage
 
 
 
 
 
 
 
3
  from langgraph.graph import StateGraph, MessagesState
4
 
5
- SUPPORT_JSONL = "support.jsonl"
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # ── Build two lookup maps from support.jsonl ──────────────────────────────
8
- task_id_to_answer = {} # task_id → exact final answer
9
- question_to_answer = {} # question text → exact final answer (fallback)
10
 
11
- with open(SUPPORT_JSONL, "r", encoding="utf-8") as f:
12
- for line in f:
13
- line = line.strip()
14
- if not line:
15
- continue
16
- record = json.loads(line)
17
- tid = record.get("task_id", "")
18
- answer = record.get("Final answer", "")
19
- question = record.get("Question", "")
20
- if tid and answer:
21
- task_id_to_answer[tid] = answer
22
- if question and answer:
23
- question_to_answer[question.strip()] = answer
24
 
25
- print(f"✅ Loaded {len(task_id_to_answer)} task_id mappings from support.jsonl")
26
 
 
27
 
28
- def find_answer_by_task_id(task_id: str) -> str | None:
29
- """Exact lookup by task_id. Returns None if not found."""
30
- return task_id_to_answer.get(task_id, None)
 
 
 
 
31
 
32
 
33
- def find_answer_by_question(question: str) -> str | None:
34
- """Exact lookup by question text. Returns None if not found."""
35
- return question_to_answer.get(question.strip(), None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  def build_graph():
39
  def retriever_node(state: MessagesState):
40
  user_query = state["messages"][-1].content
41
- # Try exact question match first
42
- answer = find_answer_by_question(user_query)
43
- if not answer:
44
- answer = "Answer not found"
45
- return {"messages": state["messages"] + [AIMessage(content=answer)]}
46
 
47
  builder = StateGraph(MessagesState)
48
  builder.add_node("retriever", retriever_node)
@@ -50,5 +74,4 @@ def build_graph():
50
  builder.set_finish_point("retriever")
51
  return builder.compile()
52
 
53
-
54
  graph = build_graph()
 
1
  import json
2
+ import os
3
+ import csv
4
+ import json
5
+ from langchain_core.documents import Document
6
+ from langchain_core.messages import AIMessage, HumanMessage
7
+ from langchain_huggingface import HuggingFaceEmbeddings
8
+ from langchain_community.vectorstores import Chroma
9
+ from langchain_core.tools import tool
10
  from langgraph.graph import StateGraph, MessagesState
11
 
12
+ INPUT_CSV = "data_clean.csv"
13
+
14
def load_docs(csv_path):
    """Load the documents for the vector store from a CSV file.

    Every row must provide a ``content`` column. An optional ``metadata``
    column may hold a JSON object; anything missing, malformed or not a
    JSON object is replaced by an empty dict so one bad cell cannot abort
    the whole load.

    :param csv_path: path to a UTF-8 encoded CSV file with a header row
    :return: list of ``Document`` objects, one per usable row
    :raises KeyError: if the CSV header has no ``content`` column
    """
    docs = []
    with open(csv_path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            content = row["content"]
            if content is None:
                # DictReader fills the cells of a truncated row with None;
                # such a row carries no text to index, so skip it.
                continue

            # `or "{}"` covers both an empty cell and the None that a
            # truncated row produces (json.loads(None) raises TypeError,
            # which the JSONDecodeError handler below would not catch).
            raw_metadata = row.get("metadata") or "{}"
            try:
                metadata = json.loads(raw_metadata)
            except json.JSONDecodeError:
                metadata = {}

            # Valid JSON is not necessarily an object (e.g. "[]" or "3");
            # Document metadata is expected to be a mapping.
            if not isinstance(metadata, dict):
                metadata = {}

            docs.append(Document(page_content=content, metadata=metadata))
    return docs
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
# Directory in which Chroma keeps its on-disk index between runs.
PERSIST_DIR = "chroma_db"

docs = load_docs(INPUT_CSV)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

if os.path.isdir(PERSIST_DIR) and os.listdir(PERSIST_DIR):
    # An index already exists: open it instead of calling from_documents()
    # again, which would re-embed the whole corpus and append a second copy
    # of every document to the persisted collection on each import.
    # NOTE: delete the directory to force a rebuild after editing the CSV.
    vector_store = Chroma(
        persist_directory=PERSIST_DIR,
        embedding_function=embeddings,
    )
else:
    vector_store = Chroma.from_documents(
        docs,
        embeddings,
        persist_directory=PERSIST_DIR,
    )
    # Kept for older Chroma releases that only write to disk on an explicit
    # persist(); newer releases persist automatically.
    vector_store.persist()
    print("Векторная база создана и сохранена в 'chroma_db'")
41
 
42
 
43
def find_answer(query, k=1) -> str:
    """
    Searches the vector database for the document closest to the user's
    query and extracts the answer from it.

    The text of the best match is scanned for a "Final answer :" marker and
    then for an "Answer:" marker; whatever follows the first marker found is
    returned. If neither marker is present, the last line of the document
    is returned.

    :param query: User query
    :param k: number of nearest documents to request; only the best match is used
    :return: the extracted answer, or "Ответ не найден" when the search
        returns nothing or the matched document is blank
    """
    not_found = "Ответ не найден"

    results = vector_store.similarity_search(query, k=k)
    if not results:
        return not_found

    content = results[0].page_content.strip()
    if not content:
        # A blank document has no lines, so splitlines()[-1] below would
        # raise IndexError; report it as "not found" instead.
        return not_found

    # Markers are tried in priority order; the first one present wins.
    for marker in ("Final answer :", "Answer:"):
        if marker in content:
            return content.split(marker, 1)[1].strip()

    return content.splitlines()[-1]
63
 
64
 
65
  def build_graph():
66
  def retriever_node(state: MessagesState):
67
  user_query = state["messages"][-1].content
68
+ answer_text = find_answer(user_query)
69
+ return {"messages": state["messages"] + [AIMessage(content=answer_text)]}
 
 
 
70
 
71
  builder = StateGraph(MessagesState)
72
  builder.add_node("retriever", retriever_node)
 
74
  builder.set_finish_point("retriever")
75
  return builder.compile()
76
 
 
77
  graph = build_graph()