Spaces:
Running
Running
Commit · 67ce2bc
1 Parent(s): 5542b9b
Simplify extraction: just call catllm.extract() instead of reimplementing logic
Browse files
app.py
CHANGED
|
@@ -849,172 +849,37 @@ with col_input:
|
|
| 849 |
st.error(f"{provider} API key not configured")
|
| 850 |
else:
|
| 851 |
model_source = get_model_source(model)
|
| 852 |
-
items_list = input_data if isinstance(input_data, list) else [input_data]
|
| 853 |
|
| 854 |
-
|
| 855 |
-
|
| 856 |
-
status_text = st.empty()
|
| 857 |
-
status_text.text("Extracting text from PDFs...")
|
| 858 |
|
| 859 |
-
|
| 860 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 861 |
|
| 862 |
-
|
| 863 |
-
|
| 864 |
-
|
| 865 |
-
# Now process as text corpus with random chunks
|
| 866 |
-
import random
|
| 867 |
-
random.shuffle(extracted_texts) # Randomize order
|
| 868 |
-
|
| 869 |
-
num_items = len(extracted_texts)
|
| 870 |
-
divisions = max(1, num_items // 15)
|
| 871 |
-
divisions = min(divisions, 5)
|
| 872 |
-
chunk_size_calc = num_items // max(1, divisions)
|
| 873 |
-
categories_per_chunk = min(10, chunk_size_calc - 1) if chunk_size_calc > 1 else 5
|
| 874 |
-
|
| 875 |
-
# Progress UI
|
| 876 |
-
progress_bar = st.progress(0)
|
| 877 |
-
start_time = time.time()
|
| 878 |
-
|
| 879 |
-
all_categories = []
|
| 880 |
-
chunk_size = len(extracted_texts) // divisions if divisions > 0 else len(extracted_texts)
|
| 881 |
-
|
| 882 |
-
for div_idx in range(divisions):
|
| 883 |
-
progress = div_idx / divisions
|
| 884 |
-
progress_bar.progress(progress)
|
| 885 |
-
|
| 886 |
-
elapsed = time.time() - start_time
|
| 887 |
-
if div_idx > 0:
|
| 888 |
-
avg_time = elapsed / div_idx
|
| 889 |
-
eta_seconds = avg_time * (divisions - div_idx)
|
| 890 |
-
eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
|
| 891 |
-
else:
|
| 892 |
-
eta_str = ""
|
| 893 |
-
|
| 894 |
-
status_text.text(f"Analyzing text chunk {div_idx + 1} of {divisions} ({progress*100:.0f}%){eta_str}")
|
| 895 |
-
|
| 896 |
-
start_idx = div_idx * chunk_size
|
| 897 |
-
end_idx = start_idx + chunk_size if div_idx < divisions - 1 else len(extracted_texts)
|
| 898 |
-
chunk_data = extracted_texts[start_idx:end_idx]
|
| 899 |
-
|
| 900 |
-
if not chunk_data:
|
| 901 |
-
continue
|
| 902 |
-
|
| 903 |
-
try:
|
| 904 |
-
# Process as text type since we extracted the text
|
| 905 |
-
extract_kwargs = {
|
| 906 |
-
'input_data': chunk_data,
|
| 907 |
-
'api_key': actual_api_key,
|
| 908 |
-
'input_type': 'text', # Treat as text now
|
| 909 |
-
'description': description,
|
| 910 |
-
'user_model': model,
|
| 911 |
-
'model_source': model_source,
|
| 912 |
-
'divisions': 1,
|
| 913 |
-
'categories_per_chunk': categories_per_chunk,
|
| 914 |
-
'max_categories': int(max_categories)
|
| 915 |
-
}
|
| 916 |
-
|
| 917 |
-
chunk_result = catllm.extract(**extract_kwargs)
|
| 918 |
-
chunk_cats = chunk_result.get('top_categories', [])
|
| 919 |
-
all_categories.extend(chunk_cats)
|
| 920 |
-
except Exception as e:
|
| 921 |
-
st.warning(f"Error on chunk {div_idx + 1}: {str(e)}")
|
| 922 |
-
continue
|
| 923 |
-
|
| 924 |
-
progress_bar.progress(1.0)
|
| 925 |
-
processing_time = time.time() - start_time
|
| 926 |
-
status_text.text(f"Extraction complete in {processing_time:.1f}s")
|
| 927 |
|
| 928 |
-
|
| 929 |
-
from collections import Counter
|
| 930 |
-
cat_counts = Counter(all_categories)
|
| 931 |
-
top_cats = [cat for cat, _ in cat_counts.most_common(int(max_categories))]
|
| 932 |
|
| 933 |
-
|
|
|
|
| 934 |
st.session_state.task_mode = "manual"
|
| 935 |
-
st.success(f"Extracted {len(top_cats)} categories in {processing_time:.1f}s")
|
| 936 |
st.rerun()
|
| 937 |
else:
|
| 938 |
st.error("No categories were extracted")
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
num_items = len(items_list)
|
| 942 |
-
|
| 943 |
-
if input_type_selected == "image":
|
| 944 |
-
divisions = min(3, max(1, num_items // 5))
|
| 945 |
-
categories_per_chunk = 12
|
| 946 |
-
else:
|
| 947 |
-
divisions = max(1, num_items // 15)
|
| 948 |
-
divisions = min(divisions, 5)
|
| 949 |
-
chunk_size_calc = num_items // max(1, divisions)
|
| 950 |
-
categories_per_chunk = min(10, chunk_size_calc - 1) if chunk_size_calc > 1 else 5
|
| 951 |
-
|
| 952 |
-
# Progress UI
|
| 953 |
-
progress_bar = st.progress(0)
|
| 954 |
-
status_text = st.empty()
|
| 955 |
-
start_time = time.time()
|
| 956 |
-
|
| 957 |
-
all_categories = []
|
| 958 |
-
chunk_size = len(items_list) // divisions if divisions > 0 else len(items_list)
|
| 959 |
-
|
| 960 |
-
for div_idx in range(divisions):
|
| 961 |
-
progress = div_idx / divisions
|
| 962 |
-
progress_bar.progress(progress)
|
| 963 |
-
|
| 964 |
-
elapsed = time.time() - start_time
|
| 965 |
-
if div_idx > 0:
|
| 966 |
-
avg_time = elapsed / div_idx
|
| 967 |
-
eta_seconds = avg_time * (divisions - div_idx)
|
| 968 |
-
eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
|
| 969 |
-
else:
|
| 970 |
-
eta_str = ""
|
| 971 |
-
|
| 972 |
-
status_text.text(f"Analyzing chunk {div_idx + 1} of {divisions} ({progress*100:.0f}%){eta_str}")
|
| 973 |
-
|
| 974 |
-
start_idx = div_idx * chunk_size
|
| 975 |
-
end_idx = start_idx + chunk_size if div_idx < divisions - 1 else len(items_list)
|
| 976 |
-
chunk_data = items_list[start_idx:end_idx]
|
| 977 |
-
|
| 978 |
-
if not chunk_data:
|
| 979 |
-
continue
|
| 980 |
-
|
| 981 |
-
try:
|
| 982 |
-
extract_kwargs = {
|
| 983 |
-
'input_data': chunk_data,
|
| 984 |
-
'api_key': actual_api_key,
|
| 985 |
-
'input_type': input_type_selected,
|
| 986 |
-
'description': description,
|
| 987 |
-
'user_model': model,
|
| 988 |
-
'model_source': model_source,
|
| 989 |
-
'divisions': 1,
|
| 990 |
-
'categories_per_chunk': categories_per_chunk,
|
| 991 |
-
'max_categories': int(max_categories)
|
| 992 |
-
}
|
| 993 |
-
if mode:
|
| 994 |
-
extract_kwargs['mode'] = mode
|
| 995 |
-
|
| 996 |
-
chunk_result = catllm.extract(**extract_kwargs)
|
| 997 |
-
chunk_cats = chunk_result.get('top_categories', [])
|
| 998 |
-
all_categories.extend(chunk_cats)
|
| 999 |
-
except Exception as e:
|
| 1000 |
-
st.warning(f"Error on chunk {div_idx + 1}: {str(e)}")
|
| 1001 |
-
continue
|
| 1002 |
-
|
| 1003 |
-
progress_bar.progress(1.0)
|
| 1004 |
-
processing_time = time.time() - start_time
|
| 1005 |
-
status_text.text(f"Extraction complete in {processing_time:.1f}s")
|
| 1006 |
-
|
| 1007 |
-
if all_categories:
|
| 1008 |
-
from collections import Counter
|
| 1009 |
-
cat_counts = Counter(all_categories)
|
| 1010 |
-
top_cats = [cat for cat, _ in cat_counts.most_common(int(max_categories))]
|
| 1011 |
-
|
| 1012 |
-
st.session_state.extracted_categories = top_cats
|
| 1013 |
-
st.session_state.task_mode = "manual"
|
| 1014 |
-
st.success(f"Extracted {len(top_cats)} categories in {processing_time:.1f}s")
|
| 1015 |
-
st.rerun()
|
| 1016 |
-
else:
|
| 1017 |
-
st.error("No categories were extracted")
|
| 1018 |
|
| 1019 |
# Category inputs (shown for manual mode or after extraction)
|
| 1020 |
if st.session_state.task_mode == "manual":
|
|
|
|
| 849 |
st.error(f"{provider} API key not configured")
|
| 850 |
else:
|
| 851 |
model_source = get_model_source(model)
|
|
|
|
| 852 |
|
| 853 |
+
with st.spinner("Extracting categories..."):
|
| 854 |
+
start_time = time.time()
|
|
|
|
|
|
|
| 855 |
|
| 856 |
+
extract_kwargs = {
|
| 857 |
+
'input_data': input_data,
|
| 858 |
+
'api_key': actual_api_key,
|
| 859 |
+
'input_type': input_type_selected,
|
| 860 |
+
'description': description,
|
| 861 |
+
'user_model': model,
|
| 862 |
+
'model_source': model_source,
|
| 863 |
+
'max_categories': int(max_categories)
|
| 864 |
+
}
|
| 865 |
+
if mode:
|
| 866 |
+
extract_kwargs['mode'] = mode
|
| 867 |
|
| 868 |
+
try:
|
| 869 |
+
extract_result = catllm.extract(**extract_kwargs)
|
| 870 |
+
categories = extract_result.get('top_categories', [])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 871 |
|
| 872 |
+
processing_time = time.time() - start_time
|
|
|
|
|
|
|
|
|
|
| 873 |
|
| 874 |
+
if categories:
|
| 875 |
+
st.session_state.extracted_categories = categories
|
| 876 |
st.session_state.task_mode = "manual"
|
| 877 |
+
st.success(f"Extracted {len(categories)} categories in {processing_time:.1f}s")
|
| 878 |
st.rerun()
|
| 879 |
else:
|
| 880 |
st.error("No categories were extracted")
|
| 881 |
+
except Exception as e:
|
| 882 |
+
st.error(f"Error: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 883 |
|
| 884 |
# Category inputs (shown for manual mode or after extraction)
|
| 885 |
if st.session_state.task_mode == "manual":
|