Spaces:
Sleeping
Sleeping
Upload app/agent/nodes.py with huggingface_hub
Browse files- app/agent/nodes.py +75 -126
app/agent/nodes.py
CHANGED
|
@@ -1,132 +1,81 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
from app.agent.state import AgentState
|
| 3 |
-
from app.
|
| 4 |
-
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from app.models.checklist import ChecklistItem
|
|
|
|
| 7 |
|
|
|
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
questions
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
return {
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
}
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
| 64 |
)
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
# Если это не последний раунд - генерируем следующие вопросы
|
| 70 |
-
if current_round < state.get("max_rounds", 3):
|
| 71 |
-
questions_data = result.get("questions", [])
|
| 72 |
-
questions = [
|
| 73 |
-
Question(id=q["id"], text=q["text"])
|
| 74 |
-
for q in questions_data
|
| 75 |
-
]
|
| 76 |
-
|
| 77 |
-
return {
|
| 78 |
-
"current_questions": questions,
|
| 79 |
-
"current_round": current_round + 1,
|
| 80 |
-
"round_summaries": round_summaries,
|
| 81 |
-
"waiting_for_answers": True,
|
| 82 |
-
"is_complete": False
|
| 83 |
-
}
|
| 84 |
-
else:
|
| 85 |
-
# Последний раунд - готовимся к генерации чеклиста
|
| 86 |
-
return {
|
| 87 |
-
"round_summaries": round_summaries,
|
| 88 |
-
"waiting_for_answers": False,
|
| 89 |
-
"is_complete": False
|
| 90 |
-
}
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def generate_checklist(state: AgentState) -> Dict[str, Any]:
|
| 94 |
-
"""Генерирует финальный чеклист"""
|
| 95 |
-
llm = get_llm_service()
|
| 96 |
-
file_gen = get_file_generator()
|
| 97 |
-
|
| 98 |
-
all_answers = state.get("all_answers", [])
|
| 99 |
-
round_summaries = state.get("round_summaries", [])
|
| 100 |
-
session_id = state.get("session_id", "unknown")
|
| 101 |
-
|
| 102 |
-
# Генерируем чеклист
|
| 103 |
-
result = llm.generate_checklist(all_answers, round_summaries)
|
| 104 |
-
|
| 105 |
-
checklist_items = [
|
| 106 |
-
ChecklistItem(**item)
|
| 107 |
-
for item in result.get("checklist", [])
|
| 108 |
-
]
|
| 109 |
-
|
| 110 |
-
# Генерируем Markdown
|
| 111 |
-
markdown = file_gen.generate_markdown(
|
| 112 |
-
session_id=session_id,
|
| 113 |
-
checklist=checklist_items,
|
| 114 |
-
round_summaries=round_summaries
|
| 115 |
)
|
| 116 |
-
|
| 117 |
-
return {
|
| 118 |
-
"checklist_items": checklist_items,
|
| 119 |
-
"markdown_content": markdown,
|
| 120 |
-
"is_complete": True
|
| 121 |
-
}
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
def check_round_complete(state: AgentState) -> str:
|
| 125 |
-
"""Проверяет, нужно ли продолжать или завершать"""
|
| 126 |
-
current_round = state.get("current_round", 1)
|
| 127 |
-
max_rounds = state.get("max_rounds", 3)
|
| 128 |
-
|
| 129 |
-
if current_round >= max_rounds:
|
| 130 |
-
return "generate_checklist"
|
| 131 |
-
else:
|
| 132 |
-
return "wait_for_answers"
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
from app.agent.state import AgentState
|
| 4 |
+
from app.agent.prompts import (
|
| 5 |
+
SYSTEM_PROMPT,
|
| 6 |
+
INITIAL_QUESTIONS_PROMPT,
|
| 7 |
+
analyze_round_prompt,
|
| 8 |
+
next_questions_prompt,
|
| 9 |
+
generate_checklist_prompt,
|
| 10 |
+
generate_markdown_prompt,
|
| 11 |
+
)
|
| 12 |
+
from app.models.question import QuestionOut
|
| 13 |
from app.models.checklist import ChecklistItem
|
| 14 |
+
from app.services.llm import gemini_service
|
| 15 |
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
+
|
| 19 |
+
async def generate_initial_questions(state: AgentState) -> dict:
|
| 20 |
+
logger.info(f"Generating initial questions for session {state['session_id']}")
|
| 21 |
+
result = await gemini_service.generate_json(SYSTEM_PROMPT, INITIAL_QUESTIONS_PROMPT)
|
| 22 |
+
questions = [QuestionOut(id=q["id"], text=q["text"]) for q in result]
|
| 23 |
+
return {"current_questions": questions, "current_round": 1}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
async def analyze_round(state: AgentState) -> dict:
|
| 27 |
+
round_num = state["current_round"]
|
| 28 |
+
logger.info(f"Analyzing round {round_num} for session {state['session_id']}")
|
| 29 |
+
|
| 30 |
+
# Build Q&A text for current round
|
| 31 |
+
round_answers = [a for a in state["all_answers"] if a.round_number == round_num]
|
| 32 |
+
qa_text = "\n".join(
|
| 33 |
+
f"В: {a.question_text}\nО: {a.audio_transcript}" for a in round_answers
|
| 34 |
+
)
|
| 35 |
+
prev_summaries = "\n---\n".join(state["round_summaries"])
|
| 36 |
+
|
| 37 |
+
result = await gemini_service.generate_json(
|
| 38 |
+
SYSTEM_PROMPT, analyze_round_prompt(round_num, qa_text, prev_summaries)
|
| 39 |
+
)
|
| 40 |
+
new_summaries = list(state["round_summaries"]) + [result["summary"]]
|
| 41 |
+
return {"round_summaries": new_summaries}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
async def generate_next_questions(state: AgentState) -> dict:
|
| 45 |
+
next_round = state["current_round"] + 1
|
| 46 |
+
logger.info(f"Generating questions for round {next_round}, session {state['session_id']}")
|
| 47 |
+
|
| 48 |
+
all_summaries = "\n---\n".join(state["round_summaries"])
|
| 49 |
+
result = await gemini_service.generate_json(
|
| 50 |
+
SYSTEM_PROMPT, next_questions_prompt(next_round, all_summaries)
|
| 51 |
+
)
|
| 52 |
+
questions = [QuestionOut(id=q["id"], text=q["text"]) for q in result]
|
| 53 |
+
return {"current_questions": questions, "current_round": next_round}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
async def generate_checklist(state: AgentState) -> dict:
|
| 57 |
+
logger.info(f"Generating checklist for session {state['session_id']}")
|
| 58 |
+
|
| 59 |
+
all_summaries = "\n---\n".join(state["round_summaries"])
|
| 60 |
+
all_qa = "\n\n".join(
|
| 61 |
+
f"Раунд {a.round_number} — В: {a.question_text}\nО: {a.audio_transcript}"
|
| 62 |
+
for a in state["all_answers"]
|
| 63 |
+
)
|
| 64 |
+
result = await gemini_service.generate_json(
|
| 65 |
+
SYSTEM_PROMPT, generate_checklist_prompt(all_summaries, all_qa)
|
| 66 |
+
)
|
| 67 |
+
items = [ChecklistItem(**item) for item in result]
|
| 68 |
+
return {"checklist_items": items, "is_complete": True}
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
async def generate_markdown(state: AgentState) -> dict:
|
| 72 |
+
logger.info(f"Generating markdown for session {state['session_id']}")
|
| 73 |
+
|
| 74 |
+
checklist_json = json.dumps(
|
| 75 |
+
[item.model_dump() for item in state["checklist_items"]], ensure_ascii=False, indent=2
|
| 76 |
)
|
| 77 |
+
result = await gemini_service.generate(
|
| 78 |
+
SYSTEM_PROMPT,
|
| 79 |
+
generate_markdown_prompt(checklist_json, state["session_id"]),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
)
|
| 81 |
+
return {"markdown_content": result}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|