Spaces:
Paused
Paused
| from langgraph.graph import StateGraph, START, END | |
| from bielik import llm | |
| from guardian import check_input | |
| from helpful_functions import get_last_user_message, check_situation, beliefs_check_function, introduction_talk, create_interview | |
| from neo4j_driver import driver | |
| from classifier import predict_raw, predict_raw1 | |
| from state import ChatState | |
| from prompts import build_system_prompt_introduction_chapter_ellis_distortion | |
def _merge_interview_field(state, key, fresh):
    """Fold a freshly extracted value into ``state[key]``.

    An empty slot simply takes the new value. A filled slot is only touched
    when the new value is non-empty, in which case both are handed to
    ``create_interview`` (new value first, stored value second) and the
    result replaces the stored value.
    """
    if state[key] == "":
        state[key] = fresh
    elif fresh != "":
        state[key] = create_interview(fresh, state[key])


def detect_distortion(state: ChatState):
    """Run one turn of the first stage: greet, classify, and gather context.

    On the very first call (no messages yet) the node seeds the conversation
    with a greeting and waits for the user.

    On later calls it:
      * bumps ``first_stage_iterations``;
      * if no distortion is stored yet, runs ``predict_raw`` on the last user
        text and, when the label is not "No Distortion" and
        ``beliefs_check_function`` is truthy, stores the ``predict_raw1``
        label plus the triggering text;
      * calls ``introduction_talk`` and merges its ``situation`` / ``emotion``
        / ``think`` attributes into the state;
      * advances to ``get_distortion_def`` once a distortion and all three
        fields are present, otherwise appends the model output and waits for
        the user again.

    Args:
        state: Mutable chat state; it is updated in place.

    Returns:
        The same ``state`` object.
    """
    if not state.get("messages"):
        print("Siema")
        state["messages"] = [{
            "role": "assistant",
            "content": "Cześć! Cieszę się, że jesteś. Co u ciebie, czy masz jakiś problem? Z checią ci pomogę!",
        }]
        state["awaitingUser"] = True
        state["stage"] = "detect_distortion"
        return state

    state["first_stage_iterations"] += 1
    print(state["first_stage_iterations"])
    print("Siema1")

    # get_last_user_message may yield nothing (talk_about_distortion guards
    # against the same case), so never subscript it blindly.
    last_message = get_last_user_message(state)
    raw_content = last_message["content"] if last_message else None
    user_text = (raw_content or "").strip()

    # Classifying an empty string cannot produce a meaningful label, so the
    # classifiers only run when there is actual user text.
    if state["distortion"] is None and user_text:
        label = predict_raw(user_text)
        if label != "No Distortion" and beliefs_check_function(user_text):
            distortion = predict_raw1(user_text)
            print(distortion)
            state["distortion"] = distortion
            state["distortion_text"] = user_text

    print("Siema2")
    system_prompt = build_system_prompt_introduction_chapter_ellis_distortion(
        state["distortion"],
        state["situation"],
        state["think"],
        state["emotion"],
    )
    talk = introduction_talk(state["messages"], system_prompt)

    _merge_interview_field(state, "situation", talk.situation)
    _merge_interview_field(state, "emotion", talk.emotion)
    _merge_interview_field(state, "think", talk.think)
    state["introduction_end_flag"] = talk.chapter_end

    stage_complete = (
        state["distortion"] is not None
        and state["situation"] != ""
        and state["think"] != ""
        and state["emotion"] != ""
    )
    if stage_complete:
        print("Next")
        state["awaitingUser"] = False
        # NOTE(review): this stores a reference to the same list, not a copy;
        # later appends to state["messages"] will show up here too - confirm
        # whether a snapshot was intended.
        state["messages_detect"] = state["messages"]
        state["stage"] = "get_distortion_def"
        return state

    state["messages"].append({"role": "assistant", "content": talk.model_output})
    state["awaitingUser"] = True
    state["stage"] = "detect_distortion"
    return state
def get_distortion_def(state: ChatState):
    """Look up the stored distortion's definition and hand over to the talk stage.

    Runs a parameterised Cypher query for the ``Distortion`` node whose
    ``name`` equals ``state["distortion"]`` and stores its ``definicja``
    property in ``state["distortion_def"]`` (``None`` when no record comes
    back). The stage is then set to ``talk_about_distortion`` without
    waiting for user input.

    Args:
        state: Mutable chat state; it is updated in place.

    Returns:
        The same ``state`` object.
    """
    print("Siema4")
    cypher = """
    MATCH (d:Distortion {name: $name})
    RETURN d.definicja AS definicja
    """
    # Only the records are needed; the summary and keys are discarded.
    rows, _summary, _keys = driver.execute_query(
        cypher,
        parameters_={"name": state["distortion"]},
    )

    if rows:
        state["distortion_def"] = rows[0]["definicja"]
    else:
        state["distortion_def"] = None

    state["stage"] = "talk_about_distortion"
    state["awaitingUser"] = False
    return state
def _system_prompt_reply(prompt_text):
    """Send one system message to the LLM and return the reply as text.

    A plain string reply is returned as is; otherwise the reply's
    ``content`` attribute is used, falling back to ``str(reply)``.
    """
    reply = llm.invoke([
        {
            "role": "system",
            "content": prompt_text,
        },
    ])
    if isinstance(reply, str):
        return reply
    return getattr(reply, "content", str(reply))


def talk_about_distortion(state: ChatState):
    """Explain the detected distortion and check whether the user follows.

    First visit (``distortion_explained`` falsy): asks the LLM for a short
    explanation, appends it, marks the distortion as explained and waits for
    the user.

    Later visits: classifies the last user message with ``check_situation``.
    The result "understand" appends a fixed confirmation and moves the stage
    to ``get_intention``; any other result asks the LLM for a simpler
    re-explanation and waits for the user again. If there is no user message
    at all, the node just waits.

    Args:
        state: Mutable chat state; it is updated in place.

    Returns:
        The same ``state`` object.
    """
    name = state["distortion"]
    definition = state["distortion_def"]
    print("Siema5")

    if not state.get("distortion_explained"):
        print("Siema6")
        intro_prompt = f"""
        Jesteś empatycznym asystentem CBT.
        Użytkownikowi wykryto zniekształcenie poznawcze:
        Nazwa: {name}
        Definicja: {definition}
        Przedstaw mu, że wykryłeś u niego zniekształcenie i wyjaśnij je w prosty, życzliwy sposób i zapytaj, czy chce, abyś pomógł mu to wspólnie przepracować.
        Język: polski, maksymalnie 2–3 zdania.
        """
        explanation = _system_prompt_reply(intro_prompt)
        state["messages"].append({"role": "assistant", "content": explanation})
        state["awaitingUser"] = True
        state["stage"] = "talk_about_distortion"
        state["distortion_explained"] = True
        return state

    print("Siema7")
    user_msg = get_last_user_message(state)
    if not user_msg:
        # Nothing to classify yet; keep waiting for input.
        state["awaitingUser"] = True
        return state

    verdict = check_situation(user_msg["content"])
    state["classify_result"] = verdict

    if verdict == "understand":
        print("Siema8")
        state["messages"].append({
            "role": "assistant",
            "content": "Super! To przejdźmy teraz do kolejnego kroku"
        })
        state["stage"] = "get_intention"
        state["awaitingUser"] = False
        return state

    # NOTE: a dedicated "low_expression" branch (gentle encouragement for
    # terse users) was sketched here earlier but is disabled; every result
    # other than "understand" currently gets the re-explanation below.
    print("Siema9")
    history = state["messages"]
    retry_prompt = f"""
    WEJSCIE
    Historia wiadomości - {history}
    Użytkownik nie zrozumiał wyjaśnienia zniekształcenia.
    Nazwa: {name}
    Definicja: {definition}
    Język tylko polski.
    Twoje zadanie:
    - Wyjaśnij prostszymi słowami (1–2 zdania).
    - Dodaj przykład z życia (1–2 zdania).
    - Zapytaj, czy teraz jest to jasne i czy możemy przejść do działania.
    Maksymalnie 3-4 zdania
    """
    simpler = _system_prompt_reply(retry_prompt)
    state["messages"].append({"role": "assistant", "content": simpler})
    state["awaitingUser"] = True
    state["stage"] = "talk_about_distortion"
    return state
def validate_input(state: ChatState):
    """Check the latest user input against the current chapter.

    The current ``stage`` is mapped to a chapter label ("ETAP 1" to
    "ETAP 4", or the string "None" for anything else) and passed to
    ``check_input`` together with the message history and
    ``state["last_user_msg_content"]``. The ``last_user_msg`` flag is always
    cleared. On acceptance the state is marked validated and processing
    continues; on rejection the reason is recorded, the guardian's message
    is appended for the user, and the node waits for new input.

    Args:
        state: Mutable chat state; it is updated in place.

    Returns:
        The same ``state`` object.
    """
    chapter_by_stage = {
        "detect_distortion": "ETAP 1",
        "talk_about_distortion": "ETAP 2",
        "get_distortion_def": "ETAP 2",
        "create_socratic_question": "ETAP 3",
        "get_intention": "ETAP 3",
        "select_intention": "ETAP 3",
        "analyze_output": "ETAP 3",
        "enter_alt_thought": "ETAP 4",
        "handle_alt_thought_input": "ETAP 4",
    }
    # Unknown or missing stages fall back to the literal string "None".
    chapter = chapter_by_stage.get(state.get("stage"), "None")

    user_content = state.get("last_user_msg_content")
    verdict = check_input(state["messages"], chapter, user_content)
    state["last_user_msg"] = False

    if not verdict.decision:
        state["noValidated"] = f"{chapter} - {user_content}"
        state["explanation"] = verdict.explanation
        state["messages"].append(
            {"role": "assistant", "content": verdict.message_to_user}
        )
        state["awaitingUser"] = True
        return state

    state["validated"] = True
    state["awaitingUser"] = False
    return state
def global_router(state: ChatState) -> str:
    """Pick the next node for the graph based on the chat state.

    Decision order:
      1. ``awaitingUser`` truthy -> ``"__end__"`` (stop and wait for input).
      2. Not yet validated and ``last_user_msg`` truthy -> ``"validate_input"``.
      3. ``stage == "end"`` -> ``"__end__"``.
      4. ``stage`` names ``get_distortion_def`` or ``talk_about_distortion``
         -> that node.
      5. ``stage`` is one of the later-chapter stages that have no node in
         this graph -> ``"__end__"``.
      6. Anything else -> ``"detect_distortion"``.

    Step 5 exists because ``talk_about_distortion`` sets the stage to
    ``"get_intention"`` with ``awaitingUser`` false. Falling through to
    ``detect_distortion`` in that case sends the flow back through
    ``get_distortion_def`` and ``talk_about_distortion`` again and again
    until the graph's recursion limit is hit.

    Args:
        state: Current chat state (only read, never modified).

    Returns:
        The routing key of the next node, or ``"__end__"``.
    """
    # Stages that validate_input assigns to chapters 3 and 4; none of them
    # is registered as a node in this graph.
    downstream_stages = (
        "create_socratic_question",
        "get_intention",
        "select_intention",
        "analyze_output",
        "enter_alt_thought",
        "handle_alt_thought_input",
    )

    if state.get("awaitingUser"):
        print("[ROUTER] awaitingUser=True → __end__")
        return "__end__"

    stage = state.get("stage")
    print(f"[ROUTER] stage={stage} (fallback)")

    if not state.get("validated") and state.get("last_user_msg"):
        return "validate_input"
    if stage == "end":
        return "__end__"
    if stage in ("get_distortion_def", "talk_about_distortion"):
        return stage
    if stage in downstream_stages:
        print(f"[ROUTER] stage={stage} has no node in this graph → __end__")
        return "__end__"

    print("[ROUTER] default → detect_distortion")
    return "detect_distortion"
# Routing table shared by every conditional edge: the keys are the values
# global_router can return, the values are the graph targets.
edge_map = {
    "detect_distortion": "detect_distortion",
    "get_distortion_def": "get_distortion_def",
    "talk_about_distortion": "talk_about_distortion",
    "validate_input": "validate_input",
    "__end__": END,
}

graph_builder = StateGraph(ChatState)

# Register the four nodes of this graph.
graph_builder.add_node("detect_distortion", detect_distortion)
graph_builder.add_node("get_distortion_def", get_distortion_def)
graph_builder.add_node("talk_about_distortion", talk_about_distortion)
graph_builder.add_node("validate_input", validate_input)

# The entry point and every node are followed by the same router.
graph_builder.add_conditional_edges(START, global_router, edge_map)
for node in (
    "detect_distortion",
    "get_distortion_def",
    "talk_about_distortion",
    "validate_input",
):
    graph_builder.add_conditional_edges(node, global_router, edge_map)

graph = graph_builder.compile()