alexorlov commited on
Commit
b5e4f35
·
verified ·
1 Parent(s): ac986f8

Upload app/agent/nodes.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app/agent/nodes.py +75 -126
app/agent/nodes.py CHANGED
@@ -1,132 +1,81 @@
1
- from typing import Dict, Any
 
2
  from app.agent.state import AgentState
3
- from app.services.llm import get_llm_service
4
- from app.services.file_generator import get_file_generator
5
- from app.models.question import Question, Answer
 
 
 
 
 
 
6
  from app.models.checklist import ChecklistItem
 
7
 
 
8
 
9
- def generate_initial_questions(state: AgentState) -> Dict[str, Any]:
10
- """Генерирует первые 3 вопроса для начала интервью"""
11
- llm = get_llm_service()
12
- questions_data = llm.generate_initial_questions()
13
-
14
- questions = [
15
- Question(id=q["id"], text=q["text"])
16
- for q in questions_data
17
- ]
18
-
19
- return {
20
- "current_questions": questions,
21
- "current_round": 1,
22
- "waiting_for_answers": True
23
- }
24
-
25
-
26
- def process_answers(state: AgentState) -> Dict[str, Any]:
27
- """Обрабатывает полученные ответы и создает Answer объекты"""
28
- transcripts = state.get("pending_transcripts", [])
29
- current_questions = state.get("current_questions", [])
30
- current_round = state.get("current_round", 1)
31
- all_answers = list(state.get("all_answers", []))
32
-
33
- # Создаем Answer объекты из транскриптов
34
- for i, transcript in enumerate(transcripts):
35
- if i < len(current_questions):
36
- answer = Answer(
37
- question_id=current_questions[i].id,
38
- question_text=current_questions[i].text,
39
- audio_transcript=transcript,
40
- round_number=current_round
41
- )
42
- all_answers.append(answer)
43
-
44
- return {
45
- "all_answers": all_answers,
46
- "pending_transcripts": [],
47
- "waiting_for_answers": False
48
- }
49
-
50
-
51
- def analyze_round(state: AgentState) -> Dict[str, Any]:
52
- """Анализирует ответы раунда и генерирует следующие вопросы или завершает"""
53
- llm = get_llm_service()
54
-
55
- current_round = state.get("current_round", 1)
56
- all_answers = state.get("all_answers", [])
57
- round_summaries = list(state.get("round_summaries", []))
58
-
59
- # Анализируем раунд
60
- result = llm.analyze_round_and_generate_questions(
61
- round_number=current_round,
62
- all_answers=all_answers,
63
- round_summaries=round_summaries
 
 
 
64
  )
65
-
66
- # Добавляем саммари раунда
67
- round_summaries.append(result.get("round_summary", ""))
68
-
69
- # Если это не последний раунд - генерируем следующие вопросы
70
- if current_round < state.get("max_rounds", 3):
71
- questions_data = result.get("questions", [])
72
- questions = [
73
- Question(id=q["id"], text=q["text"])
74
- for q in questions_data
75
- ]
76
-
77
- return {
78
- "current_questions": questions,
79
- "current_round": current_round + 1,
80
- "round_summaries": round_summaries,
81
- "waiting_for_answers": True,
82
- "is_complete": False
83
- }
84
- else:
85
- # Последний раунд - готовимся к генерации чеклиста
86
- return {
87
- "round_summaries": round_summaries,
88
- "waiting_for_answers": False,
89
- "is_complete": False
90
- }
91
-
92
-
93
- def generate_checklist(state: AgentState) -> Dict[str, Any]:
94
- """Генерирует финальный чеклист"""
95
- llm = get_llm_service()
96
- file_gen = get_file_generator()
97
-
98
- all_answers = state.get("all_answers", [])
99
- round_summaries = state.get("round_summaries", [])
100
- session_id = state.get("session_id", "unknown")
101
-
102
- # Генерируем чеклист
103
- result = llm.generate_checklist(all_answers, round_summaries)
104
-
105
- checklist_items = [
106
- ChecklistItem(**item)
107
- for item in result.get("checklist", [])
108
- ]
109
-
110
- # Генерируем Markdown
111
- markdown = file_gen.generate_markdown(
112
- session_id=session_id,
113
- checklist=checklist_items,
114
- round_summaries=round_summaries
115
  )
116
-
117
- return {
118
- "checklist_items": checklist_items,
119
- "markdown_content": markdown,
120
- "is_complete": True
121
- }
122
-
123
-
124
- def check_round_complete(state: AgentState) -> str:
125
- """Проверяет, нужно ли продолжать или завершать"""
126
- current_round = state.get("current_round", 1)
127
- max_rounds = state.get("max_rounds", 3)
128
-
129
- if current_round >= max_rounds:
130
- return "generate_checklist"
131
- else:
132
- return "wait_for_answers"
 
1
+ import json
2
+ import logging
3
  from app.agent.state import AgentState
4
+ from app.agent.prompts import (
5
+ SYSTEM_PROMPT,
6
+ INITIAL_QUESTIONS_PROMPT,
7
+ analyze_round_prompt,
8
+ next_questions_prompt,
9
+ generate_checklist_prompt,
10
+ generate_markdown_prompt,
11
+ )
12
+ from app.models.question import QuestionOut
13
  from app.models.checklist import ChecklistItem
14
+ from app.services.llm import gemini_service
15
 
16
+ logger = logging.getLogger(__name__)
17
 
18
+
19
+ async def generate_initial_questions(state: AgentState) -> dict:
20
+ logger.info(f"Generating initial questions for session {state['session_id']}")
21
+ result = await gemini_service.generate_json(SYSTEM_PROMPT, INITIAL_QUESTIONS_PROMPT)
22
+ questions = [QuestionOut(id=q["id"], text=q["text"]) for q in result]
23
+ return {"current_questions": questions, "current_round": 1}
24
+
25
+
26
+ async def analyze_round(state: AgentState) -> dict:
27
+ round_num = state["current_round"]
28
+ logger.info(f"Analyzing round {round_num} for session {state['session_id']}")
29
+
30
+ # Build Q&A text for current round
31
+ round_answers = [a for a in state["all_answers"] if a.round_number == round_num]
32
+ qa_text = "\n".join(
33
+ f"В: {a.question_text}\nО: {a.audio_transcript}" for a in round_answers
34
+ )
35
+ prev_summaries = "\n---\n".join(state["round_summaries"])
36
+
37
+ result = await gemini_service.generate_json(
38
+ SYSTEM_PROMPT, analyze_round_prompt(round_num, qa_text, prev_summaries)
39
+ )
40
+ new_summaries = list(state["round_summaries"]) + [result["summary"]]
41
+ return {"round_summaries": new_summaries}
42
+
43
+
44
+ async def generate_next_questions(state: AgentState) -> dict:
45
+ next_round = state["current_round"] + 1
46
+ logger.info(f"Generating questions for round {next_round}, session {state['session_id']}")
47
+
48
+ all_summaries = "\n---\n".join(state["round_summaries"])
49
+ result = await gemini_service.generate_json(
50
+ SYSTEM_PROMPT, next_questions_prompt(next_round, all_summaries)
51
+ )
52
+ questions = [QuestionOut(id=q["id"], text=q["text"]) for q in result]
53
+ return {"current_questions": questions, "current_round": next_round}
54
+
55
+
56
+ async def generate_checklist(state: AgentState) -> dict:
57
+ logger.info(f"Generating checklist for session {state['session_id']}")
58
+
59
+ all_summaries = "\n---\n".join(state["round_summaries"])
60
+ all_qa = "\n\n".join(
61
+ f"Раунд {a.round_number} В: {a.question_text}\nО: {a.audio_transcript}"
62
+ for a in state["all_answers"]
63
+ )
64
+ result = await gemini_service.generate_json(
65
+ SYSTEM_PROMPT, generate_checklist_prompt(all_summaries, all_qa)
66
+ )
67
+ items = [ChecklistItem(**item) for item in result]
68
+ return {"checklist_items": items, "is_complete": True}
69
+
70
+
71
+ async def generate_markdown(state: AgentState) -> dict:
72
+ logger.info(f"Generating markdown for session {state['session_id']}")
73
+
74
+ checklist_json = json.dumps(
75
+ [item.model_dump() for item in state["checklist_items"]], ensure_ascii=False, indent=2
76
  )
77
+ result = await gemini_service.generate(
78
+ SYSTEM_PROMPT,
79
+ generate_markdown_prompt(checklist_json, state["session_id"]),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  )
81
+ return {"markdown_content": result}