chrissoria commited on
Commit
13dd631
·
1 Parent(s): 28087b9

Simplify UI: merge Extract & Assign into single Classify flow with auto-extract option

Browse files

- Remove standalone Extract Categories button
- Remove Extract & Assign button
- Add Auto-extract Categories button within categories section
- Categories can now be manually entered or auto-extracted before classification
- Simplified codebase by ~500 lines

Files changed (1) hide show
  1. app.py +132 -645
app.py CHANGED
@@ -293,9 +293,7 @@ def generate_methodology_report_pdf(categories, model, column_name, num_rows, mo
293
  story = []
294
 
295
  # Title based on task type
296
- if task_type == "extract":
297
- report_title = "CatLLM Category Extraction Report"
298
- elif task_type == "extract_and_assign":
299
  report_title = "CatLLM Extraction & Classification Report"
300
  else:
301
  report_title = "CatLLM Classification Report"
@@ -306,12 +304,7 @@ def generate_methodology_report_pdf(categories, model, column_name, num_rows, mo
306
 
307
  story.append(Paragraph("About This Report", heading_style))
308
 
309
- if task_type == "extract":
310
- about_text = """This methodology report documents the category extraction process for reproducibility and transparency. \
311
- CatLLM uses Large Language Models to automatically discover meaningful categories from your data. The extraction process \
312
- analyzes your data in chunks, identifies recurring themes, and consolidates them into a final set of categories. \
313
- This automated approach helps researchers avoid confirmation bias in category selection."""
314
- elif task_type == "extract_and_assign":
315
  about_text = """This methodology report documents the automated category extraction and classification process. \
316
  CatLLM first discovers categories from your data using LLMs, then classifies each item into those categories. \
317
  This two-phase approach combines exploratory category discovery with systematic classification, ensuring both \
@@ -458,27 +451,14 @@ consistent and reproducible results."""
458
 
459
  # Summary section - adjust title based on task type
460
  story.append(PageBreak())
461
- if task_type == "extract":
462
- story.append(Paragraph("Extraction Summary", title_style))
463
- elif task_type == "extract_and_assign":
464
  story.append(Paragraph("Processing Summary", title_style))
465
  else:
466
  story.append(Paragraph("Classification Summary", title_style))
467
  story.append(Spacer(1, 15))
468
 
469
  # Build summary data based on task type
470
- if task_type == "extract":
471
- story.append(Paragraph("Extraction Details", heading_style))
472
- summary_data = [
473
- ["Source File", filename],
474
- ["Source Column/Description", column_name],
475
- ["Input Type", input_type],
476
- ["Model Used", model],
477
- ["Model Source", model_source],
478
- ["Max Categories Requested", str(max_categories or "default")],
479
- ["Categories Extracted", str(len(categories)) if categories else "0"],
480
- ]
481
- else:
482
  story.append(Paragraph("Classification Details", heading_style))
483
  summary_data = [
484
  ["Source File", filename],
@@ -596,28 +576,7 @@ is the key and a 1 if the category is present and a 0 if not.'''
596
  story.append(PageBreak())
597
  story.append(Paragraph("Reproducibility Code", title_style))
598
 
599
- if task_type == "extract":
600
- story.append(Paragraph("Use the following Python code to reproduce this category extraction:", normal_style))
601
- story.append(Spacer(1, 15))
602
-
603
- code_text = f'''import catllm
604
-
605
- # Extract categories from your data
606
- result = catllm.extract(
607
- input_data="path/to/your/data", # file path, list of paths, or list of text
608
- api_key="YOUR_API_KEY",
609
- input_type="{input_type}",
610
- description="{description or column_name}",
611
- user_model="{model}",
612
- model_source="{model_source}",
613
- max_categories={max_categories or 12}
614
- )
615
-
616
- # View extracted categories
617
- print(result["top_categories"])
618
- print(result["counts_df"])'''
619
-
620
- elif task_type == "extract_and_assign":
621
  story.append(Paragraph("Use the following Python code to reproduce this extraction and classification:", normal_style))
622
  story.append(Spacer(1, 15))
623
 
@@ -742,35 +701,13 @@ def load_columns(file):
742
 
743
  def update_task_visibility(task):
744
  """Update visibility of components based on selected task."""
745
- if task == "extract":
746
- return (
747
- gr.update(visible=False), # categories_group
748
- gr.update(visible=True), # extract_settings_group
749
- gr.update(visible=True), # model_group
750
- gr.update(visible=True, value="Extract Categories"), # run_btn
751
- gr.update(visible=True), # extract_output_group
752
- gr.update(visible=False), # classify_output_group
753
- "Ready to extract categories from your data."
754
- )
755
- elif task == "assign":
756
  return (
757
  gr.update(visible=True), # categories_group
758
- gr.update(visible=False), # extract_settings_group
759
  gr.update(visible=True), # model_group
760
  gr.update(visible=True, value="Classify Data"), # run_btn
761
- gr.update(visible=False), # extract_output_group
762
- gr.update(visible=True), # classify_output_group
763
- "Enter categories and click Classify."
764
- )
765
- elif task == "extract_and_assign":
766
- return (
767
- gr.update(visible=False), # categories_group
768
- gr.update(visible=True), # extract_settings_group
769
- gr.update(visible=True), # model_group
770
- gr.update(visible=True, value="Extract & Classify"), # run_btn
771
- gr.update(visible=True), # extract_output_group (will show extracted cats)
772
  gr.update(visible=True), # classify_output_group
773
- "Categories will be auto-extracted, then data will be classified."
774
  )
775
  else:
776
  return (
@@ -778,76 +715,36 @@ def update_task_visibility(task):
778
  gr.update(visible=False),
779
  gr.update(visible=False),
780
  gr.update(visible=False),
781
- gr.update(visible=False),
782
- gr.update(visible=False),
783
- "Select a task to continue."
784
  )
785
 
786
 
787
- def run_extract_categories(input_type, spreadsheet_file, spreadsheet_column,
788
- pdf_file, pdf_folder, pdf_description, pdf_mode,
789
- image_file, image_folder, image_description,
790
- max_categories_val,
791
- model_tier, model, model_source_input, api_key_input,
792
- progress=gr.Progress(track_tqdm=True)):
793
- """Extract categories from data and display them in a table."""
794
  if not CATLLM_AVAILABLE:
795
- yield None, None, None, "**Error:** catllm package not available"
796
- return
797
 
798
  actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
799
  if not actual_api_key:
800
- yield None, None, None, f"**Error:** {provider} API key not configured"
801
- return
802
 
803
  if model_source_input == "auto":
804
  model_source = get_model_source(model)
805
  else:
806
  model_source = model_source_input
807
 
808
- # Check file size for images and PDFs
809
- files_to_check = None
810
- if input_type == "Images":
811
- files_to_check = image_folder if image_folder else image_file
812
- elif input_type == "PDF Documents":
813
- files_to_check = pdf_folder if pdf_folder else pdf_file
814
-
815
- if files_to_check:
816
- total_size_mb = calculate_total_file_size(files_to_check)
817
- if total_size_mb > MAX_FILE_SIZE_MB:
818
- # Generate the code for the user
819
- if input_type == "Images":
820
- code = generate_extract_code("image", image_description or "images", model, model_source, int(max_categories_val))
821
- else:
822
- mode_mapping = {"Image (visual documents)": "image", "Text (text-heavy)": "text", "Both (comprehensive)": "both"}
823
- actual_mode = mode_mapping.get(pdf_mode, "image")
824
- code = generate_extract_code("pdf", pdf_description or "document", model, model_source, int(max_categories_val), actual_mode)
825
-
826
- warning_msg = f"""**⚠️ Large Upload Detected ({total_size_mb:.1f} MB)**
827
-
828
- Uploads over {MAX_FILE_SIZE_MB} MB may experience performance issues or timeouts on this web app.
829
-
830
- **Recommended:** Run the code locally using the Python package instead. See the code below, or click "See the Code" after this message.
831
-
832
- ```
833
- pip install cat-llm
834
- ```
835
- """
836
- yield None, None, code, warning_msg
837
- return
838
-
839
  try:
840
- yield None, None, None, "Extracting categories from your data..."
841
-
842
- start_time = time.time()
843
-
844
  if input_type == "Survey Responses":
845
  if not spreadsheet_file:
846
- yield None, None, None, "**Error:** Please upload a CSV/Excel file"
847
- return
848
  if not spreadsheet_column:
849
- yield None, None, None, "**Error:** Please select a column"
850
- return
851
 
852
  file_path = spreadsheet_file if isinstance(spreadsheet_file, str) else spreadsheet_file.name
853
  if file_path.endswith('.csv'):
@@ -856,130 +753,100 @@ pip install cat-llm
856
  df = pd.read_excel(file_path)
857
 
858
  input_data = df[spreadsheet_column].tolist()
859
-
860
- result = catllm.extract(
861
- input_data=input_data,
862
- api_key=actual_api_key,
863
- input_type="text",
864
- description=spreadsheet_column,
865
- user_model=model,
866
- model_source=model_source,
867
- max_categories=int(max_categories_val)
868
- )
869
 
870
  elif input_type == "PDF Documents":
871
- # Use folder if provided, otherwise use uploaded files
872
  if pdf_folder:
873
  if isinstance(pdf_folder, list):
874
- pdf_input = [f if isinstance(f, str) else f.name for f in pdf_folder if str(f.name if hasattr(f, 'name') else f).lower().endswith('.pdf')]
875
  else:
876
- pdf_input = pdf_folder if isinstance(pdf_folder, str) else pdf_folder.name
877
  elif pdf_file:
878
  if isinstance(pdf_file, list):
879
- pdf_input = [f if isinstance(f, str) else f.name for f in pdf_file]
880
  else:
881
- pdf_input = pdf_file if isinstance(pdf_file, str) else pdf_file.name
882
  else:
883
- yield None, None, None, "**Error:** Please upload PDF file(s) or a folder"
884
- return
885
-
886
- mode_mapping = {
887
- "Image (visual documents)": "image",
888
- "Text (text-heavy)": "text",
889
- "Both (comprehensive)": "both"
890
- }
891
- actual_mode = mode_mapping.get(pdf_mode, "image")
892
-
893
- # Calculate sensible divisions based on input size
894
- num_items = len(pdf_input) if isinstance(pdf_input, list) else 1
895
- divisions = min(5, max(1, num_items // 3))
896
 
897
- result = catllm.extract(
898
- input_data=pdf_input,
899
- api_key=actual_api_key,
900
- input_type="pdf",
901
- description=pdf_description or "document",
902
- mode=actual_mode,
903
- user_model=model,
904
- model_source=model_source,
905
- divisions=divisions,
906
- max_categories=int(max_categories_val)
907
- )
908
 
909
  elif input_type == "Images":
910
- # Use folder if provided, otherwise use uploaded files
911
  if image_folder:
912
  if isinstance(image_folder, list):
913
- image_input = [f if isinstance(f, str) else f.name for f in image_folder]
914
  else:
915
- image_input = image_folder if isinstance(image_folder, str) else image_folder.name
916
  elif image_file:
917
  if isinstance(image_file, list):
918
- image_input = [f if isinstance(f, str) else f.name for f in image_file]
919
  else:
920
- image_input = image_file if isinstance(image_file, str) else image_file.name
921
  else:
922
- yield None, None, None, "**Error:** Please upload image file(s) or a folder"
923
- return
924
-
925
- # For images, use fewer divisions since each image can have multiple categories
926
- num_items = len(image_input) if isinstance(image_input, list) else 1
927
- # Use 1 division for small sets, max 3 for larger sets
928
- divisions = min(3, max(1, num_items // 5))
929
 
930
- result = catllm.extract(
931
- input_data=image_input,
932
- api_key=actual_api_key,
933
- input_type="image",
934
- description=image_description or "images",
935
- user_model=model,
936
- model_source=model_source,
937
- divisions=divisions,
938
- categories_per_chunk=12, # Images often have multiple categories each
939
- max_categories=int(max_categories_val)
940
- )
941
 
942
  else:
943
- yield None, None, None, f"**Error:** Unknown input type: {input_type}"
944
- return
945
 
946
- processing_time = time.time() - start_time
 
 
 
 
947
 
948
- # Extract the categories and counts
949
- categories_df = result.get('counts_df', pd.DataFrame())
950
- top_categories = result.get('top_categories', [])
 
 
 
951
 
952
- if categories_df.empty and top_categories:
953
- # Create a simple DataFrame from top_categories
954
- categories_df = pd.DataFrame({
955
- 'Category': top_categories,
956
- 'Count': ['-'] * len(top_categories)
957
- })
 
 
 
 
 
 
 
 
958
 
959
- # Save to CSV for download
960
- with tempfile.NamedTemporaryFile(mode='w', suffix='_extracted_categories.csv', delete=False) as f:
961
- categories_df.to_csv(f.name, index=False)
962
- csv_path = f.name
963
 
964
- # Generate reproducibility code
965
- if input_type == "Survey Responses":
966
- code = generate_extract_code("text", spreadsheet_column, model, model_source, int(max_categories_val))
967
- elif input_type == "PDF Documents":
968
- mode_mapping = {"Image (visual documents)": "image", "Text (text-heavy)": "text", "Both (comprehensive)": "both"}
969
- actual_mode = mode_mapping.get(pdf_mode, "image")
970
- code = generate_extract_code("pdf", pdf_description or "document", model, model_source, int(max_categories_val), actual_mode)
971
- else: # Images
972
- code = generate_extract_code("image", image_description or "images", model, model_source, int(max_categories_val))
 
 
 
 
973
 
974
- yield (
975
- gr.update(value=categories_df, visible=True),
976
- csv_path,
977
- code,
978
- f"Extracted {len(top_categories)} categories in {processing_time:.1f}s"
979
- )
980
 
981
  except Exception as e:
982
- yield None, None, None, f"**Error:** {str(e)}"
983
 
984
 
985
  def run_classify_data(input_type, spreadsheet_file, spreadsheet_column,
@@ -1261,320 +1128,6 @@ Provide your work in JSON format where the number belonging to each category is
1261
  yield None, None, None, None, None, f"**Error:** {str(e)}"
1262
 
1263
 
1264
- def run_extract_and_assign(input_type, spreadsheet_file, spreadsheet_column,
1265
- pdf_file, pdf_folder, pdf_description, pdf_mode,
1266
- image_file, image_folder, image_description,
1267
- max_categories_val,
1268
- model_tier, model, model_source_input, api_key_input,
1269
- progress=gr.Progress(track_tqdm=True)):
1270
- """Extract categories then classify data with them."""
1271
- if not CATLLM_AVAILABLE:
1272
- yield None, None, None, None, None, None, None, None, "**Error:** catllm package not available"
1273
- return
1274
-
1275
- actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
1276
- if not actual_api_key:
1277
- yield None, None, None, None, None, None, None, None, f"**Error:** {provider} API key not configured"
1278
- return
1279
-
1280
- if model_source_input == "auto":
1281
- model_source = get_model_source(model)
1282
- else:
1283
- model_source = model_source_input
1284
-
1285
- # Check file size for images and PDFs
1286
- files_to_check = None
1287
- if input_type == "Images":
1288
- files_to_check = image_folder if image_folder else image_file
1289
- elif input_type == "PDF Documents":
1290
- files_to_check = pdf_folder if pdf_folder else pdf_file
1291
-
1292
- if files_to_check:
1293
- total_size_mb = calculate_total_file_size(files_to_check)
1294
- if total_size_mb > MAX_FILE_SIZE_MB:
1295
- # Generate the code for the user
1296
- if input_type == "Images":
1297
- extract_code = generate_extract_code("image", image_description or "images", model, model_source, int(max_categories_val))
1298
- else:
1299
- mode_mapping = {"Image (visual documents)": "image", "Text (text-heavy)": "text", "Both (comprehensive)": "both"}
1300
- actual_mode = mode_mapping.get(pdf_mode, "image")
1301
- extract_code = generate_extract_code("pdf", pdf_description or "document", model, model_source, int(max_categories_val), actual_mode)
1302
-
1303
- warning_msg = f"""**⚠️ Large Upload Detected ({total_size_mb:.1f} MB)**
1304
-
1305
- Uploads over {MAX_FILE_SIZE_MB} MB may experience performance issues or timeouts on this web app.
1306
-
1307
- **Recommended:** Run the code locally using the Python package instead. See the code below, or click "See the Code" after this message.
1308
-
1309
- ```
1310
- pip install cat-llm
1311
- ```
1312
- """
1313
- yield None, None, extract_code, None, None, None, None, None, warning_msg
1314
- return
1315
-
1316
- try:
1317
- # Phase 1: Extract categories
1318
- yield None, None, None, None, None, None, None, None, "Phase 1: Extracting categories..."
1319
-
1320
- start_time = time.time()
1321
-
1322
- if input_type == "Survey Responses":
1323
- if not spreadsheet_file:
1324
- yield None, None, None, None, None, None, None, None, "**Error:** Please upload a CSV/Excel file"
1325
- return
1326
- if not spreadsheet_column:
1327
- yield None, None, None, None, None, None, None, None, "**Error:** Please select a column"
1328
- return
1329
-
1330
- file_path = spreadsheet_file if isinstance(spreadsheet_file, str) else spreadsheet_file.name
1331
- if file_path.endswith('.csv'):
1332
- df = pd.read_csv(file_path)
1333
- else:
1334
- df = pd.read_excel(file_path)
1335
-
1336
- input_data = df[spreadsheet_column].tolist()
1337
- original_filename = file_path.split("/")[-1]
1338
- column_name = spreadsheet_column
1339
- input_type_param = "text"
1340
- description = spreadsheet_column
1341
- mode_param = None
1342
-
1343
- elif input_type == "PDF Documents":
1344
- # Use folder if provided, otherwise use uploaded files
1345
- if pdf_folder:
1346
- if isinstance(pdf_folder, list):
1347
- input_data = [f if isinstance(f, str) else f.name for f in pdf_folder if str(f.name if hasattr(f, 'name') else f).lower().endswith('.pdf')]
1348
- original_filename = "pdf_folder"
1349
- else:
1350
- input_data = pdf_folder if isinstance(pdf_folder, str) else pdf_folder.name
1351
- original_filename = input_data.split("/")[-1]
1352
- elif pdf_file:
1353
- if isinstance(pdf_file, list):
1354
- input_data = [f if isinstance(f, str) else f.name for f in pdf_file]
1355
- original_filename = "multiple_pdfs"
1356
- else:
1357
- input_data = pdf_file if isinstance(pdf_file, str) else pdf_file.name
1358
- original_filename = input_data.split("/")[-1]
1359
- else:
1360
- yield None, None, None, None, None, None, None, None, "**Error:** Please upload PDF file(s) or a folder"
1361
- return
1362
-
1363
- column_name = "PDF Pages"
1364
- input_type_param = "pdf"
1365
- description = pdf_description or "document"
1366
-
1367
- mode_mapping = {
1368
- "Image (visual documents)": "image",
1369
- "Text (text-heavy)": "text",
1370
- "Both (comprehensive)": "both"
1371
- }
1372
- mode_param = mode_mapping.get(pdf_mode, "image")
1373
-
1374
- elif input_type == "Images":
1375
- # Use folder if provided, otherwise use uploaded files
1376
- if image_folder:
1377
- if isinstance(image_folder, list):
1378
- input_data = [f if isinstance(f, str) else f.name for f in image_folder]
1379
- original_filename = "image_folder"
1380
- else:
1381
- input_data = image_folder if isinstance(image_folder, str) else image_folder.name
1382
- original_filename = input_data.split("/")[-1]
1383
- elif image_file:
1384
- if isinstance(image_file, list):
1385
- input_data = [f if isinstance(f, str) else f.name for f in image_file]
1386
- original_filename = "multiple_images"
1387
- else:
1388
- input_data = image_file if isinstance(image_file, str) else image_file.name
1389
- original_filename = input_data.split("/")[-1]
1390
- else:
1391
- yield None, None, None, None, None, None, None, None, "**Error:** Please upload image file(s) or a folder"
1392
- return
1393
-
1394
- column_name = "Image Files"
1395
- input_type_param = "image"
1396
- description = image_description or "images"
1397
- mode_param = None
1398
-
1399
- else:
1400
- yield None, None, None, None, None, None, None, None, f"**Error:** Unknown input type: {input_type}"
1401
- return
1402
-
1403
- # Calculate sensible divisions based on input size and type
1404
- if isinstance(input_data, list):
1405
- num_items = len(input_data)
1406
- else:
1407
- num_items = 1
1408
-
1409
- # Images can have multiple categories per item, so use fewer divisions
1410
- if input_type_param == "image":
1411
- divisions = min(3, max(1, num_items // 5))
1412
- categories_per_chunk = 12
1413
- else:
1414
- divisions = min(5, max(1, num_items // 3))
1415
- categories_per_chunk = 10
1416
-
1417
- # Extract categories
1418
- extract_kwargs = {
1419
- 'input_data': input_data,
1420
- 'api_key': actual_api_key,
1421
- 'input_type': input_type_param,
1422
- 'description': description,
1423
- 'user_model': model,
1424
- 'model_source': model_source,
1425
- 'divisions': divisions,
1426
- 'categories_per_chunk': categories_per_chunk,
1427
- 'max_categories': int(max_categories_val)
1428
- }
1429
- if mode_param:
1430
- extract_kwargs['mode'] = mode_param
1431
-
1432
- extract_result = catllm.extract(**extract_kwargs)
1433
- categories = extract_result.get('top_categories', [])
1434
- categories_df = extract_result.get('counts_df', pd.DataFrame())
1435
-
1436
- if not categories:
1437
- yield None, None, None, None, None, None, None, None, "**Error:** No categories were extracted"
1438
- return
1439
-
1440
- extract_time = time.time() - start_time
1441
-
1442
- # Show extracted categories
1443
- if categories_df.empty and categories:
1444
- categories_df = pd.DataFrame({
1445
- 'Category': categories,
1446
- 'Count': ['-'] * len(categories)
1447
- })
1448
-
1449
- with tempfile.NamedTemporaryFile(mode='w', suffix='_extracted_categories.csv', delete=False) as f:
1450
- categories_df.to_csv(f.name, index=False)
1451
- extract_csv_path = f.name
1452
-
1453
- # Generate extract code
1454
- extract_code = generate_extract_code(input_type_param, description, model, model_source, int(max_categories_val), mode_param)
1455
-
1456
- yield (
1457
- gr.update(value=categories_df, visible=True),
1458
- extract_csv_path,
1459
- extract_code,
1460
- None, None, None, None, None,
1461
- f"Extracted {len(categories)} categories in {extract_time:.1f}s. Now classifying..."
1462
- )
1463
-
1464
- # Phase 2: Classify with extracted categories
1465
- classify_start = time.time()
1466
-
1467
- classify_kwargs = {
1468
- 'input_data': input_data,
1469
- 'categories': categories,
1470
- 'api_key': actual_api_key,
1471
- 'input_type': input_type_param,
1472
- 'description': description,
1473
- 'user_model': model,
1474
- 'model_source': model_source
1475
- }
1476
- if mode_param:
1477
- classify_kwargs['mode'] = mode_param
1478
-
1479
- result = catllm.classify(**classify_kwargs)
1480
-
1481
- classify_time = time.time() - classify_start
1482
- total_time = time.time() - start_time
1483
- num_items = len(result)
1484
-
1485
- # Save CSV
1486
- with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
1487
- result.to_csv(f.name, index=False)
1488
- classify_csv_path = f.name
1489
-
1490
- # Calculate success rate
1491
- if 'processing_status' in result.columns:
1492
- success_count = (result['processing_status'] == 'success').sum()
1493
- success_rate = (success_count / len(result)) * 100
1494
- else:
1495
- success_rate = 100.0
1496
-
1497
- # Get version info
1498
- try:
1499
- catllm_version = catllm.__version__
1500
- except AttributeError:
1501
- catllm_version = "unknown"
1502
- python_version = sys.version.split()[0]
1503
-
1504
- # Generate methodology report
1505
- prompt_template = '''Categorize this survey response "{response}" into the following categories that apply:
1506
- {categories}
1507
-
1508
- Let's think step by step:
1509
- 1. First, identify the main themes mentioned in the response
1510
- 2. Then, match each theme to the relevant categories
1511
- 3. Finally, assign 1 to matching categories and 0 to non-matching categories
1512
-
1513
- Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values.'''
1514
-
1515
- report_pdf_path = generate_methodology_report_pdf(
1516
- categories=categories,
1517
- model=model,
1518
- column_name=column_name,
1519
- num_rows=num_items,
1520
- model_source=model_source,
1521
- filename=original_filename,
1522
- success_rate=success_rate,
1523
- result_df=result,
1524
- processing_time=total_time,
1525
- prompt_template=prompt_template,
1526
- data_quality={'null_count': 0, 'avg_length': 0, 'min_length': 0, 'max_length': 0, 'error_count': 0},
1527
- catllm_version=catllm_version,
1528
- python_version=python_version,
1529
- task_type="extract_and_assign",
1530
- max_categories=int(max_categories_val),
1531
- input_type=input_type_param,
1532
- description=description
1533
- )
1534
-
1535
- # Create distribution plot
1536
- dist_data = []
1537
- total_rows = len(result)
1538
- for i, cat in enumerate(categories, 1):
1539
- col_name = f"category_{i}"
1540
- if col_name in result.columns:
1541
- count = int(result[col_name].sum())
1542
- pct = (count / total_rows) * 100 if total_rows > 0 else 0
1543
- dist_data.append({"Category": cat, "Percentage": round(pct, 1)})
1544
-
1545
- fig, ax = plt.subplots(figsize=(10, max(4, len(dist_data) * 0.8)))
1546
- categories_list = [d["Category"] for d in dist_data][::-1]
1547
- percentages = [d["Percentage"] for d in dist_data][::-1]
1548
-
1549
- bars = ax.barh(categories_list, percentages, color='#2563eb')
1550
- ax.set_xlim(0, 100)
1551
- ax.set_xlabel('Percentage (%)', fontsize=11)
1552
- ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold')
1553
-
1554
- for bar, pct in zip(bars, percentages):
1555
- ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
1556
- f'{pct:.1f}%', va='center', fontsize=10)
1557
-
1558
- plt.tight_layout()
1559
-
1560
- # Generate classify code
1561
- classify_code = generate_classify_code(input_type_param, description, categories, model, model_source, mode_param)
1562
-
1563
- yield (
1564
- gr.update(value=categories_df, visible=True),
1565
- extract_csv_path,
1566
- extract_code,
1567
- gr.update(value=fig, visible=True),
1568
- gr.update(value=result, visible=True),
1569
- [classify_csv_path, report_pdf_path],
1570
- classify_code,
1571
- None,
1572
- f"Extracted {len(categories)} categories and classified {num_items} items in {total_time:.1f}s"
1573
- )
1574
-
1575
- except Exception as e:
1576
- yield None, None, None, None, None, None, None, None, f"**Error:** {str(e)}"
1577
-
1578
 
1579
  def add_category_field(current_count):
1580
  new_count = min(current_count + 1, MAX_CATEGORIES)
@@ -1612,9 +1165,10 @@ def reset_all():
1612
  updates.extend([
1613
  gr.update(visible=True), # add_category_btn
1614
  INITIAL_CATEGORIES, # category_count
1615
- gr.update(visible=False), # categories_group
1616
- gr.update(visible=False), # extract_settings_group
1617
  12, # max_categories (reset to default)
 
 
1618
  gr.update(visible=False), # model_group
1619
  gr.update(visible=False, value="Run"), # run_btn
1620
  "Free Models", # model_tier
@@ -1624,10 +1178,6 @@ def reset_all():
1624
  gr.update(visible=False), # api_key
1625
  "**Free tier** - no API key required!", # api_key_status
1626
  "Ready. Upload data and select a task.", # status
1627
- gr.update(visible=False), # extract_output_group
1628
- gr.update(value=None, visible=False), # extracted_categories
1629
- None, # extract_download
1630
- "# Code will be generated after extraction", # extract_code_display
1631
  gr.update(visible=False), # classify_output_group
1632
  gr.update(value=None, visible=False), # distribution_plot
1633
  gr.update(value=None, visible=False), # results
@@ -1931,16 +1481,14 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
1931
  info="Helps the LLM understand context"
1932
  )
1933
 
1934
- # Task selection buttons
1935
  gr.Markdown("### What would you like to do?")
1936
- with gr.Row():
1937
- extract_btn = gr.Button("Extract Categories", variant="secondary", elem_classes="task-btn")
1938
- assign_btn = gr.Button("Assign Categories", variant="secondary", elem_classes="task-btn")
1939
- extract_assign_btn = gr.Button("Extract & Assign", variant="secondary", elem_classes="task-btn")
1940
 
1941
  # Categories group (only visible for Assign task)
1942
  with gr.Group(visible=False) as categories_group:
1943
  gr.Markdown("### Categories")
 
1944
  category_inputs = []
1945
  placeholder_examples = [
1946
  "e.g., Positive sentiment",
@@ -1959,19 +1507,20 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
1959
  visible=visible
1960
  )
1961
  category_inputs.append(cat_input)
1962
- add_category_btn = gr.Button("+ Add More Categories", variant="secondary", size="sm")
1963
-
1964
- # Extraction settings group (only visible for Extract and Extract & Assign)
1965
- with gr.Group(visible=False) as extract_settings_group:
1966
- gr.Markdown("### Extraction Settings")
1967
- max_categories = gr.Slider(
1968
- minimum=3,
1969
- maximum=25,
1970
- value=12,
1971
- step=1,
1972
- label="Number of Categories to Extract",
1973
- info="How many categories should be identified in your data"
1974
- )
 
1975
 
1976
  # Model selection group
1977
  with gr.Group(visible=False) as model_group:
@@ -2009,23 +1558,6 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
2009
  with gr.Column():
2010
  status = gr.Markdown("Ready. Upload data and select a task.")
2011
 
2012
- # Extract output group
2013
- with gr.Group(visible=False) as extract_output_group:
2014
- gr.Markdown("### Extracted Categories")
2015
- extracted_categories = gr.DataFrame(
2016
- label="Categories",
2017
- visible=False,
2018
- wrap=True
2019
- )
2020
- extract_download = gr.File(label="Download Categories (CSV)")
2021
- with gr.Accordion("See the Code", open=False):
2022
- extract_code_display = gr.Code(
2023
- label="Python Code",
2024
- language="python",
2025
- value="# Code will be generated after extraction",
2026
- interactive=False
2027
- )
2028
-
2029
  # Classify output group
2030
  with gr.Group(visible=False) as classify_output_group:
2031
  gr.Markdown("### Classification Results")
@@ -2122,65 +1654,45 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
2122
  outputs=category_inputs + [add_category_btn, category_count]
2123
  )
2124
 
2125
- # Task button handlers
2126
- def select_extract():
2127
- return ("extract",) + update_task_visibility("extract")
2128
-
2129
- def select_assign():
2130
- return ("assign",) + update_task_visibility("assign")
2131
-
2132
- def select_extract_assign():
2133
- return ("extract_and_assign",) + update_task_visibility("extract_and_assign")
2134
 
2135
- extract_btn.click(
2136
- fn=select_extract,
2137
  inputs=[],
2138
- outputs=[task_mode, categories_group, extract_settings_group, model_group, run_btn, extract_output_group, classify_output_group, status]
2139
  )
2140
 
2141
- assign_btn.click(
2142
- fn=select_assign,
2143
- inputs=[],
2144
- outputs=[task_mode, categories_group, extract_settings_group, model_group, run_btn, extract_output_group, classify_output_group, status]
 
 
 
 
2145
  )
2146
 
2147
- extract_assign_btn.click(
2148
- fn=select_extract_assign,
 
 
 
 
2149
  inputs=[],
2150
- outputs=[task_mode, categories_group, extract_settings_group, model_group, run_btn, extract_output_group, classify_output_group, status]
2151
  )
2152
 
2153
  # Main run button handler - dispatches based on task_mode
2154
  def dispatch_run(task, input_type, spreadsheet_file, spreadsheet_column,
2155
  pdf_file, pdf_folder_val, pdf_description, pdf_mode,
2156
  image_file, image_folder_val, image_description,
2157
- max_categories_val,
2158
  cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10,
2159
  model_tier, model, model_source, api_key,
2160
  progress=gr.Progress(track_tqdm=True)):
2161
- """Dispatch to appropriate function based on task mode."""
2162
- if task == "extract":
2163
- # run_extract_categories yields: (categories_df, csv_path, code, status)
2164
- for update in run_extract_categories(
2165
- input_type, spreadsheet_file, spreadsheet_column,
2166
- pdf_file, pdf_folder_val, pdf_description, pdf_mode,
2167
- image_file, image_folder_val, image_description,
2168
- max_categories_val,
2169
- model_tier, model, model_source, api_key,
2170
- progress
2171
- ):
2172
- # Map extract outputs to full output list
2173
- yield (
2174
- update[0], # extracted_categories
2175
- update[1], # extract_download
2176
- update[2], # extract_code_display
2177
- None, # distribution_plot
2178
- None, # results
2179
- None, # download_file
2180
- None, # classify_code_display
2181
- update[3] # status
2182
- )
2183
- elif task == "assign":
2184
  # run_classify_data yields: (plot, df, files, code, unused, status)
2185
  for update in run_classify_data(
2186
  input_type, spreadsheet_file, spreadsheet_column,
@@ -2190,47 +1702,22 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
2190
  model_tier, model, model_source, api_key,
2191
  progress
2192
  ):
2193
- # Map classify outputs to full output list
2194
  yield (
2195
- None, # extracted_categories
2196
- None, # extract_download
2197
- None, # extract_code_display
2198
  update[0], # distribution_plot
2199
  update[1], # results
2200
  update[2], # download_file
2201
  update[3], # classify_code_display
2202
  update[5] # status
2203
  )
2204
- elif task == "extract_and_assign":
2205
- # run_extract_and_assign yields: (categories_df, extract_csv, extract_code, plot, df, files, classify_code, unused, status)
2206
- for update in run_extract_and_assign(
2207
- input_type, spreadsheet_file, spreadsheet_column,
2208
- pdf_file, pdf_folder_val, pdf_description, pdf_mode,
2209
- image_file, image_folder_val, image_description,
2210
- max_categories_val,
2211
- model_tier, model, model_source, api_key,
2212
- progress
2213
- ):
2214
- yield (
2215
- update[0], # extracted_categories
2216
- update[1], # extract_download
2217
- update[2], # extract_code_display
2218
- update[3], # distribution_plot
2219
- update[4], # results
2220
- update[5], # download_file
2221
- update[6], # classify_code_display
2222
- update[8] # status
2223
- )
2224
  else:
2225
- yield (None, None, None, None, None, None, None, "Please select a task first.")
2226
 
2227
  run_btn.click(
2228
  fn=dispatch_run,
2229
  inputs=[task_mode, input_type, spreadsheet_file, spreadsheet_column,
2230
  pdf_file, pdf_folder, pdf_description, pdf_mode,
2231
- image_file, image_folder, image_description,
2232
- max_categories] + category_inputs + [model_tier, model, model_source, api_key],
2233
- outputs=[extracted_categories, extract_download, extract_code_display, distribution_plot, results, download_file, classify_code_display, status]
2234
  )
2235
 
2236
  reset_btn.click(
@@ -2244,10 +1731,10 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
2244
  task_mode
2245
  ] + category_inputs + [
2246
  add_category_btn, category_count,
2247
- categories_group, extract_settings_group, max_categories, model_group, run_btn,
 
2248
  model_tier, model, model_source, api_key, api_key, api_key_status,
2249
  status,
2250
- extract_output_group, extracted_categories, extract_download, extract_code_display,
2251
  classify_output_group, distribution_plot, results, download_file, classify_code_display
2252
  ]
2253
  )
 
293
  story = []
294
 
295
  # Title based on task type
296
+ if task_type == "extract_and_assign":
 
 
297
  report_title = "CatLLM Extraction & Classification Report"
298
  else:
299
  report_title = "CatLLM Classification Report"
 
304
 
305
  story.append(Paragraph("About This Report", heading_style))
306
 
307
+ if task_type == "extract_and_assign":
 
 
 
 
 
308
  about_text = """This methodology report documents the automated category extraction and classification process. \
309
  CatLLM first discovers categories from your data using LLMs, then classifies each item into those categories. \
310
  This two-phase approach combines exploratory category discovery with systematic classification, ensuring both \
 
451
 
452
  # Summary section - adjust title based on task type
453
  story.append(PageBreak())
454
+ if task_type == "extract_and_assign":
 
 
455
  story.append(Paragraph("Processing Summary", title_style))
456
  else:
457
  story.append(Paragraph("Classification Summary", title_style))
458
  story.append(Spacer(1, 15))
459
 
460
  # Build summary data based on task type
461
+ if task_type == "assign":
 
 
 
 
 
 
 
 
 
 
 
462
  story.append(Paragraph("Classification Details", heading_style))
463
  summary_data = [
464
  ["Source File", filename],
 
576
  story.append(PageBreak())
577
  story.append(Paragraph("Reproducibility Code", title_style))
578
 
579
+ if task_type == "extract_and_assign":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580
  story.append(Paragraph("Use the following Python code to reproduce this extraction and classification:", normal_style))
581
  story.append(Spacer(1, 15))
582
 
 
701
 
702
  def update_task_visibility(task):
703
  """Update visibility of components based on selected task."""
704
+ if task == "assign":
 
 
 
 
 
 
 
 
 
 
705
  return (
706
  gr.update(visible=True), # categories_group
 
707
  gr.update(visible=True), # model_group
708
  gr.update(visible=True, value="Classify Data"), # run_btn
 
 
 
 
 
 
 
 
 
 
 
709
  gr.update(visible=True), # classify_output_group
710
+ "Enter categories (or auto-extract them) and click Classify."
711
  )
712
  else:
713
  return (
 
715
  gr.update(visible=False),
716
  gr.update(visible=False),
717
  gr.update(visible=False),
718
+ "Click 'Classify Data' to continue."
 
 
719
  )
720
 
721
 
722
+ def run_auto_extract(input_type, spreadsheet_file, spreadsheet_column,
723
+ pdf_file, pdf_folder, pdf_description, pdf_mode,
724
+ image_file, image_folder, image_description,
725
+ max_categories_val,
726
+ model_tier, model, model_source_input, api_key_input,
727
+ progress=gr.Progress(track_tqdm=True)):
728
+ """Extract categories from data and fill the category textboxes."""
729
  if not CATLLM_AVAILABLE:
730
+ # Return empty updates for all category inputs + status
731
+ return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** catllm package not available"]
732
 
733
  actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
734
  if not actual_api_key:
735
+ return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, f"**Error:** {provider} API key not configured"]
 
736
 
737
  if model_source_input == "auto":
738
  model_source = get_model_source(model)
739
  else:
740
  model_source = model_source_input
741
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
742
  try:
 
 
 
 
743
  if input_type == "Survey Responses":
744
  if not spreadsheet_file:
745
+ return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please upload a CSV/Excel file first"]
 
746
  if not spreadsheet_column:
747
+ return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please select a column first"]
 
748
 
749
  file_path = spreadsheet_file if isinstance(spreadsheet_file, str) else spreadsheet_file.name
750
  if file_path.endswith('.csv'):
 
753
  df = pd.read_excel(file_path)
754
 
755
  input_data = df[spreadsheet_column].tolist()
756
+ description = spreadsheet_column
757
+ input_type_param = "text"
758
+ mode_param = None
 
 
 
 
 
 
 
759
 
760
  elif input_type == "PDF Documents":
 
761
  if pdf_folder:
762
  if isinstance(pdf_folder, list):
763
+ input_data = [f if isinstance(f, str) else f.name for f in pdf_folder if str(f.name if hasattr(f, 'name') else f).lower().endswith('.pdf')]
764
  else:
765
+ input_data = pdf_folder if isinstance(pdf_folder, str) else pdf_folder.name
766
  elif pdf_file:
767
  if isinstance(pdf_file, list):
768
+ input_data = [f if isinstance(f, str) else f.name for f in pdf_file]
769
  else:
770
+ input_data = pdf_file if isinstance(pdf_file, str) else pdf_file.name
771
  else:
772
+ return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please upload PDF file(s) first"]
 
 
 
 
 
 
 
 
 
 
 
 
773
 
774
+ description = pdf_description or "document"
775
+ input_type_param = "pdf"
776
+ mode_mapping = {"Image (visual documents)": "image", "Text (text-heavy)": "text", "Both (comprehensive)": "both"}
777
+ mode_param = mode_mapping.get(pdf_mode, "image")
 
 
 
 
 
 
 
778
 
779
  elif input_type == "Images":
 
780
  if image_folder:
781
  if isinstance(image_folder, list):
782
+ input_data = [f if isinstance(f, str) else f.name for f in image_folder]
783
  else:
784
+ input_data = image_folder if isinstance(image_folder, str) else image_folder.name
785
  elif image_file:
786
  if isinstance(image_file, list):
787
+ input_data = [f if isinstance(f, str) else f.name for f in image_file]
788
  else:
789
+ input_data = image_file if isinstance(image_file, str) else image_file.name
790
  else:
791
+ return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please upload image file(s) first"]
 
 
 
 
 
 
792
 
793
+ description = image_description or "images"
794
+ input_type_param = "image"
795
+ mode_param = None
 
 
 
 
 
 
 
 
796
 
797
  else:
798
+ return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, f"**Error:** Unknown input type: {input_type}"]
 
799
 
800
+ # Calculate divisions based on input size
801
+ if isinstance(input_data, list):
802
+ num_items = len(input_data)
803
+ else:
804
+ num_items = 1
805
 
806
+ if input_type_param == "image":
807
+ divisions = min(3, max(1, num_items // 5))
808
+ categories_per_chunk = 12
809
+ else:
810
+ divisions = min(5, max(1, num_items // 3))
811
+ categories_per_chunk = 10
812
 
813
+ # Extract categories
814
+ extract_kwargs = {
815
+ 'input_data': input_data,
816
+ 'api_key': actual_api_key,
817
+ 'input_type': input_type_param,
818
+ 'description': description,
819
+ 'user_model': model,
820
+ 'model_source': model_source,
821
+ 'divisions': divisions,
822
+ 'categories_per_chunk': categories_per_chunk,
823
+ 'max_categories': int(max_categories_val)
824
+ }
825
+ if mode_param:
826
+ extract_kwargs['mode'] = mode_param
827
 
828
+ extract_result = catllm.extract(**extract_kwargs)
829
+ categories = extract_result.get('top_categories', [])
 
 
830
 
831
+ if not categories:
832
+ return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** No categories were extracted"]
833
+
834
+ # Fill the category textboxes
835
+ updates = []
836
+ num_categories = min(len(categories), MAX_CATEGORIES)
837
+ for i in range(MAX_CATEGORIES):
838
+ if i < num_categories:
839
+ updates.append(gr.update(value=categories[i], visible=True))
840
+ elif i < INITIAL_CATEGORIES:
841
+ updates.append(gr.update(value="", visible=True))
842
+ else:
843
+ updates.append(gr.update(value="", visible=False))
844
 
845
+ # Return updates for category inputs + new category count + status
846
+ return updates + [num_categories, f"Extracted {len(categories)} categories. Review and edit as needed, then click 'Classify Data'."]
 
 
 
 
847
 
848
  except Exception as e:
849
+ return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, f"**Error:** {str(e)}"]
850
 
851
 
852
  def run_classify_data(input_type, spreadsheet_file, spreadsheet_column,
 
1128
  yield None, None, None, None, None, f"**Error:** {str(e)}"
1129
 
1130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1131
 
1132
  def add_category_field(current_count):
1133
  new_count = min(current_count + 1, MAX_CATEGORIES)
 
1165
  updates.extend([
1166
  gr.update(visible=True), # add_category_btn
1167
  INITIAL_CATEGORIES, # category_count
1168
+ gr.update(visible=False), # auto_extract_settings
 
1169
  12, # max_categories (reset to default)
1170
+ "", # auto_extract_status
1171
+ gr.update(visible=False), # categories_group
1172
  gr.update(visible=False), # model_group
1173
  gr.update(visible=False, value="Run"), # run_btn
1174
  "Free Models", # model_tier
 
1178
  gr.update(visible=False), # api_key
1179
  "**Free tier** - no API key required!", # api_key_status
1180
  "Ready. Upload data and select a task.", # status
 
 
 
 
1181
  gr.update(visible=False), # classify_output_group
1182
  gr.update(value=None, visible=False), # distribution_plot
1183
  gr.update(value=None, visible=False), # results
 
1481
  info="Helps the LLM understand context"
1482
  )
1483
 
1484
+ # Task selection button
1485
  gr.Markdown("### What would you like to do?")
1486
+ assign_btn = gr.Button("Classify Data", variant="secondary", elem_classes="task-btn")
 
 
 
1487
 
1488
  # Categories group (only visible for Assign task)
1489
  with gr.Group(visible=False) as categories_group:
1490
  gr.Markdown("### Categories")
1491
+ gr.Markdown("Enter your categories manually, or click 'Auto-extract' to discover them from your data.")
1492
  category_inputs = []
1493
  placeholder_examples = [
1494
  "e.g., Positive sentiment",
 
1507
  visible=visible
1508
  )
1509
  category_inputs.append(cat_input)
1510
+ with gr.Row():
1511
+ add_category_btn = gr.Button("+ Add More", variant="secondary", size="sm")
1512
+ auto_extract_btn = gr.Button("Auto-extract Categories", variant="secondary", size="sm")
1513
+ with gr.Group(visible=False) as auto_extract_settings:
1514
+ max_categories = gr.Slider(
1515
+ minimum=3,
1516
+ maximum=25,
1517
+ value=12,
1518
+ step=1,
1519
+ label="Number of Categories to Extract",
1520
+ info="How many categories should be identified in your data"
1521
+ )
1522
+ run_auto_extract_btn = gr.Button("Extract Now", variant="primary", size="sm")
1523
+ auto_extract_status = gr.Markdown("")
1524
 
1525
  # Model selection group
1526
  with gr.Group(visible=False) as model_group:
 
1558
  with gr.Column():
1559
  status = gr.Markdown("Ready. Upload data and select a task.")
1560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1561
  # Classify output group
1562
  with gr.Group(visible=False) as classify_output_group:
1563
  gr.Markdown("### Classification Results")
 
1654
  outputs=category_inputs + [add_category_btn, category_count]
1655
  )
1656
 
1657
+ # Auto-extract button toggles the settings visibility
1658
+ def toggle_auto_extract_settings():
1659
+ return gr.update(visible=True), "Extracting categories..."
 
 
 
 
 
 
1660
 
1661
+ auto_extract_btn.click(
1662
+ fn=toggle_auto_extract_settings,
1663
  inputs=[],
1664
+ outputs=[auto_extract_settings, auto_extract_status]
1665
  )
1666
 
1667
+ # Run auto-extract button
1668
+ run_auto_extract_btn.click(
1669
+ fn=run_auto_extract,
1670
+ inputs=[input_type, spreadsheet_file, spreadsheet_column,
1671
+ pdf_file, pdf_folder, pdf_description, pdf_mode,
1672
+ image_file, image_folder, image_description,
1673
+ max_categories, model_tier, model, model_source, api_key],
1674
+ outputs=category_inputs + [category_count, auto_extract_status]
1675
  )
1676
 
1677
+ # Task button handler
1678
+ def select_assign():
1679
+ return ("assign",) + update_task_visibility("assign")
1680
+
1681
+ assign_btn.click(
1682
+ fn=select_assign,
1683
  inputs=[],
1684
+ outputs=[task_mode, categories_group, model_group, run_btn, classify_output_group, status]
1685
  )
1686
 
1687
  # Main run button handler - dispatches based on task_mode
1688
  def dispatch_run(task, input_type, spreadsheet_file, spreadsheet_column,
1689
  pdf_file, pdf_folder_val, pdf_description, pdf_mode,
1690
  image_file, image_folder_val, image_description,
 
1691
  cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10,
1692
  model_tier, model, model_source, api_key,
1693
  progress=gr.Progress(track_tqdm=True)):
1694
+ """Run classification with user-provided categories."""
1695
+ if task == "assign":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1696
  # run_classify_data yields: (plot, df, files, code, unused, status)
1697
  for update in run_classify_data(
1698
  input_type, spreadsheet_file, spreadsheet_column,
 
1702
  model_tier, model, model_source, api_key,
1703
  progress
1704
  ):
 
1705
  yield (
 
 
 
1706
  update[0], # distribution_plot
1707
  update[1], # results
1708
  update[2], # download_file
1709
  update[3], # classify_code_display
1710
  update[5] # status
1711
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1712
  else:
1713
+ yield (None, None, None, None, "Please click 'Classify Data' first.")
1714
 
1715
  run_btn.click(
1716
  fn=dispatch_run,
1717
  inputs=[task_mode, input_type, spreadsheet_file, spreadsheet_column,
1718
  pdf_file, pdf_folder, pdf_description, pdf_mode,
1719
+ image_file, image_folder, image_description] + category_inputs + [model_tier, model, model_source, api_key],
1720
+ outputs=[distribution_plot, results, download_file, classify_code_display, status]
 
1721
  )
1722
 
1723
  reset_btn.click(
 
1731
  task_mode
1732
  ] + category_inputs + [
1733
  add_category_btn, category_count,
1734
+ auto_extract_settings, max_categories, auto_extract_status,
1735
+ categories_group, model_group, run_btn,
1736
  model_tier, model, model_source, api_key, api_key, api_key_status,
1737
  status,
 
1738
  classify_output_group, distribution_plot, results, download_file, classify_code_display
1739
  ]
1740
  )