Spaces:
Running
Running
Commit
·
555ea2f
1
Parent(s):
249e203
Add progress bar with ETA for classification
Browse files
app.py
CHANGED
|
@@ -890,35 +890,127 @@ with col_input:
|
|
| 890 |
elif not categories_entered:
|
| 891 |
st.error("Please enter at least one category")
|
| 892 |
else:
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 908 |
|
| 909 |
-
if result_df is not None:
|
| 910 |
st.session_state.results = {
|
| 911 |
'df': result_df,
|
| 912 |
'csv_path': csv_path,
|
| 913 |
'pdf_path': pdf_path,
|
| 914 |
'code': code,
|
| 915 |
-
'status':
|
| 916 |
'categories': categories_entered
|
| 917 |
}
|
| 918 |
-
st.success(
|
| 919 |
st.rerun()
|
| 920 |
else:
|
| 921 |
-
st.error(
|
| 922 |
|
| 923 |
with col_output:
|
| 924 |
st.markdown("### Results")
|
|
|
|
| 890 |
elif not categories_entered:
|
| 891 |
st.error("Please enter at least one category")
|
| 892 |
else:
|
| 893 |
+
# Resolve the processing mode passed on to the classifier.
# Only PDF input has a mode; every other input type leaves it as None.
if input_type_selected == "pdf":
    # UI label -> short mode name.
    mode_mapping = {
        "Image (visual documents)": "image",
        "Text (text-heavy)": "text",
        "Both (comprehensive)": "both",
    }
    # A label that is not in the table falls back to image mode.
    mode = mode_mapping.get(pdf_mode, "image")
else:
    mode = None
|
| 902 |
+
|
| 903 |
+
# Classify the input one item at a time so a progress bar and ETA can be
# shown, then combine the per-item results, write the CSV / PDF artifacts and
# store everything in st.session_state.results.
actual_api_key, provider = get_api_key(model, model_tier, api_key)
if not actual_api_key:
    st.error(f"{provider} API key not configured")
else:
    model_source = get_model_source(model)

    # A single (non-list) input is handled as a one-item batch.
    items = input_data if isinstance(input_data, list) else [input_data]
    total_items = len(items)

    # Progress UI
    progress_bar = st.progress(0)
    status_text = st.empty()

    # Arguments that are identical for every item are built once, outside
    # the loop; only 'input_data' changes per call.
    classify_kwargs = {
        'categories': categories_entered,
        'api_key': actual_api_key,
        'input_type': input_type_selected,
        'description': description,
        'user_model': model,
        'model_source': model_source,
    }
    if mode:
        classify_kwargs['mode'] = mode

    all_results = []
    start_time = time.time()

    for i, item in enumerate(items):
        # Progress counts the items already finished, so it starts at 0.
        progress = i / total_items
        progress_bar.progress(progress)

        # ETA = average time per finished item * items still to do.
        # No estimate is possible before the first item has finished.
        elapsed = time.time() - start_time
        if i > 0:
            avg_time_per_item = elapsed / i
            remaining_items = total_items - i
            eta_seconds = avg_time_per_item * remaining_items
            eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
        else:
            eta_str = ""

        status_text.text(f"Processing item {i+1} of {total_items} ({progress*100:.0f}%){eta_str}")

        # Best effort: a failure on one item is reported and skipped so the
        # remaining items are still processed.
        try:
            item_result = catllm.classify(input_data=[item], **classify_kwargs)
        except Exception as e:
            st.warning(f"Error on item {i+1}: {str(e)}")
        else:
            all_results.append(item_result)

    # Complete progress
    progress_bar.progress(1.0)
    processing_time = time.time() - start_time

    # Report how many items actually produced a result; items that raised
    # above were skipped and must not be counted as completed.
    completed_items = len(all_results)
    if completed_items == total_items:
        status_text.text(f"Completed {total_items} items in {processing_time:.1f}s")
    else:
        status_text.text(
            f"Completed {completed_items} of {total_items} items in {processing_time:.1f}s"
        )

    if all_results:
        # Combine results
        # (assumes catllm.classify returns a pandas DataFrame per call - TODO confirm)
        result_df = pd.concat(all_results, ignore_index=True)

        # Save CSV. The temporary file is only used to reserve a unique path
        # (delete=False keeps it on disk); pandas writes the content itself.
        with tempfile.NamedTemporaryFile(suffix='_classified.csv', delete=False) as f:
            csv_path = f.name
        result_df.to_csv(csv_path, index=False)

        # Success rate: share of rows whose 'processing_status' is 'success';
        # treated as 100% when the column is absent.
        if 'processing_status' in result_df.columns:
            success_count = (result_df['processing_status'] == 'success').sum()
            success_rate = (success_count / len(result_df)) * 100
        else:
            success_rate = 100.0

        # Version info for the methodology report
        catllm_version = getattr(catllm, "__version__", "unknown")
        python_version = sys.version.split()[0]

        # Generate methodology report
        pdf_path = generate_methodology_report_pdf(
            categories=categories_entered,
            model=model,
            column_name=description,
            num_rows=len(result_df),
            model_source=model_source,
            filename=original_filename,
            success_rate=success_rate,
            result_df=result_df,
            processing_time=processing_time,
            catllm_version=catllm_version,
            python_version=python_version,
            task_type="assign",
            input_type=input_type_selected,
            description=description,
        )

        # Generate code
        code = generate_classify_code(input_type_selected, description, categories_entered, model, model_source, mode)

        status_message = f"Classified {len(result_df)} items in {processing_time:.1f}s"
        st.session_state.results = {
            'df': result_df,
            'csv_path': csv_path,
            'pdf_path': pdf_path,
            'code': code,
            'status': status_message,
            'categories': categories_entered
        }
        st.success(status_message)
        # NOTE(review): st.rerun() restarts the script right away, so the
        # banner above may not stay on screen; the same text is kept in
        # session state under 'status' - confirm the results panel shows it.
        st.rerun()
    else:
        st.error("No items were successfully classified")
|
| 1014 |
|
| 1015 |
with col_output:
|
| 1016 |
st.markdown("### Results")
|