File size: 5,129 Bytes
d1c266e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from langchain_core.messages import AIMessage, HumanMessage
from state import WorkflowState
from tools import tool_analyze_skin_image, tool_fetch_disease_info
from llms import (
    symptom_classifier_chain, 
    question_generation_chain, 
    summary_generation_chain
)


def node_analyze_image(state: WorkflowState) -> WorkflowState:
    """Run the skin-image tool and record its prediction on the state.

    Reads ``state["image"]`` and passes it to ``tool_analyze_skin_image``.
    On success the tool's result is stored under ``disease_prediction``.

    Both failure paths -- no image supplied, or a tool result containing
    ``"Error:"`` -- append an explanatory ``AIMessage`` to ``chat_history``
    and set ``final_diagnosis`` to ``"Error"``. That marker is what
    ``router_check_image_analysis`` checks to return ``"end_error"``
    instead of ``"fetch_symptoms"``.

    Args:
        state: Workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    print("--- Node: Analyzing Image ---")
    image = state.get("image")
    if not image:
        state['chat_history'].append(AIMessage(
            content="Please upload an image first."
        ))
        # Flag the failure exactly like the tool-error branch below. Without
        # this, router_check_image_analysis returns "fetch_symptoms" and
        # node_fetch_symptoms reads a disease_prediction that was never set.
        state['final_diagnosis'] = "Error"
        return state

    prediction_result = tool_analyze_skin_image.invoke({"image": image})

    # The tool reports failures in-band, as a result containing "Error:".
    if "Error:" in prediction_result:
        state['chat_history'].append(AIMessage(content=prediction_result))
        state['final_diagnosis'] = "Error"
        return state

    state['disease_prediction'] = prediction_result
    return state

def node_fetch_symptoms(state: WorkflowState) -> WorkflowState:
    """Look up symptom and treatment data for the predicted disease.

    Calls ``tool_fetch_disease_info`` with ``state["disease_prediction"]``.
    If the result contains ``"error"``, that error is appended to
    ``chat_history`` as an ``AIMessage`` and ``final_diagnosis`` is set to
    ``"Error"``. Otherwise the symptom list and treatment text are stored
    and the questioning progress (``current_symptom_index`` and
    ``symptoms_confirmed``) is reset.

    Args:
        state: Workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    predicted = state['disease_prediction']
    print(f"--- Node: Fetching Symptoms for {predicted} ---")

    # NOTE(review): the result is used like a dict below -- confirm against the tool.
    lookup = tool_fetch_disease_info.invoke({"disease_name": predicted})

    if "error" in lookup:
        failure_message = AIMessage(content=lookup['error'])
        state['chat_history'].append(failure_message)
        state['final_diagnosis'] = "Error"
        return state

    symptom_list = lookup.get("symptoms", [])
    state['symptoms_to_check'] = symptom_list
    state['treatment_info'] = lookup.get(
        "treatment", "No treatment info available."
    )

    # Start the question loop from scratch for this disease.
    state['current_symptom_index'] = 0
    state['symptoms_confirmed'] = []

    if not symptom_list:
        print("No symptoms found to check. Proceeding to final response.")

    return state

def node_ask_symptom_question(state: WorkflowState) -> WorkflowState:
    """Append a question about the next unasked symptom to the chat.

    Picks the symptom at ``current_symptom_index`` from
    ``symptoms_to_check``, has ``question_generation_chain`` generate a
    question for it, appends that as an ``AIMessage`` to ``chat_history``,
    and advances ``current_symptom_index`` by one.

    Args:
        state: Workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    position = state['current_symptom_index']
    print(f"--- Node: Asking Symptom Question {position} ---")

    pending_symptom = state['symptoms_to_check'][position]
    generated = question_generation_chain.invoke({"symptom": pending_symptom})

    state['chat_history'].append(AIMessage(content=generated))
    # After this, the index points one past the symptom just asked.
    state['current_symptom_index'] = position + 1
    return state

def node_process_user_response(state: WorkflowState) -> WorkflowState:
    """Record whether the user confirmed the most recently asked symptom.

    The last entry of ``chat_history`` is taken as the user's reply and is
    classified by ``symptom_classifier_chain``. When the classification is
    ``"yes"`` (compared case-insensitively, ignoring surrounding
    whitespace), the symptom at ``current_symptom_index - 1`` is appended
    to ``symptoms_confirmed``. Any other label, or a classifier failure,
    leaves the state unchanged.

    If there is no pending question (``chat_history`` is empty, or the
    index does not point just past an entry of ``symptoms_to_check``), the
    reply is ignored and the state is returned unchanged.

    Args:
        state: Workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    print("--- Node: Processing User Response ---")
    history = state['chat_history']
    symptoms = state['symptoms_to_check']
    index = state['current_symptom_index']

    # node_ask_symptom_question bumps the index after asking, so the symptom
    # awaiting an answer sits at index - 1. With index == 0 that would be
    # symptoms[-1], silently attributing the answer to the last symptom;
    # with index > len(symptoms) it would raise IndexError.
    if not history or not 1 <= index <= len(symptoms):
        print("No pending symptom question to process. Ignoring response.")
        return state

    last_human_message = history[-1].content
    last_asked_symptom = symptoms[index - 1]

    try:
        classification = symptom_classifier_chain.invoke(
            {"last_human_message": last_human_message}
        )
        label = classification.get("classification")
        # Tolerate "Yes" / " yes " style variations from the classifier.
        if isinstance(label, str) and label.strip().lower() == "yes":
            print(f"User confirmed symptom: {last_asked_symptom}")
            state['symptoms_confirmed'].append(last_asked_symptom)
        else:
            print(f"User denied symptom: {last_asked_symptom}")

    except Exception as e:
        # Deliberate best-effort: a classifier failure must not abort the
        # conversation, so the symptom is simply left unconfirmed.
        print(f"Error classifying user response: {e}. Assuming 'unclear'.")

    return state
    
def node_generate_final_response(state: WorkflowState) -> WorkflowState:
    """Produce the closing summary message and mark the workflow complete.

    Passes the predicted disease, the confirmed symptoms (comma-joined, or
    ``"None confirmed"`` when there are none), the treatment text and a
    fixed disclaimer to ``summary_generation_chain``. The chain's result is
    appended to ``chat_history`` as an ``AIMessage`` and
    ``final_diagnosis`` is set to ``"Complete"``.

    Args:
        state: Workflow state; mutated in place.

    Returns:
        The same ``state`` object.
    """
    print("--- Node: Generating Final Response ---")

    disclaimer_sentences = [
        "I am just a dumb agent, not a medical professional.",
        "This is a side project for learning purposes.",
        "Please **DO NOT** take this information for face value.",
        "Consult a real doctor or dermatologist for any medical concerns.",
    ]
    disclaimer = "\n\n**DISCLAIMER:**\n" + " ".join(disclaimer_sentences)

    # Build the chain input key by key, in the same order the prompt fields
    # are read from the state.
    chain_input = {"disease": state['disease_prediction']}
    confirmed_text = ", ".join(state['symptoms_confirmed'])
    chain_input["symptoms"] = confirmed_text if confirmed_text else "None confirmed"
    chain_input["treatment"] = state['treatment_info']
    chain_input["disclaimer"] = disclaimer

    final_text = summary_generation_chain.invoke(chain_input)

    state['chat_history'].append(AIMessage(content=final_text))
    state['final_diagnosis'] = "Complete"
    return state


def router_should_ask_symptoms(state: WorkflowState) -> str:
    """Choose the next step based on whether any symptoms are queued.

    Returns:
        ``"ask_symptom_question"`` when ``symptoms_to_check`` is present
        and non-empty, otherwise ``"generate_final_response"``.
    """
    has_symptoms = bool(state.get("symptoms_to_check"))
    return "ask_symptom_question" if has_symptoms else "generate_final_response"

def router_should_continue_asking(state: WorkflowState) -> str:
    """Decide whether another symptom question is still outstanding.

    Compares ``current_symptom_index`` with the length of
    ``symptoms_to_check``.

    Returns:
        ``"ask_symptom_question"`` while the index is below the number of
        symptoms, otherwise ``"generate_final_response"``.
    """
    next_position = state['current_symptom_index']
    total_symptoms = len(state['symptoms_to_check'])

    # Guard clause: every symptom has been asked already.
    if next_position >= total_symptoms:
        return "generate_final_response"
    return "ask_symptom_question"

def router_check_image_analysis(state: WorkflowState) -> str:
    """Route on the outcome of the image-analysis step.

    Returns:
        ``"end_error"`` when ``final_diagnosis`` equals ``"Error"``,
        otherwise ``"fetch_symptoms"``.
    """
    analysis_failed = state.get("final_diagnosis") == "Error"
    return "end_error" if analysis_failed else "fetch_symptoms"