priyansh-saxena1 commited on
Commit
44d41e8
Β·
1 Parent(s): daf4268

fix: logic for ros array mapping and loop guard progress evaluation

Browse files
Files changed (2) hide show
  1. app/graph.py +28 -27
  2. app/llm.py +5 -0
app/graph.py CHANGED
@@ -134,37 +134,38 @@ def agent_node(state: IntakeState) -> dict:
134
 
135
  # ── Loop Guard: if LLM returned same reply as last turn, force-fill stuck field ──
136
  if _detect_repeat({"messages": msgs + [{"role": "assistant", "content": result.reply}]}):
 
 
 
 
 
 
 
 
 
137
  hpi_filled = all(getattr(result, f, None) for f in HPI_FIELDS)
138
 
139
  if not hpi_filled:
140
- # Still in HPI β€” force-fill the first empty HPI field
141
- for stuck_field in HPI_FIELDS:
142
- if getattr(result, stuck_field, None) is None:
143
- object.__setattr__(result, stuck_field, "not specified")
144
- print(f"[LoopGuard] Force-filled HPI '{stuck_field}' = 'not specified' to break repeat loop")
145
- new_missing = missing_from(result)
146
- if new_missing:
147
- object.__setattr__(result, "reply", f"Thank you. Now, could you tell me about {new_missing[0].replace('HPI:', '')}?")
148
- else:
149
- object.__setattr__(result, "reply", "Thank you β€” I have everything I need.")
150
- break
 
 
 
151
  else:
152
- # In ROS stage β€” force-fill the current ROS system with patient's last answer
153
- patient_answer = ""
154
- for m in reversed(msgs):
155
- if m.get("role") == "user":
156
- patient_answer = m.get("content", "denied")
157
- break
158
- patient_answer = patient_answer or "denied"
159
-
160
- # Find which ROS system the LLM was asking about (from its repeated reply)
161
- ros = dict(result.ros) if result.ros else {}
162
- ros_label = f"patient_reported_{len(ros) + 1}"
163
- ros[ros_label] = [patient_answer]
164
- object.__setattr__(result, "ros", ros)
165
- print(f"[LoopGuard] Force-filled ROS '{ros_label}' = ['{patient_answer}'] to break ROS repeat loop")
166
-
167
- if len(ros) < ROS_REQUIRED:
168
  object.__setattr__(result, "reply", f"Thank you. Are there any other symptoms you've been experiencing?")
169
  else:
170
  object.__setattr__(result, "reply", "Thank you β€” I have everything I need.")
 
134
 
135
  # ── Loop Guard: if LLM returned same reply as last turn, force-fill stuck field ──
136
  if _detect_repeat({"messages": msgs + [{"role": "assistant", "content": result.reply}]}):
137
+ # Check if the LLM made progress extracting data despite repeating the reply
138
+ try:
139
+ prev_state = CombinedOutput.model_validate_json(current_json)
140
+ prev_filled = sum(1 for f in HPI_FIELDS if getattr(prev_state, f, None)) + len(prev_state.ros)
141
+ new_filled = sum(1 for f in HPI_FIELDS if getattr(result, f, None)) + len(result.ros)
142
+ made_progress = new_filled > prev_filled
143
+ except Exception:
144
+ made_progress = False
145
+
146
  hpi_filled = all(getattr(result, f, None) for f in HPI_FIELDS)
147
 
148
  if not hpi_filled:
149
+ if not made_progress:
150
+ # Still in HPI and stuck β€” force-fill the first empty HPI field
151
+ for stuck_field in HPI_FIELDS:
152
+ if getattr(result, stuck_field, None) is None:
153
+ object.__setattr__(result, stuck_field, "not specified")
154
+ print(f"[LoopGuard] Force-filled HPI '{stuck_field}' = 'not specified' to break repeat loop")
155
+ break
156
+
157
+ # Ensure we ask a new question to break the loop
158
+ new_missing = missing_from(result)
159
+ if new_missing:
160
+ object.__setattr__(result, "reply", f"Thank you. Now, could you tell me about {new_missing[0].replace('HPI:', '')}?")
161
+ else:
162
+ object.__setattr__(result, "reply", "Thank you β€” I have everything I need.")
163
  else:
164
+ # In ROS stage
165
+ if not made_progress:
166
+ print("[LoopGuard] LLM stuck in ROS without extracting data. Skipping system.")
167
+
168
+ if len(result.ros) < ROS_REQUIRED:
 
 
 
 
 
 
 
 
 
 
 
169
  object.__setattr__(result, "reply", f"Thank you. Are there any other symptoms you've been experiencing?")
170
  else:
171
  object.__setattr__(result, "reply", "Thank you β€” I have everything I need.")
app/llm.py CHANGED
@@ -43,6 +43,7 @@ OUTPUT FORMAT:
43
  "aggravating": "..." or null,
44
  "relieving": "..." or null,
45
  "ros": {"system_name": ["finding1", "finding2"], ...},
 
46
  "emergency": false,
47
  "reply": "Your single question"
48
  }"""
@@ -77,12 +78,15 @@ def build_state_context(current_json: str) -> str:
77
 
78
  # ROS
79
  ros = state.get("ros", {})
 
80
  if ros:
81
  for sys_name, findings in ros.items():
82
  lines.append(f' βœ… ros.{sys_name}: {findings}')
83
  ros_remaining = ROS_REQUIRED - len(ros)
84
  if ros_remaining > 0:
85
  lines.append(f" ❌ ros: {ros_remaining} more system(s) needed")
 
 
86
  else:
87
  lines.append(f" βœ… ros: all {ROS_REQUIRED} systems collected")
88
 
@@ -116,6 +120,7 @@ class CombinedOutput(BaseModel):
116
  aggravating: str | None = None
117
  relieving: str | None = None
118
  ros: dict[str, list[str]] = {}
 
119
  emergency: bool = False
120
  reply: str = ""
121
 
 
43
  "aggravating": "..." or null,
44
  "relieving": "..." or null,
45
  "ros": {"system_name": ["finding1", "finding2"], ...},
46
+ "ros_asked": ["system_name1"] (append any new system you ask about here to prevent repeating),
47
  "emergency": false,
48
  "reply": "Your single question"
49
  }"""
 
78
 
79
  # ROS
80
  ros = state.get("ros", {})
81
+ ros_asked = state.get("ros_asked", [])
82
  if ros:
83
  for sys_name, findings in ros.items():
84
  lines.append(f' βœ… ros.{sys_name}: {findings}')
85
  ros_remaining = ROS_REQUIRED - len(ros)
86
  if ros_remaining > 0:
87
  lines.append(f" ❌ ros: {ros_remaining} more system(s) needed")
88
+ if ros_asked:
89
+ lines.append(f" ℹ️ Already asked about: {', '.join(ros_asked)} β€” DO NOT ask about these again")
90
  else:
91
  lines.append(f" βœ… ros: all {ROS_REQUIRED} systems collected")
92
 
 
120
  aggravating: str | None = None
121
  relieving: str | None = None
122
  ros: dict[str, list[str]] = {}
123
+ ros_asked: list[str] = []
124
  emergency: bool = False
125
  reply: str = ""
126