DocUA committed on
Commit
fd0d61e
·
1 Parent(s): 65b972a

Implement batch classification for uploaded datasets and enhance export functionality with reasoning

Browse files
Files changed (1) hide show
  1. src/interface/file_upload_interface.py +129 -33
src/interface/file_upload_interface.py CHANGED
@@ -361,6 +361,109 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
361
  classification["indicators"] = indicators
362
 
363
  return classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
  def submit_batch_verification(self, is_correct: bool, correction: Optional[str] = None, notes: str = "") -> Tuple[bool, str, Dict[str, Any]]:
366
  """
@@ -856,47 +959,40 @@ def create_file_upload_interface() -> gr.Blocks:
856
  success, message, session = controller.start_batch_processing(verifier_name, file_result)
857
 
858
  if success:
859
- # Load first message
860
- current_message, classification_result = controller.get_current_message_for_batch_processing()
861
-
862
- if current_message:
863
- # Format displays
864
- expected_badge = controller._get_classification_badge(current_message.pre_classified_label)
865
- actual_badge = controller._get_classification_badge(classification_result.get('decision', 'unknown'))
866
- confidence_text = f"📊 {classification_result.get('confidence', 0) * 100:.1f}% confident"
867
- indicators_text = "🔍 " + ", ".join(classification_result.get('indicators', ['No indicators']))
868
-
869
- progress_text = f"Progress: 1 of {len(file_result.parsed_test_cases)} messages"
870
-
871
- return (
872
- gr.Row(visible=True), # message_processing_section
873
- session, # current_session_state
874
- current_message.text, # current_message_display
875
- f"Expected: {expected_badge}", # expected_classification_display
876
- f"AI Result: {actual_badge}", # actual_classification_display
877
- confidence_text, # classifier_confidence_display
878
- indicators_text, # classifier_indicators_display
879
- progress_text, # batch_progress_display
880
- gr.DownloadButton(visible=False), # export_csv_btn
881
- gr.DownloadButton(visible=False), # export_json_btn
882
- gr.DownloadButton(visible=False), # export_xlsx_btn
883
- message # status_message
884
- )
885
- else:
886
  return (
887
- gr.Row(visible=False), # message_processing_section
888
  session, # current_session_state
889
  "", # current_message_display
890
  "", # expected_classification_display
891
  "", # actual_classification_display
892
  "", # classifier_confidence_display
893
  "", # classifier_indicators_display
894
- "No messages to process", # batch_progress_display
895
- gr.DownloadButton(visible=False), # export_csv_btn
896
- gr.DownloadButton(visible=False), # export_json_btn
897
- gr.DownloadButton(visible=False), # export_xlsx_btn
898
- "❌ No messages in file" # status_message
899
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
900
  else:
901
  return (
902
  gr.Row(visible=False), # message_processing_section
 
361
  classification["indicators"] = indicators
362
 
363
  return classification
364
+
365
+ def run_batch_classification(self) -> Tuple[bool, str, Dict[str, Any]]:
366
+ """Run classification for the whole uploaded dataset and persist results.
367
+
368
+ File Upload Mode is already labeled (ground truth provided in the file), so we
369
+ don't need interactive message-by-message verification. Instead, we:
370
+ - classify every message
371
+ - store the model output as reasoning in `verifier_notes`
372
+ - mark each record as correct/incorrect by comparing to ground truth
373
+ """
374
+ if not self.current_session or not self.current_file_result:
375
+ return False, "❌ No active session", {}
376
+
377
+ total = len(self.current_file_result.parsed_test_cases)
378
+ if total == 0:
379
+ return False, "❌ No messages to process", {}
380
+
381
+ try:
382
+ # Reset any prior run state
383
+ self.current_session.verifications = []
384
+ self.current_session.verified_count = 0
385
+ self.current_session.correct_count = 0
386
+ self.current_session.incorrect_count = 0
387
+ self.current_session.verified_message_ids = []
388
+
389
+ self.setup_progress_tracking(total)
390
+
391
+ for idx, test_message in enumerate(self.current_file_result.parsed_test_cases):
392
+ self.batch_processing_start_time = datetime.now()
393
+
394
+ user_prompt = (
395
+ "Please analyze this patient message for spiritual distress:\n\n"
396
+ f"{test_message.text}"
397
+ )
398
+ raw_response = self.ai_client.call_spiritual_api(
399
+ system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
400
+ user_prompt=user_prompt,
401
+ temperature=0.3,
402
+ )
403
+
404
+ classification_result = self._parse_classification_response(raw_response)
405
+
406
+ classifier_decision = classification_result.get("decision", "green")
407
+ if classifier_decision not in ["green", "yellow", "red"]:
408
+ classifier_decision = "green"
409
+
410
+ ground_truth = test_message.pre_classified_label
411
+ if ground_truth not in ["green", "yellow", "red"]:
412
+ ground_truth = "green"
413
+
414
+ is_correct = classifier_decision == ground_truth
415
+
416
+ verification_record = VerificationRecord(
417
+ message_id=test_message.message_id,
418
+ original_message=test_message.text,
419
+ classifier_decision=classifier_decision,
420
+ classifier_confidence=classification_result.get("confidence", 0.0),
421
+ classifier_indicators=classification_result.get("indicators", []),
422
+ ground_truth_label=ground_truth,
423
+ verifier_notes=raw_response, # store full LLM output as reasoning
424
+ is_correct=is_correct,
425
+ )
426
+
427
+ self.current_session.verifications.append(verification_record)
428
+ self.current_session.verified_count += 1
429
+ self.current_session.verified_message_ids.append(test_message.message_id)
430
+ if is_correct:
431
+ self.current_session.correct_count += 1
432
+ else:
433
+ self.current_session.incorrect_count += 1
434
+
435
+ self.record_verification_with_timing(is_correct, self.batch_processing_start_time)
436
+ self.current_session.current_queue_index = idx + 1
437
+
438
+ self.current_session.is_complete = True
439
+ self.current_session.completed_at = datetime.now()
440
+
441
+ self.store.save_session(self.current_session)
442
+
443
+ accuracy = (
444
+ (self.current_session.correct_count / self.current_session.verified_count * 100)
445
+ if self.current_session.verified_count
446
+ else 0
447
+ )
448
+ stats = {
449
+ "processed": self.current_session.verified_count,
450
+ "total": total,
451
+ "correct": self.current_session.correct_count,
452
+ "incorrect": self.current_session.incorrect_count,
453
+ "accuracy": accuracy,
454
+ "is_complete": True,
455
+ }
456
+ return True, f"✅ Batch classification completed. Accuracy: {accuracy:.1f}%", stats
457
+
458
+ except Exception as e:
459
+ return False, f"❌ Error during batch classification: {str(e)}", {}
460
+
461
+ def export_batch_results_with_reasoning(self, format_type: str) -> Tuple[bool, str, Optional[str]]:
462
+ """Export results including LLM reasoning.
463
+
464
+ We rely on `verifier_notes` field to carry reasoning (raw model output).
465
+ """
466
+ return self.export_batch_results(format_type)
467
 
468
  def submit_batch_verification(self, is_correct: bool, correction: Optional[str] = None, notes: str = "") -> Tuple[bool, str, Dict[str, Any]]:
469
  """
 
959
  success, message, session = controller.start_batch_processing(verifier_name, file_result)
960
 
961
  if success:
962
+ # Simplified behavior: dataset is already labeled, so run full batch
963
+ # classification immediately and generate results for export.
964
+ run_ok, run_msg, stats = controller.run_batch_classification()
965
+ if run_ok:
966
+ progress_text = f"✅ Completed: {stats.get('processed', 0)}/{stats.get('total', 0)} messages"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
967
  return (
968
+ gr.Row(visible=False), # message_processing_section (not used in simplified flow)
969
  session, # current_session_state
970
  "", # current_message_display
971
  "", # expected_classification_display
972
  "", # actual_classification_display
973
  "", # classifier_confidence_display
974
  "", # classifier_indicators_display
975
+ progress_text, # batch_progress_display
976
+ gr.DownloadButton(visible=True), # export_csv_btn
977
+ gr.DownloadButton(visible=True), # export_json_btn
978
+ gr.DownloadButton(visible=True), # export_xlsx_btn
979
+ run_msg # status_message
980
  )
981
+
982
+ return (
983
+ gr.Row(visible=False), # message_processing_section
984
+ session, # current_session_state
985
+ "", # current_message_display
986
+ "", # expected_classification_display
987
+ "", # actual_classification_display
988
+ "", # classifier_confidence_display
989
+ "", # classifier_indicators_display
990
+ "❌ Batch classification failed", # batch_progress_display
991
+ gr.DownloadButton(visible=False), # export_csv_btn
992
+ gr.DownloadButton(visible=False), # export_json_btn
993
+ gr.DownloadButton(visible=False), # export_xlsx_btn
994
+ run_msg # status_message
995
+ )
996
  else:
997
  return (
998
  gr.Row(visible=False), # message_processing_section