Spaces:
Sleeping
Sleeping
File size: 5,129 Bytes
d1c266e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from langchain_core.messages import AIMessage, HumanMessage
from state import WorkflowState
from tools import tool_analyze_skin_image, tool_fetch_disease_info
from llms import (
symptom_classifier_chain,
question_generation_chain,
summary_generation_chain
)
def node_analyze_image(state: WorkflowState) -> WorkflowState:
    """Run the skin-image classifier on the uploaded image and record the result.

    Reads ``state["image"]``. On success, stores the tool output in
    ``state["disease_prediction"]``.

    On failure, an AIMessage describing the problem is appended to
    ``state["chat_history"]`` and ``state["final_diagnosis"]`` is set to
    ``"Error"``, so that ``router_check_image_analysis`` returns
    ``"end_error"``. Failure means either of:

    * no image is present;
    * the tool output contains ``"Error:"``.

    Returns:
        The same (mutated) state object.
    """
    print("--- Node: Analyzing Image ---")
    image = state.get("image")
    if not image:
        state['chat_history'].append(AIMessage(
            content="Please upload an image first."
        ))
        # Flag the failure. Without this the router continues to
        # node_fetch_symptoms, which reads 'disease_prediction' -- a key this
        # path never sets.
        state['final_diagnosis'] = "Error"
        return state

    prediction_result = tool_analyze_skin_image.invoke({"image": image})
    # The tool result is checked for an in-band "Error:" marker.
    if "Error:" in prediction_result:
        state['chat_history'].append(AIMessage(content=prediction_result))
        state['final_diagnosis'] = "Error"
        return state

    state['disease_prediction'] = prediction_result
    return state
def node_fetch_symptoms(state: WorkflowState) -> WorkflowState:
    """Look up the symptom checklist and treatment text for the predicted disease.

    Calls ``tool_fetch_disease_info`` with ``state["disease_prediction"]``.
    The result is used as a mapping.

    If the result has an ``"error"`` key:

    * its message is appended to ``state["chat_history"]`` as an AIMessage;
    * ``state["final_diagnosis"]`` is set to ``"Error"``.

    Otherwise the questioning state is initialised:

    * ``symptoms_to_check`` -- the result's ``"symptoms"`` list, default ``[]``;
    * ``treatment_info`` -- the result's ``"treatment"`` text, with a
      placeholder default;
    * ``current_symptom_index`` -- 0;
    * ``symptoms_confirmed`` -- an empty list.

    Returns:
        The same (mutated) state object.
    """
    disease = state['disease_prediction']
    print(f"--- Node: Fetching Symptoms for {disease} ---")
    info = tool_fetch_disease_info.invoke({"disease_name": disease})

    if "error" in info:
        state['chat_history'].append(AIMessage(content=info['error']))
        state['final_diagnosis'] = "Error"
        return state

    checklist = info.get("symptoms", [])
    state['symptoms_to_check'] = checklist
    state['treatment_info'] = info.get("treatment", "No treatment info available.")
    state['current_symptom_index'] = 0
    state['symptoms_confirmed'] = []

    if not checklist:
        print("No symptoms found to check. Proceeding to final response.")
    return state
def node_ask_symptom_question(state: WorkflowState) -> WorkflowState:
    """Ask the user about the symptom at ``current_symptom_index``.

    Steps:

    1. Generate a question for the symptom with ``question_generation_chain``.
    2. Append the question to ``state["chat_history"]`` as an AIMessage.
    3. Advance ``current_symptom_index`` by one.

    After step 3, ``index - 1`` points at the symptom that was just asked.

    Returns:
        The same (mutated) state object.
    """
    position = state['current_symptom_index']
    print(f"--- Node: Asking Symptom Question {position} ---")

    next_symptom = state['symptoms_to_check'][position]
    prompt_text = question_generation_chain.invoke({"symptom": next_symptom})
    state['chat_history'].append(AIMessage(content=prompt_text))

    state['current_symptom_index'] = position + 1
    return state
def node_process_user_response(state: WorkflowState) -> WorkflowState:
    """Classify the user's reply to the last symptom question and record it.

    The symptom being answered is
    ``symptoms_to_check[current_symptom_index - 1]``, because
    ``node_ask_symptom_question`` advances the index right after asking.

    The content of the last ``chat_history`` entry is sent to
    ``symptom_classifier_chain``. If its ``"classification"`` field is
    ``"yes"`` (case- and whitespace-insensitive), the symptom is appended to
    ``symptoms_confirmed``. A "no" or any other answer leaves the state
    unchanged.

    The state is also returned unchanged in these cases:

    * No question is pending, i.e. the index is outside
      ``1..len(symptoms_to_check)``. Otherwise ``index - 1 == -1`` would
      silently attribute the reply to the last symptom.
    * The chat history is empty.
    * The classifier raises. The failure is logged and treated as "unclear".

    Returns:
        The same (mutated) state object.
    """
    print("--- Node: Processing User Response ---")
    index = state['current_symptom_index']
    symptoms = state['symptoms_to_check']
    history = state['chat_history']
    if not (1 <= index <= len(symptoms)) or not history:
        print(f"No pending symptom question (index={index}); ignoring response.")
        return state

    last_asked_symptom = symptoms[index - 1]
    last_human_message = history[-1].content
    try:
        result = symptom_classifier_chain.invoke(
            {"last_human_message": last_human_message}
        )
        # Normalise so "Yes" / " yes\n" still count as a confirmation.
        label = str(result.get("classification", "")).strip().lower()
    except Exception as e:
        # Best-effort: a classifier/parsing failure must not abort the dialogue.
        print(f"Error classifying user response: {e}. Assuming 'unclear'.")
        return state

    if label == "yes":
        print(f"User confirmed symptom: {last_asked_symptom}")
        state['symptoms_confirmed'].append(last_asked_symptom)
    elif label == "no":
        print(f"User denied symptom: {last_asked_symptom}")
    else:
        print(f"Unclear response for symptom: {last_asked_symptom}")
    return state
def node_generate_final_response(state: WorkflowState) -> WorkflowState:
    """Generate the final summary (with disclaimer) and mark the run complete.

    The following are passed to ``summary_generation_chain``:

    * the predicted disease;
    * the confirmed symptoms, comma-joined, or ``"None confirmed"``;
    * the treatment info;
    * a fixed disclaimer.

    The chain output is appended to ``state["chat_history"]`` as an
    AIMessage, and ``state["final_diagnosis"]`` is set to ``"Complete"``.

    If ``state["final_diagnosis"]`` is already ``"Error"``, the state is
    returned unchanged. For example, ``node_fetch_symptoms`` may have failed;
    ``router_should_ask_symptoms`` only checks ``symptoms_to_check``, so it
    can still route here. Returning early means no summary of a failed lookup
    is produced and the error flag is not overwritten.

    Returns:
        The same (mutated) state object.
    """
    print("--- Node: Generating Final Response ---")
    if state.get("final_diagnosis") == "Error":
        print("Skipping summary: an earlier step reported an error.")
        return state

    disclaimer = (
        "\n\n**DISCLAIMER:**\n"
        "I am just a dumb agent, not a medical professional. "
        "This is a side project for learning purposes. "
        "Please **DO NOT** take this information at face value. "
        "Consult a real doctor or dermatologist for any medical concerns."
    )
    # .get defaults: these keys are only set by node_fetch_symptoms on success.
    confirmed = state.get('symptoms_confirmed') or []
    summary = summary_generation_chain.invoke({
        "disease": state['disease_prediction'],
        "symptoms": ", ".join(confirmed) or "None confirmed",
        "treatment": state.get('treatment_info', "No treatment info available."),
        "disclaimer": disclaimer,
    })
    state['chat_history'].append(AIMessage(content=summary))
    state['final_diagnosis'] = "Complete"
    return state
def router_should_ask_symptoms(state: WorkflowState) -> str:
    """Decide whether any symptom questions need to be asked.

    Returns:
        ``"ask_symptom_question"`` when ``state["symptoms_to_check"]`` exists
        and is truthy (non-empty). Otherwise ``"generate_final_response"``.
    """
    has_symptoms = bool(state.get("symptoms_to_check"))
    return "ask_symptom_question" if has_symptoms else "generate_final_response"
def router_should_continue_asking(state: WorkflowState) -> str:
    """Decide, after a user reply, whether another symptom question remains.

    Returns:
        ``"ask_symptom_question"`` while ``current_symptom_index`` is below
        the number of entries in ``symptoms_to_check``. Otherwise
        ``"generate_final_response"``.
    """
    asked_so_far = state['current_symptom_index']
    total = len(state['symptoms_to_check'])
    if asked_so_far >= total:
        return "generate_final_response"
    return "ask_symptom_question"
def router_check_image_analysis(state: WorkflowState) -> str:
    """Route on the outcome of image analysis.

    Returns:
        ``"end_error"`` if ``state["final_diagnosis"]`` equals ``"Error"``.
        Otherwise ``"fetch_symptoms"``.
    """
    analysis_failed = state.get("final_diagnosis") == "Error"
    return "end_error" if analysis_failed else "fetch_symptoms"
|