# Source: fido_changes/src/langGraphTests.py
# Author: szymskul — commit 00cccb0 ("update files")
from langgraph.graph import StateGraph, START, END
from bielik import llm
from guardian import check_input
from helpful_functions import get_last_user_message, check_situation, beliefs_check_function, introduction_talk, create_interview
from neo4j_driver import driver
from classifier import predict_raw, predict_raw1
from state import ChatState
from prompts import build_system_prompt_introduction_chapter_ellis_distortion
def _merge_interview_field(current, extracted):
    """Merge a freshly extracted interview fragment into the value collected so far.

    An empty ``current`` is replaced by ``extracted``; an empty ``extracted``
    leaves ``current`` untouched; otherwise both are combined with
    ``create_interview(extracted, current)``.
    """
    if current == "":
        return extracted
    if extracted != "":
        return create_interview(extracted, current)
    return current


def detect_distortion(state: ChatState):
    """Chapter 1: greet the user, collect situation/emotion/thought and detect a distortion.

    On the first call (no ``messages`` yet) it posts a greeting and waits for
    the user.

    On later calls it:
      * increments ``first_stage_iterations``;
      * if no ``distortion`` is stored yet, classifies the latest user message
        with ``predict_raw``. When the label is not ``"No Distortion"`` and
        ``beliefs_check_function`` accepts the text, it stores the
        ``predict_raw1`` label in ``distortion`` and the text in
        ``distortion_text``;
      * asks ``introduction_talk`` to continue the conversation and merges the
        ``situation`` / ``emotion`` / ``think`` it extracted into the state.

    Once a distortion and all three fields are known, it snapshots the
    transcript into ``messages_detect`` and hands over to
    ``get_distortion_def`` without waiting. Otherwise it posts the model's
    reply and waits for the user.

    Returns the mutated ``state``.
    """
    if not state.get("messages"):
        print("Siema")
        state["messages"] = [{
            "role": "assistant", "content": "Cześć! Cieszę się, że jesteś. Co u ciebie, czy masz jakiś problem? Z chęcią ci pomogę!"
        }]
        state["awaitingUser"] = True
        state["stage"] = "detect_distortion"
        return state

    # Tolerate a state in which the counter was never initialised.
    state["first_stage_iterations"] = (state.get("first_stage_iterations") or 0) + 1
    print(state["first_stage_iterations"])
    print("Siema1")

    last_message = get_last_user_message(state)
    if not last_message:
        # Nothing from the user to analyse yet: wait for input, as talk_about_distortion does.
        state["awaitingUser"] = True
        state["stage"] = "detect_distortion"
        return state
    user_text = (last_message["content"] or "").strip()

    if state.get("distortion") is None:
        label = predict_raw(user_text)
        # beliefs_check_function is only consulted when the classifier flagged something.
        if label != "No Distortion" and beliefs_check_function(user_text):
            distortion = predict_raw1(user_text)
            print(distortion)
            state["distortion"] = distortion
            state["distortion_text"] = user_text
    print("Siema2")

    system_prompt = build_system_prompt_introduction_chapter_ellis_distortion(
        state.get("distortion"),
        state.get("situation", ""),
        state.get("think", ""),
        state.get("emotion", ""),
    )
    talk = introduction_talk(state["messages"], system_prompt)
    state["situation"] = _merge_interview_field(state.get("situation", ""), talk.situation)
    state["emotion"] = _merge_interview_field(state.get("emotion", ""), talk.emotion)
    state["think"] = _merge_interview_field(state.get("think", ""), talk.think)
    state["introduction_end_flag"] = talk.chapter_end

    collected = all(state[key] != "" for key in ("situation", "think", "emotion"))
    if state.get("distortion") is not None and collected:
        print("Next")
        state["awaitingUser"] = False
        # Copy rather than alias: later stages keep appending to state["messages"],
        # and this snapshot should stay the chapter-1 transcript.
        state["messages_detect"] = list(state["messages"])
        state["stage"] = "get_distortion_def"
        return state

    state["messages"].append({"role": "assistant", "content": talk.model_output})
    state["awaitingUser"] = True
    state["stage"] = "detect_distortion"
    return state
def get_distortion_def(state: ChatState):
    """Fetch the definition of the detected distortion from Neo4j.

    Matches the ``Distortion`` node whose ``name`` equals ``state["distortion"]``.
    Its ``definicja`` property is stored in ``state["distortion_def"]``, or
    ``None`` when no node matches. Always continues to
    ``talk_about_distortion`` without waiting for the user.

    Returns the mutated ``state``.
    """
    print("Siema4")
    distortion_name = state["distortion"]
    cypher = """
MATCH (d:Distortion {name: $name})
RETURN d.definicja AS definicja
"""
    rows, _summary, _keys = driver.execute_query(
        cypher,
        parameters_={"name": distortion_name},
    )
    definition = None
    if rows:
        definition = rows[0]["definicja"]
    state["distortion_def"] = definition
    state["stage"] = "talk_about_distortion"
    state["awaitingUser"] = False
    return state
def _llm_text(system_prompt):
    """Send one system message to ``llm`` and return its reply as text.

    A bare string reply is returned unchanged. Any other reply yields its
    ``content`` attribute, or ``str(reply)`` when it has none.
    """
    reply = llm.invoke([
        {
            "role": "system",
            "content": system_prompt,
        },
    ])
    if isinstance(reply, str):
        return reply
    return getattr(reply, "content", str(reply))


def talk_about_distortion(state: ChatState):
    """Chapter 2: explain the detected distortion and check that the user understood it.

    First call (``distortion_explained`` not set): ask the LLM for a short
    Polish explanation built from ``distortion`` and ``distortion_def``, which
    ends by asking whether the user wants help working through it. The reply
    is posted, the distortion is marked as explained, and the node waits.

    Later calls: classify the latest user message with ``check_situation`` and
    store the label in ``classify_result``.
      * ``"understand"``: confirm, switch to stage ``get_intention`` and do not wait.
      * anything else: ask the LLM for a simpler explanation with an example,
        post it and wait again.
    With no user message available it only waits.

    Returns the mutated ``state``.
    """
    distortion_name = state["distortion"]
    definition = state["distortion_def"]
    print("Siema5")

    if state.get("distortion_explained"):
        print("Siema7")
        user_msg = get_last_user_message(state)
        if not user_msg:
            state["awaitingUser"] = True
            return state

        verdict = check_situation(user_msg["content"])
        state["classify_result"] = verdict

        if verdict == "understand":
            print("Siema8")
            state["messages"].append({
                "role": "assistant",
                "content": "Super! To przejdźmy teraz do kolejnego kroku",
            })
            state["stage"] = "get_intention"
            state["awaitingUser"] = False
            return state

        # A disabled draft of a "low_expression" branch (gently encouraging a
        # terse user to keep talking) used to sit here:
        # elif classify_result == "low_expression":
        #     system_prompt = f"""
        # WEJSCIE
        # Historia wiadomości - {state["messages"]}
        #
        # Użytkownik jest mało wylewny i odpowiada krótko.
        # Twoim zadaniem jest napisać 2–3 empatyczne zdania po polsku, które spokojnie i nienachalnie zachęcą go do kontynuowania rozmowy.
        # Brzmi naturalnie, bez punktów, presji ani oceniania.
        # Na końcu zapytaj czy możemy możemy przejść do działania
        # Twoją rolą jest tylko i wyłącznie zachęcenie do działania nie pisz nic innego
        # """
        #     ... invoke llm, append reply, awaitingUser=True, stage="talk_about_distortion"

        print("Siema9")
        reexplain_prompt = f"""
WEJSCIE
Historia wiadomości - {state["messages"]}
Użytkownik nie zrozumiał wyjaśnienia zniekształcenia.
Nazwa: {distortion_name}
Definicja: {definition}
Język tylko polski.
Twoje zadanie:
- Wyjaśnij prostszymi słowami (1–2 zdania).
- Dodaj przykład z życia (1–2 zdania).
- Zapytaj, czy teraz jest to jasne i czy możemy przejść do działania.
Maksymalnie 3-4 zdania
"""
        answer = _llm_text(reexplain_prompt)
        state["messages"].append({"role": "assistant", "content": answer})
        state["awaitingUser"] = True
        state["stage"] = "talk_about_distortion"
        return state

    print("Siema6")
    intro_prompt = f"""
Jesteś empatycznym asystentem CBT.
Użytkownikowi wykryto zniekształcenie poznawcze:
Nazwa: {distortion_name}
Definicja: {definition}
Przedstaw mu, że wykryłeś u niego zniekształcenie i wyjaśnij je w prosty, życzliwy sposób i zapytaj, czy chce, abyś pomógł mu to wspólnie przepracować.
Język: polski, maksymalnie 2–3 zdania.
"""
    answer = _llm_text(intro_prompt)
    state["messages"].append({"role": "assistant", "content": answer})
    state["awaitingUser"] = True
    state["stage"] = "talk_about_distortion"
    state["distortion_explained"] = True
    return state
# Stage name -> chapter label handed to check_input. Stages missing here map to "None".
# Several of these stages (ETAP 3/4) have no node wired into the graph in this file.
_STAGE_CHAPTERS = {
    "detect_distortion": "ETAP 1",
    "talk_about_distortion": "ETAP 2",
    "get_distortion_def": "ETAP 2",
    "create_socratic_question": "ETAP 3",
    "get_intention": "ETAP 3",
    "select_intention": "ETAP 3",
    "analyze_output": "ETAP 3",
    "enter_alt_thought": "ETAP 4",
    "handle_alt_thought_input": "ETAP 4",
}


def validate_input(state: ChatState):
    """Run the guardian check on the latest user message for the current chapter.

    The chapter label is derived from ``stage``. The message text is taken from
    ``last_user_msg_content``. The ``last_user_msg`` flag is always cleared.
      * Accepted: ``validated`` is set and the flow continues (``awaitingUser=False``).
      * Rejected: the reason is recorded in ``noValidated`` / ``explanation``,
        the guardian's message is posted to the user, and the node waits.

    Returns the mutated ``state``.
    """
    chapter = _STAGE_CHAPTERS.get(state.get("stage"), "None")
    last_user_msg = state.get("last_user_msg_content")
    result = check_input(state["messages"], chapter, last_user_msg)
    # Clear the "new user message" flag so global_router does not route this message here again.
    state["last_user_msg"] = False
    if result.decision:
        state["validated"] = True
        state["awaitingUser"] = False
    else:
        state["noValidated"] = f"{chapter} - {last_user_msg}"
        state["explanation"] = result.explanation
        state["messages"].append({"role": "assistant", "content": result.message_to_user})
        state["awaitingUser"] = True
    return state
def global_router(state: "ChatState") -> str:
    """Choose the next node from the current chat state.

    Priority:
      1. ``awaitingUser`` set → ``"__end__"``: stop until the user answers.
      2. A flagged ``last_user_msg`` that is not ``validated`` yet → ``"validate_input"``.
      3. Dispatch on ``stage``:
         * ``"end"`` → ``"__end__"``;
         * missing/empty or ``"detect_distortion"`` → ``"detect_distortion"``;
         * ``"get_distortion_def"`` / ``"talk_about_distortion"`` → that node;
         * any other stage → ``"__end__"``.

    The last case covers stages such as ``"get_intention"``, which
    ``talk_about_distortion`` sets but which have no node in this graph.
    """
    if state.get("awaitingUser"):
        print("[ROUTER] awaitingUser=True → __end__")
        return "__end__"
    stage = state.get("stage")
    print(f"[ROUTER] stage={stage} (fallback)")
    if not state.get("validated") and state.get("last_user_msg"):
        return "validate_input"
    if stage == "end":
        return "__end__"
    if not stage or stage == "detect_distortion":
        print("[ROUTER] default → detect_distortion")
        return "detect_distortion"
    if stage in ("get_distortion_def", "talk_about_distortion"):
        return stage
    # Previously unknown stages fell back to detect_distortion. With chapter 1
    # already complete, that re-ran detect → get_distortion_def →
    # talk_about_distortion in a loop until the graph's recursion limit.
    print(f"[ROUTER] no node for stage={stage} → __end__")
    return "__end__"
# Node name -> node function. This one table drives node registration and the
# router's path map, so the three can no longer drift apart.
_NODES = {
    "detect_distortion": detect_distortion,
    "get_distortion_def": get_distortion_def,
    "talk_about_distortion": talk_about_distortion,
    "validate_input": validate_input,
}

graph_builder = StateGraph(ChatState)
for node_name, node_fn in _NODES.items():
    graph_builder.add_node(node_name, node_fn)

# global_router returns either a node name or "__end__".
edge_map = {name: name for name in _NODES}
edge_map["__end__"] = END

# Every transition, including the very first one from START, is decided by global_router.
graph_builder.add_conditional_edges(START, global_router, edge_map)
for node in _NODES:
    graph_builder.add_conditional_edges(node, global_router, edge_map)

graph = graph_builder.compile()