chrissoria commited on
Commit
555ea2f
·
1 Parent(s): 249e203

Add progress bar with ETA for classification

Browse files
Files changed (1) hide show
  1. app.py +111 -19
app.py CHANGED
@@ -890,35 +890,127 @@ with col_input:
890
  elif not categories_entered:
891
  st.error("Please enter at least one category")
892
  else:
893
- with st.spinner("Classifying data... This may take a few minutes."):
894
- mode = None
895
- if input_type_selected == "pdf":
896
- mode_mapping = {
897
- "Image (visual documents)": "image",
898
- "Text (text-heavy)": "text",
899
- "Both (comprehensive)": "both"
900
- }
901
- mode = mode_mapping.get(pdf_mode, "image")
902
-
903
- result_df, csv_path, pdf_path, code, status = run_classify_data(
904
- input_type_selected, input_data, description,
905
- categories_entered, model_tier, model, api_key, mode,
906
- original_filename, description
907
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
908
 
909
- if result_df is not None:
910
  st.session_state.results = {
911
  'df': result_df,
912
  'csv_path': csv_path,
913
  'pdf_path': pdf_path,
914
  'code': code,
915
- 'status': status,
916
  'categories': categories_entered
917
  }
918
- st.success(status)
919
  st.rerun()
920
  else:
921
- st.error(status)
922
 
923
  with col_output:
924
  st.markdown("### Results")
 
890
  elif not categories_entered:
891
  st.error("Please enter at least one category")
892
  else:
893
+ # Set up progress tracking
894
+ mode = None
895
+ if input_type_selected == "pdf":
896
+ mode_mapping = {
897
+ "Image (visual documents)": "image",
898
+ "Text (text-heavy)": "text",
899
+ "Both (comprehensive)": "both"
900
+ }
901
+ mode = mode_mapping.get(pdf_mode, "image")
902
+
903
+ actual_api_key, provider = get_api_key(model, model_tier, api_key)
904
+ if not actual_api_key:
905
+ st.error(f"{provider} API key not configured")
906
+ else:
907
+ model_source = get_model_source(model)
908
+ total_items = len(input_data) if isinstance(input_data, list) else 1
909
+
910
+ # Progress UI
911
+ progress_bar = st.progress(0)
912
+ status_text = st.empty()
913
+
914
+ all_results = []
915
+ start_time = time.time()
916
+
917
+ for i, item in enumerate(input_data if isinstance(input_data, list) else [input_data]):
918
+ # Update progress
919
+ progress = (i / total_items)
920
+ progress_bar.progress(progress)
921
+
922
+ # Calculate ETA
923
+ elapsed = time.time() - start_time
924
+ if i > 0:
925
+ avg_time_per_item = elapsed / i
926
+ remaining_items = total_items - i
927
+ eta_seconds = avg_time_per_item * remaining_items
928
+ eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
929
+ else:
930
+ eta_str = ""
931
+
932
+ status_text.text(f"Processing item {i+1} of {total_items} ({progress*100:.0f}%){eta_str}")
933
+
934
+ try:
935
+ classify_kwargs = {
936
+ 'input_data': [item],
937
+ 'categories': categories_entered,
938
+ 'api_key': actual_api_key,
939
+ 'input_type': input_type_selected,
940
+ 'description': description,
941
+ 'user_model': model,
942
+ 'model_source': model_source
943
+ }
944
+ if mode:
945
+ classify_kwargs['mode'] = mode
946
+
947
+ item_result = catllm.classify(**classify_kwargs)
948
+ all_results.append(item_result)
949
+ except Exception as e:
950
+ st.warning(f"Error on item {i+1}: {str(e)}")
951
+ continue
952
+
953
+ # Complete progress
954
+ progress_bar.progress(1.0)
955
+ processing_time = time.time() - start_time
956
+ status_text.text(f"Completed {total_items} items in {processing_time:.1f}s")
957
+
958
+ if all_results:
959
+ # Combine results
960
+ result_df = pd.concat(all_results, ignore_index=True)
961
+
962
+ # Save CSV
963
+ with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
964
+ result_df.to_csv(f.name, index=False)
965
+ csv_path = f.name
966
+
967
+ # Calculate success rate
968
+ if 'processing_status' in result_df.columns:
969
+ success_count = (result_df['processing_status'] == 'success').sum()
970
+ success_rate = (success_count / len(result_df)) * 100
971
+ else:
972
+ success_rate = 100.0
973
+
974
+ # Get version info
975
+ try:
976
+ catllm_version = catllm.__version__
977
+ except AttributeError:
978
+ catllm_version = "unknown"
979
+ python_version = sys.version.split()[0]
980
+
981
+ # Generate methodology report
982
+ pdf_path = generate_methodology_report_pdf(
983
+ categories=categories_entered,
984
+ model=model,
985
+ column_name=description,
986
+ num_rows=len(result_df),
987
+ model_source=model_source,
988
+ filename=original_filename,
989
+ success_rate=success_rate,
990
+ result_df=result_df,
991
+ processing_time=processing_time,
992
+ catllm_version=catllm_version,
993
+ python_version=python_version,
994
+ task_type="assign",
995
+ input_type=input_type_selected,
996
+ description=description
997
+ )
998
+
999
+ # Generate code
1000
+ code = generate_classify_code(input_type_selected, description, categories_entered, model, model_source, mode)
1001
 
 
1002
  st.session_state.results = {
1003
  'df': result_df,
1004
  'csv_path': csv_path,
1005
  'pdf_path': pdf_path,
1006
  'code': code,
1007
+ 'status': f"Classified {len(result_df)} items in {processing_time:.1f}s",
1008
  'categories': categories_entered
1009
  }
1010
+ st.success(f"Classified {len(result_df)} items in {processing_time:.1f}s")
1011
  st.rerun()
1012
  else:
1013
+ st.error("No items were successfully classified")
1014
 
1015
  with col_output:
1016
  st.markdown("### Results")