CatoG commited on
Commit
bec7c31
·
1 Parent(s): 61b5f58

revision 3

Browse files
Files changed (3) hide show
  1. app.py +104 -55
  2. test_workflow.py +196 -0
  3. workflow_helpers.py +125 -0
app.py CHANGED
@@ -20,6 +20,8 @@ from workflow_helpers import (
20
  select_relevant_roles, identify_revision_targets,
21
  compress_final_answer, strip_internal_noise,
22
  get_synthesizer_format_instruction, get_qa_format_instruction,
 
 
23
  ROLE_RELEVANCE,
24
  )
25
  from evidence import (
@@ -601,6 +603,8 @@ class WorkflowState(TypedDict):
601
  output_format: str # detected output format (single_choice, short_answer, etc.)
602
  brevity_requirement: str # minimal, short, normal, verbose
603
  qa_structured: Optional[dict] # serialised QAResult for structured QA
 
 
604
 
605
 
606
  # --- Role system prompts ---
@@ -628,6 +632,8 @@ _PLANNER_SYSTEM = (
628
  "- QA results are BINDING — if QA says FAIL, you MUST revise, never approve.\n\n"
629
  "Respond in this exact format:\n"
630
  "TASK BREAKDOWN:\n<subtask list>\n\n"
 
 
631
  "ROLE TO CALL: <specialist name>\n\n"
632
  "SUCCESS CRITERIA:\n<what a correct, complete answer looks like>\n\n"
633
  "GUIDANCE FOR SPECIALIST:\n<any constraints or focus areas>"
@@ -1655,6 +1661,7 @@ _EMPTY_STATE_BASE: WorkflowState = {
1655
  "draft_output": "", "qa_report": "", "qa_role_feedback": {}, "qa_passed": False,
1656
  "revision_count": 0, "final_answer": "",
1657
  "output_format": "other", "brevity_requirement": "normal", "qa_structured": None,
 
1658
  }
1659
 
1660
 
@@ -1921,6 +1928,8 @@ def run_multi_role_workflow(
1921
  "output_format": output_format,
1922
  "brevity_requirement": brevity,
1923
  "qa_structured": None,
 
 
1924
  }
1925
 
1926
  trace: List[str] = [
@@ -1958,6 +1967,14 @@ def run_multi_role_workflow(
1958
  try:
1959
  if planner_active:
1960
  state = _step_plan(chat_model, state, trace)
 
 
 
 
 
 
 
 
1961
  else:
1962
  state["current_role"] = active_specialist_keys[0]
1963
  state["plan"] = message
@@ -1993,39 +2010,62 @@ def run_multi_role_workflow(
1993
  + ", ".join(AGENT_ROLES.get(k, k) for k in selected_roles)
1994
  )
1995
 
1996
- # Main orchestration loop
1997
- while True:
1998
- # Step 4: Run selected specialists
1999
- if primary_role not in selected_roles:
2000
- primary_role = selected_roles[0]
2001
- state["current_role"] = primary_role
2002
-
2003
- # Run primary specialist (research gets evidence injected)
2004
- primary_fn = _SPECIALIST_STEPS.get(primary_role, _step_technical)
2005
- if primary_role == "research" and evidence:
2006
- state = _step_research(chat_model, state, trace, evidence=evidence)
2007
- else:
2008
- state = primary_fn(chat_model, state, trace)
2009
- primary_output = state["draft_output"]
2010
- planner_state.specialist_outputs[primary_role] = primary_output[:500]
2011
-
2012
- all_outputs: List[Tuple[str, str]] = [(primary_role, primary_output)]
2013
- for specialist_role in selected_roles:
2014
- if specialist_role == primary_role:
2015
- continue
2016
- if specialist_role == "research" and evidence:
2017
- state = _step_research(chat_model, state, trace, evidence=evidence)
2018
- else:
2019
- step_fn = _SPECIALIST_STEPS[specialist_role]
2020
- state = step_fn(chat_model, state, trace)
2021
- output = state["draft_output"]
2022
- all_outputs.append((specialist_role, output))
2023
- planner_state.specialist_outputs[specialist_role] = output[:500]
2024
-
2025
- # Step 5: Synthesize — format-aware, evidence-grounded
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2026
  state = _step_synthesize(chat_model, state, trace, all_outputs,
2027
  evidence=evidence)
 
 
2028
 
 
 
 
2029
  # Step 6: QA validation (with evidence context)
2030
  if qa_active:
2031
  state = _step_qa(chat_model, state, trace, all_outputs,
@@ -2056,7 +2096,18 @@ def run_multi_role_workflow(
2056
  trace.append("\n═══ WORKFLOW COMPLETE — APPROVED ═══")
2057
  break
2058
 
2059
- # QA failed and planner was forced to revise
 
 
 
 
 
 
 
 
 
 
 
2060
  state["revision_count"] += 1
2061
  planner_state.revision_count = state["revision_count"]
2062
 
@@ -2079,7 +2130,6 @@ def run_multi_role_workflow(
2079
  planner_state.record_event("escalation", escalation)
2080
 
2081
  if escalation == "suppress_role":
2082
- # Suppress roles that keep producing unsupported content
2083
  suppress = planner_state.get_roles_to_suppress()
2084
  for role_label in suppress:
2085
  role_key = _ROLE_LABEL_TO_KEY.get(role_label)
@@ -2090,17 +2140,15 @@ def run_multi_role_workflow(
2090
  selected_roles = [primary_role]
2091
 
2092
  elif escalation == "rewrite_from_state":
2093
- # Synthesizer should rewrite from state, not reuse bloated draft
2094
  trace.append(" ⚠ Synthesizer will rewrite from state instead of reusing draft")
2095
- state["draft_output"] = "" # Force synthesizer to rebuild
2096
 
2097
  elif escalation == "narrow_scope":
2098
- # Reduce to a single specialist
2099
  if len(selected_roles) > 1:
2100
  selected_roles = [selected_roles[0]]
2101
  trace.append(f" ⚠ Narrowed to single specialist: {selected_roles[0]}")
2102
 
2103
- # Step 9: TARGETED REVISIONS — only rerun failing role(s)
2104
  revision_targets = identify_revision_targets(qa_result, _ROLE_LABEL_TO_KEY)
2105
  trace.append(
2106
  f"\n═══ REVISION {state['revision_count']} / {MAX_REVISIONS} ═══\n"
@@ -2108,45 +2156,46 @@ def run_multi_role_workflow(
2108
  )
2109
  planner_state.record_event("revision", f"targets={revision_targets}")
2110
 
2111
- # Determine what to rerun
2112
  rerun_specialists = [
2113
  t for t in revision_targets
2114
  if t in _SPECIALIST_STEPS and t in selected_roles
2115
  ]
2116
- rerun_synthesizer = "synthesizer" in revision_targets or rerun_specialists
2117
 
2118
  if rerun_specialists:
2119
- # Only rerun the targeted specialists
2120
  new_outputs = []
2121
  for rk in rerun_specialists:
2122
- if rk == "research" and evidence:
2123
- state = _step_research(chat_model, state, trace, evidence=evidence)
2124
- else:
2125
- step_fn = _SPECIALIST_STEPS[rk]
2126
- state = step_fn(chat_model, state, trace)
2127
  new_outputs.append((rk, state["draft_output"]))
2128
  planner_state.specialist_outputs[rk] = state["draft_output"][:500]
2129
 
2130
- # Merge with previous outputs (replace updated roles)
2131
  updated_keys = {rk for rk, _ in new_outputs}
2132
- merged_outputs = [
2133
  (rk, out) for rk, out in all_outputs if rk not in updated_keys
2134
  ] + new_outputs
2135
- all_outputs = merged_outputs
2136
 
2137
  if rerun_synthesizer or rerun_specialists:
2138
  state = _step_synthesize(chat_model, state, trace, all_outputs,
2139
  evidence=evidence)
2140
 
2141
- # Update selected_roles based on planner's new routing
2142
- primary_role = state["current_role"]
2143
- if primary_role in selected_roles:
2144
- pass # keep existing selection
2145
- elif primary_role in active_specialist_keys:
2146
- selected_roles = [primary_role] + [r for r in selected_roles if r != primary_role]
2147
- selected_roles = selected_roles[:config.max_specialists_per_task]
 
 
 
 
 
 
2148
 
2149
- continue # Loop back to QA
 
2150
 
2151
  else:
2152
  # No Planner review loop — accept the draft
 
20
  select_relevant_roles, identify_revision_targets,
21
  compress_final_answer, strip_internal_noise,
22
  get_synthesizer_format_instruction, get_qa_format_instruction,
23
+ validate_output_format, format_violations_instruction,
24
+ parse_task_assumptions, format_assumptions_for_prompt,
25
  ROLE_RELEVANCE,
26
  )
27
  from evidence import (
 
603
  output_format: str # detected output format (single_choice, short_answer, etc.)
604
  brevity_requirement: str # minimal, short, normal, verbose
605
  qa_structured: Optional[dict] # serialised QAResult for structured QA
606
+ task_assumptions: Dict[str, str] # shared assumptions all specialists must use
607
+ revision_instruction: str # latest revision instruction from planner
608
 
609
 
610
  # --- Role system prompts ---
 
632
  "- QA results are BINDING — if QA says FAIL, you MUST revise, never approve.\n\n"
633
  "Respond in this exact format:\n"
634
  "TASK BREAKDOWN:\n<subtask list>\n\n"
635
+ "TASK ASSUMPTIONS:\n<shared assumptions all specialists must use, e.g. cost model, "
636
+ "coverage rate, units, scope, time frame — one per line as 'key: value'>\n\n"
637
  "ROLE TO CALL: <specialist name>\n\n"
638
  "SUCCESS CRITERIA:\n<what a correct, complete answer looks like>\n\n"
639
  "GUIDANCE FOR SPECIALIST:\n<any constraints or focus areas>"
 
1661
  "draft_output": "", "qa_report": "", "qa_role_feedback": {}, "qa_passed": False,
1662
  "revision_count": 0, "final_answer": "",
1663
  "output_format": "other", "brevity_requirement": "normal", "qa_structured": None,
1664
+ "task_assumptions": {}, "revision_instruction": "",
1665
  }
1666
 
1667
 
 
1928
  "output_format": output_format,
1929
  "brevity_requirement": brevity,
1930
  "qa_structured": None,
1931
+ "task_assumptions": {},
1932
+ "revision_instruction": "",
1933
  }
1934
 
1935
  trace: List[str] = [
 
1967
  try:
1968
  if planner_active:
1969
  state = _step_plan(chat_model, state, trace)
1970
+
1971
+ # Parse shared task assumptions from planner output
1972
+ assumptions = parse_task_assumptions(state["plan"])
1973
+ if assumptions:
1974
+ state["task_assumptions"] = assumptions
1975
+ planner_state.task_assumptions = assumptions
1976
+ trace.append(f"[ASSUMPTIONS] {len(assumptions)} shared assumption(s) set: "
1977
+ + ", ".join(f"{k}={v}" for k, v in assumptions.items()))
1978
  else:
1979
  state["current_role"] = active_specialist_keys[0]
1980
  state["plan"] = message
 
2010
  + ", ".join(AGENT_ROLES.get(k, k) for k in selected_roles)
2011
  )
2012
 
2013
+ # Step 4: Run ALL selected specialists (initial run only)
2014
+ if primary_role not in selected_roles:
2015
+ primary_role = selected_roles[0]
2016
+ state["current_role"] = primary_role
2017
+
2018
+ # Build assumptions context for specialist prompts
2019
+ assumptions_ctx = format_assumptions_for_prompt(state.get("task_assumptions", {}))
2020
+
2021
+ def _run_specialist(role_key):
2022
+ """Run a single specialist, injecting evidence and assumptions as needed."""
2023
+ if role_key == "research" and evidence:
2024
+ return _step_research(chat_model, state, trace, evidence=evidence)
2025
+ step_fn = _SPECIALIST_STEPS.get(role_key, _step_technical)
2026
+ # Inject shared assumptions into plan context for specialist
2027
+ if assumptions_ctx and assumptions_ctx not in state["plan"]:
2028
+ state["plan"] = state["plan"] + "\n\n" + assumptions_ctx
2029
+ return step_fn(chat_model, state, trace)
2030
+
2031
+ # Run primary specialist
2032
+ state = _run_specialist(primary_role)
2033
+ primary_output = state["draft_output"]
2034
+ planner_state.specialist_outputs[primary_role] = primary_output[:500]
2035
+
2036
+ all_outputs: List[Tuple[str, str]] = [(primary_role, primary_output)]
2037
+ for specialist_role in selected_roles:
2038
+ if specialist_role == primary_role:
2039
+ continue
2040
+ state = _run_specialist(specialist_role)
2041
+ output = state["draft_output"]
2042
+ all_outputs.append((specialist_role, output))
2043
+ planner_state.specialist_outputs[specialist_role] = output[:500]
2044
+
2045
+ # Step 5: Synthesize — format-aware, evidence-grounded
2046
+ state = _step_synthesize(chat_model, state, trace, all_outputs,
2047
+ evidence=evidence)
2048
+
2049
+ # Step 5b: Pre-QA format validation — catch structural violations early
2050
+ fmt_violations = validate_output_format(
2051
+ state["draft_output"], output_format, brevity
2052
+ )
2053
+ if fmt_violations:
2054
+ trace.append(
2055
+ "\n[FORMAT VALIDATION] Violations detected before QA:\n"
2056
+ + "\n".join(f" - {v}" for v in fmt_violations)
2057
+ )
2058
+ # Re-synthesize with explicit violation feedback
2059
+ violation_instr = format_violations_instruction(fmt_violations)
2060
+ state["plan"] = state["plan"] + "\n\n" + violation_instr
2061
  state = _step_synthesize(chat_model, state, trace, all_outputs,
2062
  evidence=evidence)
2063
+ planner_state.record_event("format_rewrite", "; ".join(fmt_violations))
2064
+ trace.append("[FORMAT VALIDATION] Re-synthesized to fix format violations.")
2065
 
2066
+ # === QA-REVISION LOOP ===
2067
+ # From here, only QA + planner review + targeted revision (no full specialist rerun)
2068
+ while True:
2069
  # Step 6: QA validation (with evidence context)
2070
  if qa_active:
2071
  state = _step_qa(chat_model, state, trace, all_outputs,
 
2096
  trace.append("\n═══ WORKFLOW COMPLETE — APPROVED ═══")
2097
  break
2098
 
2099
+ # QA failed and planner was forced to revise
2100
+ # store revision instruction reliably
2101
+ revision_instr = ""
2102
+ if "REVISED INSTRUCTIONS:" in state.get("plan", ""):
2103
+ revision_instr = state["plan"]
2104
+ elif qa_result.correction_instruction:
2105
+ revision_instr = qa_result.correction_instruction
2106
+ state["revision_instruction"] = revision_instr
2107
+ planner_state.revision_instruction = revision_instr
2108
+ planner_state.record_event("revision_instruction_stored",
2109
+ revision_instr[:200] if revision_instr else "MISSING")
2110
+
2111
  state["revision_count"] += 1
2112
  planner_state.revision_count = state["revision_count"]
2113
 
 
2130
  planner_state.record_event("escalation", escalation)
2131
 
2132
  if escalation == "suppress_role":
 
2133
  suppress = planner_state.get_roles_to_suppress()
2134
  for role_label in suppress:
2135
  role_key = _ROLE_LABEL_TO_KEY.get(role_label)
 
2140
  selected_roles = [primary_role]
2141
 
2142
  elif escalation == "rewrite_from_state":
 
2143
  trace.append(" ⚠ Synthesizer will rewrite from state instead of reusing draft")
2144
+ state["draft_output"] = ""
2145
 
2146
  elif escalation == "narrow_scope":
 
2147
  if len(selected_roles) > 1:
2148
  selected_roles = [selected_roles[0]]
2149
  trace.append(f" ⚠ Narrowed to single specialist: {selected_roles[0]}")
2150
 
2151
+ # Step 9: TARGETED REVISIONS — only rerun the failing role(s)
2152
  revision_targets = identify_revision_targets(qa_result, _ROLE_LABEL_TO_KEY)
2153
  trace.append(
2154
  f"\n═══ REVISION {state['revision_count']} / {MAX_REVISIONS} ═══\n"
 
2156
  )
2157
  planner_state.record_event("revision", f"targets={revision_targets}")
2158
 
2159
+ # Only rerun the targeted specialists — NOT all specialists
2160
  rerun_specialists = [
2161
  t for t in revision_targets
2162
  if t in _SPECIALIST_STEPS and t in selected_roles
2163
  ]
2164
+ rerun_synthesizer = "synthesizer" in revision_targets or bool(rerun_specialists)
2165
 
2166
  if rerun_specialists:
 
2167
  new_outputs = []
2168
  for rk in rerun_specialists:
2169
+ state = _run_specialist(rk)
 
 
 
 
2170
  new_outputs.append((rk, state["draft_output"]))
2171
  planner_state.specialist_outputs[rk] = state["draft_output"][:500]
2172
 
2173
+ # Merge: replace updated roles, keep others unchanged
2174
  updated_keys = {rk for rk, _ in new_outputs}
2175
+ all_outputs = [
2176
  (rk, out) for rk, out in all_outputs if rk not in updated_keys
2177
  ] + new_outputs
 
2178
 
2179
  if rerun_synthesizer or rerun_specialists:
2180
  state = _step_synthesize(chat_model, state, trace, all_outputs,
2181
  evidence=evidence)
2182
 
2183
+ # Post-revision format validation
2184
+ fmt_violations = validate_output_format(
2185
+ state["draft_output"], output_format, brevity
2186
+ )
2187
+ if fmt_violations:
2188
+ trace.append(
2189
+ "\n[FORMAT VALIDATION] Post-revision violations:\n"
2190
+ + "\n".join(f" - {v}" for v in fmt_violations)
2191
+ )
2192
+ violation_instr = format_violations_instruction(fmt_violations)
2193
+ state["plan"] = state["plan"] + "\n\n" + violation_instr
2194
+ state = _step_synthesize(chat_model, state, trace, all_outputs,
2195
+ evidence=evidence)
2196
 
2197
+ # Loop back to QA — NOT back to specialists
2198
+ continue
2199
 
2200
  else:
2201
  # No Planner review loop — accept the draft
test_workflow.py CHANGED
@@ -35,6 +35,10 @@ from workflow_helpers import (
35
  FailureRecord,
36
  get_synthesizer_format_instruction,
37
  get_qa_format_instruction,
 
 
 
 
38
  )
39
  from evidence import (
40
  EvidenceItem,
@@ -1023,6 +1027,198 @@ class TestPlannerStateExtended(unittest.TestCase):
1023
  # Test: Scenario - Role Selection with Task Categories
1024
  # ============================================================
1025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1026
  class TestTaskAwareScenarios(unittest.TestCase):
1027
  """End-to-end scenario tests validating the 4 user-specified cases."""
1028
 
 
35
  FailureRecord,
36
  get_synthesizer_format_instruction,
37
  get_qa_format_instruction,
38
+ validate_output_format,
39
+ format_violations_instruction,
40
+ parse_task_assumptions,
41
+ format_assumptions_for_prompt,
42
  )
43
  from evidence import (
44
  EvidenceItem,
 
1027
  # Test: Scenario - Role Selection with Task Categories
1028
  # ============================================================
1029
 
1030
+ # ============================================================
1031
+ # Test: Output Format Validation
1032
+ # ============================================================
1033
+
1034
class TestFormatValidation(unittest.TestCase):
    """Unit tests for validate_output_format: structural (bullets/headings/
    tables/code) and brevity (line-count) checks performed before QA."""

    def test_paragraph_with_bullets_fails(self):
        text = "This is a paragraph.\n- bullet one\n- bullet two"
        violations = validate_output_format(text, "paragraph", "normal")
        self.assertTrue(any("bullet" in v.lower() for v in violations))

    def test_paragraph_with_headings_fails(self):
        text = "## Heading\nSome paragraph text."
        violations = validate_output_format(text, "paragraph", "normal")
        self.assertTrue(any("heading" in v.lower() for v in violations))

    def test_paragraph_with_table_fails(self):
        text = "Some text.\n| A | B |\n|---|---|\n| 1 | 2 |"
        violations = validate_output_format(text, "paragraph", "normal")
        self.assertTrue(any("table" in v.lower() for v in violations))

    def test_paragraph_clean_passes(self):
        text = "This is a clean paragraph without any lists or headings."
        violations = validate_output_format(text, "paragraph", "normal")
        self.assertEqual(violations, [])

    def test_code_without_code_fails(self):
        text = "Here is an explanation about coding but no actual code."
        violations = validate_output_format(text, "code", "normal")
        self.assertTrue(any("code" in v.lower() for v in violations))

    def test_code_with_block_passes(self):
        text = "```python\nprint('hello')\n```"
        violations = validate_output_format(text, "code", "normal")
        self.assertEqual(violations, [])

    def test_code_with_recognisable_code_passes(self):
        # No fenced block, but contains a recognisable code keyword ("def ").
        text = "def hello():\n return 'world'"
        violations = validate_output_format(text, "code", "normal")
        self.assertEqual(violations, [])

    def test_table_without_table_fails(self):
        text = "Just a paragraph about tables."
        violations = validate_output_format(text, "table", "normal")
        self.assertTrue(any("table" in v.lower() for v in violations))

    def test_table_with_table_passes(self):
        text = "| Name | Value |\n|------|-------|\n| A | 1 |"
        violations = validate_output_format(text, "table", "normal")
        self.assertEqual(violations, [])

    def test_single_choice_too_many_lines_fails(self):
        # 10 non-blank lines exceeds the single-choice limit of 5.
        text = "\n".join(f"Line {i}" for i in range(10))
        violations = validate_output_format(text, "single_choice", "normal")
        self.assertTrue(any("single choice" in v.lower() for v in violations))

    def test_single_choice_short_passes(self):
        text = "Vegan is the best choice."
        violations = validate_output_format(text, "single_choice", "normal")
        self.assertEqual(violations, [])

    def test_minimal_brevity_too_long(self):
        # 12 non-blank lines exceeds the "minimal" limit of 8.
        text = "\n".join(f"Line {i}" for i in range(12))
        violations = validate_output_format(text, "paragraph", "minimal")
        self.assertTrue(any("minimal" in v.lower() for v in violations))

    def test_short_brevity_too_long(self):
        # 25 non-blank lines exceeds the "short" limit of 20.
        text = "\n".join(f"Line {i}" for i in range(25))
        violations = validate_output_format(text, "paragraph", "short")
        self.assertTrue(any("short" in v.lower() for v in violations))

    def test_normal_brevity_no_length_check(self):
        # "normal" brevity imposes no line-count limit.
        text = "\n".join(f"Line {i}" for i in range(50))
        violations = validate_output_format(text, "paragraph", "normal")
        self.assertEqual(violations, [])

    def test_empty_output(self):
        violations = validate_output_format("", "paragraph", "normal")
        self.assertTrue(any("empty" in v.lower() for v in violations))
1109
+
1110
+
1111
class TestFormatViolationsInstruction(unittest.TestCase):
    """Tests that format_violations_instruction turns violation strings into
    a rewrite directive containing the header, each violation, and 'Rewrite'."""

    def test_produces_instruction(self):
        violations = ["Output has bullets.", "Too many lines."]
        result = format_violations_instruction(violations)
        self.assertIn("FORMAT VIOLATIONS", result)
        self.assertIn("Output has bullets.", result)
        self.assertIn("Too many lines.", result)
        self.assertIn("Rewrite", result)

    def test_empty_violations(self):
        # Even with no violations the header is still produced.
        result = format_violations_instruction([])
        self.assertIn("FORMAT VIOLATIONS", result)
1124
+
1125
+
1126
+ # ============================================================
1127
+ # Test: Task Assumptions Parsing
1128
+ # ============================================================
1129
+
1130
class TestTaskAssumptions(unittest.TestCase):
    """Tests for parse_task_assumptions (planner-output parsing) and
    format_assumptions_for_prompt (prompt-injection formatting)."""

    def test_parse_assumptions_basic(self):
        plan = (
            "TASK ASSUMPTIONS:\n"
            "- cost_model: per-unit pricing\n"
            "- coverage_rate: 95%\n"
            "- time_frame: 2024 Q4\n"
            "TASK BREAKDOWN:\n"
            "1. Do the thing"
        )
        result = parse_task_assumptions(plan)
        self.assertEqual(result["cost_model"], "per-unit pricing")
        self.assertEqual(result["coverage_rate"], "95%")
        self.assertEqual(result["time_frame"], "2024 Q4")

    def test_parse_assumptions_missing_section(self):
        # No TASK ASSUMPTIONS header at all -> empty dict.
        plan = "TASK BREAKDOWN:\n1. Do the thing"
        result = parse_task_assumptions(plan)
        self.assertEqual(result, {})

    def test_parse_assumptions_multiple_headers(self):
        # Parsing must stop at the next planner header (ROLE TO CALL:).
        plan = (
            "TASK ASSUMPTIONS:\n"
            "units: metric\n"
            "scope: global\n"
            "ROLE TO CALL:\n"
            "Technical Specialist"
        )
        result = parse_task_assumptions(plan)
        self.assertEqual(result["units"], "metric")
        self.assertEqual(result["scope"], "global")
        self.assertNotIn("technical_specialist", result)

    def test_parse_assumptions_normalises_keys(self):
        # Keys are lowercased and spaces become underscores.
        plan = "TASK ASSUMPTIONS:\nCost Model: expensive\n"
        result = parse_task_assumptions(plan)
        self.assertIn("cost_model", result)

    def test_format_assumptions_empty(self):
        result = format_assumptions_for_prompt({})
        self.assertEqual(result, "")

    def test_format_assumptions_nonempty(self):
        result = format_assumptions_for_prompt({"units": "metric", "scope": "global"})
        self.assertIn("SHARED TASK ASSUMPTIONS", result)
        self.assertIn("units: metric", result)
        self.assertIn("scope: global", result)
        self.assertIn("do NOT invent your own", result)
1179
+
1180
+
1181
+ # ============================================================
1182
+ # Test: PlannerState Assumptions & Revision Instruction
1183
+ # ============================================================
1184
+
1185
class TestPlannerStateNewFields(unittest.TestCase):
    """Tests the new PlannerState fields (task_assumptions,
    revision_instruction) in both to_state_dict and to_context_string."""

    def test_task_assumptions_in_state_dict(self):
        ps = PlannerState(user_request="test")
        ps.task_assumptions = {"units": "metric", "scope": "global"}
        d = ps.to_state_dict()
        self.assertEqual(d["task_assumptions"], {"units": "metric", "scope": "global"})

    def test_revision_instruction_in_state_dict(self):
        ps = PlannerState(user_request="test")
        ps.revision_instruction = "Fix the table format."
        d = ps.to_state_dict()
        self.assertEqual(d["revision_instruction"], "Fix the table format.")

    def test_task_assumptions_in_context_string(self):
        ps = PlannerState(user_request="test")
        ps.task_assumptions = {"rate": "5%"}
        ctx = ps.to_context_string()
        self.assertIn("rate: 5%", ctx)
        self.assertIn("Shared assumptions", ctx)

    def test_revision_instruction_in_context_string(self):
        ps = PlannerState(user_request="test")
        ps.revision_instruction = "Shorten the output."
        ctx = ps.to_context_string()
        self.assertIn("Shorten the output.", ctx)

    def test_empty_assumptions_not_in_context(self):
        # The "Shared assumptions" line is only emitted when non-empty.
        ps = PlannerState(user_request="test")
        ctx = ps.to_context_string()
        self.assertNotIn("Shared assumptions", ctx)
1216
+
1217
+
1218
+ # ============================================================
1219
+ # Test: Task-Aware Scenarios
1220
+ # ============================================================
1221
+
1222
  class TestTaskAwareScenarios(unittest.TestCase):
1223
  """End-to-end scenario tests validating the 4 user-specified cases."""
1224
 
workflow_helpers.py CHANGED
@@ -523,6 +523,11 @@ def select_relevant_roles(
523
  if kw.lower() in lower:
524
  score += 1
525
 
 
 
 
 
 
526
  # Task-category affinity bonus
527
  role_tasks = meta.get("task_types", [])
528
  if task_category in role_tasks:
@@ -751,10 +756,12 @@ class PlannerState:
751
  selected_roles: List[str] = field(default_factory=list)
752
  specialist_outputs: Dict[str, str] = field(default_factory=dict)
753
  evidence: Optional[Dict] = None # serialised EvidenceResult
 
754
  current_draft: str = ""
755
  qa_result: Optional[QAResult] = None
756
  revision_count: int = 0
757
  max_revisions: int = 3
 
758
  failure_history: List[FailureRecord] = field(default_factory=list)
759
  history: List[Dict[str, str]] = field(default_factory=list)
760
  final_answer: str = ""
@@ -829,10 +836,15 @@ class PlannerState:
829
  ]
830
  if self.success_criteria:
831
  lines.append(f"Success criteria: {'; '.join(self.success_criteria)}")
 
 
 
832
  if self.evidence:
833
  conf = self.evidence.get("confidence", "unknown")
834
  n_items = len(self.evidence.get("results", []))
835
  lines.append(f"Evidence: {n_items} items (confidence: {conf})")
 
 
836
  if self.qa_result and not self.qa_result.passed:
837
  lines.append(f"QA status: FAIL — {self.qa_result.reason}")
838
  if self.qa_result.correction_instruction:
@@ -856,9 +868,11 @@ class PlannerState:
856
  "selected_roles": self.selected_roles,
857
  "specialist_outputs": self.specialist_outputs,
858
  "evidence": self.evidence,
 
859
  "current_draft": self.current_draft[:500],
860
  "revision_count": self.revision_count,
861
  "max_revisions": self.max_revisions,
 
862
  "failure_history": [f.to_dict() for f in self.failure_history],
863
  "final_answer": self.final_answer[:500] if self.final_answer else "",
864
  }
@@ -931,3 +945,114 @@ def get_qa_format_instruction(output_format: str, brevity: str) -> str:
931
  if brevity in ("minimal", "short"):
932
  rules.append("FAIL if the output is excessively verbose for a brevity requirement.")
933
  return "\n".join(rules) if rules else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
523
  if kw.lower() in lower:
524
  score += 1
525
 
526
+ # Domain affinity — boost if the request touches a role's domain
527
+ for domain in meta.get("domains", []):
528
+ if domain.lower() in lower:
529
+ score += 1
530
+
531
  # Task-category affinity bonus
532
  role_tasks = meta.get("task_types", [])
533
  if task_category in role_tasks:
 
756
  selected_roles: List[str] = field(default_factory=list)
757
  specialist_outputs: Dict[str, str] = field(default_factory=dict)
758
  evidence: Optional[Dict] = None # serialised EvidenceResult
759
+ task_assumptions: Dict[str, str] = field(default_factory=dict)
760
  current_draft: str = ""
761
  qa_result: Optional[QAResult] = None
762
  revision_count: int = 0
763
  max_revisions: int = 3
764
+ revision_instruction: str = "" # latest revision instruction from planner
765
  failure_history: List[FailureRecord] = field(default_factory=list)
766
  history: List[Dict[str, str]] = field(default_factory=list)
767
  final_answer: str = ""
 
836
  ]
837
  if self.success_criteria:
838
  lines.append(f"Success criteria: {'; '.join(self.success_criteria)}")
839
+ if self.task_assumptions:
840
+ assumptions_str = "; ".join(f"{k}: {v}" for k, v in self.task_assumptions.items())
841
+ lines.append(f"Shared assumptions: {assumptions_str}")
842
  if self.evidence:
843
  conf = self.evidence.get("confidence", "unknown")
844
  n_items = len(self.evidence.get("results", []))
845
  lines.append(f"Evidence: {n_items} items (confidence: {conf})")
846
+ if self.revision_instruction:
847
+ lines.append(f"Revision instruction: {self.revision_instruction}")
848
  if self.qa_result and not self.qa_result.passed:
849
  lines.append(f"QA status: FAIL — {self.qa_result.reason}")
850
  if self.qa_result.correction_instruction:
 
868
  "selected_roles": self.selected_roles,
869
  "specialist_outputs": self.specialist_outputs,
870
  "evidence": self.evidence,
871
+ "task_assumptions": self.task_assumptions,
872
  "current_draft": self.current_draft[:500],
873
  "revision_count": self.revision_count,
874
  "max_revisions": self.max_revisions,
875
+ "revision_instruction": self.revision_instruction,
876
  "failure_history": [f.to_dict() for f in self.failure_history],
877
  "final_answer": self.final_answer[:500] if self.final_answer else "",
878
  }
 
945
  if brevity in ("minimal", "short"):
946
  rules.append("FAIL if the output is excessively verbose for a brevity requirement.")
947
  return "\n".join(rules) if rules else ""
948
+
949
+
950
+ # ============================================================
951
+ # Output Format Validation (pre-QA structural check)
952
+ # ============================================================
953
+
954
def validate_output_format(text: str, output_format: str, brevity: str) -> List[str]:
    """Check structural format constraints before QA.

    Returns a list of violation descriptions; an empty list means the output
    is structurally valid. This catches the common problems the synthesizer
    repeatedly ignores (e.g. bullet lists when paragraph-only was requested).
    """
    stripped = text.strip()
    if not stripped:
        return ["Output is empty."]

    problems: List[str] = []

    # Structural features of the draft.
    bullets = re.search(r"^[\s]*[-•*]\s", stripped, re.MULTILINE) is not None
    numbered = re.search(r"^[\s]*\d+[.)]\s", stripped, re.MULTILINE) is not None
    headings = re.search(r"^#{1,4}\s", stripped, re.MULTILINE) is not None
    table = re.search(r"\|.*\|.*\|", stripped) is not None
    fenced = "```" in stripped
    n_lines = sum(1 for ln in stripped.splitlines() if ln.strip())

    if output_format == "paragraph":
        # Prose-only: no list structure, no headings, no tables.
        checks = (
            (bullets or numbered,
             "Paragraph format requested but output contains bullet/numbered lists."),
            (headings,
             "Paragraph format requested but output contains markdown headings."),
            (table,
             "Paragraph format requested but output contains a table."),
        )
        for bad, message in checks:
            if bad:
                problems.append(message)

    elif output_format == "code":
        # Accept either a fenced block or text that looks like real code.
        looks_like_code = fenced or re.search(
            r"(?:def |class |import |function |const |let |var )", stripped
        )
        if not looks_like_code:
            problems.append("Code format requested but output contains no code block or recognisable code.")

    elif output_format == "table":
        if not table:
            problems.append("Table format requested but output contains no markdown table.")

    elif output_format == "single_choice":
        if n_lines > 5:
            problems.append("Single choice requested but output is multi-section (too many lines).")
        if bullets and n_lines > 3:
            problems.append("Single choice requested but output contains a bullet list.")

    # Brevity constraints apply regardless of output format.
    if brevity == "minimal" and n_lines > 8:
        problems.append(f"Minimal brevity requested but output has {n_lines} lines.")
    elif brevity == "short" and n_lines > 20:
        problems.append(f"Short brevity requested but output has {n_lines} lines.")

    return problems
1003
+
1004
+
1005
def format_violations_instruction(violations: List[str]) -> str:
    """Turn format violation descriptions into a synthesis rewrite instruction."""
    header = "FORMAT VIOLATIONS DETECTED — you MUST fix these before QA:\n"
    bullet_list = "\n".join(f"- {v}" for v in violations)
    footer = "\nRewrite the output to satisfy the required format strictly."
    return header + bullet_list + footer
1012
+
1013
+
1014
+ # ============================================================
1015
+ # Shared Assumptions Parsing
1016
+ # ============================================================
1017
+
1018
def parse_task_assumptions(plan_text: str) -> Dict[str, str]:
    """Extract TASK ASSUMPTIONS from planner output.

    Looks for lines like 'key: value' under a 'TASK ASSUMPTIONS:' header.
    Keys are normalised (lowercase, spaces -> underscores); leading bullet
    markers are stripped. Returns a dict of assumption key -> value; empty
    dict when the header is absent.

    Fix: the section must end at the EARLIEST following planner header in the
    text, not at the first header in declaration order. The old code split at
    whichever header came first in the tuple, so a plan laid out as
    'TASK ASSUMPTIONS ... ROLE TO CALL ... TASK BREAKDOWN' leaked the
    'ROLE TO CALL' content into the assumptions dict.
    """
    assumptions: Dict[str, str] = {}
    if "TASK ASSUMPTIONS:" not in plan_text:
        return assumptions

    section = plan_text.split("TASK ASSUMPTIONS:", 1)[1]

    headers = (
        "TASK BREAKDOWN:", "ROLE TO CALL:", "SUCCESS CRITERIA:",
        "GUIDANCE FOR SPECIALIST:", "REVISED INSTRUCTIONS:",
    )
    # Truncate at the earliest occurring header, if any.
    cut = min((section.index(h) for h in headers if h in section), default=None)
    if cut is not None:
        section = section[:cut]

    for line in section.strip().splitlines():
        line = line.strip().lstrip("•-* ")
        if ":" not in line:
            continue
        key, _, value = line.partition(":")
        key = key.strip().lower().replace(" ", "_")
        value = value.strip()
        # Skip blank keys/values (e.g. bare header-like lines).
        if key and value:
            assumptions[key] = value

    return assumptions
1049
+
1050
+
1051
def format_assumptions_for_prompt(assumptions: Dict[str, str]) -> str:
    """Format shared assumptions for injection into specialist prompts."""
    if not assumptions:
        return ""
    header = "SHARED TASK ASSUMPTIONS (use these — do NOT invent your own):"
    body = [f" - {key}: {value}" for key, value in assumptions.items()]
    return "\n".join([header, *body])