Spaces:
Sleeping
Sleeping
Implement batch classification for uploaded datasets and enhance export functionality with reasoning
Browse files
src/interface/file_upload_interface.py
CHANGED
|
@@ -361,6 +361,109 @@ class FileUploadInterfaceController(ProgressTrackingMixin):
|
|
| 361 |
classification["indicators"] = indicators
|
| 362 |
|
| 363 |
return classification
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
def submit_batch_verification(self, is_correct: bool, correction: Optional[str] = None, notes: str = "") -> Tuple[bool, str, Dict[str, Any]]:
|
| 366 |
"""
|
|
@@ -856,47 +959,40 @@ def create_file_upload_interface() -> gr.Blocks:
|
|
| 856 |
success, message, session = controller.start_batch_processing(verifier_name, file_result)
|
| 857 |
|
| 858 |
if success:
|
| 859 |
-
#
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
if
|
| 863 |
-
|
| 864 |
-
expected_badge = controller._get_classification_badge(current_message.pre_classified_label)
|
| 865 |
-
actual_badge = controller._get_classification_badge(classification_result.get('decision', 'unknown'))
|
| 866 |
-
confidence_text = f"📊 {classification_result.get('confidence', 0) * 100:.1f}% confident"
|
| 867 |
-
indicators_text = "🔍 " + ", ".join(classification_result.get('indicators', ['No indicators']))
|
| 868 |
-
|
| 869 |
-
progress_text = f"Progress: 1 of {len(file_result.parsed_test_cases)} messages"
|
| 870 |
-
|
| 871 |
-
return (
|
| 872 |
-
gr.Row(visible=True), # message_processing_section
|
| 873 |
-
session, # current_session_state
|
| 874 |
-
current_message.text, # current_message_display
|
| 875 |
-
f"Expected: {expected_badge}", # expected_classification_display
|
| 876 |
-
f"AI Result: {actual_badge}", # actual_classification_display
|
| 877 |
-
confidence_text, # classifier_confidence_display
|
| 878 |
-
indicators_text, # classifier_indicators_display
|
| 879 |
-
progress_text, # batch_progress_display
|
| 880 |
-
gr.DownloadButton(visible=False), # export_csv_btn
|
| 881 |
-
gr.DownloadButton(visible=False), # export_json_btn
|
| 882 |
-
gr.DownloadButton(visible=False), # export_xlsx_btn
|
| 883 |
-
message # status_message
|
| 884 |
-
)
|
| 885 |
-
else:
|
| 886 |
return (
|
| 887 |
-
gr.Row(visible=False), # message_processing_section
|
| 888 |
session, # current_session_state
|
| 889 |
"", # current_message_display
|
| 890 |
"", # expected_classification_display
|
| 891 |
"", # actual_classification_display
|
| 892 |
"", # classifier_confidence_display
|
| 893 |
"", # classifier_indicators_display
|
| 894 |
-
|
| 895 |
-
gr.DownloadButton(visible=
|
| 896 |
-
gr.DownloadButton(visible=
|
| 897 |
-
gr.DownloadButton(visible=
|
| 898 |
-
|
| 899 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 900 |
else:
|
| 901 |
return (
|
| 902 |
gr.Row(visible=False), # message_processing_section
|
|
|
|
| 361 |
classification["indicators"] = indicators
|
| 362 |
|
| 363 |
return classification
|
| 364 |
+
|
| 365 |
+
def run_batch_classification(self) -> Tuple[bool, str, Dict[str, Any]]:
    """Run classification for the whole uploaded dataset and persist results.

    File Upload Mode is already labeled (ground truth provided in the file), so we
    don't need interactive message-by-message verification. Instead, we:
    - classify every message
    - store the model output as reasoning in `verifier_notes`
    - mark each record as correct/incorrect by comparing to ground truth

    Labels on both sides (model decision and ground truth) are compared
    case- and whitespace-insensitively. A label that is still not one of
    green/yellow/red after normalization is treated as "green".

    Returns:
        A ``(success, status_message, stats)`` tuple. On success ``stats``
        holds ``processed``, ``total``, ``correct``, ``incorrect``,
        ``accuracy`` (a percentage) and ``is_complete``. On any failure
        ``stats`` is an empty dict.
    """
    if not self.current_session or not self.current_file_result:
        return False, "❌ No active session", {}

    test_cases = self.current_file_result.parsed_test_cases
    total = len(test_cases)
    if total == 0:
        return False, "❌ No messages to process", {}

    # Loop invariants, hoisted out of the per-message loop.
    valid_labels = ("green", "yellow", "red")
    prompt_prefix = "Please analyze this patient message for spiritual distress:\n\n"

    def _normalize_label(value: Any) -> str:
        """Return a canonical label, falling back to "green" when unrecognized."""
        if isinstance(value, str):
            candidate = value.strip().lower()
            if candidate in valid_labels:
                return candidate
        return "green"

    try:
        # Reset any prior run state
        self.current_session.verifications = []
        self.current_session.verified_count = 0
        self.current_session.correct_count = 0
        self.current_session.incorrect_count = 0
        self.current_session.verified_message_ids = []

        self.setup_progress_tracking(total)

        for idx, test_message in enumerate(test_cases):
            # Restarted for every message so the timing recorded below
            # covers only this message's classification.
            self.batch_processing_start_time = datetime.now()

            user_prompt = f"{prompt_prefix}{test_message.text}"
            raw_response = self.ai_client.call_spiritual_api(
                system_prompt=SYSTEM_PROMPT_ENTRY_CLASSIFIER,
                user_prompt=user_prompt,
                temperature=0.3,
            )

            classification_result = self._parse_classification_response(raw_response)

            # Normalize before comparing: an exact-match check would turn
            # "Red" or " yellow " into "green" and distort the accuracy.
            classifier_decision = _normalize_label(
                classification_result.get("decision", "green")
            )
            ground_truth = _normalize_label(test_message.pre_classified_label)

            is_correct = classifier_decision == ground_truth

            verification_record = VerificationRecord(
                message_id=test_message.message_id,
                original_message=test_message.text,
                classifier_decision=classifier_decision,
                classifier_confidence=classification_result.get("confidence", 0.0),
                classifier_indicators=classification_result.get("indicators", []),
                ground_truth_label=ground_truth,
                verifier_notes=raw_response,  # store full LLM output as reasoning
                is_correct=is_correct,
            )

            self.current_session.verifications.append(verification_record)
            self.current_session.verified_count += 1
            self.current_session.verified_message_ids.append(test_message.message_id)
            if is_correct:
                self.current_session.correct_count += 1
            else:
                self.current_session.incorrect_count += 1

            self.record_verification_with_timing(is_correct, self.batch_processing_start_time)
            self.current_session.current_queue_index = idx + 1

        self.current_session.is_complete = True
        self.current_session.completed_at = datetime.now()

        self.store.save_session(self.current_session)

        accuracy = (
            (self.current_session.correct_count / self.current_session.verified_count * 100)
            if self.current_session.verified_count
            else 0
        )
        stats = {
            "processed": self.current_session.verified_count,
            "total": total,
            "correct": self.current_session.correct_count,
            "incorrect": self.current_session.incorrect_count,
            "accuracy": accuracy,
            "is_complete": True,
        }
        return True, f"✅ Batch classification completed. Accuracy: {accuracy:.1f}%", stats

    except Exception as e:
        # UI boundary: report the failure to the caller instead of raising,
        # but keep the traceback in the log. The session is saved only after
        # the loop finishes, so a failure here leaves it partially populated
        # in memory and not persisted.
        import logging

        logging.getLogger(__name__).exception("Batch classification failed")
        return False, f"❌ Error during batch classification: {str(e)}", {}
|
| 460 |
+
|
| 461 |
+
def export_batch_results_with_reasoning(self, format_type: str) -> Tuple[bool, str, Optional[str]]:
    """Export the batch results, LLM reasoning included.

    The reasoning is carried by each record's `verifier_notes` field (the raw
    model output), so this simply delegates to `export_batch_results` with
    the requested format and hands its result back unchanged.
    """
    export_outcome = self.export_batch_results(format_type)
    return export_outcome
|
| 467 |
|
| 468 |
def submit_batch_verification(self, is_correct: bool, correction: Optional[str] = None, notes: str = "") -> Tuple[bool, str, Dict[str, Any]]:
|
| 469 |
"""
|
|
|
|
| 959 |
success, message, session = controller.start_batch_processing(verifier_name, file_result)
|
| 960 |
|
| 961 |
if success:
|
| 962 |
+
# Simplified behavior: dataset is already labeled, so run full batch
|
| 963 |
+
# classification immediately and generate results for export.
|
| 964 |
+
run_ok, run_msg, stats = controller.run_batch_classification()
|
| 965 |
+
if run_ok:
|
| 966 |
+
progress_text = f"✅ Completed: {stats.get('processed', 0)}/{stats.get('total', 0)} messages"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 967 |
return (
|
| 968 |
+
gr.Row(visible=False), # message_processing_section (not used in simplified flow)
|
| 969 |
session, # current_session_state
|
| 970 |
"", # current_message_display
|
| 971 |
"", # expected_classification_display
|
| 972 |
"", # actual_classification_display
|
| 973 |
"", # classifier_confidence_display
|
| 974 |
"", # classifier_indicators_display
|
| 975 |
+
progress_text, # batch_progress_display
|
| 976 |
+
gr.DownloadButton(visible=True), # export_csv_btn
|
| 977 |
+
gr.DownloadButton(visible=True), # export_json_btn
|
| 978 |
+
gr.DownloadButton(visible=True), # export_xlsx_btn
|
| 979 |
+
run_msg # status_message
|
| 980 |
)
|
| 981 |
+
|
| 982 |
+
return (
|
| 983 |
+
gr.Row(visible=False), # message_processing_section
|
| 984 |
+
session, # current_session_state
|
| 985 |
+
"", # current_message_display
|
| 986 |
+
"", # expected_classification_display
|
| 987 |
+
"", # actual_classification_display
|
| 988 |
+
"", # classifier_confidence_display
|
| 989 |
+
"", # classifier_indicators_display
|
| 990 |
+
"❌ Batch classification failed", # batch_progress_display
|
| 991 |
+
gr.DownloadButton(visible=False), # export_csv_btn
|
| 992 |
+
gr.DownloadButton(visible=False), # export_json_btn
|
| 993 |
+
gr.DownloadButton(visible=False), # export_xlsx_btn
|
| 994 |
+
run_msg # status_message
|
| 995 |
+
)
|
| 996 |
else:
|
| 997 |
return (
|
| 998 |
gr.Row(visible=False), # message_processing_section
|