chrissoria committed on
Commit
df05499
·
1 Parent(s): af04675

Use progress callback for page-by-page PDF progress; install catllm from GitHub

Browse files
Files changed (2) hide show
  1. app.py +50 -65
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1027,77 +1027,67 @@ with col_input:
1027
  # Progress UI
1028
  progress_bar = st.progress(0)
1029
  status_text = st.empty()
1030
-
1031
- all_results = []
1032
  start_time = time.time()
1033
 
1034
- # For PDFs, extract pages and process one at a time
1035
  if input_type_selected == "pdf":
1036
- pdf_name_map = st.session_state.get('pdf_name_map', {})
1037
- status_text.text("Extracting PDF pages...")
1038
-
1039
- pages = extract_pdf_pages(items_list, pdf_name_map, mode)
1040
- total_pages = len(pages)
1041
-
1042
- for i, page_data in enumerate(pages):
1043
- # Update progress
1044
- progress = i / total_pages if total_pages > 0 else 0
1045
  progress_bar.progress(min(progress, 1.0))
1046
 
1047
- # Calculate ETA
1048
  elapsed = time.time() - start_time
1049
- if i > 0:
1050
- avg_time = elapsed / i
1051
- eta_seconds = avg_time * (total_pages - i)
1052
  eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
1053
  else:
1054
  eta_str = ""
1055
 
1056
- page_label = page_data[1]
1057
- status_text.text(f"Processing page {i+1} of {total_pages} ({page_label}) ({progress*100:.0f}%){eta_str}")
1058
 
1059
- try:
1060
- if page_data[2] == "text":
1061
- # Text mode - classify as text
1062
- result = catllm.classify(
1063
- input_data=[page_data[0]],
1064
- categories=categories_entered,
1065
- api_key=actual_api_key,
1066
- input_type="text",
1067
- description=description,
1068
- user_model=model,
1069
- model_source=model_source
1070
- )
1071
- else:
1072
- # Image mode - classify as image
1073
- result = catllm.classify(
1074
- input_data=[page_data[0]],
1075
- categories=categories_entered,
1076
- api_key=actual_api_key,
1077
- input_type="image",
1078
- description=description,
1079
- user_model=model,
1080
- model_source=model_source
1081
- )
1082
-
1083
- # Replace the input column with the page label
1084
- if 'image_input' in result.columns:
1085
- result['pdf_input'] = page_label
1086
- result = result.drop(columns=['image_input'])
1087
- elif 'text_input' in result.columns:
1088
- result['pdf_input'] = page_label
1089
- result = result.drop(columns=['text_input'])
1090
- else:
1091
- result['pdf_input'] = page_label
1092
-
1093
- all_results.append(result)
1094
- except Exception as e:
1095
- st.warning(f"Error on {page_label}: {str(e)}")
1096
- continue
1097
 
1098
- total_items = total_pages
1099
  else:
1100
- # Non-PDF processing (text, images)
 
1101
  total_items = len(items_list)
1102
 
1103
  for i, item in enumerate(items_list):
@@ -1126,7 +1116,6 @@ with col_input:
1126
  )
1127
  all_results.append(item_result)
1128
 
1129
- # Update progress after processing
1130
  progress = (i + 1) / total_items if total_items > 0 else 1.0
1131
  progress_bar.progress(min(progress, 1.0))
1132
 
@@ -1134,12 +1123,8 @@ with col_input:
1134
  st.warning(f"Error on item {i+1}: {str(e)}")
1135
  continue
1136
 
1137
- # Complete progress
1138
- progress_bar.progress(1.0)
1139
- processing_time = time.time() - start_time
1140
- if input_type_selected == "pdf":
1141
- status_text.text(f"Completed {total_items} pages in {processing_time:.1f}s")
1142
- else:
1143
  status_text.text(f"Completed {total_items} items in {processing_time:.1f}s")
1144
 
1145
  if all_results:
 
1027
  # Progress UI
1028
  progress_bar = st.progress(0)
1029
  status_text = st.empty()
 
 
1030
  start_time = time.time()
1031
 
1032
+ # For PDFs, use progress callback
1033
  if input_type_selected == "pdf":
1034
+ # Progress callback for PDF page-by-page updates
1035
+ def pdf_progress_callback(current_idx, total_pages, page_label):
1036
+ progress = current_idx / total_pages if total_pages > 0 else 0
 
 
 
 
 
 
1037
  progress_bar.progress(min(progress, 1.0))
1038
 
 
1039
  elapsed = time.time() - start_time
1040
+ if current_idx > 0:
1041
+ avg_time = elapsed / current_idx
1042
+ eta_seconds = avg_time * (total_pages - current_idx)
1043
  eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
1044
  else:
1045
  eta_str = ""
1046
 
1047
+ status_text.text(f"Processing page {current_idx+1} of {total_pages} ({page_label}) ({progress*100:.0f}%){eta_str}")
 
1048
 
1049
+ try:
1050
+ result_df = catllm.classify(
1051
+ input_data=items_list,
1052
+ categories=categories_entered,
1053
+ api_key=actual_api_key,
1054
+ input_type="pdf",
1055
+ description=description,
1056
+ user_model=model,
1057
+ model_source=model_source,
1058
+ mode=mode,
1059
+ progress_callback=pdf_progress_callback
1060
+ )
1061
+
1062
+ processing_time = time.time() - start_time
1063
+ total_items = len(result_df)
1064
+ progress_bar.progress(1.0)
1065
+ status_text.text(f"Completed {total_items} pages in {processing_time:.1f}s")
1066
+
1067
+ # Replace temp paths with original filenames in pdf_input column
1068
+ if 'pdf_input' in result_df.columns:
1069
+ pdf_name_map = st.session_state.get('pdf_name_map', {})
1070
+ def replace_temp_path(val):
1071
+ if pd.isna(val):
1072
+ return val
1073
+ val_str = str(val)
1074
+ for temp_path, orig_name in pdf_name_map.items():
1075
+ # Check if the temp path's filename (without extension) is in the value
1076
+ temp_name = os.path.basename(temp_path).replace('.pdf', '')
1077
+ if temp_name in val_str:
1078
+ return val_str.replace(temp_name, orig_name)
1079
+ return val_str
1080
+ result_df['pdf_input'] = result_df['pdf_input'].apply(replace_temp_path)
1081
+
1082
+ all_results = [result_df]
1083
+
1084
+ except Exception as e:
1085
+ st.error(f"Error: {str(e)}")
1086
+ all_results = []
1087
 
 
1088
  else:
1089
+ # Non-PDF processing (text, images) - item by item
1090
+ all_results = []
1091
  total_items = len(items_list)
1092
 
1093
  for i, item in enumerate(items_list):
 
1116
  )
1117
  all_results.append(item_result)
1118
 
 
1119
  progress = (i + 1) / total_items if total_items > 0 else 1.0
1120
  progress_bar.progress(min(progress, 1.0))
1121
 
 
1123
  st.warning(f"Error on item {i+1}: {str(e)}")
1124
  continue
1125
 
1126
+ processing_time = time.time() - start_time
1127
+ progress_bar.progress(1.0)
 
 
 
 
1128
  status_text.text(f"Completed {total_items} items in {processing_time:.1f}s")
1129
 
1130
  if all_results:
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  streamlit>=1.32.0
2
- cat-llm[pdf]>=0.1.6
3
  mistralai
4
  pydantic==2.10.6
5
  huggingface_hub<0.27.0
 
1
  streamlit>=1.32.0
2
+ cat-llm[pdf] @ git+https://github.com/chrissoria/cat-llm.git@main
3
  mistralai
4
  pydantic==2.10.6
5
  huggingface_hub<0.27.0