chrissoria committed on
Commit
67ce2bc
·
1 Parent(s): 5542b9b

Simplify extraction: just call catllm.extract() instead of reimplementing logic

Browse files
Files changed (1) hide show
  1. app.py +22 -157
app.py CHANGED
@@ -849,172 +849,37 @@ with col_input:
849
  st.error(f"{provider} API key not configured")
850
  else:
851
  model_source = get_model_source(model)
852
- items_list = input_data if isinstance(input_data, list) else [input_data]
853
 
854
- # For PDF text mode, extract text and process like text corpus
855
- if input_type_selected == "pdf" and mode == "text":
856
- status_text = st.empty()
857
- status_text.text("Extracting text from PDFs...")
858
 
859
- # Extract text from all PDF pages
860
- extracted_texts = extract_text_from_pdfs(items_list)
 
 
 
 
 
 
 
 
 
861
 
862
- if not extracted_texts:
863
- st.error("No text could be extracted from the PDFs")
864
- else:
865
- # Now process as text corpus with random chunks
866
- import random
867
- random.shuffle(extracted_texts) # Randomize order
868
-
869
- num_items = len(extracted_texts)
870
- divisions = max(1, num_items // 15)
871
- divisions = min(divisions, 5)
872
- chunk_size_calc = num_items // max(1, divisions)
873
- categories_per_chunk = min(10, chunk_size_calc - 1) if chunk_size_calc > 1 else 5
874
-
875
- # Progress UI
876
- progress_bar = st.progress(0)
877
- start_time = time.time()
878
-
879
- all_categories = []
880
- chunk_size = len(extracted_texts) // divisions if divisions > 0 else len(extracted_texts)
881
-
882
- for div_idx in range(divisions):
883
- progress = div_idx / divisions
884
- progress_bar.progress(progress)
885
-
886
- elapsed = time.time() - start_time
887
- if div_idx > 0:
888
- avg_time = elapsed / div_idx
889
- eta_seconds = avg_time * (divisions - div_idx)
890
- eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
891
- else:
892
- eta_str = ""
893
-
894
- status_text.text(f"Analyzing text chunk {div_idx + 1} of {divisions} ({progress*100:.0f}%){eta_str}")
895
-
896
- start_idx = div_idx * chunk_size
897
- end_idx = start_idx + chunk_size if div_idx < divisions - 1 else len(extracted_texts)
898
- chunk_data = extracted_texts[start_idx:end_idx]
899
-
900
- if not chunk_data:
901
- continue
902
-
903
- try:
904
- # Process as text type since we extracted the text
905
- extract_kwargs = {
906
- 'input_data': chunk_data,
907
- 'api_key': actual_api_key,
908
- 'input_type': 'text', # Treat as text now
909
- 'description': description,
910
- 'user_model': model,
911
- 'model_source': model_source,
912
- 'divisions': 1,
913
- 'categories_per_chunk': categories_per_chunk,
914
- 'max_categories': int(max_categories)
915
- }
916
-
917
- chunk_result = catllm.extract(**extract_kwargs)
918
- chunk_cats = chunk_result.get('top_categories', [])
919
- all_categories.extend(chunk_cats)
920
- except Exception as e:
921
- st.warning(f"Error on chunk {div_idx + 1}: {str(e)}")
922
- continue
923
-
924
- progress_bar.progress(1.0)
925
- processing_time = time.time() - start_time
926
- status_text.text(f"Extraction complete in {processing_time:.1f}s")
927
 
928
- if all_categories:
929
- from collections import Counter
930
- cat_counts = Counter(all_categories)
931
- top_cats = [cat for cat, _ in cat_counts.most_common(int(max_categories))]
932
 
933
- st.session_state.extracted_categories = top_cats
 
934
  st.session_state.task_mode = "manual"
935
- st.success(f"Extracted {len(top_cats)} categories from {len(extracted_texts)} pages in {processing_time:.1f}s")
936
  st.rerun()
937
  else:
938
  st.error("No categories were extracted")
939
- else:
940
- # Original logic for images, PDFs in image/both mode, and text
941
- num_items = len(items_list)
942
-
943
- if input_type_selected == "image":
944
- divisions = min(3, max(1, num_items // 5))
945
- categories_per_chunk = 12
946
- else:
947
- divisions = max(1, num_items // 15)
948
- divisions = min(divisions, 5)
949
- chunk_size_calc = num_items // max(1, divisions)
950
- categories_per_chunk = min(10, chunk_size_calc - 1) if chunk_size_calc > 1 else 5
951
-
952
- # Progress UI
953
- progress_bar = st.progress(0)
954
- status_text = st.empty()
955
- start_time = time.time()
956
-
957
- all_categories = []
958
- chunk_size = len(items_list) // divisions if divisions > 0 else len(items_list)
959
-
960
- for div_idx in range(divisions):
961
- progress = div_idx / divisions
962
- progress_bar.progress(progress)
963
-
964
- elapsed = time.time() - start_time
965
- if div_idx > 0:
966
- avg_time = elapsed / div_idx
967
- eta_seconds = avg_time * (divisions - div_idx)
968
- eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
969
- else:
970
- eta_str = ""
971
-
972
- status_text.text(f"Analyzing chunk {div_idx + 1} of {divisions} ({progress*100:.0f}%){eta_str}")
973
-
974
- start_idx = div_idx * chunk_size
975
- end_idx = start_idx + chunk_size if div_idx < divisions - 1 else len(items_list)
976
- chunk_data = items_list[start_idx:end_idx]
977
-
978
- if not chunk_data:
979
- continue
980
-
981
- try:
982
- extract_kwargs = {
983
- 'input_data': chunk_data,
984
- 'api_key': actual_api_key,
985
- 'input_type': input_type_selected,
986
- 'description': description,
987
- 'user_model': model,
988
- 'model_source': model_source,
989
- 'divisions': 1,
990
- 'categories_per_chunk': categories_per_chunk,
991
- 'max_categories': int(max_categories)
992
- }
993
- if mode:
994
- extract_kwargs['mode'] = mode
995
-
996
- chunk_result = catllm.extract(**extract_kwargs)
997
- chunk_cats = chunk_result.get('top_categories', [])
998
- all_categories.extend(chunk_cats)
999
- except Exception as e:
1000
- st.warning(f"Error on chunk {div_idx + 1}: {str(e)}")
1001
- continue
1002
-
1003
- progress_bar.progress(1.0)
1004
- processing_time = time.time() - start_time
1005
- status_text.text(f"Extraction complete in {processing_time:.1f}s")
1006
-
1007
- if all_categories:
1008
- from collections import Counter
1009
- cat_counts = Counter(all_categories)
1010
- top_cats = [cat for cat, _ in cat_counts.most_common(int(max_categories))]
1011
-
1012
- st.session_state.extracted_categories = top_cats
1013
- st.session_state.task_mode = "manual"
1014
- st.success(f"Extracted {len(top_cats)} categories in {processing_time:.1f}s")
1015
- st.rerun()
1016
- else:
1017
- st.error("No categories were extracted")
1018
 
1019
  # Category inputs (shown for manual mode or after extraction)
1020
  if st.session_state.task_mode == "manual":
 
849
  st.error(f"{provider} API key not configured")
850
  else:
851
  model_source = get_model_source(model)
 
852
 
853
+ with st.spinner("Extracting categories..."):
854
+ start_time = time.time()
 
 
855
 
856
+ extract_kwargs = {
857
+ 'input_data': input_data,
858
+ 'api_key': actual_api_key,
859
+ 'input_type': input_type_selected,
860
+ 'description': description,
861
+ 'user_model': model,
862
+ 'model_source': model_source,
863
+ 'max_categories': int(max_categories)
864
+ }
865
+ if mode:
866
+ extract_kwargs['mode'] = mode
867
 
868
+ try:
869
+ extract_result = catllm.extract(**extract_kwargs)
870
+ categories = extract_result.get('top_categories', [])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
871
 
872
+ processing_time = time.time() - start_time
 
 
 
873
 
874
+ if categories:
875
+ st.session_state.extracted_categories = categories
876
  st.session_state.task_mode = "manual"
877
+ st.success(f"Extracted {len(categories)} categories in {processing_time:.1f}s")
878
  st.rerun()
879
  else:
880
  st.error("No categories were extracted")
881
+ except Exception as e:
882
+ st.error(f"Error: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
883
 
884
  # Category inputs (shown for manual mode or after extraction)
885
  if st.session_state.task_mode == "manual":