KeenWoo committed on
Commit
b008ff6
·
verified ·
1 Parent(s): 11de8fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -6
app.py CHANGED
@@ -665,7 +665,79 @@ def run_nlu_test(test_title: str):
665
  print("--- TEST COMPLETE ---\n")
666
  return status, comparison_data
667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
668
 
 
 
 
 
 
 
669
 
670
  # In app.py, inside the Gradio Callbacks section for debugging
671
  def test_save_file():
@@ -741,17 +813,20 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
741
 
742
  with gr.Tab("Testing"):
743
  gr.Markdown("## NLU Context Detection Tests")
744
- gr.Markdown("Select a test case from `conversation_test_fixtures.jsonl` to run it through the NLU classifier and see the results.")
 
745
  with gr.Row():
746
- test_case_dropdown = gr.Dropdown(label="Select Test Case", scale=3)
747
- run_test_btn = gr.Button("Load & Run Test", variant="primary", scale=1)
748
- test_status_md = gr.Markdown("### Please select and run a test case.")
 
749
  test_results_df = gr.DataFrame(
750
  label="Test Results Comparison",
751
- headers=["Category", "Expected", "Actual", "Result"],
 
752
  interactive=False
753
  )
754
-
755
  with gr.Tab("Settings"):
756
  with gr.Group():
757
  gr.Markdown("## Conversation & Persona Settings")
@@ -823,6 +898,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
823
  test_save_btn.click(fn=test_save_file, inputs=None, outputs=[test_status])
824
  check_save_btn.click(fn=check_test_file, inputs=None, outputs=[test_status])
825
 
 
826
  # --- ADD WIRING FOR THE TESTING TAB ---
827
  demo.load(load_test_fixtures, outputs=[test_case_dropdown])
828
  run_test_btn.click(
@@ -830,7 +906,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CSS) as demo:
830
  inputs=[test_case_dropdown],
831
  outputs=[test_status_md, test_results_df]
832
  )
 
 
833
 
 
834
  # --- Startup Logic ---
835
  def pre_load_indexes():
836
  global personal_vectorstore
 
665
  print("--- TEST COMPLETE ---\n")
666
  return status, comparison_data
667
 
668
# Batch callback wired to the "Run All Tests" button on the Testing tab.
def run_all_nlu_tests():
    """Run every loaded NLU test fixture and summarize pass/fail results.

    Returns:
        A (summary_markdown, rows) tuple; each row is
        [fixture title, overall pass/fail marker, "passed / total" categories].
    """
    # Fixtures are loaded lazily the first time the batch runner is used.
    if not test_fixtures:
        load_test_fixtures()
    if not test_fixtures:
        return "## Batch Test Summary: No test fixtures found. Please ensure `conversation_test_fixtures.jsonl` is present.", []

    print("\n--- RUNNING ALL NLU TESTS ---")

    # Tag vocabularies come from the app-wide CONFIG; missing keys mean "none".
    tag_kwargs = {
        "behavior_options": CONFIG.get("behavior_tags", []),
        "emotion_options": CONFIG.get("emotion_tags", []),
        "topic_options": CONFIG.get("topic_tags", []),
        "context_options": CONFIG.get("context_tags", []),
    }

    rows = []
    fully_passed = 0

    for case in test_fixtures:
        expected = case["expected"]
        raw = detect_tags_from_query(case["turns"][0]["text"], **tag_kwargs)

        # Normalize single-valued detections into lists so every category
        # can be compared against its expected set the same way.
        detected = {
            "emotion": [raw.get("detected_emotion")],
            "behaviors": raw.get("detected_behaviors", []),
            "topic_tags": [raw.get("detected_topic")],
            "context_tags": raw.get("detected_contexts", []),
        }

        category_hits = 0
        category_total = 0
        for key in sorted(expected):
            want = set(expected.get(key, []))
            if not want:
                continue  # nothing expected for this category -> not scored

            category_total += 1
            got = {a for a in detected.get(key, []) if a and a != "None"}

            # Flexible pass logic: any overlap between expected and actual counts.
            if want & got:
                category_hits += 1

        # A fixture passes only when every scored category passed.
        if category_total > 0 and category_hits == category_total:
            fully_passed += 1
            verdict = "✅ Pass"
        else:
            verdict = "❌ Fail"

        rows.append([
            case["title"],
            verdict,
            f"{category_hits} / {category_total}"
        ])

    total = len(test_fixtures)
    rate = (fully_passed / total) * 100 if total > 0 else 0
    summary_md = f"## Batch Test Summary: {fully_passed} / {total} Tests Passed ({rate:.1f}%)"
    print(f"--- BATCH TEST COMPLETE: {summary_md} ---")

    return summary_md, rows
741
 
742
  # In app.py, inside the Gradio Callbacks section for debugging
743
  def test_save_file():
 
813
 
814
  with gr.Tab("Testing"):
815
  gr.Markdown("## NLU Context Detection Tests")
816
+ gr.Markdown("Select a single test case to run, or run the entire batch of fixtures to get a summary of the NLU's performance.")
817
+ batch_summary_md = gr.Markdown("### Batch Test Summary: Not yet run.") # <-- ADD THIS
818
  with gr.Row():
819
+ test_case_dropdown = gr.Dropdown(label="Select Single Test Case", scale=2)
820
+ run_test_btn = gr.Button("Run Single Test", variant="secondary", scale=1)
821
+ run_all_btn = gr.Button("Run All Tests", variant="primary", scale=1) # <-- ADD THIS
822
+ test_status_md = gr.Markdown("### Test Results")
823
  test_results_df = gr.DataFrame(
824
  label="Test Results Comparison",
825
+ # UPDATE these headers for the batch summary
826
+ headers=["Test Case Title", "Overall Result", "Categories Passed"],
827
  interactive=False
828
  )
829
+
830
  with gr.Tab("Settings"):
831
  with gr.Group():
832
  gr.Markdown("## Conversation & Persona Settings")
 
898
  test_save_btn.click(fn=test_save_file, inputs=None, outputs=[test_status])
899
  check_save_btn.click(fn=check_test_file, inputs=None, outputs=[test_status])
900
 
901
+
902
  # --- ADD WIRING FOR THE TESTING TAB ---
903
  demo.load(load_test_fixtures, outputs=[test_case_dropdown])
904
  run_test_btn.click(
 
906
  inputs=[test_case_dropdown],
907
  outputs=[test_status_md, test_results_df]
908
  )
909
+ # --- ADD THE LINE BELOW ---
910
+ run_all_btn.click(fn=run_all_nlu_tests, outputs=[batch_summary_md, test_results_df])
911
 
912
+
913
  # --- Startup Logic ---
914
  def pre_load_indexes():
915
  global personal_vectorstore