chrissoria Claude commited on
Commit
7779e40
·
1 Parent(s): 02239e1

Add max_categories slider for Extract and Extract & Assign tasks

Browse files

- New "Extraction Settings" group with slider (3-25, default 12)
- Shows for Extract and Extract & Assign tasks, hidden for Assign
- Passes max_categories to catllm.extract() calls
- Updated reset function to include new components

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +37 -9
app.py CHANGED
@@ -450,6 +450,7 @@ def update_task_visibility(task):
450
  if task == "extract":
451
  return (
452
  gr.update(visible=False), # categories_group
 
453
  gr.update(visible=True), # model_group
454
  gr.update(visible=True, value="Extract Categories"), # run_btn
455
  gr.update(visible=True), # extract_output_group
@@ -459,6 +460,7 @@ def update_task_visibility(task):
459
  elif task == "assign":
460
  return (
461
  gr.update(visible=True), # categories_group
 
462
  gr.update(visible=True), # model_group
463
  gr.update(visible=True, value="Classify Data"), # run_btn
464
  gr.update(visible=False), # extract_output_group
@@ -468,6 +470,7 @@ def update_task_visibility(task):
468
  elif task == "extract_and_assign":
469
  return (
470
  gr.update(visible=False), # categories_group
 
471
  gr.update(visible=True), # model_group
472
  gr.update(visible=True, value="Extract & Classify"), # run_btn
473
  gr.update(visible=True), # extract_output_group (will show extracted cats)
@@ -481,6 +484,7 @@ def update_task_visibility(task):
481
  gr.update(visible=False),
482
  gr.update(visible=False),
483
  gr.update(visible=False),
 
484
  "Select a task to continue."
485
  )
486
 
@@ -488,6 +492,7 @@ def update_task_visibility(task):
488
  def run_extract_categories(input_type, spreadsheet_file, spreadsheet_column,
489
  pdf_file, pdf_folder, pdf_description, pdf_mode,
490
  image_file, image_folder, image_description,
 
491
  model_tier, model, model_source_input, api_key_input,
492
  progress=gr.Progress(track_tqdm=True)):
493
  """Extract categories from data and display them in a table."""
@@ -532,7 +537,8 @@ def run_extract_categories(input_type, spreadsheet_file, spreadsheet_column,
532
  input_type="text",
533
  description=spreadsheet_column,
534
  user_model=model,
535
- model_source=model_source
 
536
  )
537
 
538
  elif input_type == "PDF Documents":
@@ -570,7 +576,8 @@ def run_extract_categories(input_type, spreadsheet_file, spreadsheet_column,
570
  mode=actual_mode,
571
  user_model=model,
572
  model_source=model_source,
573
- divisions=divisions
 
574
  )
575
 
576
  elif input_type == "Images":
@@ -602,7 +609,8 @@ def run_extract_categories(input_type, spreadsheet_file, spreadsheet_column,
602
  user_model=model,
603
  model_source=model_source,
604
  divisions=divisions,
605
- categories_per_chunk=12 # Images often have multiple categories each
 
606
  )
607
 
608
  else:
@@ -863,6 +871,7 @@ Provide your work in JSON format where the number belonging to each category is
863
  def run_extract_and_assign(input_type, spreadsheet_file, spreadsheet_column,
864
  pdf_file, pdf_folder, pdf_description, pdf_mode,
865
  image_file, image_folder, image_description,
 
866
  model_tier, model, model_source_input, api_key_input,
867
  progress=gr.Progress(track_tqdm=True)):
868
  """Extract categories then classify data with them."""
@@ -990,7 +999,8 @@ def run_extract_and_assign(input_type, spreadsheet_file, spreadsheet_column,
990
  'user_model': model,
991
  'model_source': model_source,
992
  'divisions': divisions,
993
- 'categories_per_chunk': categories_per_chunk
 
994
  }
995
  if mode_param:
996
  extract_kwargs['mode'] = mode_param
@@ -1166,6 +1176,8 @@ def reset_all():
1166
  gr.update(visible=True), # add_category_btn
1167
  INITIAL_CATEGORIES, # category_count
1168
  gr.update(visible=False), # categories_group
 
 
1169
  gr.update(visible=False), # model_group
1170
  gr.update(visible=False, value="Run"), # run_btn
1171
  "Free Models", # model_tier
@@ -1337,6 +1349,18 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
1337
  category_inputs.append(cat_input)
1338
  add_category_btn = gr.Button("+ Add More Categories", variant="secondary", size="sm")
1339
 
 
 
 
 
 
 
 
 
 
 
 
 
1340
  # Model selection group
1341
  with gr.Group(visible=False) as model_group:
1342
  gr.Markdown("### Model")
@@ -1481,25 +1505,26 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
1481
  extract_btn.click(
1482
  fn=select_extract,
1483
  inputs=[],
1484
- outputs=[task_mode, categories_group, model_group, run_btn, extract_output_group, classify_output_group, status]
1485
  )
1486
 
1487
  assign_btn.click(
1488
  fn=select_assign,
1489
  inputs=[],
1490
- outputs=[task_mode, categories_group, model_group, run_btn, extract_output_group, classify_output_group, status]
1491
  )
1492
 
1493
  extract_assign_btn.click(
1494
  fn=select_extract_assign,
1495
  inputs=[],
1496
- outputs=[task_mode, categories_group, model_group, run_btn, extract_output_group, classify_output_group, status]
1497
  )
1498
 
1499
  # Main run button handler - dispatches based on task_mode
1500
  def dispatch_run(task, input_type, spreadsheet_file, spreadsheet_column,
1501
  pdf_file, pdf_folder_val, pdf_description, pdf_mode,
1502
  image_file, image_folder_val, image_description,
 
1503
  cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10,
1504
  model_tier, model, model_source, api_key,
1505
  progress=gr.Progress(track_tqdm=True)):
@@ -1509,6 +1534,7 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
1509
  input_type, spreadsheet_file, spreadsheet_column,
1510
  pdf_file, pdf_folder_val, pdf_description, pdf_mode,
1511
  image_file, image_folder_val, image_description,
 
1512
  model_tier, model, model_source, api_key,
1513
  progress
1514
  ):
@@ -1544,6 +1570,7 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
1544
  input_type, spreadsheet_file, spreadsheet_column,
1545
  pdf_file, pdf_folder_val, pdf_description, pdf_mode,
1546
  image_file, image_folder_val, image_description,
 
1547
  model_tier, model, model_source, api_key,
1548
  progress
1549
  ):
@@ -1562,7 +1589,8 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
1562
  fn=dispatch_run,
1563
  inputs=[task_mode, input_type, spreadsheet_file, spreadsheet_column,
1564
  pdf_file, pdf_folder, pdf_description, pdf_mode,
1565
- image_file, image_folder, image_description] + category_inputs + [model_tier, model, model_source, api_key],
 
1566
  outputs=[extracted_categories, extract_download, distribution_plot, results, download_file, status]
1567
  )
1568
 
@@ -1577,7 +1605,7 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
1577
  task_mode
1578
  ] + category_inputs + [
1579
  add_category_btn, category_count,
1580
- categories_group, model_group, run_btn,
1581
  model_tier, model, model_source, api_key, api_key, api_key_status,
1582
  status,
1583
  extract_output_group, extracted_categories, extract_download,
 
450
  if task == "extract":
451
  return (
452
  gr.update(visible=False), # categories_group
453
+ gr.update(visible=True), # extract_settings_group
454
  gr.update(visible=True), # model_group
455
  gr.update(visible=True, value="Extract Categories"), # run_btn
456
  gr.update(visible=True), # extract_output_group
 
460
  elif task == "assign":
461
  return (
462
  gr.update(visible=True), # categories_group
463
+ gr.update(visible=False), # extract_settings_group
464
  gr.update(visible=True), # model_group
465
  gr.update(visible=True, value="Classify Data"), # run_btn
466
  gr.update(visible=False), # extract_output_group
 
470
  elif task == "extract_and_assign":
471
  return (
472
  gr.update(visible=False), # categories_group
473
+ gr.update(visible=True), # extract_settings_group
474
  gr.update(visible=True), # model_group
475
  gr.update(visible=True, value="Extract & Classify"), # run_btn
476
  gr.update(visible=True), # extract_output_group (will show extracted cats)
 
484
  gr.update(visible=False),
485
  gr.update(visible=False),
486
  gr.update(visible=False),
487
+ gr.update(visible=False),
488
  "Select a task to continue."
489
  )
490
 
 
492
  def run_extract_categories(input_type, spreadsheet_file, spreadsheet_column,
493
  pdf_file, pdf_folder, pdf_description, pdf_mode,
494
  image_file, image_folder, image_description,
495
+ max_categories_val,
496
  model_tier, model, model_source_input, api_key_input,
497
  progress=gr.Progress(track_tqdm=True)):
498
  """Extract categories from data and display them in a table."""
 
537
  input_type="text",
538
  description=spreadsheet_column,
539
  user_model=model,
540
+ model_source=model_source,
541
+ max_categories=int(max_categories_val)
542
  )
543
 
544
  elif input_type == "PDF Documents":
 
576
  mode=actual_mode,
577
  user_model=model,
578
  model_source=model_source,
579
+ divisions=divisions,
580
+ max_categories=int(max_categories_val)
581
  )
582
 
583
  elif input_type == "Images":
 
609
  user_model=model,
610
  model_source=model_source,
611
  divisions=divisions,
612
+ categories_per_chunk=12, # Images often have multiple categories each
613
+ max_categories=int(max_categories_val)
614
  )
615
 
616
  else:
 
871
  def run_extract_and_assign(input_type, spreadsheet_file, spreadsheet_column,
872
  pdf_file, pdf_folder, pdf_description, pdf_mode,
873
  image_file, image_folder, image_description,
874
+ max_categories_val,
875
  model_tier, model, model_source_input, api_key_input,
876
  progress=gr.Progress(track_tqdm=True)):
877
  """Extract categories then classify data with them."""
 
999
  'user_model': model,
1000
  'model_source': model_source,
1001
  'divisions': divisions,
1002
+ 'categories_per_chunk': categories_per_chunk,
1003
+ 'max_categories': int(max_categories_val)
1004
  }
1005
  if mode_param:
1006
  extract_kwargs['mode'] = mode_param
 
1176
  gr.update(visible=True), # add_category_btn
1177
  INITIAL_CATEGORIES, # category_count
1178
  gr.update(visible=False), # categories_group
1179
+ gr.update(visible=False), # extract_settings_group
1180
+ 12, # max_categories (reset to default)
1181
  gr.update(visible=False), # model_group
1182
  gr.update(visible=False, value="Run"), # run_btn
1183
  "Free Models", # model_tier
 
1349
  category_inputs.append(cat_input)
1350
  add_category_btn = gr.Button("+ Add More Categories", variant="secondary", size="sm")
1351
 
1352
+ # Extraction settings group (only visible for Extract and Extract & Assign)
1353
+ with gr.Group(visible=False) as extract_settings_group:
1354
+ gr.Markdown("### Extraction Settings")
1355
+ max_categories = gr.Slider(
1356
+ minimum=3,
1357
+ maximum=25,
1358
+ value=12,
1359
+ step=1,
1360
+ label="Number of Categories to Extract",
1361
+ info="How many categories should be identified in your data"
1362
+ )
1363
+
1364
  # Model selection group
1365
  with gr.Group(visible=False) as model_group:
1366
  gr.Markdown("### Model")
 
1505
  extract_btn.click(
1506
  fn=select_extract,
1507
  inputs=[],
1508
+ outputs=[task_mode, categories_group, extract_settings_group, model_group, run_btn, extract_output_group, classify_output_group, status]
1509
  )
1510
 
1511
  assign_btn.click(
1512
  fn=select_assign,
1513
  inputs=[],
1514
+ outputs=[task_mode, categories_group, extract_settings_group, model_group, run_btn, extract_output_group, classify_output_group, status]
1515
  )
1516
 
1517
  extract_assign_btn.click(
1518
  fn=select_extract_assign,
1519
  inputs=[],
1520
+ outputs=[task_mode, categories_group, extract_settings_group, model_group, run_btn, extract_output_group, classify_output_group, status]
1521
  )
1522
 
1523
  # Main run button handler - dispatches based on task_mode
1524
  def dispatch_run(task, input_type, spreadsheet_file, spreadsheet_column,
1525
  pdf_file, pdf_folder_val, pdf_description, pdf_mode,
1526
  image_file, image_folder_val, image_description,
1527
+ max_categories_val,
1528
  cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10,
1529
  model_tier, model, model_source, api_key,
1530
  progress=gr.Progress(track_tqdm=True)):
 
1534
  input_type, spreadsheet_file, spreadsheet_column,
1535
  pdf_file, pdf_folder_val, pdf_description, pdf_mode,
1536
  image_file, image_folder_val, image_description,
1537
+ max_categories_val,
1538
  model_tier, model, model_source, api_key,
1539
  progress
1540
  ):
 
1570
  input_type, spreadsheet_file, spreadsheet_column,
1571
  pdf_file, pdf_folder_val, pdf_description, pdf_mode,
1572
  image_file, image_folder_val, image_description,
1573
+ max_categories_val,
1574
  model_tier, model, model_source, api_key,
1575
  progress
1576
  ):
 
1589
  fn=dispatch_run,
1590
  inputs=[task_mode, input_type, spreadsheet_file, spreadsheet_column,
1591
  pdf_file, pdf_folder, pdf_description, pdf_mode,
1592
+ image_file, image_folder, image_description,
1593
+ max_categories] + category_inputs + [model_tier, model, model_source, api_key],
1594
  outputs=[extracted_categories, extract_download, distribution_plot, results, download_file, status]
1595
  )
1596
 
 
1605
  task_mode
1606
  ] + category_inputs + [
1607
  add_category_btn, category_count,
1608
+ categories_group, extract_settings_group, max_categories, model_group, run_btn,
1609
  model_tier, model, model_source, api_key, api_key, api_key_status,
1610
  status,
1611
  extract_output_group, extracted_categories, extract_download,