Spaces:
Running
Running
Commit ·
13dd631
1
Parent(s): 28087b9
Simplify UI: merge Extract & Assign into single Classify flow with auto-extract option
Browse files- Remove standalone Extract Categories button
- Remove Extract & Assign button
- Add Auto-extract Categories button within categories section
- Categories can now be manually entered or auto-extracted before classification
- Simplified codebase by ~500 lines
app.py
CHANGED
|
@@ -293,9 +293,7 @@ def generate_methodology_report_pdf(categories, model, column_name, num_rows, mo
|
|
| 293 |
story = []
|
| 294 |
|
| 295 |
# Title based on task type
|
| 296 |
-
if task_type == "
|
| 297 |
-
report_title = "CatLLM Category Extraction Report"
|
| 298 |
-
elif task_type == "extract_and_assign":
|
| 299 |
report_title = "CatLLM Extraction & Classification Report"
|
| 300 |
else:
|
| 301 |
report_title = "CatLLM Classification Report"
|
|
@@ -306,12 +304,7 @@ def generate_methodology_report_pdf(categories, model, column_name, num_rows, mo
|
|
| 306 |
|
| 307 |
story.append(Paragraph("About This Report", heading_style))
|
| 308 |
|
| 309 |
-
if task_type == "
|
| 310 |
-
about_text = """This methodology report documents the category extraction process for reproducibility and transparency. \
|
| 311 |
-
CatLLM uses Large Language Models to automatically discover meaningful categories from your data. The extraction process \
|
| 312 |
-
analyzes your data in chunks, identifies recurring themes, and consolidates them into a final set of categories. \
|
| 313 |
-
This automated approach helps researchers avoid confirmation bias in category selection."""
|
| 314 |
-
elif task_type == "extract_and_assign":
|
| 315 |
about_text = """This methodology report documents the automated category extraction and classification process. \
|
| 316 |
CatLLM first discovers categories from your data using LLMs, then classifies each item into those categories. \
|
| 317 |
This two-phase approach combines exploratory category discovery with systematic classification, ensuring both \
|
|
@@ -458,27 +451,14 @@ consistent and reproducible results."""
|
|
| 458 |
|
| 459 |
# Summary section - adjust title based on task type
|
| 460 |
story.append(PageBreak())
|
| 461 |
-
if task_type == "
|
| 462 |
-
story.append(Paragraph("Extraction Summary", title_style))
|
| 463 |
-
elif task_type == "extract_and_assign":
|
| 464 |
story.append(Paragraph("Processing Summary", title_style))
|
| 465 |
else:
|
| 466 |
story.append(Paragraph("Classification Summary", title_style))
|
| 467 |
story.append(Spacer(1, 15))
|
| 468 |
|
| 469 |
# Build summary data based on task type
|
| 470 |
-
if task_type == "
|
| 471 |
-
story.append(Paragraph("Extraction Details", heading_style))
|
| 472 |
-
summary_data = [
|
| 473 |
-
["Source File", filename],
|
| 474 |
-
["Source Column/Description", column_name],
|
| 475 |
-
["Input Type", input_type],
|
| 476 |
-
["Model Used", model],
|
| 477 |
-
["Model Source", model_source],
|
| 478 |
-
["Max Categories Requested", str(max_categories or "default")],
|
| 479 |
-
["Categories Extracted", str(len(categories)) if categories else "0"],
|
| 480 |
-
]
|
| 481 |
-
else:
|
| 482 |
story.append(Paragraph("Classification Details", heading_style))
|
| 483 |
summary_data = [
|
| 484 |
["Source File", filename],
|
|
@@ -596,28 +576,7 @@ is the key and a 1 if the category is present and a 0 if not.'''
|
|
| 596 |
story.append(PageBreak())
|
| 597 |
story.append(Paragraph("Reproducibility Code", title_style))
|
| 598 |
|
| 599 |
-
if task_type == "
|
| 600 |
-
story.append(Paragraph("Use the following Python code to reproduce this category extraction:", normal_style))
|
| 601 |
-
story.append(Spacer(1, 15))
|
| 602 |
-
|
| 603 |
-
code_text = f'''import catllm
|
| 604 |
-
|
| 605 |
-
# Extract categories from your data
|
| 606 |
-
result = catllm.extract(
|
| 607 |
-
input_data="path/to/your/data", # file path, list of paths, or list of text
|
| 608 |
-
api_key="YOUR_API_KEY",
|
| 609 |
-
input_type="{input_type}",
|
| 610 |
-
description="{description or column_name}",
|
| 611 |
-
user_model="{model}",
|
| 612 |
-
model_source="{model_source}",
|
| 613 |
-
max_categories={max_categories or 12}
|
| 614 |
-
)
|
| 615 |
-
|
| 616 |
-
# View extracted categories
|
| 617 |
-
print(result["top_categories"])
|
| 618 |
-
print(result["counts_df"])'''
|
| 619 |
-
|
| 620 |
-
elif task_type == "extract_and_assign":
|
| 621 |
story.append(Paragraph("Use the following Python code to reproduce this extraction and classification:", normal_style))
|
| 622 |
story.append(Spacer(1, 15))
|
| 623 |
|
|
@@ -742,35 +701,13 @@ def load_columns(file):
|
|
| 742 |
|
| 743 |
def update_task_visibility(task):
|
| 744 |
"""Update visibility of components based on selected task."""
|
| 745 |
-
if task == "
|
| 746 |
-
return (
|
| 747 |
-
gr.update(visible=False), # categories_group
|
| 748 |
-
gr.update(visible=True), # extract_settings_group
|
| 749 |
-
gr.update(visible=True), # model_group
|
| 750 |
-
gr.update(visible=True, value="Extract Categories"), # run_btn
|
| 751 |
-
gr.update(visible=True), # extract_output_group
|
| 752 |
-
gr.update(visible=False), # classify_output_group
|
| 753 |
-
"Ready to extract categories from your data."
|
| 754 |
-
)
|
| 755 |
-
elif task == "assign":
|
| 756 |
return (
|
| 757 |
gr.update(visible=True), # categories_group
|
| 758 |
-
gr.update(visible=False), # extract_settings_group
|
| 759 |
gr.update(visible=True), # model_group
|
| 760 |
gr.update(visible=True, value="Classify Data"), # run_btn
|
| 761 |
-
gr.update(visible=False), # extract_output_group
|
| 762 |
-
gr.update(visible=True), # classify_output_group
|
| 763 |
-
"Enter categories and click Classify."
|
| 764 |
-
)
|
| 765 |
-
elif task == "extract_and_assign":
|
| 766 |
-
return (
|
| 767 |
-
gr.update(visible=False), # categories_group
|
| 768 |
-
gr.update(visible=True), # extract_settings_group
|
| 769 |
-
gr.update(visible=True), # model_group
|
| 770 |
-
gr.update(visible=True, value="Extract & Classify"), # run_btn
|
| 771 |
-
gr.update(visible=True), # extract_output_group (will show extracted cats)
|
| 772 |
gr.update(visible=True), # classify_output_group
|
| 773 |
-
"
|
| 774 |
)
|
| 775 |
else:
|
| 776 |
return (
|
|
@@ -778,76 +715,36 @@ def update_task_visibility(task):
|
|
| 778 |
gr.update(visible=False),
|
| 779 |
gr.update(visible=False),
|
| 780 |
gr.update(visible=False),
|
| 781 |
-
|
| 782 |
-
gr.update(visible=False),
|
| 783 |
-
"Select a task to continue."
|
| 784 |
)
|
| 785 |
|
| 786 |
|
| 787 |
-
def
|
| 788 |
-
|
| 789 |
-
|
| 790 |
-
|
| 791 |
-
|
| 792 |
-
|
| 793 |
-
"""Extract categories from data and
|
| 794 |
if not CATLLM_AVAILABLE:
|
| 795 |
-
|
| 796 |
-
return
|
| 797 |
|
| 798 |
actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
|
| 799 |
if not actual_api_key:
|
| 800 |
-
|
| 801 |
-
return
|
| 802 |
|
| 803 |
if model_source_input == "auto":
|
| 804 |
model_source = get_model_source(model)
|
| 805 |
else:
|
| 806 |
model_source = model_source_input
|
| 807 |
|
| 808 |
-
# Check file size for images and PDFs
|
| 809 |
-
files_to_check = None
|
| 810 |
-
if input_type == "Images":
|
| 811 |
-
files_to_check = image_folder if image_folder else image_file
|
| 812 |
-
elif input_type == "PDF Documents":
|
| 813 |
-
files_to_check = pdf_folder if pdf_folder else pdf_file
|
| 814 |
-
|
| 815 |
-
if files_to_check:
|
| 816 |
-
total_size_mb = calculate_total_file_size(files_to_check)
|
| 817 |
-
if total_size_mb > MAX_FILE_SIZE_MB:
|
| 818 |
-
# Generate the code for the user
|
| 819 |
-
if input_type == "Images":
|
| 820 |
-
code = generate_extract_code("image", image_description or "images", model, model_source, int(max_categories_val))
|
| 821 |
-
else:
|
| 822 |
-
mode_mapping = {"Image (visual documents)": "image", "Text (text-heavy)": "text", "Both (comprehensive)": "both"}
|
| 823 |
-
actual_mode = mode_mapping.get(pdf_mode, "image")
|
| 824 |
-
code = generate_extract_code("pdf", pdf_description or "document", model, model_source, int(max_categories_val), actual_mode)
|
| 825 |
-
|
| 826 |
-
warning_msg = f"""**⚠️ Large Upload Detected ({total_size_mb:.1f} MB)**
|
| 827 |
-
|
| 828 |
-
Uploads over {MAX_FILE_SIZE_MB} MB may experience performance issues or timeouts on this web app.
|
| 829 |
-
|
| 830 |
-
**Recommended:** Run the code locally using the Python package instead. See the code below, or click "See the Code" after this message.
|
| 831 |
-
|
| 832 |
-
```
|
| 833 |
-
pip install cat-llm
|
| 834 |
-
```
|
| 835 |
-
"""
|
| 836 |
-
yield None, None, code, warning_msg
|
| 837 |
-
return
|
| 838 |
-
|
| 839 |
try:
|
| 840 |
-
yield None, None, None, "Extracting categories from your data..."
|
| 841 |
-
|
| 842 |
-
start_time = time.time()
|
| 843 |
-
|
| 844 |
if input_type == "Survey Responses":
|
| 845 |
if not spreadsheet_file:
|
| 846 |
-
|
| 847 |
-
return
|
| 848 |
if not spreadsheet_column:
|
| 849 |
-
|
| 850 |
-
return
|
| 851 |
|
| 852 |
file_path = spreadsheet_file if isinstance(spreadsheet_file, str) else spreadsheet_file.name
|
| 853 |
if file_path.endswith('.csv'):
|
|
@@ -856,130 +753,100 @@ pip install cat-llm
|
|
| 856 |
df = pd.read_excel(file_path)
|
| 857 |
|
| 858 |
input_data = df[spreadsheet_column].tolist()
|
| 859 |
-
|
| 860 |
-
|
| 861 |
-
|
| 862 |
-
api_key=actual_api_key,
|
| 863 |
-
input_type="text",
|
| 864 |
-
description=spreadsheet_column,
|
| 865 |
-
user_model=model,
|
| 866 |
-
model_source=model_source,
|
| 867 |
-
max_categories=int(max_categories_val)
|
| 868 |
-
)
|
| 869 |
|
| 870 |
elif input_type == "PDF Documents":
|
| 871 |
-
# Use folder if provided, otherwise use uploaded files
|
| 872 |
if pdf_folder:
|
| 873 |
if isinstance(pdf_folder, list):
|
| 874 |
-
|
| 875 |
else:
|
| 876 |
-
|
| 877 |
elif pdf_file:
|
| 878 |
if isinstance(pdf_file, list):
|
| 879 |
-
|
| 880 |
else:
|
| 881 |
-
|
| 882 |
else:
|
| 883 |
-
|
| 884 |
-
return
|
| 885 |
-
|
| 886 |
-
mode_mapping = {
|
| 887 |
-
"Image (visual documents)": "image",
|
| 888 |
-
"Text (text-heavy)": "text",
|
| 889 |
-
"Both (comprehensive)": "both"
|
| 890 |
-
}
|
| 891 |
-
actual_mode = mode_mapping.get(pdf_mode, "image")
|
| 892 |
-
|
| 893 |
-
# Calculate sensible divisions based on input size
|
| 894 |
-
num_items = len(pdf_input) if isinstance(pdf_input, list) else 1
|
| 895 |
-
divisions = min(5, max(1, num_items // 3))
|
| 896 |
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
description=pdf_description or "document",
|
| 902 |
-
mode=actual_mode,
|
| 903 |
-
user_model=model,
|
| 904 |
-
model_source=model_source,
|
| 905 |
-
divisions=divisions,
|
| 906 |
-
max_categories=int(max_categories_val)
|
| 907 |
-
)
|
| 908 |
|
| 909 |
elif input_type == "Images":
|
| 910 |
-
# Use folder if provided, otherwise use uploaded files
|
| 911 |
if image_folder:
|
| 912 |
if isinstance(image_folder, list):
|
| 913 |
-
|
| 914 |
else:
|
| 915 |
-
|
| 916 |
elif image_file:
|
| 917 |
if isinstance(image_file, list):
|
| 918 |
-
|
| 919 |
else:
|
| 920 |
-
|
| 921 |
else:
|
| 922 |
-
|
| 923 |
-
return
|
| 924 |
-
|
| 925 |
-
# For images, use fewer divisions since each image can have multiple categories
|
| 926 |
-
num_items = len(image_input) if isinstance(image_input, list) else 1
|
| 927 |
-
# Use 1 division for small sets, max 3 for larger sets
|
| 928 |
-
divisions = min(3, max(1, num_items // 5))
|
| 929 |
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
input_type="image",
|
| 934 |
-
description=image_description or "images",
|
| 935 |
-
user_model=model,
|
| 936 |
-
model_source=model_source,
|
| 937 |
-
divisions=divisions,
|
| 938 |
-
categories_per_chunk=12, # Images often have multiple categories each
|
| 939 |
-
max_categories=int(max_categories_val)
|
| 940 |
-
)
|
| 941 |
|
| 942 |
else:
|
| 943 |
-
|
| 944 |
-
return
|
| 945 |
|
| 946 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 947 |
|
| 948 |
-
|
| 949 |
-
|
| 950 |
-
|
|
|
|
|
|
|
|
|
|
| 951 |
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 958 |
|
| 959 |
-
|
| 960 |
-
|
| 961 |
-
categories_df.to_csv(f.name, index=False)
|
| 962 |
-
csv_path = f.name
|
| 963 |
|
| 964 |
-
|
| 965 |
-
|
| 966 |
-
|
| 967 |
-
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 973 |
|
| 974 |
-
|
| 975 |
-
|
| 976 |
-
csv_path,
|
| 977 |
-
code,
|
| 978 |
-
f"Extracted {len(top_categories)} categories in {processing_time:.1f}s"
|
| 979 |
-
)
|
| 980 |
|
| 981 |
except Exception as e:
|
| 982 |
-
|
| 983 |
|
| 984 |
|
| 985 |
def run_classify_data(input_type, spreadsheet_file, spreadsheet_column,
|
|
@@ -1261,320 +1128,6 @@ Provide your work in JSON format where the number belonging to each category is
|
|
| 1261 |
yield None, None, None, None, None, f"**Error:** {str(e)}"
|
| 1262 |
|
| 1263 |
|
| 1264 |
-
def run_extract_and_assign(input_type, spreadsheet_file, spreadsheet_column,
|
| 1265 |
-
pdf_file, pdf_folder, pdf_description, pdf_mode,
|
| 1266 |
-
image_file, image_folder, image_description,
|
| 1267 |
-
max_categories_val,
|
| 1268 |
-
model_tier, model, model_source_input, api_key_input,
|
| 1269 |
-
progress=gr.Progress(track_tqdm=True)):
|
| 1270 |
-
"""Extract categories then classify data with them."""
|
| 1271 |
-
if not CATLLM_AVAILABLE:
|
| 1272 |
-
yield None, None, None, None, None, None, None, None, "**Error:** catllm package not available"
|
| 1273 |
-
return
|
| 1274 |
-
|
| 1275 |
-
actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
|
| 1276 |
-
if not actual_api_key:
|
| 1277 |
-
yield None, None, None, None, None, None, None, None, f"**Error:** {provider} API key not configured"
|
| 1278 |
-
return
|
| 1279 |
-
|
| 1280 |
-
if model_source_input == "auto":
|
| 1281 |
-
model_source = get_model_source(model)
|
| 1282 |
-
else:
|
| 1283 |
-
model_source = model_source_input
|
| 1284 |
-
|
| 1285 |
-
# Check file size for images and PDFs
|
| 1286 |
-
files_to_check = None
|
| 1287 |
-
if input_type == "Images":
|
| 1288 |
-
files_to_check = image_folder if image_folder else image_file
|
| 1289 |
-
elif input_type == "PDF Documents":
|
| 1290 |
-
files_to_check = pdf_folder if pdf_folder else pdf_file
|
| 1291 |
-
|
| 1292 |
-
if files_to_check:
|
| 1293 |
-
total_size_mb = calculate_total_file_size(files_to_check)
|
| 1294 |
-
if total_size_mb > MAX_FILE_SIZE_MB:
|
| 1295 |
-
# Generate the code for the user
|
| 1296 |
-
if input_type == "Images":
|
| 1297 |
-
extract_code = generate_extract_code("image", image_description or "images", model, model_source, int(max_categories_val))
|
| 1298 |
-
else:
|
| 1299 |
-
mode_mapping = {"Image (visual documents)": "image", "Text (text-heavy)": "text", "Both (comprehensive)": "both"}
|
| 1300 |
-
actual_mode = mode_mapping.get(pdf_mode, "image")
|
| 1301 |
-
extract_code = generate_extract_code("pdf", pdf_description or "document", model, model_source, int(max_categories_val), actual_mode)
|
| 1302 |
-
|
| 1303 |
-
warning_msg = f"""**⚠️ Large Upload Detected ({total_size_mb:.1f} MB)**
|
| 1304 |
-
|
| 1305 |
-
Uploads over {MAX_FILE_SIZE_MB} MB may experience performance issues or timeouts on this web app.
|
| 1306 |
-
|
| 1307 |
-
**Recommended:** Run the code locally using the Python package instead. See the code below, or click "See the Code" after this message.
|
| 1308 |
-
|
| 1309 |
-
```
|
| 1310 |
-
pip install cat-llm
|
| 1311 |
-
```
|
| 1312 |
-
"""
|
| 1313 |
-
yield None, None, extract_code, None, None, None, None, None, warning_msg
|
| 1314 |
-
return
|
| 1315 |
-
|
| 1316 |
-
try:
|
| 1317 |
-
# Phase 1: Extract categories
|
| 1318 |
-
yield None, None, None, None, None, None, None, None, "Phase 1: Extracting categories..."
|
| 1319 |
-
|
| 1320 |
-
start_time = time.time()
|
| 1321 |
-
|
| 1322 |
-
if input_type == "Survey Responses":
|
| 1323 |
-
if not spreadsheet_file:
|
| 1324 |
-
yield None, None, None, None, None, None, None, None, "**Error:** Please upload a CSV/Excel file"
|
| 1325 |
-
return
|
| 1326 |
-
if not spreadsheet_column:
|
| 1327 |
-
yield None, None, None, None, None, None, None, None, "**Error:** Please select a column"
|
| 1328 |
-
return
|
| 1329 |
-
|
| 1330 |
-
file_path = spreadsheet_file if isinstance(spreadsheet_file, str) else spreadsheet_file.name
|
| 1331 |
-
if file_path.endswith('.csv'):
|
| 1332 |
-
df = pd.read_csv(file_path)
|
| 1333 |
-
else:
|
| 1334 |
-
df = pd.read_excel(file_path)
|
| 1335 |
-
|
| 1336 |
-
input_data = df[spreadsheet_column].tolist()
|
| 1337 |
-
original_filename = file_path.split("/")[-1]
|
| 1338 |
-
column_name = spreadsheet_column
|
| 1339 |
-
input_type_param = "text"
|
| 1340 |
-
description = spreadsheet_column
|
| 1341 |
-
mode_param = None
|
| 1342 |
-
|
| 1343 |
-
elif input_type == "PDF Documents":
|
| 1344 |
-
# Use folder if provided, otherwise use uploaded files
|
| 1345 |
-
if pdf_folder:
|
| 1346 |
-
if isinstance(pdf_folder, list):
|
| 1347 |
-
input_data = [f if isinstance(f, str) else f.name for f in pdf_folder if str(f.name if hasattr(f, 'name') else f).lower().endswith('.pdf')]
|
| 1348 |
-
original_filename = "pdf_folder"
|
| 1349 |
-
else:
|
| 1350 |
-
input_data = pdf_folder if isinstance(pdf_folder, str) else pdf_folder.name
|
| 1351 |
-
original_filename = input_data.split("/")[-1]
|
| 1352 |
-
elif pdf_file:
|
| 1353 |
-
if isinstance(pdf_file, list):
|
| 1354 |
-
input_data = [f if isinstance(f, str) else f.name for f in pdf_file]
|
| 1355 |
-
original_filename = "multiple_pdfs"
|
| 1356 |
-
else:
|
| 1357 |
-
input_data = pdf_file if isinstance(pdf_file, str) else pdf_file.name
|
| 1358 |
-
original_filename = input_data.split("/")[-1]
|
| 1359 |
-
else:
|
| 1360 |
-
yield None, None, None, None, None, None, None, None, "**Error:** Please upload PDF file(s) or a folder"
|
| 1361 |
-
return
|
| 1362 |
-
|
| 1363 |
-
column_name = "PDF Pages"
|
| 1364 |
-
input_type_param = "pdf"
|
| 1365 |
-
description = pdf_description or "document"
|
| 1366 |
-
|
| 1367 |
-
mode_mapping = {
|
| 1368 |
-
"Image (visual documents)": "image",
|
| 1369 |
-
"Text (text-heavy)": "text",
|
| 1370 |
-
"Both (comprehensive)": "both"
|
| 1371 |
-
}
|
| 1372 |
-
mode_param = mode_mapping.get(pdf_mode, "image")
|
| 1373 |
-
|
| 1374 |
-
elif input_type == "Images":
|
| 1375 |
-
# Use folder if provided, otherwise use uploaded files
|
| 1376 |
-
if image_folder:
|
| 1377 |
-
if isinstance(image_folder, list):
|
| 1378 |
-
input_data = [f if isinstance(f, str) else f.name for f in image_folder]
|
| 1379 |
-
original_filename = "image_folder"
|
| 1380 |
-
else:
|
| 1381 |
-
input_data = image_folder if isinstance(image_folder, str) else image_folder.name
|
| 1382 |
-
original_filename = input_data.split("/")[-1]
|
| 1383 |
-
elif image_file:
|
| 1384 |
-
if isinstance(image_file, list):
|
| 1385 |
-
input_data = [f if isinstance(f, str) else f.name for f in image_file]
|
| 1386 |
-
original_filename = "multiple_images"
|
| 1387 |
-
else:
|
| 1388 |
-
input_data = image_file if isinstance(image_file, str) else image_file.name
|
| 1389 |
-
original_filename = input_data.split("/")[-1]
|
| 1390 |
-
else:
|
| 1391 |
-
yield None, None, None, None, None, None, None, None, "**Error:** Please upload image file(s) or a folder"
|
| 1392 |
-
return
|
| 1393 |
-
|
| 1394 |
-
column_name = "Image Files"
|
| 1395 |
-
input_type_param = "image"
|
| 1396 |
-
description = image_description or "images"
|
| 1397 |
-
mode_param = None
|
| 1398 |
-
|
| 1399 |
-
else:
|
| 1400 |
-
yield None, None, None, None, None, None, None, None, f"**Error:** Unknown input type: {input_type}"
|
| 1401 |
-
return
|
| 1402 |
-
|
| 1403 |
-
# Calculate sensible divisions based on input size and type
|
| 1404 |
-
if isinstance(input_data, list):
|
| 1405 |
-
num_items = len(input_data)
|
| 1406 |
-
else:
|
| 1407 |
-
num_items = 1
|
| 1408 |
-
|
| 1409 |
-
# Images can have multiple categories per item, so use fewer divisions
|
| 1410 |
-
if input_type_param == "image":
|
| 1411 |
-
divisions = min(3, max(1, num_items // 5))
|
| 1412 |
-
categories_per_chunk = 12
|
| 1413 |
-
else:
|
| 1414 |
-
divisions = min(5, max(1, num_items // 3))
|
| 1415 |
-
categories_per_chunk = 10
|
| 1416 |
-
|
| 1417 |
-
# Extract categories
|
| 1418 |
-
extract_kwargs = {
|
| 1419 |
-
'input_data': input_data,
|
| 1420 |
-
'api_key': actual_api_key,
|
| 1421 |
-
'input_type': input_type_param,
|
| 1422 |
-
'description': description,
|
| 1423 |
-
'user_model': model,
|
| 1424 |
-
'model_source': model_source,
|
| 1425 |
-
'divisions': divisions,
|
| 1426 |
-
'categories_per_chunk': categories_per_chunk,
|
| 1427 |
-
'max_categories': int(max_categories_val)
|
| 1428 |
-
}
|
| 1429 |
-
if mode_param:
|
| 1430 |
-
extract_kwargs['mode'] = mode_param
|
| 1431 |
-
|
| 1432 |
-
extract_result = catllm.extract(**extract_kwargs)
|
| 1433 |
-
categories = extract_result.get('top_categories', [])
|
| 1434 |
-
categories_df = extract_result.get('counts_df', pd.DataFrame())
|
| 1435 |
-
|
| 1436 |
-
if not categories:
|
| 1437 |
-
yield None, None, None, None, None, None, None, None, "**Error:** No categories were extracted"
|
| 1438 |
-
return
|
| 1439 |
-
|
| 1440 |
-
extract_time = time.time() - start_time
|
| 1441 |
-
|
| 1442 |
-
# Show extracted categories
|
| 1443 |
-
if categories_df.empty and categories:
|
| 1444 |
-
categories_df = pd.DataFrame({
|
| 1445 |
-
'Category': categories,
|
| 1446 |
-
'Count': ['-'] * len(categories)
|
| 1447 |
-
})
|
| 1448 |
-
|
| 1449 |
-
with tempfile.NamedTemporaryFile(mode='w', suffix='_extracted_categories.csv', delete=False) as f:
|
| 1450 |
-
categories_df.to_csv(f.name, index=False)
|
| 1451 |
-
extract_csv_path = f.name
|
| 1452 |
-
|
| 1453 |
-
# Generate extract code
|
| 1454 |
-
extract_code = generate_extract_code(input_type_param, description, model, model_source, int(max_categories_val), mode_param)
|
| 1455 |
-
|
| 1456 |
-
yield (
|
| 1457 |
-
gr.update(value=categories_df, visible=True),
|
| 1458 |
-
extract_csv_path,
|
| 1459 |
-
extract_code,
|
| 1460 |
-
None, None, None, None, None,
|
| 1461 |
-
f"Extracted {len(categories)} categories in {extract_time:.1f}s. Now classifying..."
|
| 1462 |
-
)
|
| 1463 |
-
|
| 1464 |
-
# Phase 2: Classify with extracted categories
|
| 1465 |
-
classify_start = time.time()
|
| 1466 |
-
|
| 1467 |
-
classify_kwargs = {
|
| 1468 |
-
'input_data': input_data,
|
| 1469 |
-
'categories': categories,
|
| 1470 |
-
'api_key': actual_api_key,
|
| 1471 |
-
'input_type': input_type_param,
|
| 1472 |
-
'description': description,
|
| 1473 |
-
'user_model': model,
|
| 1474 |
-
'model_source': model_source
|
| 1475 |
-
}
|
| 1476 |
-
if mode_param:
|
| 1477 |
-
classify_kwargs['mode'] = mode_param
|
| 1478 |
-
|
| 1479 |
-
result = catllm.classify(**classify_kwargs)
|
| 1480 |
-
|
| 1481 |
-
classify_time = time.time() - classify_start
|
| 1482 |
-
total_time = time.time() - start_time
|
| 1483 |
-
num_items = len(result)
|
| 1484 |
-
|
| 1485 |
-
# Save CSV
|
| 1486 |
-
with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
|
| 1487 |
-
result.to_csv(f.name, index=False)
|
| 1488 |
-
classify_csv_path = f.name
|
| 1489 |
-
|
| 1490 |
-
# Calculate success rate
|
| 1491 |
-
if 'processing_status' in result.columns:
|
| 1492 |
-
success_count = (result['processing_status'] == 'success').sum()
|
| 1493 |
-
success_rate = (success_count / len(result)) * 100
|
| 1494 |
-
else:
|
| 1495 |
-
success_rate = 100.0
|
| 1496 |
-
|
| 1497 |
-
# Get version info
|
| 1498 |
-
try:
|
| 1499 |
-
catllm_version = catllm.__version__
|
| 1500 |
-
except AttributeError:
|
| 1501 |
-
catllm_version = "unknown"
|
| 1502 |
-
python_version = sys.version.split()[0]
|
| 1503 |
-
|
| 1504 |
-
# Generate methodology report
|
| 1505 |
-
prompt_template = '''Categorize this survey response "{response}" into the following categories that apply:
|
| 1506 |
-
{categories}
|
| 1507 |
-
|
| 1508 |
-
Let's think step by step:
|
| 1509 |
-
1. First, identify the main themes mentioned in the response
|
| 1510 |
-
2. Then, match each theme to the relevant categories
|
| 1511 |
-
3. Finally, assign 1 to matching categories and 0 to non-matching categories
|
| 1512 |
-
|
| 1513 |
-
Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values.'''
|
| 1514 |
-
|
| 1515 |
-
report_pdf_path = generate_methodology_report_pdf(
|
| 1516 |
-
categories=categories,
|
| 1517 |
-
model=model,
|
| 1518 |
-
column_name=column_name,
|
| 1519 |
-
num_rows=num_items,
|
| 1520 |
-
model_source=model_source,
|
| 1521 |
-
filename=original_filename,
|
| 1522 |
-
success_rate=success_rate,
|
| 1523 |
-
result_df=result,
|
| 1524 |
-
processing_time=total_time,
|
| 1525 |
-
prompt_template=prompt_template,
|
| 1526 |
-
data_quality={'null_count': 0, 'avg_length': 0, 'min_length': 0, 'max_length': 0, 'error_count': 0},
|
| 1527 |
-
catllm_version=catllm_version,
|
| 1528 |
-
python_version=python_version,
|
| 1529 |
-
task_type="extract_and_assign",
|
| 1530 |
-
max_categories=int(max_categories_val),
|
| 1531 |
-
input_type=input_type_param,
|
| 1532 |
-
description=description
|
| 1533 |
-
)
|
| 1534 |
-
|
| 1535 |
-
# Create distribution plot
|
| 1536 |
-
dist_data = []
|
| 1537 |
-
total_rows = len(result)
|
| 1538 |
-
for i, cat in enumerate(categories, 1):
|
| 1539 |
-
col_name = f"category_{i}"
|
| 1540 |
-
if col_name in result.columns:
|
| 1541 |
-
count = int(result[col_name].sum())
|
| 1542 |
-
pct = (count / total_rows) * 100 if total_rows > 0 else 0
|
| 1543 |
-
dist_data.append({"Category": cat, "Percentage": round(pct, 1)})
|
| 1544 |
-
|
| 1545 |
-
fig, ax = plt.subplots(figsize=(10, max(4, len(dist_data) * 0.8)))
|
| 1546 |
-
categories_list = [d["Category"] for d in dist_data][::-1]
|
| 1547 |
-
percentages = [d["Percentage"] for d in dist_data][::-1]
|
| 1548 |
-
|
| 1549 |
-
bars = ax.barh(categories_list, percentages, color='#2563eb')
|
| 1550 |
-
ax.set_xlim(0, 100)
|
| 1551 |
-
ax.set_xlabel('Percentage (%)', fontsize=11)
|
| 1552 |
-
ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold')
|
| 1553 |
-
|
| 1554 |
-
for bar, pct in zip(bars, percentages):
|
| 1555 |
-
ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
|
| 1556 |
-
f'{pct:.1f}%', va='center', fontsize=10)
|
| 1557 |
-
|
| 1558 |
-
plt.tight_layout()
|
| 1559 |
-
|
| 1560 |
-
# Generate classify code
|
| 1561 |
-
classify_code = generate_classify_code(input_type_param, description, categories, model, model_source, mode_param)
|
| 1562 |
-
|
| 1563 |
-
yield (
|
| 1564 |
-
gr.update(value=categories_df, visible=True),
|
| 1565 |
-
extract_csv_path,
|
| 1566 |
-
extract_code,
|
| 1567 |
-
gr.update(value=fig, visible=True),
|
| 1568 |
-
gr.update(value=result, visible=True),
|
| 1569 |
-
[classify_csv_path, report_pdf_path],
|
| 1570 |
-
classify_code,
|
| 1571 |
-
None,
|
| 1572 |
-
f"Extracted {len(categories)} categories and classified {num_items} items in {total_time:.1f}s"
|
| 1573 |
-
)
|
| 1574 |
-
|
| 1575 |
-
except Exception as e:
|
| 1576 |
-
yield None, None, None, None, None, None, None, None, f"**Error:** {str(e)}"
|
| 1577 |
-
|
| 1578 |
|
| 1579 |
def add_category_field(current_count):
|
| 1580 |
new_count = min(current_count + 1, MAX_CATEGORIES)
|
|
@@ -1612,9 +1165,10 @@ def reset_all():
|
|
| 1612 |
updates.extend([
|
| 1613 |
gr.update(visible=True), # add_category_btn
|
| 1614 |
INITIAL_CATEGORIES, # category_count
|
| 1615 |
-
gr.update(visible=False), #
|
| 1616 |
-
gr.update(visible=False), # extract_settings_group
|
| 1617 |
12, # max_categories (reset to default)
|
|
|
|
|
|
|
| 1618 |
gr.update(visible=False), # model_group
|
| 1619 |
gr.update(visible=False, value="Run"), # run_btn
|
| 1620 |
"Free Models", # model_tier
|
|
@@ -1624,10 +1178,6 @@ def reset_all():
|
|
| 1624 |
gr.update(visible=False), # api_key
|
| 1625 |
"**Free tier** - no API key required!", # api_key_status
|
| 1626 |
"Ready. Upload data and select a task.", # status
|
| 1627 |
-
gr.update(visible=False), # extract_output_group
|
| 1628 |
-
gr.update(value=None, visible=False), # extracted_categories
|
| 1629 |
-
None, # extract_download
|
| 1630 |
-
"# Code will be generated after extraction", # extract_code_display
|
| 1631 |
gr.update(visible=False), # classify_output_group
|
| 1632 |
gr.update(value=None, visible=False), # distribution_plot
|
| 1633 |
gr.update(value=None, visible=False), # results
|
|
@@ -1931,16 +1481,14 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
|
|
| 1931 |
info="Helps the LLM understand context"
|
| 1932 |
)
|
| 1933 |
|
| 1934 |
-
# Task selection
|
| 1935 |
gr.Markdown("### What would you like to do?")
|
| 1936 |
-
|
| 1937 |
-
extract_btn = gr.Button("Extract Categories", variant="secondary", elem_classes="task-btn")
|
| 1938 |
-
assign_btn = gr.Button("Assign Categories", variant="secondary", elem_classes="task-btn")
|
| 1939 |
-
extract_assign_btn = gr.Button("Extract & Assign", variant="secondary", elem_classes="task-btn")
|
| 1940 |
|
| 1941 |
# Categories group (only visible for Assign task)
|
| 1942 |
with gr.Group(visible=False) as categories_group:
|
| 1943 |
gr.Markdown("### Categories")
|
|
|
|
| 1944 |
category_inputs = []
|
| 1945 |
placeholder_examples = [
|
| 1946 |
"e.g., Positive sentiment",
|
|
@@ -1959,19 +1507,20 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
|
|
| 1959 |
visible=visible
|
| 1960 |
)
|
| 1961 |
category_inputs.append(cat_input)
|
| 1962 |
-
|
| 1963 |
-
|
| 1964 |
-
|
| 1965 |
-
|
| 1966 |
-
|
| 1967 |
-
|
| 1968 |
-
|
| 1969 |
-
|
| 1970 |
-
|
| 1971 |
-
|
| 1972 |
-
|
| 1973 |
-
|
| 1974 |
-
|
|
|
|
| 1975 |
|
| 1976 |
# Model selection group
|
| 1977 |
with gr.Group(visible=False) as model_group:
|
|
@@ -2009,23 +1558,6 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
|
|
| 2009 |
with gr.Column():
|
| 2010 |
status = gr.Markdown("Ready. Upload data and select a task.")
|
| 2011 |
|
| 2012 |
-
# Extract output group
|
| 2013 |
-
with gr.Group(visible=False) as extract_output_group:
|
| 2014 |
-
gr.Markdown("### Extracted Categories")
|
| 2015 |
-
extracted_categories = gr.DataFrame(
|
| 2016 |
-
label="Categories",
|
| 2017 |
-
visible=False,
|
| 2018 |
-
wrap=True
|
| 2019 |
-
)
|
| 2020 |
-
extract_download = gr.File(label="Download Categories (CSV)")
|
| 2021 |
-
with gr.Accordion("See the Code", open=False):
|
| 2022 |
-
extract_code_display = gr.Code(
|
| 2023 |
-
label="Python Code",
|
| 2024 |
-
language="python",
|
| 2025 |
-
value="# Code will be generated after extraction",
|
| 2026 |
-
interactive=False
|
| 2027 |
-
)
|
| 2028 |
-
|
| 2029 |
# Classify output group
|
| 2030 |
with gr.Group(visible=False) as classify_output_group:
|
| 2031 |
gr.Markdown("### Classification Results")
|
|
@@ -2122,65 +1654,45 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
|
|
| 2122 |
outputs=category_inputs + [add_category_btn, category_count]
|
| 2123 |
)
|
| 2124 |
|
| 2125 |
-
#
|
| 2126 |
-
def
|
| 2127 |
-
return (
|
| 2128 |
-
|
| 2129 |
-
def select_assign():
|
| 2130 |
-
return ("assign",) + update_task_visibility("assign")
|
| 2131 |
-
|
| 2132 |
-
def select_extract_assign():
|
| 2133 |
-
return ("extract_and_assign",) + update_task_visibility("extract_and_assign")
|
| 2134 |
|
| 2135 |
-
|
| 2136 |
-
fn=
|
| 2137 |
inputs=[],
|
| 2138 |
-
outputs=[
|
| 2139 |
)
|
| 2140 |
|
| 2141 |
-
|
| 2142 |
-
|
| 2143 |
-
|
| 2144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2145 |
)
|
| 2146 |
|
| 2147 |
-
|
| 2148 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2149 |
inputs=[],
|
| 2150 |
-
outputs=[task_mode, categories_group,
|
| 2151 |
)
|
| 2152 |
|
| 2153 |
# Main run button handler - dispatches based on task_mode
|
| 2154 |
def dispatch_run(task, input_type, spreadsheet_file, spreadsheet_column,
|
| 2155 |
pdf_file, pdf_folder_val, pdf_description, pdf_mode,
|
| 2156 |
image_file, image_folder_val, image_description,
|
| 2157 |
-
max_categories_val,
|
| 2158 |
cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10,
|
| 2159 |
model_tier, model, model_source, api_key,
|
| 2160 |
progress=gr.Progress(track_tqdm=True)):
|
| 2161 |
-
"""
|
| 2162 |
-
if task == "
|
| 2163 |
-
# run_extract_categories yields: (categories_df, csv_path, code, status)
|
| 2164 |
-
for update in run_extract_categories(
|
| 2165 |
-
input_type, spreadsheet_file, spreadsheet_column,
|
| 2166 |
-
pdf_file, pdf_folder_val, pdf_description, pdf_mode,
|
| 2167 |
-
image_file, image_folder_val, image_description,
|
| 2168 |
-
max_categories_val,
|
| 2169 |
-
model_tier, model, model_source, api_key,
|
| 2170 |
-
progress
|
| 2171 |
-
):
|
| 2172 |
-
# Map extract outputs to full output list
|
| 2173 |
-
yield (
|
| 2174 |
-
update[0], # extracted_categories
|
| 2175 |
-
update[1], # extract_download
|
| 2176 |
-
update[2], # extract_code_display
|
| 2177 |
-
None, # distribution_plot
|
| 2178 |
-
None, # results
|
| 2179 |
-
None, # download_file
|
| 2180 |
-
None, # classify_code_display
|
| 2181 |
-
update[3] # status
|
| 2182 |
-
)
|
| 2183 |
-
elif task == "assign":
|
| 2184 |
# run_classify_data yields: (plot, df, files, code, unused, status)
|
| 2185 |
for update in run_classify_data(
|
| 2186 |
input_type, spreadsheet_file, spreadsheet_column,
|
|
@@ -2190,47 +1702,22 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
|
|
| 2190 |
model_tier, model, model_source, api_key,
|
| 2191 |
progress
|
| 2192 |
):
|
| 2193 |
-
# Map classify outputs to full output list
|
| 2194 |
yield (
|
| 2195 |
-
None, # extracted_categories
|
| 2196 |
-
None, # extract_download
|
| 2197 |
-
None, # extract_code_display
|
| 2198 |
update[0], # distribution_plot
|
| 2199 |
update[1], # results
|
| 2200 |
update[2], # download_file
|
| 2201 |
update[3], # classify_code_display
|
| 2202 |
update[5] # status
|
| 2203 |
)
|
| 2204 |
-
elif task == "extract_and_assign":
|
| 2205 |
-
# run_extract_and_assign yields: (categories_df, extract_csv, extract_code, plot, df, files, classify_code, unused, status)
|
| 2206 |
-
for update in run_extract_and_assign(
|
| 2207 |
-
input_type, spreadsheet_file, spreadsheet_column,
|
| 2208 |
-
pdf_file, pdf_folder_val, pdf_description, pdf_mode,
|
| 2209 |
-
image_file, image_folder_val, image_description,
|
| 2210 |
-
max_categories_val,
|
| 2211 |
-
model_tier, model, model_source, api_key,
|
| 2212 |
-
progress
|
| 2213 |
-
):
|
| 2214 |
-
yield (
|
| 2215 |
-
update[0], # extracted_categories
|
| 2216 |
-
update[1], # extract_download
|
| 2217 |
-
update[2], # extract_code_display
|
| 2218 |
-
update[3], # distribution_plot
|
| 2219 |
-
update[4], # results
|
| 2220 |
-
update[5], # download_file
|
| 2221 |
-
update[6], # classify_code_display
|
| 2222 |
-
update[8] # status
|
| 2223 |
-
)
|
| 2224 |
else:
|
| 2225 |
-
yield (None, None, None, None,
|
| 2226 |
|
| 2227 |
run_btn.click(
|
| 2228 |
fn=dispatch_run,
|
| 2229 |
inputs=[task_mode, input_type, spreadsheet_file, spreadsheet_column,
|
| 2230 |
pdf_file, pdf_folder, pdf_description, pdf_mode,
|
| 2231 |
-
image_file, image_folder, image_description,
|
| 2232 |
-
|
| 2233 |
-
outputs=[extracted_categories, extract_download, extract_code_display, distribution_plot, results, download_file, classify_code_display, status]
|
| 2234 |
)
|
| 2235 |
|
| 2236 |
reset_btn.click(
|
|
@@ -2244,10 +1731,10 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
|
|
| 2244 |
task_mode
|
| 2245 |
] + category_inputs + [
|
| 2246 |
add_category_btn, category_count,
|
| 2247 |
-
|
|
|
|
| 2248 |
model_tier, model, model_source, api_key, api_key, api_key_status,
|
| 2249 |
status,
|
| 2250 |
-
extract_output_group, extracted_categories, extract_download, extract_code_display,
|
| 2251 |
classify_output_group, distribution_plot, results, download_file, classify_code_display
|
| 2252 |
]
|
| 2253 |
)
|
|
|
|
| 293 |
story = []
|
| 294 |
|
| 295 |
# Title based on task type
|
| 296 |
+
if task_type == "extract_and_assign":
|
|
|
|
|
|
|
| 297 |
report_title = "CatLLM Extraction & Classification Report"
|
| 298 |
else:
|
| 299 |
report_title = "CatLLM Classification Report"
|
|
|
|
| 304 |
|
| 305 |
story.append(Paragraph("About This Report", heading_style))
|
| 306 |
|
| 307 |
+
if task_type == "extract_and_assign":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
about_text = """This methodology report documents the automated category extraction and classification process. \
|
| 309 |
CatLLM first discovers categories from your data using LLMs, then classifies each item into those categories. \
|
| 310 |
This two-phase approach combines exploratory category discovery with systematic classification, ensuring both \
|
|
|
|
| 451 |
|
| 452 |
# Summary section - adjust title based on task type
|
| 453 |
story.append(PageBreak())
|
| 454 |
+
if task_type == "extract_and_assign":
|
|
|
|
|
|
|
| 455 |
story.append(Paragraph("Processing Summary", title_style))
|
| 456 |
else:
|
| 457 |
story.append(Paragraph("Classification Summary", title_style))
|
| 458 |
story.append(Spacer(1, 15))
|
| 459 |
|
| 460 |
# Build summary data based on task type
|
| 461 |
+
if task_type == "assign":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 462 |
story.append(Paragraph("Classification Details", heading_style))
|
| 463 |
summary_data = [
|
| 464 |
["Source File", filename],
|
|
|
|
| 576 |
story.append(PageBreak())
|
| 577 |
story.append(Paragraph("Reproducibility Code", title_style))
|
| 578 |
|
| 579 |
+
if task_type == "extract_and_assign":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 580 |
story.append(Paragraph("Use the following Python code to reproduce this extraction and classification:", normal_style))
|
| 581 |
story.append(Spacer(1, 15))
|
| 582 |
|
|
|
|
| 701 |
|
| 702 |
def update_task_visibility(task):
|
| 703 |
"""Update visibility of components based on selected task."""
|
| 704 |
+
if task == "assign":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 705 |
return (
|
| 706 |
gr.update(visible=True), # categories_group
|
|
|
|
| 707 |
gr.update(visible=True), # model_group
|
| 708 |
gr.update(visible=True, value="Classify Data"), # run_btn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 709 |
gr.update(visible=True), # classify_output_group
|
| 710 |
+
"Enter categories (or auto-extract them) and click Classify."
|
| 711 |
)
|
| 712 |
else:
|
| 713 |
return (
|
|
|
|
| 715 |
gr.update(visible=False),
|
| 716 |
gr.update(visible=False),
|
| 717 |
gr.update(visible=False),
|
| 718 |
+
"Click 'Classify Data' to continue."
|
|
|
|
|
|
|
| 719 |
)
|
| 720 |
|
| 721 |
|
| 722 |
+
def run_auto_extract(input_type, spreadsheet_file, spreadsheet_column,
|
| 723 |
+
pdf_file, pdf_folder, pdf_description, pdf_mode,
|
| 724 |
+
image_file, image_folder, image_description,
|
| 725 |
+
max_categories_val,
|
| 726 |
+
model_tier, model, model_source_input, api_key_input,
|
| 727 |
+
progress=gr.Progress(track_tqdm=True)):
|
| 728 |
+
"""Extract categories from data and fill the category textboxes."""
|
| 729 |
if not CATLLM_AVAILABLE:
|
| 730 |
+
# Return empty updates for all category inputs + status
|
| 731 |
+
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** catllm package not available"]
|
| 732 |
|
| 733 |
actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
|
| 734 |
if not actual_api_key:
|
| 735 |
+
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, f"**Error:** {provider} API key not configured"]
|
|
|
|
| 736 |
|
| 737 |
if model_source_input == "auto":
|
| 738 |
model_source = get_model_source(model)
|
| 739 |
else:
|
| 740 |
model_source = model_source_input
|
| 741 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 742 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 743 |
if input_type == "Survey Responses":
|
| 744 |
if not spreadsheet_file:
|
| 745 |
+
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please upload a CSV/Excel file first"]
|
|
|
|
| 746 |
if not spreadsheet_column:
|
| 747 |
+
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please select a column first"]
|
|
|
|
| 748 |
|
| 749 |
file_path = spreadsheet_file if isinstance(spreadsheet_file, str) else spreadsheet_file.name
|
| 750 |
if file_path.endswith('.csv'):
|
|
|
|
| 753 |
df = pd.read_excel(file_path)
|
| 754 |
|
| 755 |
input_data = df[spreadsheet_column].tolist()
|
| 756 |
+
description = spreadsheet_column
|
| 757 |
+
input_type_param = "text"
|
| 758 |
+
mode_param = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 759 |
|
| 760 |
elif input_type == "PDF Documents":
|
|
|
|
| 761 |
if pdf_folder:
|
| 762 |
if isinstance(pdf_folder, list):
|
| 763 |
+
input_data = [f if isinstance(f, str) else f.name for f in pdf_folder if str(f.name if hasattr(f, 'name') else f).lower().endswith('.pdf')]
|
| 764 |
else:
|
| 765 |
+
input_data = pdf_folder if isinstance(pdf_folder, str) else pdf_folder.name
|
| 766 |
elif pdf_file:
|
| 767 |
if isinstance(pdf_file, list):
|
| 768 |
+
input_data = [f if isinstance(f, str) else f.name for f in pdf_file]
|
| 769 |
else:
|
| 770 |
+
input_data = pdf_file if isinstance(pdf_file, str) else pdf_file.name
|
| 771 |
else:
|
| 772 |
+
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please upload PDF file(s) first"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 773 |
|
| 774 |
+
description = pdf_description or "document"
|
| 775 |
+
input_type_param = "pdf"
|
| 776 |
+
mode_mapping = {"Image (visual documents)": "image", "Text (text-heavy)": "text", "Both (comprehensive)": "both"}
|
| 777 |
+
mode_param = mode_mapping.get(pdf_mode, "image")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 778 |
|
| 779 |
elif input_type == "Images":
|
|
|
|
| 780 |
if image_folder:
|
| 781 |
if isinstance(image_folder, list):
|
| 782 |
+
input_data = [f if isinstance(f, str) else f.name for f in image_folder]
|
| 783 |
else:
|
| 784 |
+
input_data = image_folder if isinstance(image_folder, str) else image_folder.name
|
| 785 |
elif image_file:
|
| 786 |
if isinstance(image_file, list):
|
| 787 |
+
input_data = [f if isinstance(f, str) else f.name for f in image_file]
|
| 788 |
else:
|
| 789 |
+
input_data = image_file if isinstance(image_file, str) else image_file.name
|
| 790 |
else:
|
| 791 |
+
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please upload image file(s) first"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 792 |
|
| 793 |
+
description = image_description or "images"
|
| 794 |
+
input_type_param = "image"
|
| 795 |
+
mode_param = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
|
| 797 |
else:
|
| 798 |
+
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, f"**Error:** Unknown input type: {input_type}"]
|
|
|
|
| 799 |
|
| 800 |
+
# Calculate divisions based on input size
|
| 801 |
+
if isinstance(input_data, list):
|
| 802 |
+
num_items = len(input_data)
|
| 803 |
+
else:
|
| 804 |
+
num_items = 1
|
| 805 |
|
| 806 |
+
if input_type_param == "image":
|
| 807 |
+
divisions = min(3, max(1, num_items // 5))
|
| 808 |
+
categories_per_chunk = 12
|
| 809 |
+
else:
|
| 810 |
+
divisions = min(5, max(1, num_items // 3))
|
| 811 |
+
categories_per_chunk = 10
|
| 812 |
|
| 813 |
+
# Extract categories
|
| 814 |
+
extract_kwargs = {
|
| 815 |
+
'input_data': input_data,
|
| 816 |
+
'api_key': actual_api_key,
|
| 817 |
+
'input_type': input_type_param,
|
| 818 |
+
'description': description,
|
| 819 |
+
'user_model': model,
|
| 820 |
+
'model_source': model_source,
|
| 821 |
+
'divisions': divisions,
|
| 822 |
+
'categories_per_chunk': categories_per_chunk,
|
| 823 |
+
'max_categories': int(max_categories_val)
|
| 824 |
+
}
|
| 825 |
+
if mode_param:
|
| 826 |
+
extract_kwargs['mode'] = mode_param
|
| 827 |
|
| 828 |
+
extract_result = catllm.extract(**extract_kwargs)
|
| 829 |
+
categories = extract_result.get('top_categories', [])
|
|
|
|
|
|
|
| 830 |
|
| 831 |
+
if not categories:
|
| 832 |
+
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** No categories were extracted"]
|
| 833 |
+
|
| 834 |
+
# Fill the category textboxes
|
| 835 |
+
updates = []
|
| 836 |
+
num_categories = min(len(categories), MAX_CATEGORIES)
|
| 837 |
+
for i in range(MAX_CATEGORIES):
|
| 838 |
+
if i < num_categories:
|
| 839 |
+
updates.append(gr.update(value=categories[i], visible=True))
|
| 840 |
+
elif i < INITIAL_CATEGORIES:
|
| 841 |
+
updates.append(gr.update(value="", visible=True))
|
| 842 |
+
else:
|
| 843 |
+
updates.append(gr.update(value="", visible=False))
|
| 844 |
|
| 845 |
+
# Return updates for category inputs + new category count + status
|
| 846 |
+
return updates + [num_categories, f"Extracted {len(categories)} categories. Review and edit as needed, then click 'Classify Data'."]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
|
| 848 |
except Exception as e:
|
| 849 |
+
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, f"**Error:** {str(e)}"]
|
| 850 |
|
| 851 |
|
| 852 |
def run_classify_data(input_type, spreadsheet_file, spreadsheet_column,
|
|
|
|
| 1128 |
yield None, None, None, None, None, f"**Error:** {str(e)}"
|
| 1129 |
|
| 1130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1131 |
|
| 1132 |
def add_category_field(current_count):
|
| 1133 |
new_count = min(current_count + 1, MAX_CATEGORIES)
|
|
|
|
| 1165 |
updates.extend([
|
| 1166 |
gr.update(visible=True), # add_category_btn
|
| 1167 |
INITIAL_CATEGORIES, # category_count
|
| 1168 |
+
gr.update(visible=False), # auto_extract_settings
|
|
|
|
| 1169 |
12, # max_categories (reset to default)
|
| 1170 |
+
"", # auto_extract_status
|
| 1171 |
+
gr.update(visible=False), # categories_group
|
| 1172 |
gr.update(visible=False), # model_group
|
| 1173 |
gr.update(visible=False, value="Run"), # run_btn
|
| 1174 |
"Free Models", # model_tier
|
|
|
|
| 1178 |
gr.update(visible=False), # api_key
|
| 1179 |
"**Free tier** - no API key required!", # api_key_status
|
| 1180 |
"Ready. Upload data and select a task.", # status
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1181 |
gr.update(visible=False), # classify_output_group
|
| 1182 |
gr.update(value=None, visible=False), # distribution_plot
|
| 1183 |
gr.update(value=None, visible=False), # results
|
|
|
|
| 1481 |
info="Helps the LLM understand context"
|
| 1482 |
)
|
| 1483 |
|
| 1484 |
+
# Task selection button
|
| 1485 |
gr.Markdown("### What would you like to do?")
|
| 1486 |
+
assign_btn = gr.Button("Classify Data", variant="secondary", elem_classes="task-btn")
|
|
|
|
|
|
|
|
|
|
| 1487 |
|
| 1488 |
# Categories group (only visible for Assign task)
|
| 1489 |
with gr.Group(visible=False) as categories_group:
|
| 1490 |
gr.Markdown("### Categories")
|
| 1491 |
+
gr.Markdown("Enter your categories manually, or click 'Auto-extract' to discover them from your data.")
|
| 1492 |
category_inputs = []
|
| 1493 |
placeholder_examples = [
|
| 1494 |
"e.g., Positive sentiment",
|
|
|
|
| 1507 |
visible=visible
|
| 1508 |
)
|
| 1509 |
category_inputs.append(cat_input)
|
| 1510 |
+
with gr.Row():
|
| 1511 |
+
add_category_btn = gr.Button("+ Add More", variant="secondary", size="sm")
|
| 1512 |
+
auto_extract_btn = gr.Button("Auto-extract Categories", variant="secondary", size="sm")
|
| 1513 |
+
with gr.Group(visible=False) as auto_extract_settings:
|
| 1514 |
+
max_categories = gr.Slider(
|
| 1515 |
+
minimum=3,
|
| 1516 |
+
maximum=25,
|
| 1517 |
+
value=12,
|
| 1518 |
+
step=1,
|
| 1519 |
+
label="Number of Categories to Extract",
|
| 1520 |
+
info="How many categories should be identified in your data"
|
| 1521 |
+
)
|
| 1522 |
+
run_auto_extract_btn = gr.Button("Extract Now", variant="primary", size="sm")
|
| 1523 |
+
auto_extract_status = gr.Markdown("")
|
| 1524 |
|
| 1525 |
# Model selection group
|
| 1526 |
with gr.Group(visible=False) as model_group:
|
|
|
|
| 1558 |
with gr.Column():
|
| 1559 |
status = gr.Markdown("Ready. Upload data and select a task.")
|
| 1560 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1561 |
# Classify output group
|
| 1562 |
with gr.Group(visible=False) as classify_output_group:
|
| 1563 |
gr.Markdown("### Classification Results")
|
|
|
|
| 1654 |
outputs=category_inputs + [add_category_btn, category_count]
|
| 1655 |
)
|
| 1656 |
|
| 1657 |
+
# Auto-extract button toggles the settings visibility
|
| 1658 |
+
def toggle_auto_extract_settings():
|
| 1659 |
+
return gr.update(visible=True), "Extracting categories..."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1660 |
|
| 1661 |
+
auto_extract_btn.click(
|
| 1662 |
+
fn=toggle_auto_extract_settings,
|
| 1663 |
inputs=[],
|
| 1664 |
+
outputs=[auto_extract_settings, auto_extract_status]
|
| 1665 |
)
|
| 1666 |
|
| 1667 |
+
# Run auto-extract button
|
| 1668 |
+
run_auto_extract_btn.click(
|
| 1669 |
+
fn=run_auto_extract,
|
| 1670 |
+
inputs=[input_type, spreadsheet_file, spreadsheet_column,
|
| 1671 |
+
pdf_file, pdf_folder, pdf_description, pdf_mode,
|
| 1672 |
+
image_file, image_folder, image_description,
|
| 1673 |
+
max_categories, model_tier, model, model_source, api_key],
|
| 1674 |
+
outputs=category_inputs + [category_count, auto_extract_status]
|
| 1675 |
)
|
| 1676 |
|
| 1677 |
+
# Task button handler
|
| 1678 |
+
def select_assign():
|
| 1679 |
+
return ("assign",) + update_task_visibility("assign")
|
| 1680 |
+
|
| 1681 |
+
assign_btn.click(
|
| 1682 |
+
fn=select_assign,
|
| 1683 |
inputs=[],
|
| 1684 |
+
outputs=[task_mode, categories_group, model_group, run_btn, classify_output_group, status]
|
| 1685 |
)
|
| 1686 |
|
| 1687 |
# Main run button handler - dispatches based on task_mode
|
| 1688 |
def dispatch_run(task, input_type, spreadsheet_file, spreadsheet_column,
|
| 1689 |
pdf_file, pdf_folder_val, pdf_description, pdf_mode,
|
| 1690 |
image_file, image_folder_val, image_description,
|
|
|
|
| 1691 |
cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10,
|
| 1692 |
model_tier, model, model_source, api_key,
|
| 1693 |
progress=gr.Progress(track_tqdm=True)):
|
| 1694 |
+
"""Run classification with user-provided categories."""
|
| 1695 |
+
if task == "assign":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1696 |
# run_classify_data yields: (plot, df, files, code, unused, status)
|
| 1697 |
for update in run_classify_data(
|
| 1698 |
input_type, spreadsheet_file, spreadsheet_column,
|
|
|
|
| 1702 |
model_tier, model, model_source, api_key,
|
| 1703 |
progress
|
| 1704 |
):
|
|
|
|
| 1705 |
yield (
|
|
|
|
|
|
|
|
|
|
| 1706 |
update[0], # distribution_plot
|
| 1707 |
update[1], # results
|
| 1708 |
update[2], # download_file
|
| 1709 |
update[3], # classify_code_display
|
| 1710 |
update[5] # status
|
| 1711 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1712 |
else:
|
| 1713 |
+
yield (None, None, None, None, "Please click 'Classify Data' first.")
|
| 1714 |
|
| 1715 |
run_btn.click(
|
| 1716 |
fn=dispatch_run,
|
| 1717 |
inputs=[task_mode, input_type, spreadsheet_file, spreadsheet_column,
|
| 1718 |
pdf_file, pdf_folder, pdf_description, pdf_mode,
|
| 1719 |
+
image_file, image_folder, image_description] + category_inputs + [model_tier, model, model_source, api_key],
|
| 1720 |
+
outputs=[distribution_plot, results, download_file, classify_code_display, status]
|
|
|
|
| 1721 |
)
|
| 1722 |
|
| 1723 |
reset_btn.click(
|
|
|
|
| 1731 |
task_mode
|
| 1732 |
] + category_inputs + [
|
| 1733 |
add_category_btn, category_count,
|
| 1734 |
+
auto_extract_settings, max_categories, auto_extract_status,
|
| 1735 |
+
categories_group, model_group, run_btn,
|
| 1736 |
model_tier, model, model_source, api_key, api_key, api_key_status,
|
| 1737 |
status,
|
|
|
|
| 1738 |
classify_output_group, distribution_plot, results, download_file, classify_code_display
|
| 1739 |
]
|
| 1740 |
)
|