Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files
app.py
CHANGED
|
@@ -998,25 +998,71 @@ with col_input:
|
|
| 998 |
st.rerun()
|
| 999 |
|
| 1000 |
st.markdown("### Model Selection")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1001 |
model_tier = st.radio(
|
| 1002 |
"Model Tier",
|
| 1003 |
options=["Free Models", "Bring Your Own Key"],
|
| 1004 |
key="classify_model_tier"
|
| 1005 |
)
|
| 1006 |
|
|
|
|
|
|
|
|
|
|
| 1007 |
if model_tier == "Free Models":
|
| 1008 |
-
|
| 1009 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1010 |
api_key = ""
|
| 1011 |
else:
|
| 1012 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1013 |
api_key = st.text_input("API Key", type="password", key="classify_api_key")
|
| 1014 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1015 |
if st.button("Categorize Data", type="primary", use_container_width=True):
|
| 1016 |
if input_data is None:
|
| 1017 |
st.error("Please upload data first")
|
| 1018 |
elif not categories_entered:
|
| 1019 |
st.error("Please enter at least one category")
|
|
|
|
|
|
|
| 1020 |
else:
|
| 1021 |
# Set up progress tracking
|
| 1022 |
mode = None
|
|
@@ -1028,11 +1074,20 @@ with col_input:
|
|
| 1028 |
}
|
| 1029 |
mode = mode_mapping.get(pdf_mode, "image")
|
| 1030 |
|
| 1031 |
-
|
| 1032 |
-
|
| 1033 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1034 |
else:
|
| 1035 |
-
model_source = get_model_source(model)
|
| 1036 |
items_list = input_data if isinstance(input_data, list) else [input_data]
|
| 1037 |
|
| 1038 |
# Progress UI
|
|
@@ -1058,14 +1113,20 @@ with col_input:
|
|
| 1058 |
status_text.text(f"Processing page {current_idx+1} of {total_pages} ({page_label}) ({progress*100:.0f}%){eta_str}")
|
| 1059 |
|
| 1060 |
try:
|
| 1061 |
-
|
| 1062 |
-
|
| 1063 |
-
|
| 1064 |
-
|
| 1065 |
-
|
| 1066 |
-
|
| 1067 |
-
|
| 1068 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1069 |
|
| 1070 |
processing_time = time.time() - start_time
|
| 1071 |
total_items = len(result_df)
|
|
@@ -1113,13 +1174,19 @@ with col_input:
|
|
| 1113 |
status_text.text(f"Processing item {current_idx+1} of {total} ({progress*100:.0f}%){eta_str}")
|
| 1114 |
|
| 1115 |
try:
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
|
| 1119 |
-
|
| 1120 |
-
|
| 1121 |
-
|
| 1122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1123 |
all_results = [result_df]
|
| 1124 |
|
| 1125 |
processing_time = time.time() - start_time
|
|
@@ -1154,13 +1221,21 @@ with col_input:
|
|
| 1154 |
catllm_version = "unknown"
|
| 1155 |
python_version = sys.version.split()[0]
|
| 1156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1157 |
# Generate methodology report
|
| 1158 |
pdf_path = generate_methodology_report_pdf(
|
| 1159 |
categories=categories_entered,
|
| 1160 |
-
model=
|
| 1161 |
column_name=description,
|
| 1162 |
num_rows=len(result_df),
|
| 1163 |
-
model_source=
|
| 1164 |
filename=original_filename,
|
| 1165 |
success_rate=success_rate,
|
| 1166 |
result_df=result_df,
|
|
@@ -1173,7 +1248,7 @@ with col_input:
|
|
| 1173 |
)
|
| 1174 |
|
| 1175 |
# Generate code
|
| 1176 |
-
code = generate_classify_code(input_type_selected, description, categories_entered,
|
| 1177 |
|
| 1178 |
st.session_state.results = {
|
| 1179 |
'df': result_df,
|
|
|
|
| 998 |
st.rerun()
|
| 999 |
|
| 1000 |
st.markdown("### Model Selection")
|
| 1001 |
+
|
| 1002 |
+
# Classification mode selector
|
| 1003 |
+
classify_mode = st.radio(
|
| 1004 |
+
"Classification Mode",
|
| 1005 |
+
options=["Single Model", "Model Comparison", "Ensemble"],
|
| 1006 |
+
horizontal=True,
|
| 1007 |
+
key="classify_mode",
|
| 1008 |
+
help="Single: one model. Comparison: see results from multiple models side-by-side. Ensemble: multiple models vote for consensus."
|
| 1009 |
+
)
|
| 1010 |
+
|
| 1011 |
model_tier = st.radio(
|
| 1012 |
"Model Tier",
|
| 1013 |
options=["Free Models", "Bring Your Own Key"],
|
| 1014 |
key="classify_model_tier"
|
| 1015 |
)
|
| 1016 |
|
| 1017 |
+
# Multi-model mode uses multiselect
|
| 1018 |
+
is_multi_model = classify_mode in ["Model Comparison", "Ensemble"]
|
| 1019 |
+
|
| 1020 |
if model_tier == "Free Models":
|
| 1021 |
+
if is_multi_model:
|
| 1022 |
+
model_displays = st.multiselect(
|
| 1023 |
+
"Models (select 2+)",
|
| 1024 |
+
options=FREE_MODEL_DISPLAY_NAMES,
|
| 1025 |
+
default=[FREE_MODEL_DISPLAY_NAMES[0], FREE_MODEL_DISPLAY_NAMES[1]] if len(FREE_MODEL_DISPLAY_NAMES) >= 2 else FREE_MODEL_DISPLAY_NAMES[:1],
|
| 1026 |
+
key="classify_models_multi"
|
| 1027 |
+
)
|
| 1028 |
+
models_list = [FREE_MODELS_MAP[d] for d in model_displays]
|
| 1029 |
+
else:
|
| 1030 |
+
model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="classify_model")
|
| 1031 |
+
model = FREE_MODELS_MAP[model_display] # Convert to actual model name
|
| 1032 |
+
models_list = [model]
|
| 1033 |
api_key = ""
|
| 1034 |
else:
|
| 1035 |
+
if is_multi_model:
|
| 1036 |
+
models_list = st.multiselect(
|
| 1037 |
+
"Models (select 2+)",
|
| 1038 |
+
options=PAID_MODEL_CHOICES,
|
| 1039 |
+
default=[PAID_MODEL_CHOICES[0], PAID_MODEL_CHOICES[1]] if len(PAID_MODEL_CHOICES) >= 2 else PAID_MODEL_CHOICES[:1],
|
| 1040 |
+
key="classify_models_multi_paid"
|
| 1041 |
+
)
|
| 1042 |
+
else:
|
| 1043 |
+
model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="classify_model_paid")
|
| 1044 |
+
models_list = [model]
|
| 1045 |
api_key = st.text_input("API Key", type="password", key="classify_api_key")
|
| 1046 |
|
| 1047 |
+
# Ensemble-specific options
|
| 1048 |
+
if classify_mode == "Ensemble":
|
| 1049 |
+
consensus_threshold = st.slider(
|
| 1050 |
+
"Consensus Threshold",
|
| 1051 |
+
min_value=0.0,
|
| 1052 |
+
max_value=1.0,
|
| 1053 |
+
value=0.5,
|
| 1054 |
+
step=0.1,
|
| 1055 |
+
key="consensus_threshold",
|
| 1056 |
+
help="Minimum agreement ratio needed for consensus (0.5 = majority vote)"
|
| 1057 |
+
)
|
| 1058 |
+
|
| 1059 |
if st.button("Categorize Data", type="primary", use_container_width=True):
|
| 1060 |
if input_data is None:
|
| 1061 |
st.error("Please upload data first")
|
| 1062 |
elif not categories_entered:
|
| 1063 |
st.error("Please enter at least one category")
|
| 1064 |
+
elif is_multi_model and len(models_list) < 2:
|
| 1065 |
+
st.error("Please select at least 2 models for comparison/ensemble mode")
|
| 1066 |
else:
|
| 1067 |
# Set up progress tracking
|
| 1068 |
mode = None
|
|
|
|
| 1074 |
}
|
| 1075 |
mode = mode_mapping.get(pdf_mode, "image")
|
| 1076 |
|
| 1077 |
+
# Build models tuples list: [(model, source, api_key), ...]
|
| 1078 |
+
models_tuples = []
|
| 1079 |
+
api_key_error = None
|
| 1080 |
+
for m in models_list:
|
| 1081 |
+
actual_key, provider = get_api_key(m, model_tier, api_key)
|
| 1082 |
+
if not actual_key:
|
| 1083 |
+
api_key_error = f"{provider} API key not configured for {m}"
|
| 1084 |
+
break
|
| 1085 |
+
m_source = get_model_source(m)
|
| 1086 |
+
models_tuples.append((m, m_source, actual_key))
|
| 1087 |
+
|
| 1088 |
+
if api_key_error:
|
| 1089 |
+
st.error(api_key_error)
|
| 1090 |
else:
|
|
|
|
| 1091 |
items_list = input_data if isinstance(input_data, list) else [input_data]
|
| 1092 |
|
| 1093 |
# Progress UI
|
|
|
|
| 1113 |
status_text.text(f"Processing page {current_idx+1} of {total_pages} ({page_label}) ({progress*100:.0f}%){eta_str}")
|
| 1114 |
|
| 1115 |
try:
|
| 1116 |
+
# Build kwargs for classify_ensemble
|
| 1117 |
+
classify_kwargs = {
|
| 1118 |
+
"survey_input": items_list,
|
| 1119 |
+
"categories": categories_entered,
|
| 1120 |
+
"models": models_tuples,
|
| 1121 |
+
"input_description": description,
|
| 1122 |
+
"pdf_mode": mode,
|
| 1123 |
+
"progress_callback": pdf_progress_callback,
|
| 1124 |
+
}
|
| 1125 |
+
# Add consensus_threshold for ensemble mode
|
| 1126 |
+
if classify_mode == "Ensemble":
|
| 1127 |
+
classify_kwargs["consensus_threshold"] = consensus_threshold
|
| 1128 |
+
|
| 1129 |
+
result_df = catllm.classify_ensemble(**classify_kwargs)
|
| 1130 |
|
| 1131 |
processing_time = time.time() - start_time
|
| 1132 |
total_items = len(result_df)
|
|
|
|
| 1174 |
status_text.text(f"Processing item {current_idx+1} of {total} ({progress*100:.0f}%){eta_str}")
|
| 1175 |
|
| 1176 |
try:
|
| 1177 |
+
# Build kwargs for classify_ensemble
|
| 1178 |
+
classify_kwargs = {
|
| 1179 |
+
"survey_input": items_list,
|
| 1180 |
+
"categories": categories_entered,
|
| 1181 |
+
"models": models_tuples,
|
| 1182 |
+
"input_description": description,
|
| 1183 |
+
"progress_callback": item_progress_callback,
|
| 1184 |
+
}
|
| 1185 |
+
# Add consensus_threshold for ensemble mode
|
| 1186 |
+
if classify_mode == "Ensemble":
|
| 1187 |
+
classify_kwargs["consensus_threshold"] = consensus_threshold
|
| 1188 |
+
|
| 1189 |
+
result_df = catllm.classify_ensemble(**classify_kwargs)
|
| 1190 |
all_results = [result_df]
|
| 1191 |
|
| 1192 |
processing_time = time.time() - start_time
|
|
|
|
| 1221 |
catllm_version = "unknown"
|
| 1222 |
python_version = sys.version.split()[0]
|
| 1223 |
|
| 1224 |
+
# For reports: create model string (single or list)
|
| 1225 |
+
if len(models_list) == 1:
|
| 1226 |
+
report_model = models_list[0]
|
| 1227 |
+
report_model_source = models_tuples[0][1]
|
| 1228 |
+
else:
|
| 1229 |
+
report_model = ", ".join(models_list)
|
| 1230 |
+
report_model_source = f"{classify_mode} ({len(models_list)} models)"
|
| 1231 |
+
|
| 1232 |
# Generate methodology report
|
| 1233 |
pdf_path = generate_methodology_report_pdf(
|
| 1234 |
categories=categories_entered,
|
| 1235 |
+
model=report_model,
|
| 1236 |
column_name=description,
|
| 1237 |
num_rows=len(result_df),
|
| 1238 |
+
model_source=report_model_source,
|
| 1239 |
filename=original_filename,
|
| 1240 |
success_rate=success_rate,
|
| 1241 |
result_df=result_df,
|
|
|
|
| 1248 |
)
|
| 1249 |
|
| 1250 |
# Generate code
|
| 1251 |
+
code = generate_classify_code(input_type_selected, description, categories_entered, report_model, report_model_source, mode)
|
| 1252 |
|
| 1253 |
st.session_state.results = {
|
| 1254 |
'df': result_df,
|