chrissoria committed on
Commit
03b483d
·
verified ·
1 Parent(s): 0905db0

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +100 -25
app.py CHANGED
@@ -998,25 +998,71 @@ with col_input:
998
  st.rerun()
999
 
1000
  st.markdown("### Model Selection")
 
 
 
 
 
 
 
 
 
 
1001
  model_tier = st.radio(
1002
  "Model Tier",
1003
  options=["Free Models", "Bring Your Own Key"],
1004
  key="classify_model_tier"
1005
  )
1006
 
 
 
 
1007
  if model_tier == "Free Models":
1008
- model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="classify_model")
1009
- model = FREE_MODELS_MAP[model_display] # Convert to actual model name
 
 
 
 
 
 
 
 
 
 
1010
  api_key = ""
1011
  else:
1012
- model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="classify_model_paid")
 
 
 
 
 
 
 
 
 
1013
  api_key = st.text_input("API Key", type="password", key="classify_api_key")
1014
 
 
 
 
 
 
 
 
 
 
 
 
 
1015
  if st.button("Categorize Data", type="primary", use_container_width=True):
1016
  if input_data is None:
1017
  st.error("Please upload data first")
1018
  elif not categories_entered:
1019
  st.error("Please enter at least one category")
 
 
1020
  else:
1021
  # Set up progress tracking
1022
  mode = None
@@ -1028,11 +1074,20 @@ with col_input:
1028
  }
1029
  mode = mode_mapping.get(pdf_mode, "image")
1030
 
1031
- actual_api_key, provider = get_api_key(model, model_tier, api_key)
1032
- if not actual_api_key:
1033
- st.error(f"{provider} API key not configured")
 
 
 
 
 
 
 
 
 
 
1034
  else:
1035
- model_source = get_model_source(model)
1036
  items_list = input_data if isinstance(input_data, list) else [input_data]
1037
 
1038
  # Progress UI
@@ -1058,14 +1113,20 @@ with col_input:
1058
  status_text.text(f"Processing page {current_idx+1} of {total_pages} ({page_label}) ({progress*100:.0f}%){eta_str}")
1059
 
1060
  try:
1061
- result_df = catllm.classify_ensemble(
1062
- survey_input=items_list,
1063
- categories=categories_entered,
1064
- models=[(model, model_source, actual_api_key)],
1065
- input_description=description,
1066
- pdf_mode=mode,
1067
- progress_callback=pdf_progress_callback
1068
- )
 
 
 
 
 
 
1069
 
1070
  processing_time = time.time() - start_time
1071
  total_items = len(result_df)
@@ -1113,13 +1174,19 @@ with col_input:
1113
  status_text.text(f"Processing item {current_idx+1} of {total} ({progress*100:.0f}%){eta_str}")
1114
 
1115
  try:
1116
- result_df = catllm.classify_ensemble(
1117
- survey_input=items_list,
1118
- categories=categories_entered,
1119
- models=[(model, model_source, actual_api_key)],
1120
- input_description=description,
1121
- progress_callback=item_progress_callback,
1122
- )
 
 
 
 
 
 
1123
  all_results = [result_df]
1124
 
1125
  processing_time = time.time() - start_time
@@ -1154,13 +1221,21 @@ with col_input:
1154
  catllm_version = "unknown"
1155
  python_version = sys.version.split()[0]
1156
 
 
 
 
 
 
 
 
 
1157
  # Generate methodology report
1158
  pdf_path = generate_methodology_report_pdf(
1159
  categories=categories_entered,
1160
- model=model,
1161
  column_name=description,
1162
  num_rows=len(result_df),
1163
- model_source=model_source,
1164
  filename=original_filename,
1165
  success_rate=success_rate,
1166
  result_df=result_df,
@@ -1173,7 +1248,7 @@ with col_input:
1173
  )
1174
 
1175
  # Generate code
1176
- code = generate_classify_code(input_type_selected, description, categories_entered, model, model_source, mode)
1177
 
1178
  st.session_state.results = {
1179
  'df': result_df,
 
998
  st.rerun()
999
 
1000
  st.markdown("### Model Selection")
1001
+
1002
+ # Classification mode selector
1003
+ classify_mode = st.radio(
1004
+ "Classification Mode",
1005
+ options=["Single Model", "Model Comparison", "Ensemble"],
1006
+ horizontal=True,
1007
+ key="classify_mode",
1008
+ help="Single: one model. Comparison: see results from multiple models side-by-side. Ensemble: multiple models vote for consensus."
1009
+ )
1010
+
1011
  model_tier = st.radio(
1012
  "Model Tier",
1013
  options=["Free Models", "Bring Your Own Key"],
1014
  key="classify_model_tier"
1015
  )
1016
 
1017
+ # Multi-model mode uses multiselect
1018
+ is_multi_model = classify_mode in ["Model Comparison", "Ensemble"]
1019
+
1020
  if model_tier == "Free Models":
1021
+ if is_multi_model:
1022
+ model_displays = st.multiselect(
1023
+ "Models (select 2+)",
1024
+ options=FREE_MODEL_DISPLAY_NAMES,
1025
+ default=[FREE_MODEL_DISPLAY_NAMES[0], FREE_MODEL_DISPLAY_NAMES[1]] if len(FREE_MODEL_DISPLAY_NAMES) >= 2 else FREE_MODEL_DISPLAY_NAMES[:1],
1026
+ key="classify_models_multi"
1027
+ )
1028
+ models_list = [FREE_MODELS_MAP[d] for d in model_displays]
1029
+ else:
1030
+ model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="classify_model")
1031
+ model = FREE_MODELS_MAP[model_display] # Convert to actual model name
1032
+ models_list = [model]
1033
  api_key = ""
1034
  else:
1035
+ if is_multi_model:
1036
+ models_list = st.multiselect(
1037
+ "Models (select 2+)",
1038
+ options=PAID_MODEL_CHOICES,
1039
+ default=[PAID_MODEL_CHOICES[0], PAID_MODEL_CHOICES[1]] if len(PAID_MODEL_CHOICES) >= 2 else PAID_MODEL_CHOICES[:1],
1040
+ key="classify_models_multi_paid"
1041
+ )
1042
+ else:
1043
+ model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="classify_model_paid")
1044
+ models_list = [model]
1045
  api_key = st.text_input("API Key", type="password", key="classify_api_key")
1046
 
1047
+ # Ensemble-specific options
1048
+ if classify_mode == "Ensemble":
1049
+ consensus_threshold = st.slider(
1050
+ "Consensus Threshold",
1051
+ min_value=0.0,
1052
+ max_value=1.0,
1053
+ value=0.5,
1054
+ step=0.1,
1055
+ key="consensus_threshold",
1056
+ help="Minimum agreement ratio needed for consensus (0.5 = majority vote)"
1057
+ )
1058
+
1059
  if st.button("Categorize Data", type="primary", use_container_width=True):
1060
  if input_data is None:
1061
  st.error("Please upload data first")
1062
  elif not categories_entered:
1063
  st.error("Please enter at least one category")
1064
+ elif is_multi_model and len(models_list) < 2:
1065
+ st.error("Please select at least 2 models for comparison/ensemble mode")
1066
  else:
1067
  # Set up progress tracking
1068
  mode = None
 
1074
  }
1075
  mode = mode_mapping.get(pdf_mode, "image")
1076
 
1077
+ # Build the list of model tuples: [(model, source, api_key), ...]
1078
+ models_tuples = []
1079
+ api_key_error = None
1080
+ for m in models_list:
1081
+ actual_key, provider = get_api_key(m, model_tier, api_key)
1082
+ if not actual_key:
1083
+ api_key_error = f"{provider} API key not configured for {m}"
1084
+ break
1085
+ m_source = get_model_source(m)
1086
+ models_tuples.append((m, m_source, actual_key))
1087
+
1088
+ if api_key_error:
1089
+ st.error(api_key_error)
1090
  else:
 
1091
  items_list = input_data if isinstance(input_data, list) else [input_data]
1092
 
1093
  # Progress UI
 
1113
  status_text.text(f"Processing page {current_idx+1} of {total_pages} ({page_label}) ({progress*100:.0f}%){eta_str}")
1114
 
1115
  try:
1116
+ # Build kwargs for classify_ensemble
1117
+ classify_kwargs = {
1118
+ "survey_input": items_list,
1119
+ "categories": categories_entered,
1120
+ "models": models_tuples,
1121
+ "input_description": description,
1122
+ "pdf_mode": mode,
1123
+ "progress_callback": pdf_progress_callback,
1124
+ }
1125
+ # Add consensus_threshold for ensemble mode
1126
+ if classify_mode == "Ensemble":
1127
+ classify_kwargs["consensus_threshold"] = consensus_threshold
1128
+
1129
+ result_df = catllm.classify_ensemble(**classify_kwargs)
1130
 
1131
  processing_time = time.time() - start_time
1132
  total_items = len(result_df)
 
1174
  status_text.text(f"Processing item {current_idx+1} of {total} ({progress*100:.0f}%){eta_str}")
1175
 
1176
  try:
1177
+ # Build kwargs for classify_ensemble
1178
+ classify_kwargs = {
1179
+ "survey_input": items_list,
1180
+ "categories": categories_entered,
1181
+ "models": models_tuples,
1182
+ "input_description": description,
1183
+ "progress_callback": item_progress_callback,
1184
+ }
1185
+ # Add consensus_threshold for ensemble mode
1186
+ if classify_mode == "Ensemble":
1187
+ classify_kwargs["consensus_threshold"] = consensus_threshold
1188
+
1189
+ result_df = catllm.classify_ensemble(**classify_kwargs)
1190
  all_results = [result_df]
1191
 
1192
  processing_time = time.time() - start_time
 
1221
  catllm_version = "unknown"
1222
  python_version = sys.version.split()[0]
1223
 
1224
+ # For reports: build the model string (one model name, or a comma-separated list of names)
1225
+ if len(models_list) == 1:
1226
+ report_model = models_list[0]
1227
+ report_model_source = models_tuples[0][1]
1228
+ else:
1229
+ report_model = ", ".join(models_list)
1230
+ report_model_source = f"{classify_mode} ({len(models_list)} models)"
1231
+
1232
  # Generate methodology report
1233
  pdf_path = generate_methodology_report_pdf(
1234
  categories=categories_entered,
1235
+ model=report_model,
1236
  column_name=description,
1237
  num_rows=len(result_df),
1238
+ model_source=report_model_source,
1239
  filename=original_filename,
1240
  success_rate=success_rate,
1241
  result_df=result_df,
 
1248
  )
1249
 
1250
  # Generate code
1251
+ code = generate_classify_code(input_type_selected, description, categories_entered, report_model, report_model_source, mode)
1252
 
1253
  st.session_state.results = {
1254
  'df': result_df,