Spaces:
Running
Running
Commit
·
9e11d83
1
Parent(s):
52af3d5
Migrate from Gradio to Streamlit for better mobile support
Browse files- Rewrote app.py using Streamlit instead of Gradio
- Updated README.md to use sdk: streamlit
- Updated requirements.txt: added streamlit, removed gradio
- All 3 input modes preserved: Survey Responses, PDF Documents, Images
- Fixes mobile Safari tab switching bug in Gradio
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
- README.md +7 -7
- app.py +523 -1295
- requirements.txt +2 -1
README.md
CHANGED
|
@@ -1,17 +1,17 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk:
|
| 7 |
-
sdk_version: "
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
short_description: Classify survey responses using LLMs
|
| 12 |
---
|
| 13 |
|
| 14 |
-
#
|
| 15 |
|
| 16 |
A web interface for the [catllm](https://github.com/chrissoria/cat-llm) Python package. Classify survey responses into custom categories using various LLM providers.
|
| 17 |
|
|
|
|
| 1 |
---
|
| 2 |
+
title: CatLLM - Survey Response Classifier
|
| 3 |
+
emoji: 🐱
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: yellow
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: "1.32.0"
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
short_description: Classify survey responses using LLMs
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# CatLLM - Survey Response Classifier
|
| 15 |
|
| 16 |
A web interface for the [catllm](https://github.com/chrissoria/cat-llm) Python package. Classify survey responses into custom categories using various LLM providers.
|
| 17 |
|
app.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
| 1 |
"""
|
| 2 |
-
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
-
import
|
| 6 |
import pandas as pd
|
| 7 |
import tempfile
|
| 8 |
import os
|
|
@@ -21,7 +22,89 @@ except ImportError as e:
|
|
| 21 |
|
| 22 |
MAX_CATEGORIES = 10
|
| 23 |
INITIAL_CATEGORIES = 3
|
| 24 |
-
MAX_FILE_SIZE_MB = 100
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def calculate_total_file_size(files):
|
|
@@ -34,11 +117,13 @@ def calculate_total_file_size(files):
|
|
| 34 |
total_bytes = 0
|
| 35 |
for f in files:
|
| 36 |
try:
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
except (OSError, AttributeError):
|
| 40 |
pass
|
| 41 |
-
return total_bytes / (1024 * 1024)
|
| 42 |
|
| 43 |
|
| 44 |
def generate_extract_code(input_type, description, model, model_source, max_categories, mode=None):
|
|
@@ -71,7 +156,7 @@ print(result["counts_df"])
|
|
| 71 |
|
| 72 |
# Extract categories from PDF documents
|
| 73 |
result = catllm.extract(
|
| 74 |
-
input_data="path/to/your/pdfs/",
|
| 75 |
api_key="YOUR_API_KEY",
|
| 76 |
input_type="pdf",
|
| 77 |
description="{description}"{mode_line},
|
|
@@ -89,7 +174,7 @@ print(result["counts_df"])
|
|
| 89 |
|
| 90 |
# Extract categories from images
|
| 91 |
result = catllm.extract(
|
| 92 |
-
input_data="path/to/your/images/",
|
| 93 |
api_key="YOUR_API_KEY",
|
| 94 |
input_type="image",
|
| 95 |
description="{description}",
|
|
@@ -146,7 +231,7 @@ categories = [
|
|
| 146 |
|
| 147 |
# Classify PDF documents
|
| 148 |
result = catllm.classify(
|
| 149 |
-
input_data="path/to/your/pdfs/",
|
| 150 |
categories=categories,
|
| 151 |
api_key="YOUR_API_KEY",
|
| 152 |
input_type="pdf",
|
|
@@ -169,7 +254,7 @@ categories = [
|
|
| 169 |
|
| 170 |
# Classify images
|
| 171 |
result = catllm.classify(
|
| 172 |
-
input_data="path/to/your/images/",
|
| 173 |
categories=categories,
|
| 174 |
api_key="YOUR_API_KEY",
|
| 175 |
input_type="image",
|
|
@@ -183,98 +268,13 @@ print(result)
|
|
| 183 |
result.to_csv("classified_results.csv", index=False)
|
| 184 |
'''
|
| 185 |
|
| 186 |
-
# Free models (uses Space secrets - no user API key needed)
|
| 187 |
-
FREE_MODEL_CHOICES = [
|
| 188 |
-
"Qwen/Qwen3-VL-235B-A22B-Instruct:novita",
|
| 189 |
-
"deepseek-ai/DeepSeek-V3.1:novita",
|
| 190 |
-
"meta-llama/Llama-3.3-70B-Instruct:groq",
|
| 191 |
-
"gemini-2.5-flash",
|
| 192 |
-
"gpt-4o",
|
| 193 |
-
"mistral-medium-2505",
|
| 194 |
-
"claude-3-haiku-20240307",
|
| 195 |
-
"grok-4-fast-non-reasoning",
|
| 196 |
-
]
|
| 197 |
-
|
| 198 |
-
# Paid models (user provides their own API key)
|
| 199 |
-
PAID_MODEL_CHOICES = [
|
| 200 |
-
"gpt-4.1",
|
| 201 |
-
"gpt-4o",
|
| 202 |
-
"gpt-4o-mini",
|
| 203 |
-
"claude-sonnet-4-5-20250929",
|
| 204 |
-
"claude-opus-4-20250514",
|
| 205 |
-
"claude-3-5-haiku-20241022",
|
| 206 |
-
"gemini-2.5-pro",
|
| 207 |
-
"gemini-2.5-flash",
|
| 208 |
-
"mistral-large-latest",
|
| 209 |
-
]
|
| 210 |
-
|
| 211 |
-
# Models routed through HuggingFace
|
| 212 |
-
HF_ROUTED_MODELS = [
|
| 213 |
-
"Qwen/Qwen3-VL-235B-A22B-Instruct:novita",
|
| 214 |
-
"deepseek-ai/DeepSeek-V3.1:novita",
|
| 215 |
-
"meta-llama/Llama-3.3-70B-Instruct:groq",
|
| 216 |
-
]
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
def is_free_model(model, model_tier):
|
| 220 |
-
"""Check if using free tier (Space pays for API)."""
|
| 221 |
-
return model_tier == "Free Models"
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
def get_model_source(model):
|
| 225 |
-
"""Auto-detect model source. All HF router models (novita, groq, etc) use 'huggingface'."""
|
| 226 |
-
model_lower = model.lower()
|
| 227 |
-
if "gpt" in model_lower:
|
| 228 |
-
return "openai"
|
| 229 |
-
elif "claude" in model_lower:
|
| 230 |
-
return "anthropic"
|
| 231 |
-
elif "gemini" in model_lower:
|
| 232 |
-
return "google"
|
| 233 |
-
elif "mistral" in model_lower and ":novita" not in model_lower:
|
| 234 |
-
return "mistral"
|
| 235 |
-
elif any(x in model_lower for x in [":novita", ":groq", "qwen", "llama", "deepseek"]):
|
| 236 |
-
return "huggingface"
|
| 237 |
-
elif "sonar" in model_lower:
|
| 238 |
-
return "perplexity"
|
| 239 |
-
elif "grok" in model_lower:
|
| 240 |
-
return "xai"
|
| 241 |
-
return "huggingface"
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
def get_api_key(model, model_tier, api_key_input):
|
| 245 |
-
"""Get the appropriate API key based on model and tier."""
|
| 246 |
-
if is_free_model(model, model_tier):
|
| 247 |
-
if model in HF_ROUTED_MODELS:
|
| 248 |
-
return os.environ.get("HF_API_KEY", ""), "HuggingFace"
|
| 249 |
-
elif "gpt" in model.lower():
|
| 250 |
-
return os.environ.get("OPENAI_API_KEY", ""), "OpenAI"
|
| 251 |
-
elif "gemini" in model.lower():
|
| 252 |
-
return os.environ.get("GOOGLE_API_KEY", ""), "Google"
|
| 253 |
-
elif "mistral" in model.lower():
|
| 254 |
-
return os.environ.get("MISTRAL_API_KEY", ""), "Mistral"
|
| 255 |
-
elif "claude" in model.lower():
|
| 256 |
-
return os.environ.get("ANTHROPIC_API_KEY", ""), "Anthropic"
|
| 257 |
-
elif "sonar" in model.lower():
|
| 258 |
-
return os.environ.get("PERPLEXITY_API_KEY", ""), "Perplexity"
|
| 259 |
-
elif "grok" in model.lower():
|
| 260 |
-
return os.environ.get("XAI_API_KEY", ""), "xAI"
|
| 261 |
-
else:
|
| 262 |
-
return os.environ.get("HF_API_KEY", ""), "HuggingFace"
|
| 263 |
-
else:
|
| 264 |
-
if api_key_input and api_key_input.strip():
|
| 265 |
-
return api_key_input.strip(), "User"
|
| 266 |
-
return "", "User"
|
| 267 |
-
|
| 268 |
|
| 269 |
def generate_methodology_report_pdf(categories, model, column_name, num_rows, model_source, filename, success_rate,
|
| 270 |
result_df=None, processing_time=None, prompt_template=None,
|
| 271 |
data_quality=None, catllm_version=None, python_version=None,
|
| 272 |
task_type="assign", extracted_categories_df=None, max_categories=None,
|
| 273 |
input_type="text", description=None):
|
| 274 |
-
"""Generate a PDF methodology report
|
| 275 |
-
|
| 276 |
-
task_type: "extract", "assign", or "extract_and_assign"
|
| 277 |
-
"""
|
| 278 |
from reportlab.lib.pagesizes import letter
|
| 279 |
from reportlab.lib import colors
|
| 280 |
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
|
@@ -291,7 +291,6 @@ def generate_methodology_report_pdf(categories, model, column_name, num_rows, mo
|
|
| 291 |
|
| 292 |
story = []
|
| 293 |
|
| 294 |
-
# Title based on task type
|
| 295 |
if task_type == "extract_and_assign":
|
| 296 |
report_title = "CatLLM Extraction & Classification Report"
|
| 297 |
else:
|
|
@@ -305,42 +304,20 @@ def generate_methodology_report_pdf(categories, model, column_name, num_rows, mo
|
|
| 305 |
|
| 306 |
if task_type == "extract_and_assign":
|
| 307 |
about_text = """This methodology report documents the automated category extraction and classification process. \
|
| 308 |
-
CatLLM first discovers categories from your data using LLMs, then classifies each item into those categories.
|
| 309 |
-
This two-phase approach combines exploratory category discovery with systematic classification, ensuring both \
|
| 310 |
-
data-driven category selection and reproducible assignments."""
|
| 311 |
else:
|
| 312 |
about_text = """This methodology report documents the classification process for reproducibility and transparency. \
|
| 313 |
-
CatLLM
|
| 314 |
-
arXiv:2504.14571): researchers could keep modifying prompts to obtain outputs that support desired conclusions, and \
|
| 315 |
-
this variability in pseudo-natural language poses a challenge for reproducibility since each prompt, even if only \
|
| 316 |
-
slightly altered, can yield different outputs, making it impossible to replicate findings reliably. CatLLM restricts \
|
| 317 |
-
the prompt to a standard template that is impartial to the researcher's inclinations, ensuring \
|
| 318 |
consistent and reproducible results."""
|
| 319 |
|
| 320 |
story.append(Paragraph(about_text, normal_style))
|
| 321 |
story.append(Spacer(1, 15))
|
| 322 |
|
| 323 |
-
|
| 324 |
-
if task_type in ["extract", "extract_and_assign"]:
|
| 325 |
-
story.append(Paragraph("Category Extraction Methodology", heading_style))
|
| 326 |
-
extraction_text = f"""Categories were automatically extracted from your data using the following process:
|
| 327 |
-
|
| 328 |
-
1. Data was divided into chunks for analysis
|
| 329 |
-
2. Each chunk was analyzed by the LLM to identify themes and categories
|
| 330 |
-
3. Categories from all chunks were consolidated and deduplicated
|
| 331 |
-
4. Final categories were selected based on frequency and relevance
|
| 332 |
-
5. Maximum categories requested: {max_categories or 'default'}"""
|
| 333 |
-
story.append(Paragraph(extraction_text.replace('\n', '<br/>'), normal_style))
|
| 334 |
-
story.append(Spacer(1, 15))
|
| 335 |
-
|
| 336 |
-
# Category mapping section - for assign and extract_and_assign
|
| 337 |
-
if task_type in ["assign", "extract_and_assign"] and categories:
|
| 338 |
story.append(Paragraph("Category Mapping", heading_style))
|
| 339 |
story.append(Paragraph("Each category column contains binary values: 1 = present, 0 = not present", normal_style))
|
| 340 |
story.append(Spacer(1, 8))
|
| 341 |
|
| 342 |
-
# Category table - show for all task types that have categories
|
| 343 |
-
if categories:
|
| 344 |
category_data = [["Column Name", "Category Description"]]
|
| 345 |
for i, cat in enumerate(categories, 1):
|
| 346 |
category_data.append([f"category_{i}", cat])
|
|
@@ -357,120 +334,26 @@ consistent and reproducible results."""
|
|
| 357 |
story.append(cat_table)
|
| 358 |
story.append(Spacer(1, 15))
|
| 359 |
|
| 360 |
-
# Output columns description - only for classification tasks
|
| 361 |
-
if task_type in ["assign", "extract_and_assign"]:
|
| 362 |
-
story.append(Paragraph("Other Output Columns", heading_style))
|
| 363 |
-
other_cols = [
|
| 364 |
-
["Column Name", "Description"],
|
| 365 |
-
["survey_input", "The original text that was classified"],
|
| 366 |
-
["model_response", "Raw response from the LLM"],
|
| 367 |
-
["json", "Extracted JSON with category assignments"],
|
| 368 |
-
["processing_status", "'success' if classification worked, 'error' if failed"],
|
| 369 |
-
["categories_id", "Comma-separated list of assigned category numbers"],
|
| 370 |
-
]
|
| 371 |
-
other_table = Table(other_cols, colWidths=[120, 330])
|
| 372 |
-
other_table.setStyle(TableStyle([
|
| 373 |
-
('BACKGROUND', (0, 0), (-1, 0), colors.grey),
|
| 374 |
-
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
| 375 |
-
('GRID', (0, 0), (-1, -1), 1, colors.black),
|
| 376 |
-
('PADDING', (0, 0), (-1, -1), 6),
|
| 377 |
-
('BACKGROUND', (0, 1), (0, -1), colors.lightgrey),
|
| 378 |
-
('FONTSIZE', (0, 0), (-1, -1), 9),
|
| 379 |
-
]))
|
| 380 |
-
story.append(other_table)
|
| 381 |
-
|
| 382 |
story.append(Spacer(1, 30))
|
| 383 |
story.append(Paragraph("Citation", heading_style))
|
| 384 |
story.append(Paragraph("If you use CatLLM in your research, please cite:", normal_style))
|
| 385 |
story.append(Spacer(1, 5))
|
| 386 |
story.append(Paragraph("Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DOI: 10.5281/zenodo.15532316", normal_style))
|
| 387 |
|
| 388 |
-
#
|
| 389 |
-
if task_type in ["assign", "extract_and_assign"] and result_df is not None and len(result_df) > 0:
|
| 390 |
-
story.append(PageBreak())
|
| 391 |
-
story.append(Paragraph("Sample Results (First 5 Rows)", title_style))
|
| 392 |
-
story.append(Paragraph("Example classifications showing original text and assigned categories:", normal_style))
|
| 393 |
-
story.append(Spacer(1, 15))
|
| 394 |
-
|
| 395 |
-
sample_data = [["Original Text (truncated)", "Assigned Categories"]]
|
| 396 |
-
sample_df = result_df.head(5)
|
| 397 |
-
|
| 398 |
-
for _, row in sample_df.iterrows():
|
| 399 |
-
original_text = str(row.get('survey_input', ''))[:80]
|
| 400 |
-
if len(str(row.get('survey_input', ''))) > 80:
|
| 401 |
-
original_text += "..."
|
| 402 |
-
assigned = row.get('categories_id', '')
|
| 403 |
-
if pd.isna(assigned) or assigned == '':
|
| 404 |
-
assigned = "None"
|
| 405 |
-
sample_data.append([original_text, str(assigned)])
|
| 406 |
-
|
| 407 |
-
sample_table = Table(sample_data, colWidths=[320, 130])
|
| 408 |
-
sample_table.setStyle(TableStyle([
|
| 409 |
-
('BACKGROUND', (0, 0), (-1, 0), colors.grey),
|
| 410 |
-
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
| 411 |
-
('GRID', (0, 0), (-1, -1), 1, colors.black),
|
| 412 |
-
('PADDING', (0, 0), (-1, -1), 8),
|
| 413 |
-
('FONTSIZE', (0, 0), (-1, -1), 9),
|
| 414 |
-
('VALIGN', (0, 0), (-1, -1), 'TOP'),
|
| 415 |
-
]))
|
| 416 |
-
story.append(sample_table)
|
| 417 |
-
|
| 418 |
-
# Category distribution - only for classification tasks
|
| 419 |
-
if task_type in ["assign", "extract_and_assign"]:
|
| 420 |
-
story.append(PageBreak())
|
| 421 |
-
story.append(Paragraph("Category Distribution", title_style))
|
| 422 |
-
story.append(Paragraph("Count and percentage of responses assigned to each category:", normal_style))
|
| 423 |
-
story.append(Spacer(1, 15))
|
| 424 |
-
|
| 425 |
-
if result_df is not None:
|
| 426 |
-
dist_data = [["Category", "Description", "Count", "Percentage"]]
|
| 427 |
-
total_rows = len(result_df)
|
| 428 |
-
|
| 429 |
-
for i, cat in enumerate(categories, 1):
|
| 430 |
-
col_name = f"category_{i}"
|
| 431 |
-
if col_name in result_df.columns:
|
| 432 |
-
count = int(result_df[col_name].sum())
|
| 433 |
-
pct = (count / total_rows) * 100 if total_rows > 0 else 0
|
| 434 |
-
dist_data.append([col_name, cat[:40], str(count), f"{pct:.1f}%"])
|
| 435 |
-
else:
|
| 436 |
-
dist_data.append([col_name, cat[:40], "N/A", "N/A"])
|
| 437 |
-
|
| 438 |
-
dist_table = Table(dist_data, colWidths=[80, 200, 60, 80])
|
| 439 |
-
dist_table.setStyle(TableStyle([
|
| 440 |
-
('BACKGROUND', (0, 0), (-1, 0), colors.grey),
|
| 441 |
-
('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
|
| 442 |
-
('GRID', (0, 0), (-1, -1), 1, colors.black),
|
| 443 |
-
('PADDING', (0, 0), (-1, -1), 6),
|
| 444 |
-
('FONTSIZE', (0, 0), (-1, -1), 9),
|
| 445 |
-
('ALIGN', (2, 1), (-1, -1), 'CENTER'),
|
| 446 |
-
]))
|
| 447 |
-
story.append(dist_table)
|
| 448 |
-
story.append(Spacer(1, 15))
|
| 449 |
-
story.append(Paragraph(f"<i>Note: Percentages may sum to more than 100% as responses can be assigned to multiple categories.</i>", normal_style))
|
| 450 |
-
|
| 451 |
-
# Summary section - adjust title based on task type
|
| 452 |
story.append(PageBreak())
|
| 453 |
-
|
| 454 |
-
story.append(Paragraph("Processing Summary", title_style))
|
| 455 |
-
else:
|
| 456 |
-
story.append(Paragraph("Classification Summary", title_style))
|
| 457 |
story.append(Spacer(1, 15))
|
| 458 |
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
["Rows Classified", str(num_rows)],
|
| 469 |
-
["Number of Categories", str(len(categories)) if categories else "0"],
|
| 470 |
-
["Success Rate", f"{success_rate:.2f}%"],
|
| 471 |
-
]
|
| 472 |
-
if task_type == "extract_and_assign":
|
| 473 |
-
summary_data.insert(4, ["Categories Auto-Extracted", "Yes"])
|
| 474 |
summary_table = Table(summary_data, colWidths=[150, 300])
|
| 475 |
summary_table.setStyle(TableStyle([
|
| 476 |
('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
|
|
@@ -499,27 +382,8 @@ consistent and reproducible results."""
|
|
| 499 |
('FONTSIZE', (0, 0), (-1, -1), 9),
|
| 500 |
]))
|
| 501 |
story.append(time_table)
|
| 502 |
-
story.append(Spacer(1, 15))
|
| 503 |
-
|
| 504 |
-
if data_quality is not None:
|
| 505 |
-
story.append(Paragraph("Data Quality Notes", heading_style))
|
| 506 |
-
quality_data = [
|
| 507 |
-
["Empty/Null Inputs Skipped", str(data_quality.get('null_count', 0))],
|
| 508 |
-
["Average Text Length", f"{data_quality.get('avg_length', 0)} characters"],
|
| 509 |
-
["Min Text Length", f"{data_quality.get('min_length', 0)} characters"],
|
| 510 |
-
["Max Text Length", f"{data_quality.get('max_length', 0)} characters"],
|
| 511 |
-
["Responses with Errors", str(data_quality.get('error_count', 0))],
|
| 512 |
-
]
|
| 513 |
-
quality_table = Table(quality_data, colWidths=[180, 270])
|
| 514 |
-
quality_table.setStyle(TableStyle([
|
| 515 |
-
('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
|
| 516 |
-
('GRID', (0, 0), (-1, -1), 1, colors.black),
|
| 517 |
-
('PADDING', (0, 0), (-1, -1), 6),
|
| 518 |
-
('FONTSIZE', (0, 0), (-1, -1), 9),
|
| 519 |
-
]))
|
| 520 |
-
story.append(quality_table)
|
| 521 |
-
story.append(Spacer(1, 15))
|
| 522 |
|
|
|
|
| 523 |
story.append(Paragraph("Version Information", heading_style))
|
| 524 |
version_data = [
|
| 525 |
["CatLLM Version", catllm_version or "unknown"],
|
|
@@ -535,300 +399,41 @@ consistent and reproducible results."""
|
|
| 535 |
]))
|
| 536 |
story.append(version_table)
|
| 537 |
|
| 538 |
-
# Prompt template - only for classification tasks
|
| 539 |
-
if task_type in ["assign", "extract_and_assign"] and prompt_template:
|
| 540 |
-
story.append(PageBreak())
|
| 541 |
-
story.append(Paragraph("Prompt Template Used", title_style))
|
| 542 |
-
story.append(Paragraph("The following prompt template was sent to the LLM for each classification:", normal_style))
|
| 543 |
-
story.append(Spacer(1, 15))
|
| 544 |
-
|
| 545 |
-
story.append(Paragraph("Template with Placeholders:", heading_style))
|
| 546 |
-
story.append(Spacer(1, 8))
|
| 547 |
-
|
| 548 |
-
for line in prompt_template.split('\n'):
|
| 549 |
-
escaped_line = line.replace('&', '&').replace('<', '<').replace('>', '>')
|
| 550 |
-
if escaped_line.strip():
|
| 551 |
-
story.append(Paragraph(escaped_line, code_style))
|
| 552 |
-
else:
|
| 553 |
-
story.append(Spacer(1, 5))
|
| 554 |
-
|
| 555 |
-
story.append(Spacer(1, 20))
|
| 556 |
-
|
| 557 |
-
if categories:
|
| 558 |
-
story.append(Paragraph("Example with Your Categories:", heading_style))
|
| 559 |
-
story.append(Spacer(1, 8))
|
| 560 |
-
|
| 561 |
-
categories_list = "\n".join([f" {i}. {cat}" for i, cat in enumerate(categories, 1)])
|
| 562 |
-
example_prompt = f'''Categorize this survey response "[YOUR TEXT HERE]" into the following categories:
|
| 563 |
-
{categories_list}
|
| 564 |
-
Provide your work in JSON format where the number belonging to each category
|
| 565 |
-
is the key and a 1 if the category is present and a 0 if not.'''
|
| 566 |
-
|
| 567 |
-
for line in example_prompt.split('\n'):
|
| 568 |
-
escaped_line = line.replace('&', '&').replace('<', '<').replace('>', '>')
|
| 569 |
-
if escaped_line.strip():
|
| 570 |
-
story.append(Paragraph(escaped_line, code_style))
|
| 571 |
-
else:
|
| 572 |
-
story.append(Spacer(1, 5))
|
| 573 |
-
|
| 574 |
-
# Reproducibility code section
|
| 575 |
-
story.append(PageBreak())
|
| 576 |
-
story.append(Paragraph("Reproducibility Code", title_style))
|
| 577 |
-
|
| 578 |
-
if task_type == "extract_and_assign":
|
| 579 |
-
story.append(Paragraph("Use the following Python code to reproduce this extraction and classification:", normal_style))
|
| 580 |
-
story.append(Spacer(1, 15))
|
| 581 |
-
|
| 582 |
-
categories_str = ", ".join([f'"{cat}"' for cat in categories]) if categories else ""
|
| 583 |
-
|
| 584 |
-
code_text = f'''import catllm
|
| 585 |
-
|
| 586 |
-
# Step 1: Extract categories from your data
|
| 587 |
-
extract_result = catllm.extract(
|
| 588 |
-
input_data="path/to/your/data",
|
| 589 |
-
api_key="YOUR_API_KEY",
|
| 590 |
-
input_type="{input_type}",
|
| 591 |
-
description="{description or column_name}",
|
| 592 |
-
user_model="{model}",
|
| 593 |
-
model_source="{model_source}",
|
| 594 |
-
max_categories={max_categories or 12}
|
| 595 |
-
)
|
| 596 |
-
|
| 597 |
-
categories = extract_result["top_categories"]
|
| 598 |
-
print("Extracted categories:", categories)
|
| 599 |
-
|
| 600 |
-
# Step 2: Classify data using extracted categories
|
| 601 |
-
result = catllm.classify(
|
| 602 |
-
input_data="path/to/your/data",
|
| 603 |
-
categories=categories,
|
| 604 |
-
api_key="YOUR_API_KEY",
|
| 605 |
-
input_type="{input_type}",
|
| 606 |
-
description="{description or column_name}",
|
| 607 |
-
user_model="{model}",
|
| 608 |
-
model_source="{model_source}"
|
| 609 |
-
)
|
| 610 |
-
|
| 611 |
-
# View results
|
| 612 |
-
print(result)
|
| 613 |
-
result.to_csv("classified_results.csv", index=False)'''
|
| 614 |
-
|
| 615 |
-
else: # assign
|
| 616 |
-
story.append(Paragraph("Use the following Python code to reproduce this classification:", normal_style))
|
| 617 |
-
story.append(Spacer(1, 15))
|
| 618 |
-
|
| 619 |
-
categories_str = ", ".join([f'"{cat}"' for cat in categories]) if categories else ""
|
| 620 |
-
|
| 621 |
-
code_text = f'''import catllm
|
| 622 |
-
import pandas as pd
|
| 623 |
-
|
| 624 |
-
# Load your survey data
|
| 625 |
-
df = pd.read_csv("{filename}")
|
| 626 |
-
|
| 627 |
-
# Define your categories
|
| 628 |
-
categories = [{categories_str}]
|
| 629 |
-
|
| 630 |
-
# Classify the responses
|
| 631 |
-
result = catllm.classify(
|
| 632 |
-
input_data=df["{column_name}"].tolist(),
|
| 633 |
-
categories=categories,
|
| 634 |
-
api_key="YOUR_API_KEY",
|
| 635 |
-
input_type="{input_type}",
|
| 636 |
-
description="{description or column_name}",
|
| 637 |
-
user_model="{model}",
|
| 638 |
-
model_source="{model_source}"
|
| 639 |
-
)
|
| 640 |
-
|
| 641 |
-
# View results
|
| 642 |
-
print(result)
|
| 643 |
-
|
| 644 |
-
# Save to CSV
|
| 645 |
-
result.to_csv("classified_results.csv", index=False)'''
|
| 646 |
-
|
| 647 |
-
for line in code_text.split('\n'):
|
| 648 |
-
if line.strip() == '':
|
| 649 |
-
story.append(Spacer(1, 5))
|
| 650 |
-
else:
|
| 651 |
-
escaped_line = line.replace('&', '&').replace('<', '<').replace('>', '>')
|
| 652 |
-
story.append(Paragraph(escaped_line, code_style))
|
| 653 |
-
|
| 654 |
doc.build(story)
|
| 655 |
return pdf_file.name
|
| 656 |
|
| 657 |
|
| 658 |
-
def
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
try:
|
| 662 |
-
df = pd.read_csv(example_path)
|
| 663 |
-
columns = df.columns.tolist()
|
| 664 |
-
return (
|
| 665 |
-
example_path,
|
| 666 |
-
gr.update(choices=columns, value=columns[0] if columns else None),
|
| 667 |
-
f"Loaded example dataset ({len(df)} rows). Select column and choose a task."
|
| 668 |
-
)
|
| 669 |
-
except Exception as e:
|
| 670 |
-
return None, gr.update(choices=[], value=None), f"**Error loading example:** {str(e)}"
|
| 671 |
-
|
| 672 |
-
|
| 673 |
-
def load_columns(file):
|
| 674 |
-
if file is None:
|
| 675 |
-
return gr.update(choices=[], value=None), "Please upload a file first"
|
| 676 |
-
|
| 677 |
-
try:
|
| 678 |
-
file_path = file if isinstance(file, str) else file.name
|
| 679 |
-
if file_path.endswith('.csv'):
|
| 680 |
-
df = pd.read_csv(file_path)
|
| 681 |
-
else:
|
| 682 |
-
df = pd.read_excel(file_path)
|
| 683 |
-
|
| 684 |
-
columns = df.columns.tolist()
|
| 685 |
-
num_rows = len(df)
|
| 686 |
-
|
| 687 |
-
if num_rows > 1000:
|
| 688 |
-
est_minutes = round(num_rows * 1.5 / 60)
|
| 689 |
-
status_msg = f"Loaded {num_rows:,} rows. Processing may take ~{est_minutes} minutes."
|
| 690 |
-
else:
|
| 691 |
-
status_msg = f"Loaded {num_rows:,} rows. Choose a task to proceed."
|
| 692 |
-
|
| 693 |
-
return (
|
| 694 |
-
gr.update(choices=columns, value=columns[0] if columns else None),
|
| 695 |
-
status_msg
|
| 696 |
-
)
|
| 697 |
-
except Exception as e:
|
| 698 |
-
return gr.update(choices=[], value=None), f"**Error:** {str(e)}"
|
| 699 |
-
|
| 700 |
-
|
| 701 |
-
def update_task_visibility(task):
|
| 702 |
-
"""Update visibility of components based on selected task."""
|
| 703 |
-
if task == "manual":
|
| 704 |
-
return (
|
| 705 |
-
gr.update(visible=False), # auto_extract_group
|
| 706 |
-
gr.update(visible=True), # categories_group
|
| 707 |
-
gr.update(visible=True), # model_group
|
| 708 |
-
gr.update(visible=True, value="Classify Data"), # run_btn
|
| 709 |
-
gr.update(visible=True), # classify_output_group
|
| 710 |
-
"Enter your categories and click 'Classify Data'."
|
| 711 |
-
)
|
| 712 |
-
elif task == "auto_extract":
|
| 713 |
-
return (
|
| 714 |
-
gr.update(visible=True), # auto_extract_group
|
| 715 |
-
gr.update(visible=False), # categories_group (will show after extraction)
|
| 716 |
-
gr.update(visible=False), # model_group (will show after extraction)
|
| 717 |
-
gr.update(visible=False), # run_btn (will show after extraction)
|
| 718 |
-
gr.update(visible=True), # classify_output_group
|
| 719 |
-
"Select number of categories and click 'Extract Categories'."
|
| 720 |
-
)
|
| 721 |
-
else:
|
| 722 |
-
return (
|
| 723 |
-
gr.update(visible=False),
|
| 724 |
-
gr.update(visible=False),
|
| 725 |
-
gr.update(visible=False),
|
| 726 |
-
gr.update(visible=False),
|
| 727 |
-
gr.update(visible=False),
|
| 728 |
-
"Upload data and select a task to continue."
|
| 729 |
-
)
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
def run_auto_extract(input_type, spreadsheet_file, spreadsheet_column,
|
| 733 |
-
pdf_file, pdf_folder, pdf_description, pdf_mode,
|
| 734 |
-
image_file, image_folder, image_description,
|
| 735 |
-
max_categories_val,
|
| 736 |
-
model_tier, model, model_source_input, api_key_input,
|
| 737 |
-
progress=gr.Progress(track_tqdm=True)):
|
| 738 |
-
"""Extract categories from data and fill the category textboxes."""
|
| 739 |
if not CATLLM_AVAILABLE:
|
| 740 |
-
|
| 741 |
-
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** catllm package not available"]
|
| 742 |
|
| 743 |
actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
|
| 744 |
if not actual_api_key:
|
| 745 |
-
return
|
| 746 |
|
| 747 |
-
|
| 748 |
-
model_source = get_model_source(model)
|
| 749 |
-
else:
|
| 750 |
-
model_source = model_source_input
|
| 751 |
|
| 752 |
try:
|
| 753 |
-
if input_type == "Survey Responses":
|
| 754 |
-
if not spreadsheet_file:
|
| 755 |
-
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please upload a CSV/Excel file first"]
|
| 756 |
-
if not spreadsheet_column:
|
| 757 |
-
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please select a column first"]
|
| 758 |
-
|
| 759 |
-
file_path = spreadsheet_file if isinstance(spreadsheet_file, str) else spreadsheet_file.name
|
| 760 |
-
if file_path.endswith('.csv'):
|
| 761 |
-
df = pd.read_csv(file_path)
|
| 762 |
-
else:
|
| 763 |
-
df = pd.read_excel(file_path)
|
| 764 |
-
|
| 765 |
-
input_data = df[spreadsheet_column].tolist()
|
| 766 |
-
description = spreadsheet_column
|
| 767 |
-
input_type_param = "text"
|
| 768 |
-
mode_param = None
|
| 769 |
-
|
| 770 |
-
elif input_type == "PDF Documents":
|
| 771 |
-
if pdf_folder:
|
| 772 |
-
if isinstance(pdf_folder, list):
|
| 773 |
-
input_data = [f if isinstance(f, str) else f.name for f in pdf_folder if str(f.name if hasattr(f, 'name') else f).lower().endswith('.pdf')]
|
| 774 |
-
else:
|
| 775 |
-
input_data = pdf_folder if isinstance(pdf_folder, str) else pdf_folder.name
|
| 776 |
-
elif pdf_file:
|
| 777 |
-
if isinstance(pdf_file, list):
|
| 778 |
-
input_data = [f if isinstance(f, str) else f.name for f in pdf_file]
|
| 779 |
-
else:
|
| 780 |
-
input_data = pdf_file if isinstance(pdf_file, str) else pdf_file.name
|
| 781 |
-
else:
|
| 782 |
-
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please upload PDF file(s) first"]
|
| 783 |
-
|
| 784 |
-
description = pdf_description or "document"
|
| 785 |
-
input_type_param = "pdf"
|
| 786 |
-
mode_mapping = {"Image (visual documents)": "image", "Text (text-heavy)": "text", "Both (comprehensive)": "both"}
|
| 787 |
-
mode_param = mode_mapping.get(pdf_mode, "image")
|
| 788 |
-
|
| 789 |
-
elif input_type == "Images":
|
| 790 |
-
if image_folder:
|
| 791 |
-
if isinstance(image_folder, list):
|
| 792 |
-
input_data = [f if isinstance(f, str) else f.name for f in image_folder]
|
| 793 |
-
else:
|
| 794 |
-
input_data = image_folder if isinstance(image_folder, str) else image_folder.name
|
| 795 |
-
elif image_file:
|
| 796 |
-
if isinstance(image_file, list):
|
| 797 |
-
input_data = [f if isinstance(f, str) else f.name for f in image_file]
|
| 798 |
-
else:
|
| 799 |
-
input_data = image_file if isinstance(image_file, str) else image_file.name
|
| 800 |
-
else:
|
| 801 |
-
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, "**Error:** Please upload image file(s) first"]
|
| 802 |
-
|
| 803 |
-
description = image_description or "images"
|
| 804 |
-
input_type_param = "image"
|
| 805 |
-
mode_param = None
|
| 806 |
-
|
| 807 |
-
else:
|
| 808 |
-
return [gr.update()] * MAX_CATEGORIES + [MAX_CATEGORIES, f"**Error:** Unknown input type: {input_type}"]
|
| 809 |
-
|
| 810 |
-
# Calculate divisions based on input size
|
| 811 |
if isinstance(input_data, list):
|
| 812 |
num_items = len(input_data)
|
| 813 |
else:
|
| 814 |
num_items = 1
|
| 815 |
|
| 816 |
-
if
|
| 817 |
divisions = min(3, max(1, num_items // 5))
|
| 818 |
categories_per_chunk = 12
|
| 819 |
else:
|
| 820 |
-
|
| 821 |
-
|
| 822 |
-
divisions = max(1, num_items // 15) # Aim for ~15 items per chunk
|
| 823 |
-
divisions = min(divisions, 5) # Cap at 5 divisions
|
| 824 |
chunk_size = num_items // max(1, divisions)
|
| 825 |
-
categories_per_chunk = min(10, chunk_size - 1)
|
| 826 |
|
| 827 |
-
# Extract categories
|
| 828 |
extract_kwargs = {
|
| 829 |
'input_data': input_data,
|
| 830 |
'api_key': actual_api_key,
|
| 831 |
-
'input_type':
|
| 832 |
'description': description,
|
| 833 |
'user_model': model,
|
| 834 |
'model_source': model_source,
|
|
@@ -836,200 +441,54 @@ def run_auto_extract(input_type, spreadsheet_file, spreadsheet_column,
|
|
| 836 |
'categories_per_chunk': categories_per_chunk,
|
| 837 |
'max_categories': int(max_categories_val)
|
| 838 |
}
|
| 839 |
-
if
|
| 840 |
-
extract_kwargs['mode'] =
|
| 841 |
|
| 842 |
extract_result = catllm.extract(**extract_kwargs)
|
| 843 |
categories = extract_result.get('top_categories', [])
|
| 844 |
|
| 845 |
if not categories:
|
| 846 |
-
return
|
| 847 |
-
|
| 848 |
-
# Fill the category textboxes
|
| 849 |
-
updates = []
|
| 850 |
-
num_categories = min(len(categories), MAX_CATEGORIES)
|
| 851 |
-
for i in range(MAX_CATEGORIES):
|
| 852 |
-
if i < num_categories:
|
| 853 |
-
updates.append(gr.update(value=categories[i], visible=True))
|
| 854 |
-
elif i < INITIAL_CATEGORIES:
|
| 855 |
-
updates.append(gr.update(value="", visible=True))
|
| 856 |
-
else:
|
| 857 |
-
updates.append(gr.update(value="", visible=False))
|
| 858 |
|
| 859 |
-
|
| 860 |
-
return updates + [num_categories, f"Extracted {len(categories)} categories. Review and edit as needed, then click 'Classify Data'."]
|
| 861 |
|
| 862 |
except Exception as e:
|
| 863 |
-
return
|
| 864 |
|
| 865 |
|
| 866 |
-
def run_classify_data(input_type,
|
| 867 |
-
|
| 868 |
-
|
| 869 |
-
|
| 870 |
-
model_tier, model, model_source_input, api_key_input,
|
| 871 |
-
progress=gr.Progress(track_tqdm=True)):
|
| 872 |
"""Classify data with user-provided categories."""
|
| 873 |
if not CATLLM_AVAILABLE:
|
| 874 |
-
|
| 875 |
-
return
|
| 876 |
-
|
| 877 |
-
all_cats = [cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10]
|
| 878 |
-
categories = [c.strip() for c in all_cats if c and c.strip()]
|
| 879 |
|
| 880 |
if not categories:
|
| 881 |
-
|
| 882 |
-
return
|
| 883 |
|
| 884 |
actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
|
| 885 |
if not actual_api_key:
|
| 886 |
-
|
| 887 |
-
return
|
| 888 |
-
|
| 889 |
-
if model_source_input == "auto":
|
| 890 |
-
model_source = get_model_source(model)
|
| 891 |
-
else:
|
| 892 |
-
model_source = model_source_input
|
| 893 |
-
|
| 894 |
-
# Check file size for images and PDFs
|
| 895 |
-
files_to_check = None
|
| 896 |
-
if input_type == "Images":
|
| 897 |
-
files_to_check = image_folder if image_folder else image_file
|
| 898 |
-
elif input_type == "PDF Documents":
|
| 899 |
-
files_to_check = pdf_folder if pdf_folder else pdf_file
|
| 900 |
-
|
| 901 |
-
if files_to_check:
|
| 902 |
-
total_size_mb = calculate_total_file_size(files_to_check)
|
| 903 |
-
if total_size_mb > MAX_FILE_SIZE_MB:
|
| 904 |
-
# Generate the code for the user
|
| 905 |
-
if input_type == "Images":
|
| 906 |
-
code = generate_classify_code("image", image_description or "images", categories, model, model_source)
|
| 907 |
-
else:
|
| 908 |
-
mode_mapping = {"Image (visual documents)": "image", "Text (text-heavy)": "text", "Both (comprehensive)": "both"}
|
| 909 |
-
actual_mode = mode_mapping.get(pdf_mode, "image")
|
| 910 |
-
code = generate_classify_code("pdf", pdf_description or "document", categories, model, model_source, actual_mode)
|
| 911 |
|
| 912 |
-
|
| 913 |
-
|
| 914 |
-
Uploads over {MAX_FILE_SIZE_MB} MB may experience performance issues or timeouts on this web app.
|
| 915 |
-
|
| 916 |
-
**Recommended:** Run the code locally using the Python package instead. See the code below, or click "See the Code" after this message.
|
| 917 |
-
|
| 918 |
-
```
|
| 919 |
-
pip install cat-llm
|
| 920 |
-
```
|
| 921 |
-
"""
|
| 922 |
-
yield None, None, None, code, None, warning_msg
|
| 923 |
-
return
|
| 924 |
|
| 925 |
try:
|
| 926 |
-
yield None, None, None, None, None, "Classifying your data..."
|
| 927 |
-
|
| 928 |
start_time = time.time()
|
| 929 |
|
| 930 |
-
|
| 931 |
-
|
| 932 |
-
|
| 933 |
-
|
| 934 |
-
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
-
else:
|
| 942 |
-
df = pd.read_excel(file_path)
|
| 943 |
-
|
| 944 |
-
input_data = df[spreadsheet_column].tolist()
|
| 945 |
-
original_filename = file_path.split("/")[-1]
|
| 946 |
-
column_name = spreadsheet_column
|
| 947 |
-
|
| 948 |
-
result = catllm.classify(
|
| 949 |
-
input_data=input_data,
|
| 950 |
-
categories=categories,
|
| 951 |
-
api_key=actual_api_key,
|
| 952 |
-
input_type="text",
|
| 953 |
-
description=spreadsheet_column,
|
| 954 |
-
user_model=model,
|
| 955 |
-
model_source=model_source
|
| 956 |
-
)
|
| 957 |
-
|
| 958 |
-
elif input_type == "PDF Documents":
|
| 959 |
-
# Use folder if provided, otherwise use uploaded files
|
| 960 |
-
if pdf_folder:
|
| 961 |
-
if isinstance(pdf_folder, list):
|
| 962 |
-
pdf_input = [f if isinstance(f, str) else f.name for f in pdf_folder if str(f.name if hasattr(f, 'name') else f).lower().endswith('.pdf')]
|
| 963 |
-
original_filename = "pdf_folder"
|
| 964 |
-
else:
|
| 965 |
-
pdf_input = pdf_folder if isinstance(pdf_folder, str) else pdf_folder.name
|
| 966 |
-
original_filename = pdf_input.split("/")[-1]
|
| 967 |
-
elif pdf_file:
|
| 968 |
-
if isinstance(pdf_file, list):
|
| 969 |
-
pdf_input = [f if isinstance(f, str) else f.name for f in pdf_file]
|
| 970 |
-
original_filename = "multiple_pdfs"
|
| 971 |
-
else:
|
| 972 |
-
pdf_input = pdf_file if isinstance(pdf_file, str) else pdf_file.name
|
| 973 |
-
original_filename = pdf_input.split("/")[-1]
|
| 974 |
-
else:
|
| 975 |
-
yield None, None, None, None, None, "**Error:** Please upload PDF file(s) or a folder"
|
| 976 |
-
return
|
| 977 |
-
|
| 978 |
-
column_name = "PDF Pages"
|
| 979 |
-
|
| 980 |
-
mode_mapping = {
|
| 981 |
-
"Image (visual documents)": "image",
|
| 982 |
-
"Text (text-heavy)": "text",
|
| 983 |
-
"Both (comprehensive)": "both"
|
| 984 |
-
}
|
| 985 |
-
actual_mode = mode_mapping.get(pdf_mode, "image")
|
| 986 |
-
|
| 987 |
-
result = catllm.classify(
|
| 988 |
-
input_data=pdf_input,
|
| 989 |
-
categories=categories,
|
| 990 |
-
api_key=actual_api_key,
|
| 991 |
-
input_type="pdf",
|
| 992 |
-
description=pdf_description or "document",
|
| 993 |
-
mode=actual_mode,
|
| 994 |
-
user_model=model,
|
| 995 |
-
model_source=model_source
|
| 996 |
-
)
|
| 997 |
-
|
| 998 |
-
elif input_type == "Images":
|
| 999 |
-
# Use folder if provided, otherwise use uploaded files
|
| 1000 |
-
if image_folder:
|
| 1001 |
-
if isinstance(image_folder, list):
|
| 1002 |
-
image_input = [f if isinstance(f, str) else f.name for f in image_folder]
|
| 1003 |
-
original_filename = "image_folder"
|
| 1004 |
-
else:
|
| 1005 |
-
image_input = image_folder if isinstance(image_folder, str) else image_folder.name
|
| 1006 |
-
original_filename = image_input.split("/")[-1]
|
| 1007 |
-
elif image_file:
|
| 1008 |
-
if isinstance(image_file, list):
|
| 1009 |
-
image_input = [f if isinstance(f, str) else f.name for f in image_file]
|
| 1010 |
-
original_filename = "multiple_images"
|
| 1011 |
-
else:
|
| 1012 |
-
image_input = image_file if isinstance(image_file, str) else image_file.name
|
| 1013 |
-
original_filename = image_input.split("/")[-1]
|
| 1014 |
-
else:
|
| 1015 |
-
yield None, None, None, None, None, "**Error:** Please upload image file(s) or a folder"
|
| 1016 |
-
return
|
| 1017 |
-
|
| 1018 |
-
column_name = "Image Files"
|
| 1019 |
-
|
| 1020 |
-
result = catllm.classify(
|
| 1021 |
-
input_data=image_input,
|
| 1022 |
-
categories=categories,
|
| 1023 |
-
api_key=actual_api_key,
|
| 1024 |
-
input_type="image",
|
| 1025 |
-
description=image_description or "images",
|
| 1026 |
-
user_model=model,
|
| 1027 |
-
model_source=model_source
|
| 1028 |
-
)
|
| 1029 |
|
| 1030 |
-
|
| 1031 |
-
yield None, None, None, None, None, f"**Error:** Unknown input type: {input_type}"
|
| 1032 |
-
return
|
| 1033 |
|
| 1034 |
processing_time = time.time() - start_time
|
| 1035 |
num_items = len(result)
|
|
@@ -1054,27 +513,6 @@ pip install cat-llm
|
|
| 1054 |
python_version = sys.version.split()[0]
|
| 1055 |
|
| 1056 |
# Generate methodology report
|
| 1057 |
-
prompt_template = '''Categorize this survey response "{response}" into the following categories that apply:
|
| 1058 |
-
{categories}
|
| 1059 |
-
|
| 1060 |
-
Let's think step by step:
|
| 1061 |
-
1. First, identify the main themes mentioned in the response
|
| 1062 |
-
2. Then, match each theme to the relevant categories
|
| 1063 |
-
3. Finally, assign 1 to matching categories and 0 to non-matching categories
|
| 1064 |
-
|
| 1065 |
-
Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values.'''
|
| 1066 |
-
|
| 1067 |
-
# Determine input_type_param for the report
|
| 1068 |
-
if input_type == "Survey Responses":
|
| 1069 |
-
input_type_param = "text"
|
| 1070 |
-
description_param = spreadsheet_column
|
| 1071 |
-
elif input_type == "PDF Documents":
|
| 1072 |
-
input_type_param = "pdf"
|
| 1073 |
-
description_param = pdf_description or "document"
|
| 1074 |
-
else:
|
| 1075 |
-
input_type_param = "image"
|
| 1076 |
-
description_param = image_description or "images"
|
| 1077 |
-
|
| 1078 |
report_pdf_path = generate_methodology_report_pdf(
|
| 1079 |
categories=categories,
|
| 1080 |
model=model,
|
|
@@ -1085,252 +523,81 @@ Provide your work in JSON format where the number belonging to each category is
|
|
| 1085 |
success_rate=success_rate,
|
| 1086 |
result_df=result,
|
| 1087 |
processing_time=processing_time,
|
| 1088 |
-
prompt_template=prompt_template,
|
| 1089 |
-
data_quality={'null_count': 0, 'avg_length': 0, 'min_length': 0, 'max_length': 0, 'error_count': 0},
|
| 1090 |
catllm_version=catllm_version,
|
| 1091 |
python_version=python_version,
|
| 1092 |
task_type="assign",
|
| 1093 |
-
input_type=
|
| 1094 |
-
description=
|
| 1095 |
)
|
| 1096 |
|
| 1097 |
-
#
|
| 1098 |
-
|
| 1099 |
-
total_rows = len(result)
|
| 1100 |
-
for i, cat in enumerate(categories, 1):
|
| 1101 |
-
col_name = f"category_{i}"
|
| 1102 |
-
if col_name in result.columns:
|
| 1103 |
-
count = int(result[col_name].sum())
|
| 1104 |
-
pct = (count / total_rows) * 100 if total_rows > 0 else 0
|
| 1105 |
-
dist_data.append({"Category": cat, "Percentage": round(pct, 1)})
|
| 1106 |
|
| 1107 |
-
|
| 1108 |
-
categories_list = [d["Category"] for d in dist_data][::-1]
|
| 1109 |
-
percentages = [d["Percentage"] for d in dist_data][::-1]
|
| 1110 |
|
| 1111 |
-
|
| 1112 |
-
|
| 1113 |
-
ax.set_xlabel('Percentage (%)', fontsize=11)
|
| 1114 |
-
ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold')
|
| 1115 |
|
| 1116 |
-
for bar, pct in zip(bars, percentages):
|
| 1117 |
-
ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
|
| 1118 |
-
f'{pct:.1f}%', va='center', fontsize=10)
|
| 1119 |
|
| 1120 |
-
|
|
|
|
|
|
|
| 1121 |
|
| 1122 |
-
|
| 1123 |
-
|
| 1124 |
-
|
| 1125 |
-
|
| 1126 |
-
|
| 1127 |
-
|
| 1128 |
-
|
| 1129 |
-
|
| 1130 |
-
code = generate_classify_code("image", image_description or "images", categories, model, model_source)
|
| 1131 |
-
|
| 1132 |
-
yield (
|
| 1133 |
-
gr.update(value=fig, visible=True),
|
| 1134 |
-
gr.update(value=result, visible=True),
|
| 1135 |
-
[csv_path, report_pdf_path],
|
| 1136 |
-
code,
|
| 1137 |
-
None,
|
| 1138 |
-
f"Classified {num_items} items in {processing_time:.1f}s"
|
| 1139 |
-
)
|
| 1140 |
|
| 1141 |
-
|
| 1142 |
-
|
| 1143 |
-
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
-
def add_category_field(current_count):
|
| 1147 |
-
new_count = min(current_count + 1, MAX_CATEGORIES)
|
| 1148 |
-
updates = []
|
| 1149 |
-
for i in range(MAX_CATEGORIES):
|
| 1150 |
-
updates.append(gr.update(visible=(i < new_count)))
|
| 1151 |
-
updates.append(gr.update(visible=(new_count < MAX_CATEGORIES)))
|
| 1152 |
-
updates.append(new_count)
|
| 1153 |
-
return updates
|
| 1154 |
-
|
| 1155 |
-
|
| 1156 |
-
def reset_all():
|
| 1157 |
-
"""Reset all inputs and outputs to initial state."""
|
| 1158 |
-
updates = [
|
| 1159 |
-
"Survey Responses", # input_type (State)
|
| 1160 |
-
gr.update(selected="survey"), # input_tabs - reset to first tab
|
| 1161 |
-
None, # spreadsheet_file
|
| 1162 |
-
gr.update(choices=[], value=None), # spreadsheet_column
|
| 1163 |
-
"Upload File(s)", # pdf_upload_type
|
| 1164 |
-
None, # pdf_file
|
| 1165 |
-
None, # pdf_folder
|
| 1166 |
-
"", # pdf_description
|
| 1167 |
-
"Image (visual documents)", # pdf_mode
|
| 1168 |
-
"Upload File(s)", # image_upload_type
|
| 1169 |
-
None, # image_file
|
| 1170 |
-
None, # image_folder
|
| 1171 |
-
"", # image_description
|
| 1172 |
-
None, # task_mode
|
| 1173 |
-
gr.update(variant="secondary"), # manual_btn - reset to unselected
|
| 1174 |
-
gr.update(variant="secondary"), # auto_extract_btn - reset to unselected
|
| 1175 |
-
]
|
| 1176 |
-
# Reset category inputs
|
| 1177 |
-
for i in range(MAX_CATEGORIES):
|
| 1178 |
-
updates.append(gr.update(value="", visible=(i < INITIAL_CATEGORIES)))
|
| 1179 |
-
updates.extend([
|
| 1180 |
-
gr.update(visible=True), # add_category_btn
|
| 1181 |
-
INITIAL_CATEGORIES, # category_count
|
| 1182 |
-
gr.update(visible=False), # auto_extract_group
|
| 1183 |
-
12, # max_categories (reset to default)
|
| 1184 |
-
"", # auto_extract_status
|
| 1185 |
-
gr.update(visible=False), # categories_group
|
| 1186 |
-
gr.update(visible=False), # model_group
|
| 1187 |
-
gr.update(visible=False, value="Classify Data"), # run_btn
|
| 1188 |
-
"Free Models", # model_tier
|
| 1189 |
-
FREE_MODEL_CHOICES[0], # model
|
| 1190 |
-
"auto", # model_source
|
| 1191 |
-
"", # api_key
|
| 1192 |
-
gr.update(visible=False), # api_key
|
| 1193 |
-
"**Free tier** - no API key required!", # api_key_status
|
| 1194 |
-
"Ready. Upload data and select a task.", # status
|
| 1195 |
-
gr.update(visible=False), # classify_output_group
|
| 1196 |
-
gr.update(value=None, visible=False), # distribution_plot
|
| 1197 |
-
gr.update(value=None, visible=False), # results
|
| 1198 |
-
None, # download_file
|
| 1199 |
-
"# Code will be generated after classification", # classify_code_display
|
| 1200 |
-
])
|
| 1201 |
-
return updates
|
| 1202 |
-
|
| 1203 |
-
|
| 1204 |
-
custom_css = """
|
| 1205 |
-
* {
|
| 1206 |
-
font-family: Helvetica, Arial, sans-serif !important;
|
| 1207 |
-
}
|
| 1208 |
-
|
| 1209 |
-
.task-btn {
|
| 1210 |
-
min-width: 150px !important;
|
| 1211 |
-
}
|
| 1212 |
-
|
| 1213 |
-
/* Expandable plot styles */
|
| 1214 |
-
.expandable-plot {
|
| 1215 |
-
min-height: 400px !important;
|
| 1216 |
-
cursor: pointer;
|
| 1217 |
-
transition: transform 0.3s ease;
|
| 1218 |
-
}
|
| 1219 |
-
|
| 1220 |
-
.expandable-plot:hover {
|
| 1221 |
-
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.15);
|
| 1222 |
-
}
|
| 1223 |
-
|
| 1224 |
-
.expandable-plot img {
|
| 1225 |
-
min-height: 350px !important;
|
| 1226 |
-
object-fit: contain;
|
| 1227 |
-
}
|
| 1228 |
-
|
| 1229 |
-
/* Fullscreen modal for plot */
|
| 1230 |
-
.expandable-plot.fullscreen {
|
| 1231 |
-
position: fixed !important;
|
| 1232 |
-
top: 0 !important;
|
| 1233 |
-
left: 0 !important;
|
| 1234 |
-
width: 100vw !important;
|
| 1235 |
-
height: 100vh !important;
|
| 1236 |
-
z-index: 9999 !important;
|
| 1237 |
-
background: rgba(255, 255, 255, 0.98) !important;
|
| 1238 |
-
padding: 20px !important;
|
| 1239 |
-
margin: 0 !important;
|
| 1240 |
-
display: flex !important;
|
| 1241 |
-
align-items: center !important;
|
| 1242 |
-
justify-content: center !important;
|
| 1243 |
-
}
|
| 1244 |
-
|
| 1245 |
-
.expandable-plot.fullscreen img {
|
| 1246 |
-
max-width: 95vw !important;
|
| 1247 |
-
max-height: 90vh !important;
|
| 1248 |
-
min-height: auto !important;
|
| 1249 |
-
}
|
| 1250 |
-
|
| 1251 |
-
/* ===========================================
|
| 1252 |
-
TAB FIXES
|
| 1253 |
-
=========================================== */
|
| 1254 |
-
/* Hide the overflow tab buttons that Gradio creates */
|
| 1255 |
-
.tab-container.visually-hidden {
|
| 1256 |
-
display: none !important;
|
| 1257 |
-
}
|
| 1258 |
-
|
| 1259 |
-
/* Ensure tablist is visible and tabs wrap properly */
|
| 1260 |
-
[role="tablist"] {
|
| 1261 |
-
display: flex !important;
|
| 1262 |
-
flex-wrap: wrap !important;
|
| 1263 |
-
}
|
| 1264 |
-
|
| 1265 |
-
/* Style tab buttons for better mobile tap targets */
|
| 1266 |
-
[role="tab"] {
|
| 1267 |
-
padding: 12px 16px !important;
|
| 1268 |
-
font-size: 14px !important;
|
| 1269 |
-
flex-shrink: 0 !important;
|
| 1270 |
-
min-height: 44px !important;
|
| 1271 |
-
display: flex !important;
|
| 1272 |
-
align-items: center !important;
|
| 1273 |
-
cursor: pointer !important;
|
| 1274 |
-
-webkit-tap-highlight-color: rgba(0, 0, 0, 0.1) !important;
|
| 1275 |
-
touch-action: manipulation !important;
|
| 1276 |
-
}
|
| 1277 |
-
|
| 1278 |
-
/* Desktop-only experience - no mobile responsive CSS */
|
| 1279 |
-
/* Solution 6: Force desktop rendering with CSS transform on mobile */
|
| 1280 |
-
|
| 1281 |
-
/* Force desktop viewport behavior on narrow screens */
|
| 1282 |
-
@media screen and (max-width: 768px) {
|
| 1283 |
-
html, body {
|
| 1284 |
-
width: 1200px !important;
|
| 1285 |
-
overflow-x: scroll !important;
|
| 1286 |
-
}
|
| 1287 |
-
|
| 1288 |
-
body {
|
| 1289 |
-
transform: scale(0.325);
|
| 1290 |
-
transform-origin: top left;
|
| 1291 |
-
min-height: calc(100vh / 0.325);
|
| 1292 |
-
}
|
| 1293 |
-
|
| 1294 |
-
.gradio-container {
|
| 1295 |
-
width: 1200px !important;
|
| 1296 |
-
}
|
| 1297 |
-
}
|
| 1298 |
-
"""
|
| 1299 |
|
| 1300 |
-
|
| 1301 |
-
|
| 1302 |
-
|
| 1303 |
-
|
| 1304 |
-
|
| 1305 |
-
|
| 1306 |
-
|
| 1307 |
-
|
| 1308 |
-
|
| 1309 |
-
|
| 1310 |
-
|
| 1311 |
-
document.addEventListener('keydown', function(e) {
|
| 1312 |
-
if (e.key === 'Escape') {
|
| 1313 |
-
const fullscreenPlot = document.querySelector('.expandable-plot.fullscreen');
|
| 1314 |
-
if (fullscreenPlot) {
|
| 1315 |
-
fullscreenPlot.classList.remove('fullscreen');
|
| 1316 |
-
}
|
| 1317 |
-
}
|
| 1318 |
-
});
|
| 1319 |
-
}
|
| 1320 |
-
"""
|
| 1321 |
|
| 1322 |
-
# Force desktop viewport - 1200px width scaled down to fit mobile
|
| 1323 |
-
custom_head = """
|
| 1324 |
-
<meta name="viewport" content="width=1200, initial-scale=0.33, minimum-scale=0.33, maximum-scale=5.0, user-scalable=yes">
|
| 1325 |
-
"""
|
| 1326 |
|
| 1327 |
-
|
| 1328 |
-
|
| 1329 |
-
|
| 1330 |
-
|
|
|
|
|
|
|
| 1331 |
|
| 1332 |
-
|
| 1333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1334 |
**Privacy Notice:** Your data is sent to third-party LLM APIs for classification. Do not upload sensitive, confidential, or personally identifiable information (PII).
|
| 1335 |
|
| 1336 |
---
|
|
@@ -1359,369 +626,330 @@ Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DO
|
|
| 1359 |
```
|
| 1360 |
""")
|
| 1361 |
|
| 1362 |
-
|
| 1363 |
-
|
| 1364 |
-
task_mode = gr.State(value=None)
|
| 1365 |
-
input_type = gr.State(value="Survey Responses") # Track selected input type
|
| 1366 |
-
|
| 1367 |
-
with gr.Row():
|
| 1368 |
-
with gr.Column():
|
| 1369 |
-
# Input type selector using Tabs (more reliable on mobile than Radio + visibility)
|
| 1370 |
-
with gr.Tabs() as input_tabs:
|
| 1371 |
-
with gr.TabItem("Survey Responses", id="survey"):
|
| 1372 |
-
spreadsheet_file = gr.File(
|
| 1373 |
-
label="Upload Data (CSV or Excel)",
|
| 1374 |
-
file_types=[".csv", ".xlsx", ".xls"]
|
| 1375 |
-
)
|
| 1376 |
-
example_btn = gr.Button("Try Example Dataset", variant="secondary", size="sm")
|
| 1377 |
-
spreadsheet_column = gr.Dropdown(
|
| 1378 |
-
label="Column to Process",
|
| 1379 |
-
choices=[],
|
| 1380 |
-
info="Select the column containing text"
|
| 1381 |
-
)
|
| 1382 |
|
| 1383 |
-
|
| 1384 |
-
|
| 1385 |
-
|
| 1386 |
-
value="Upload File(s)",
|
| 1387 |
-
label="Upload Type"
|
| 1388 |
-
)
|
| 1389 |
-
pdf_file = gr.File(
|
| 1390 |
-
label="Upload PDF Document(s)",
|
| 1391 |
-
file_types=[".pdf"],
|
| 1392 |
-
file_count="multiple"
|
| 1393 |
-
)
|
| 1394 |
-
pdf_folder = gr.File(
|
| 1395 |
-
label="Upload PDF Folder",
|
| 1396 |
-
file_count="directory",
|
| 1397 |
-
visible=False
|
| 1398 |
-
)
|
| 1399 |
-
pdf_description = gr.Textbox(
|
| 1400 |
-
label="Document Description",
|
| 1401 |
-
placeholder="e.g., 'research papers', 'interview transcripts'",
|
| 1402 |
-
info="Helps the LLM understand context"
|
| 1403 |
-
)
|
| 1404 |
-
pdf_mode = gr.Radio(
|
| 1405 |
-
choices=["Image (visual documents)", "Text (text-heavy)", "Both (comprehensive)"],
|
| 1406 |
-
value="Image (visual documents)",
|
| 1407 |
-
label="Processing Mode"
|
| 1408 |
-
)
|
| 1409 |
|
| 1410 |
-
|
| 1411 |
-
|
| 1412 |
-
|
| 1413 |
-
|
| 1414 |
-
|
| 1415 |
-
|
| 1416 |
-
image_file = gr.File(
|
| 1417 |
-
label="Upload Images",
|
| 1418 |
-
file_types=["image"],
|
| 1419 |
-
file_count="multiple"
|
| 1420 |
-
)
|
| 1421 |
-
image_folder = gr.File(
|
| 1422 |
-
label="Upload Image Folder",
|
| 1423 |
-
file_count="directory",
|
| 1424 |
-
visible=False
|
| 1425 |
-
)
|
| 1426 |
-
image_description = gr.Textbox(
|
| 1427 |
-
label="Image Description",
|
| 1428 |
-
placeholder="e.g., 'product photos', 'social media posts'",
|
| 1429 |
-
info="Helps the LLM understand context"
|
| 1430 |
-
)
|
| 1431 |
|
| 1432 |
-
|
| 1433 |
-
|
| 1434 |
-
|
| 1435 |
-
|
| 1436 |
-
|
| 1437 |
-
|
| 1438 |
-
|
| 1439 |
-
|
| 1440 |
-
|
| 1441 |
-
|
| 1442 |
-
|
| 1443 |
-
|
| 1444 |
-
|
| 1445 |
-
|
| 1446 |
-
|
| 1447 |
-
|
| 1448 |
-
|
| 1449 |
-
)
|
| 1450 |
-
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
|
| 1456 |
-
|
| 1457 |
-
|
| 1458 |
-
|
| 1459 |
-
|
| 1460 |
-
|
| 1461 |
-
|
| 1462 |
-
|
| 1463 |
-
|
| 1464 |
-
|
| 1465 |
-
|
| 1466 |
-
for i in range(MAX_CATEGORIES):
|
| 1467 |
-
visible = i < INITIAL_CATEGORIES
|
| 1468 |
-
placeholder = placeholder_examples[i] if i < len(placeholder_examples) else "e.g., Custom category"
|
| 1469 |
-
cat_input = gr.Textbox(
|
| 1470 |
-
label=f"Category {i+1}",
|
| 1471 |
-
placeholder=placeholder,
|
| 1472 |
-
visible=visible
|
| 1473 |
-
)
|
| 1474 |
-
category_inputs.append(cat_input)
|
| 1475 |
-
add_category_btn = gr.Button("+ Add More", variant="secondary", size="sm")
|
| 1476 |
-
|
| 1477 |
-
# Model selection group
|
| 1478 |
-
with gr.Group(visible=False) as model_group:
|
| 1479 |
-
gr.Markdown("### Model")
|
| 1480 |
-
model_tier = gr.Radio(
|
| 1481 |
-
choices=["Free Models", "Bring Your Own Key"],
|
| 1482 |
-
value="Free Models",
|
| 1483 |
-
label="Model Tier"
|
| 1484 |
-
)
|
| 1485 |
-
model = gr.Dropdown(
|
| 1486 |
-
choices=FREE_MODEL_CHOICES,
|
| 1487 |
-
value="Qwen/Qwen3-VL-235B-A22B-Instruct:novita",
|
| 1488 |
-
label="Model",
|
| 1489 |
-
allow_custom_value=False, # Only allow custom for "Bring Your Own Key"
|
| 1490 |
-
interactive=True
|
| 1491 |
-
)
|
| 1492 |
-
model_source = gr.Dropdown(
|
| 1493 |
-
choices=["auto", "openai", "anthropic", "google", "mistral", "xai", "huggingface", "perplexity"],
|
| 1494 |
-
value="auto",
|
| 1495 |
-
label="Model Source",
|
| 1496 |
-
visible=False # Hide for free tier, show for BYOK
|
| 1497 |
-
)
|
| 1498 |
-
api_key = gr.Textbox(
|
| 1499 |
-
label="API Key",
|
| 1500 |
-
type="password",
|
| 1501 |
-
placeholder="Enter your API key",
|
| 1502 |
-
visible=False
|
| 1503 |
-
)
|
| 1504 |
-
api_key_status = gr.Markdown("**Free tier** - no API key required!")
|
| 1505 |
-
|
| 1506 |
-
# Run button
|
| 1507 |
-
run_btn = gr.Button("Run", variant="primary", size="lg", visible=False)
|
| 1508 |
-
reset_btn = gr.Button("Reset", variant="stop")
|
| 1509 |
-
|
| 1510 |
-
with gr.Column():
|
| 1511 |
-
status = gr.Markdown("Ready. Upload data and select a task.")
|
| 1512 |
-
|
| 1513 |
-
# Classify output group
|
| 1514 |
-
with gr.Group(visible=False) as classify_output_group:
|
| 1515 |
-
gr.Markdown("### Classification Results")
|
| 1516 |
-
distribution_plot = gr.Plot(label="Category Distribution (%)", visible=False, elem_classes="expandable-plot")
|
| 1517 |
-
results = gr.DataFrame(label="Full Results", visible=False)
|
| 1518 |
-
download_file = gr.File(label="Download Results (CSV + Methodology Report)", file_count="multiple")
|
| 1519 |
-
with gr.Accordion("See the Code", open=False):
|
| 1520 |
-
classify_code_display = gr.Code(
|
| 1521 |
-
label="Python Code",
|
| 1522 |
-
language="python",
|
| 1523 |
-
value="# Code will be generated after classification",
|
| 1524 |
-
interactive=False
|
| 1525 |
-
)
|
| 1526 |
|
| 1527 |
-
|
| 1528 |
-
|
| 1529 |
-
|
| 1530 |
-
|
| 1531 |
-
"survey": "Survey Responses",
|
| 1532 |
-
"pdf": "PDF Documents",
|
| 1533 |
-
"images": "Images"
|
| 1534 |
-
}
|
| 1535 |
-
input_type_val = tab_mapping.get(tab_id, "Survey Responses")
|
| 1536 |
-
return input_type_val, f"Ready to process {input_type_val.lower()}."
|
| 1537 |
-
|
| 1538 |
-
input_tabs.change(
|
| 1539 |
-
fn=on_tab_change,
|
| 1540 |
-
inputs=[input_tabs],
|
| 1541 |
-
outputs=[input_type, status]
|
| 1542 |
-
)
|
| 1543 |
-
|
| 1544 |
-
def update_model_tier(tier):
|
| 1545 |
-
if tier == "Free Models":
|
| 1546 |
-
return (
|
| 1547 |
-
gr.update(choices=FREE_MODEL_CHOICES, value=FREE_MODEL_CHOICES[0], allow_custom_value=False),
|
| 1548 |
-
gr.update(visible=False), # model_source hidden for free
|
| 1549 |
-
gr.update(visible=False), # api_key hidden for free
|
| 1550 |
-
"**Free tier** - no API key required!"
|
| 1551 |
-
)
|
| 1552 |
-
else:
|
| 1553 |
-
return (
|
| 1554 |
-
gr.update(choices=PAID_MODEL_CHOICES, value=PAID_MODEL_CHOICES[0], allow_custom_value=True),
|
| 1555 |
-
gr.update(visible=True), # model_source shown for BYOK
|
| 1556 |
-
gr.update(visible=True), # api_key shown for BYOK
|
| 1557 |
-
"**Bring Your Own Key** - enter your API key below."
|
| 1558 |
-
)
|
| 1559 |
|
| 1560 |
-
|
| 1561 |
-
|
| 1562 |
-
|
| 1563 |
-
|
| 1564 |
-
|
| 1565 |
-
|
| 1566 |
-
|
| 1567 |
-
|
| 1568 |
-
|
| 1569 |
-
|
| 1570 |
-
|
| 1571 |
-
|
| 1572 |
-
|
| 1573 |
-
|
| 1574 |
-
|
| 1575 |
-
|
| 1576 |
-
|
| 1577 |
-
|
| 1578 |
-
|
| 1579 |
-
|
| 1580 |
-
|
| 1581 |
-
|
| 1582 |
-
|
| 1583 |
-
|
| 1584 |
-
|
| 1585 |
-
|
| 1586 |
-
|
| 1587 |
-
|
| 1588 |
-
|
| 1589 |
-
|
| 1590 |
-
|
| 1591 |
-
|
| 1592 |
-
|
| 1593 |
-
|
| 1594 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1595 |
else:
|
| 1596 |
-
|
| 1597 |
-
|
| 1598 |
-
|
| 1599 |
-
|
| 1600 |
-
|
| 1601 |
-
|
| 1602 |
-
|
| 1603 |
-
|
| 1604 |
-
|
| 1605 |
-
|
| 1606 |
-
|
| 1607 |
-
|
| 1608 |
-
|
| 1609 |
-
|
| 1610 |
-
|
| 1611 |
-
|
| 1612 |
-
|
| 1613 |
-
|
| 1614 |
-
|
| 1615 |
-
|
| 1616 |
-
|
| 1617 |
-
|
| 1618 |
-
|
| 1619 |
-
|
| 1620 |
-
|
| 1621 |
-
|
| 1622 |
-
|
| 1623 |
-
|
| 1624 |
-
|
| 1625 |
-
|
| 1626 |
-
|
| 1627 |
-
|
| 1628 |
-
|
| 1629 |
-
|
| 1630 |
-
|
| 1631 |
-
|
| 1632 |
-
|
| 1633 |
-
|
| 1634 |
-
|
| 1635 |
-
|
| 1636 |
-
|
| 1637 |
-
|
| 1638 |
-
|
| 1639 |
-
|
| 1640 |
-
|
| 1641 |
-
|
| 1642 |
-
|
| 1643 |
-
|
| 1644 |
-
|
| 1645 |
-
progress=gr.Progress(track_tqdm=True)):
|
| 1646 |
-
"""Extract categories and then show the categories/model/run sections."""
|
| 1647 |
-
# Run the extraction
|
| 1648 |
-
result = run_auto_extract(input_type, spreadsheet_file, spreadsheet_column,
|
| 1649 |
-
pdf_file, pdf_folder, pdf_description, pdf_mode,
|
| 1650 |
-
image_file, image_folder, image_description,
|
| 1651 |
-
max_categories_val, model_tier, model, model_source, api_key,
|
| 1652 |
-
progress)
|
| 1653 |
-
# result is: category_inputs updates + [category_count, status]
|
| 1654 |
-
# Add visibility updates for categories_group, model_group, run_btn
|
| 1655 |
-
return result + [
|
| 1656 |
-
gr.update(visible=True), # categories_group
|
| 1657 |
-
gr.update(visible=True), # model_group
|
| 1658 |
-
gr.update(visible=True, value="Classify Data"), # run_btn
|
| 1659 |
]
|
| 1660 |
|
| 1661 |
-
|
| 1662 |
-
|
| 1663 |
-
|
| 1664 |
-
|
| 1665 |
-
|
| 1666 |
-
|
| 1667 |
-
|
| 1668 |
-
|
| 1669 |
-
|
| 1670 |
-
|
| 1671 |
-
|
| 1672 |
-
|
| 1673 |
-
|
| 1674 |
-
|
| 1675 |
-
|
| 1676 |
-
|
| 1677 |
-
|
| 1678 |
-
|
| 1679 |
-
|
| 1680 |
-
|
| 1681 |
-
|
| 1682 |
-
|
| 1683 |
-
|
| 1684 |
-
|
| 1685 |
-
|
| 1686 |
-
|
| 1687 |
-
)
|
| 1688 |
-
|
| 1689 |
-
|
| 1690 |
-
update[1], # results
|
| 1691 |
-
update[2], # download_file
|
| 1692 |
-
update[3], # classify_code_display
|
| 1693 |
-
update[5] # status
|
| 1694 |
-
)
|
| 1695 |
else:
|
| 1696 |
-
|
| 1697 |
-
|
| 1698 |
-
|
| 1699 |
-
|
| 1700 |
-
|
| 1701 |
-
|
| 1702 |
-
|
| 1703 |
-
|
| 1704 |
-
|
| 1705 |
-
|
| 1706 |
-
|
| 1707 |
-
|
| 1708 |
-
|
| 1709 |
-
|
| 1710 |
-
|
| 1711 |
-
|
| 1712 |
-
|
| 1713 |
-
|
| 1714 |
-
|
| 1715 |
-
|
| 1716 |
-
|
| 1717 |
-
|
| 1718 |
-
|
| 1719 |
-
|
| 1720 |
-
status,
|
| 1721 |
-
classify_output_group, distribution_plot, results, download_file, classify_code_display
|
| 1722 |
-
]
|
| 1723 |
-
)
|
| 1724 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1725 |
|
| 1726 |
-
|
| 1727 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
"""
|
| 2 |
+
Streamlit app - CatLLM Survey Response Classifier
|
| 3 |
+
Migrated from Gradio for better mobile support
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
import streamlit as st
|
| 7 |
import pandas as pd
|
| 8 |
import tempfile
|
| 9 |
import os
|
|
|
|
| 22 |
|
| 23 |
MAX_CATEGORIES = 10
|
| 24 |
INITIAL_CATEGORIES = 3
|
| 25 |
+
MAX_FILE_SIZE_MB = 100
|
| 26 |
+
|
| 27 |
+
# Free models (uses Space secrets - no user API key needed)
|
| 28 |
+
FREE_MODEL_CHOICES = [
|
| 29 |
+
"Qwen/Qwen3-VL-235B-A22B-Instruct:novita",
|
| 30 |
+
"deepseek-ai/DeepSeek-V3.1:novita",
|
| 31 |
+
"meta-llama/Llama-3.3-70B-Instruct:groq",
|
| 32 |
+
"gemini-2.5-flash",
|
| 33 |
+
"gpt-4o",
|
| 34 |
+
"mistral-medium-2505",
|
| 35 |
+
"claude-3-haiku-20240307",
|
| 36 |
+
"grok-4-fast-non-reasoning",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
# Paid models (user provides their own API key)
|
| 40 |
+
PAID_MODEL_CHOICES = [
|
| 41 |
+
"gpt-4.1",
|
| 42 |
+
"gpt-4o",
|
| 43 |
+
"gpt-4o-mini",
|
| 44 |
+
"claude-sonnet-4-5-20250929",
|
| 45 |
+
"claude-opus-4-20250514",
|
| 46 |
+
"claude-3-5-haiku-20241022",
|
| 47 |
+
"gemini-2.5-pro",
|
| 48 |
+
"gemini-2.5-flash",
|
| 49 |
+
"mistral-large-latest",
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
# Models routed through HuggingFace
|
| 53 |
+
HF_ROUTED_MODELS = [
|
| 54 |
+
"Qwen/Qwen3-VL-235B-A22B-Instruct:novita",
|
| 55 |
+
"deepseek-ai/DeepSeek-V3.1:novita",
|
| 56 |
+
"meta-llama/Llama-3.3-70B-Instruct:groq",
|
| 57 |
+
]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def is_free_model(model, model_tier):
|
| 61 |
+
"""Check if using free tier (Space pays for API)."""
|
| 62 |
+
return model_tier == "Free Models"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_model_source(model):
|
| 66 |
+
"""Auto-detect model source."""
|
| 67 |
+
model_lower = model.lower()
|
| 68 |
+
if "gpt" in model_lower:
|
| 69 |
+
return "openai"
|
| 70 |
+
elif "claude" in model_lower:
|
| 71 |
+
return "anthropic"
|
| 72 |
+
elif "gemini" in model_lower:
|
| 73 |
+
return "google"
|
| 74 |
+
elif "mistral" in model_lower and ":novita" not in model_lower:
|
| 75 |
+
return "mistral"
|
| 76 |
+
elif any(x in model_lower for x in [":novita", ":groq", "qwen", "llama", "deepseek"]):
|
| 77 |
+
return "huggingface"
|
| 78 |
+
elif "sonar" in model_lower:
|
| 79 |
+
return "perplexity"
|
| 80 |
+
elif "grok" in model_lower:
|
| 81 |
+
return "xai"
|
| 82 |
+
return "huggingface"
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def get_api_key(model, model_tier, api_key_input):
|
| 86 |
+
"""Get the appropriate API key based on model and tier."""
|
| 87 |
+
if is_free_model(model, model_tier):
|
| 88 |
+
if model in HF_ROUTED_MODELS:
|
| 89 |
+
return os.environ.get("HF_API_KEY", ""), "HuggingFace"
|
| 90 |
+
elif "gpt" in model.lower():
|
| 91 |
+
return os.environ.get("OPENAI_API_KEY", ""), "OpenAI"
|
| 92 |
+
elif "gemini" in model.lower():
|
| 93 |
+
return os.environ.get("GOOGLE_API_KEY", ""), "Google"
|
| 94 |
+
elif "mistral" in model.lower():
|
| 95 |
+
return os.environ.get("MISTRAL_API_KEY", ""), "Mistral"
|
| 96 |
+
elif "claude" in model.lower():
|
| 97 |
+
return os.environ.get("ANTHROPIC_API_KEY", ""), "Anthropic"
|
| 98 |
+
elif "sonar" in model.lower():
|
| 99 |
+
return os.environ.get("PERPLEXITY_API_KEY", ""), "Perplexity"
|
| 100 |
+
elif "grok" in model.lower():
|
| 101 |
+
return os.environ.get("XAI_API_KEY", ""), "xAI"
|
| 102 |
+
else:
|
| 103 |
+
return os.environ.get("HF_API_KEY", ""), "HuggingFace"
|
| 104 |
+
else:
|
| 105 |
+
if api_key_input and api_key_input.strip():
|
| 106 |
+
return api_key_input.strip(), "User"
|
| 107 |
+
return "", "User"
|
| 108 |
|
| 109 |
|
| 110 |
def calculate_total_file_size(files):
|
|
|
|
| 117 |
total_bytes = 0
|
| 118 |
for f in files:
|
| 119 |
try:
|
| 120 |
+
if hasattr(f, 'size'):
|
| 121 |
+
total_bytes += f.size
|
| 122 |
+
elif hasattr(f, 'name'):
|
| 123 |
+
total_bytes += os.path.getsize(f.name)
|
| 124 |
except (OSError, AttributeError):
|
| 125 |
pass
|
| 126 |
+
return total_bytes / (1024 * 1024)
|
| 127 |
|
| 128 |
|
| 129 |
def generate_extract_code(input_type, description, model, model_source, max_categories, mode=None):
|
|
|
|
| 156 |
|
| 157 |
# Extract categories from PDF documents
|
| 158 |
result = catllm.extract(
|
| 159 |
+
input_data="path/to/your/pdfs/",
|
| 160 |
api_key="YOUR_API_KEY",
|
| 161 |
input_type="pdf",
|
| 162 |
description="{description}"{mode_line},
|
|
|
|
| 174 |
|
| 175 |
# Extract categories from images
|
| 176 |
result = catllm.extract(
|
| 177 |
+
input_data="path/to/your/images/",
|
| 178 |
api_key="YOUR_API_KEY",
|
| 179 |
input_type="image",
|
| 180 |
description="{description}",
|
|
|
|
| 231 |
|
| 232 |
# Classify PDF documents
|
| 233 |
result = catllm.classify(
|
| 234 |
+
input_data="path/to/your/pdfs/",
|
| 235 |
categories=categories,
|
| 236 |
api_key="YOUR_API_KEY",
|
| 237 |
input_type="pdf",
|
|
|
|
| 254 |
|
| 255 |
# Classify images
|
| 256 |
result = catllm.classify(
|
| 257 |
+
input_data="path/to/your/images/",
|
| 258 |
categories=categories,
|
| 259 |
api_key="YOUR_API_KEY",
|
| 260 |
input_type="image",
|
|
|
|
| 268 |
result.to_csv("classified_results.csv", index=False)
|
| 269 |
'''
|
| 270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
def generate_methodology_report_pdf(categories, model, column_name, num_rows, model_source, filename, success_rate,
|
| 273 |
result_df=None, processing_time=None, prompt_template=None,
|
| 274 |
data_quality=None, catllm_version=None, python_version=None,
|
| 275 |
task_type="assign", extracted_categories_df=None, max_categories=None,
|
| 276 |
input_type="text", description=None):
|
| 277 |
+
"""Generate a PDF methodology report."""
|
|
|
|
|
|
|
|
|
|
| 278 |
from reportlab.lib.pagesizes import letter
|
| 279 |
from reportlab.lib import colors
|
| 280 |
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
|
|
|
|
| 291 |
|
| 292 |
story = []
|
| 293 |
|
|
|
|
| 294 |
if task_type == "extract_and_assign":
|
| 295 |
report_title = "CatLLM Extraction & Classification Report"
|
| 296 |
else:
|
|
|
|
| 304 |
|
| 305 |
if task_type == "extract_and_assign":
|
| 306 |
about_text = """This methodology report documents the automated category extraction and classification process. \
|
| 307 |
+
CatLLM first discovers categories from your data using LLMs, then classifies each item into those categories."""
|
|
|
|
|
|
|
| 308 |
else:
|
| 309 |
about_text = """This methodology report documents the classification process for reproducibility and transparency. \
|
| 310 |
+
CatLLM restricts the prompt to a standard template that is impartial to the researcher's inclinations, ensuring \
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
consistent and reproducible results."""
|
| 312 |
|
| 313 |
story.append(Paragraph(about_text, normal_style))
|
| 314 |
story.append(Spacer(1, 15))
|
| 315 |
|
| 316 |
+
if categories:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
story.append(Paragraph("Category Mapping", heading_style))
|
| 318 |
story.append(Paragraph("Each category column contains binary values: 1 = present, 0 = not present", normal_style))
|
| 319 |
story.append(Spacer(1, 8))
|
| 320 |
|
|
|
|
|
|
|
| 321 |
category_data = [["Column Name", "Category Description"]]
|
| 322 |
for i, cat in enumerate(categories, 1):
|
| 323 |
category_data.append([f"category_{i}", cat])
|
|
|
|
| 334 |
story.append(cat_table)
|
| 335 |
story.append(Spacer(1, 15))
|
| 336 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 337 |
story.append(Spacer(1, 30))
|
| 338 |
story.append(Paragraph("Citation", heading_style))
|
| 339 |
story.append(Paragraph("If you use CatLLM in your research, please cite:", normal_style))
|
| 340 |
story.append(Spacer(1, 5))
|
| 341 |
story.append(Paragraph("Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DOI: 10.5281/zenodo.15532316", normal_style))
|
| 342 |
|
| 343 |
+
# Summary section
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 344 |
story.append(PageBreak())
|
| 345 |
+
story.append(Paragraph("Classification Summary", title_style))
|
|
|
|
|
|
|
|
|
|
| 346 |
story.append(Spacer(1, 15))
|
| 347 |
|
| 348 |
+
summary_data = [
|
| 349 |
+
["Source File", filename],
|
| 350 |
+
["Source Column", column_name],
|
| 351 |
+
["Model Used", model],
|
| 352 |
+
["Model Source", model_source],
|
| 353 |
+
["Rows Classified", str(num_rows)],
|
| 354 |
+
["Number of Categories", str(len(categories)) if categories else "0"],
|
| 355 |
+
["Success Rate", f"{success_rate:.2f}%"],
|
| 356 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
summary_table = Table(summary_data, colWidths=[150, 300])
|
| 358 |
summary_table.setStyle(TableStyle([
|
| 359 |
('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
|
|
|
|
| 382 |
('FONTSIZE', (0, 0), (-1, -1), 9),
|
| 383 |
]))
|
| 384 |
story.append(time_table)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 385 |
|
| 386 |
+
story.append(Spacer(1, 15))
|
| 387 |
story.append(Paragraph("Version Information", heading_style))
|
| 388 |
version_data = [
|
| 389 |
["CatLLM Version", catllm_version or "unknown"],
|
|
|
|
| 399 |
]))
|
| 400 |
story.append(version_table)
|
| 401 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 402 |
doc.build(story)
|
| 403 |
return pdf_file.name
|
| 404 |
|
| 405 |
|
| 406 |
+
def run_auto_extract(input_type, input_data, description, max_categories_val,
|
| 407 |
+
model_tier, model, api_key_input, mode=None, progress_callback=None):
|
| 408 |
+
"""Extract categories from data."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
if not CATLLM_AVAILABLE:
|
| 410 |
+
return None, "catllm package not available"
|
|
|
|
| 411 |
|
| 412 |
actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
|
| 413 |
if not actual_api_key:
|
| 414 |
+
return None, f"{provider} API key not configured"
|
| 415 |
|
| 416 |
+
model_source = get_model_source(model)
|
|
|
|
|
|
|
|
|
|
| 417 |
|
| 418 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
if isinstance(input_data, list):
|
| 420 |
num_items = len(input_data)
|
| 421 |
else:
|
| 422 |
num_items = 1
|
| 423 |
|
| 424 |
+
if input_type == "image":
|
| 425 |
divisions = min(3, max(1, num_items // 5))
|
| 426 |
categories_per_chunk = 12
|
| 427 |
else:
|
| 428 |
+
divisions = max(1, num_items // 15)
|
| 429 |
+
divisions = min(divisions, 5)
|
|
|
|
|
|
|
| 430 |
chunk_size = num_items // max(1, divisions)
|
| 431 |
+
categories_per_chunk = min(10, chunk_size - 1)
|
| 432 |
|
|
|
|
| 433 |
extract_kwargs = {
|
| 434 |
'input_data': input_data,
|
| 435 |
'api_key': actual_api_key,
|
| 436 |
+
'input_type': input_type,
|
| 437 |
'description': description,
|
| 438 |
'user_model': model,
|
| 439 |
'model_source': model_source,
|
|
|
|
| 441 |
'categories_per_chunk': categories_per_chunk,
|
| 442 |
'max_categories': int(max_categories_val)
|
| 443 |
}
|
| 444 |
+
if mode:
|
| 445 |
+
extract_kwargs['mode'] = mode
|
| 446 |
|
| 447 |
extract_result = catllm.extract(**extract_kwargs)
|
| 448 |
categories = extract_result.get('top_categories', [])
|
| 449 |
|
| 450 |
if not categories:
|
| 451 |
+
return None, "No categories were extracted"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
|
| 453 |
+
return categories, f"Extracted {len(categories)} categories successfully!"
|
|
|
|
| 454 |
|
| 455 |
except Exception as e:
|
| 456 |
+
return None, f"Error: {str(e)}"
|
| 457 |
|
| 458 |
|
| 459 |
+
def run_classify_data(input_type, input_data, description, categories,
|
| 460 |
+
model_tier, model, api_key_input, mode=None,
|
| 461 |
+
original_filename="data", column_name="text",
|
| 462 |
+
progress_callback=None):
|
|
|
|
|
|
|
| 463 |
"""Classify data with user-provided categories."""
|
| 464 |
if not CATLLM_AVAILABLE:
|
| 465 |
+
return None, None, None, None, "catllm package not available"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
|
| 467 |
if not categories:
|
| 468 |
+
return None, None, None, None, "Please enter at least one category"
|
|
|
|
| 469 |
|
| 470 |
actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
|
| 471 |
if not actual_api_key:
|
| 472 |
+
return None, None, None, None, f"{provider} API key not configured"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
|
| 474 |
+
model_source = get_model_source(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
try:
|
|
|
|
|
|
|
| 477 |
start_time = time.time()
|
| 478 |
|
| 479 |
+
classify_kwargs = {
|
| 480 |
+
'input_data': input_data,
|
| 481 |
+
'categories': categories,
|
| 482 |
+
'api_key': actual_api_key,
|
| 483 |
+
'input_type': input_type,
|
| 484 |
+
'description': description,
|
| 485 |
+
'user_model': model,
|
| 486 |
+
'model_source': model_source
|
| 487 |
+
}
|
| 488 |
+
if mode:
|
| 489 |
+
classify_kwargs['mode'] = mode
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 490 |
|
| 491 |
+
result = catllm.classify(**classify_kwargs)
|
|
|
|
|
|
|
| 492 |
|
| 493 |
processing_time = time.time() - start_time
|
| 494 |
num_items = len(result)
|
|
|
|
| 513 |
python_version = sys.version.split()[0]
|
| 514 |
|
| 515 |
# Generate methodology report
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 516 |
report_pdf_path = generate_methodology_report_pdf(
|
| 517 |
categories=categories,
|
| 518 |
model=model,
|
|
|
|
| 523 |
success_rate=success_rate,
|
| 524 |
result_df=result,
|
| 525 |
processing_time=processing_time,
|
|
|
|
|
|
|
| 526 |
catllm_version=catllm_version,
|
| 527 |
python_version=python_version,
|
| 528 |
task_type="assign",
|
| 529 |
+
input_type=input_type,
|
| 530 |
+
description=description
|
| 531 |
)
|
| 532 |
|
| 533 |
+
# Generate reproducibility code
|
| 534 |
+
code = generate_classify_code(input_type, description, categories, model, model_source, mode)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 535 |
|
| 536 |
+
return result, csv_path, report_pdf_path, code, f"Classified {num_items} items in {processing_time:.1f}s"
|
|
|
|
|
|
|
| 537 |
|
| 538 |
+
except Exception as e:
|
| 539 |
+
return None, None, None, None, f"Error: {str(e)}"
|
|
|
|
|
|
|
| 540 |
|
|
|
|
|
|
|
|
|
|
| 541 |
|
| 542 |
+
def create_distribution_chart(result_df, categories):
|
| 543 |
+
"""Create a bar chart showing category distribution."""
|
| 544 |
+
fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))
|
| 545 |
|
| 546 |
+
dist_data = []
|
| 547 |
+
total_rows = len(result_df)
|
| 548 |
+
for i, cat in enumerate(categories, 1):
|
| 549 |
+
col_name = f"category_{i}"
|
| 550 |
+
if col_name in result_df.columns:
|
| 551 |
+
count = int(result_df[col_name].sum())
|
| 552 |
+
pct = (count / total_rows) * 100 if total_rows > 0 else 0
|
| 553 |
+
dist_data.append({"Category": cat, "Percentage": round(pct, 1)})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
+
categories_list = [d["Category"] for d in dist_data][::-1]
|
| 556 |
+
percentages = [d["Percentage"] for d in dist_data][::-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 557 |
|
| 558 |
+
bars = ax.barh(categories_list, percentages, color='#2563eb')
|
| 559 |
+
ax.set_xlim(0, 100)
|
| 560 |
+
ax.set_xlabel('Percentage (%)', fontsize=11)
|
| 561 |
+
ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold')
|
| 562 |
+
|
| 563 |
+
for bar, pct in zip(bars, percentages):
|
| 564 |
+
ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
|
| 565 |
+
f'{pct:.1f}%', va='center', fontsize=10)
|
| 566 |
+
|
| 567 |
+
plt.tight_layout()
|
| 568 |
+
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 570 |
|
| 571 |
+
# Page config
|
| 572 |
+
st.set_page_config(
|
| 573 |
+
page_title="CatLLM - Research Data Classifier",
|
| 574 |
+
page_icon="logo.png",
|
| 575 |
+
layout="wide"
|
| 576 |
+
)
|
| 577 |
|
| 578 |
+
# Initialize session state
|
| 579 |
+
if 'categories' not in st.session_state:
|
| 580 |
+
st.session_state.categories = [''] * MAX_CATEGORIES
|
| 581 |
+
if 'category_count' not in st.session_state:
|
| 582 |
+
st.session_state.category_count = INITIAL_CATEGORIES
|
| 583 |
+
if 'task_mode' not in st.session_state:
|
| 584 |
+
st.session_state.task_mode = None
|
| 585 |
+
if 'extracted_categories' not in st.session_state:
|
| 586 |
+
st.session_state.extracted_categories = None
|
| 587 |
+
if 'results' not in st.session_state:
|
| 588 |
+
st.session_state.results = None
|
| 589 |
+
|
| 590 |
+
# Logo and title
|
| 591 |
+
col_logo, col_title = st.columns([1, 6])
|
| 592 |
+
with col_logo:
|
| 593 |
+
st.image("logo.png", width=100)
|
| 594 |
+
with col_title:
|
| 595 |
+
st.title("CatLLM - Research Data Classifier")
|
| 596 |
+
st.markdown("Extract categories from or classify text data, PDFs, and images using LLMs.")
|
| 597 |
+
|
| 598 |
+
# About section
|
| 599 |
+
with st.expander("About This App"):
|
| 600 |
+
st.markdown("""
|
| 601 |
**Privacy Notice:** Your data is sent to third-party LLM APIs for classification. Do not upload sensitive, confidential, or personally identifiable information (PII).
|
| 602 |
|
| 603 |
---
|
|
|
|
| 626 |
```
|
| 627 |
""")
|
| 628 |
|
| 629 |
+
# Main layout
|
| 630 |
+
col_input, col_output = st.columns([1, 1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
|
| 632 |
+
with col_input:
|
| 633 |
+
# Input type tabs
|
| 634 |
+
tab_survey, tab_pdf, tab_images = st.tabs(["Survey Responses", "PDF Documents", "Images"])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 635 |
|
| 636 |
+
with tab_survey:
|
| 637 |
+
uploaded_file = st.file_uploader(
|
| 638 |
+
"Upload Data (CSV or Excel)",
|
| 639 |
+
type=['csv', 'xlsx', 'xls'],
|
| 640 |
+
key="survey_file"
|
| 641 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 642 |
|
| 643 |
+
# Load example button
|
| 644 |
+
if st.button("Try Example Dataset", key="example_btn"):
|
| 645 |
+
try:
|
| 646 |
+
st.session_state.example_loaded = True
|
| 647 |
+
except:
|
| 648 |
+
st.error("Could not load example dataset")
|
| 649 |
+
|
| 650 |
+
# Column selector
|
| 651 |
+
columns = []
|
| 652 |
+
df = None
|
| 653 |
+
if uploaded_file is not None:
|
| 654 |
+
try:
|
| 655 |
+
if uploaded_file.name.endswith('.csv'):
|
| 656 |
+
df = pd.read_csv(uploaded_file)
|
| 657 |
+
else:
|
| 658 |
+
df = pd.read_excel(uploaded_file)
|
| 659 |
+
columns = df.columns.tolist()
|
| 660 |
+
st.success(f"Loaded {len(df):,} rows")
|
| 661 |
+
except Exception as e:
|
| 662 |
+
st.error(f"Error loading file: {e}")
|
| 663 |
+
elif hasattr(st.session_state, 'example_loaded') and st.session_state.example_loaded:
|
| 664 |
+
try:
|
| 665 |
+
df = pd.read_csv("example_data.csv")
|
| 666 |
+
columns = df.columns.tolist()
|
| 667 |
+
st.success(f"Loaded example dataset ({len(df)} rows)")
|
| 668 |
+
except:
|
| 669 |
+
pass
|
| 670 |
+
|
| 671 |
+
selected_column = st.selectbox(
|
| 672 |
+
"Column to Process",
|
| 673 |
+
options=columns if columns else ["Upload a file first"],
|
| 674 |
+
disabled=not columns,
|
| 675 |
+
key="survey_column"
|
| 676 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 677 |
|
| 678 |
+
input_type_selected = "text"
|
| 679 |
+
input_data = None
|
| 680 |
+
description = selected_column if columns else ""
|
| 681 |
+
original_filename = uploaded_file.name if uploaded_file else "example_data.csv"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
|
| 683 |
+
if df is not None and columns and selected_column in columns:
|
| 684 |
+
input_data = df[selected_column].tolist()
|
| 685 |
+
|
| 686 |
+
with tab_pdf:
|
| 687 |
+
pdf_files = st.file_uploader(
|
| 688 |
+
"Upload PDF Document(s)",
|
| 689 |
+
type=['pdf'],
|
| 690 |
+
accept_multiple_files=True,
|
| 691 |
+
key="pdf_files"
|
| 692 |
+
)
|
| 693 |
+
|
| 694 |
+
pdf_description = st.text_input(
|
| 695 |
+
"Document Description",
|
| 696 |
+
placeholder="e.g., 'research papers', 'interview transcripts'",
|
| 697 |
+
help="Helps the LLM understand context",
|
| 698 |
+
key="pdf_desc"
|
| 699 |
+
)
|
| 700 |
+
|
| 701 |
+
pdf_mode = st.radio(
|
| 702 |
+
"Processing Mode",
|
| 703 |
+
options=["Image (visual documents)", "Text (text-heavy)", "Both (comprehensive)"],
|
| 704 |
+
key="pdf_mode"
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
input_type_selected = "pdf"
|
| 708 |
+
if pdf_files:
|
| 709 |
+
input_data = []
|
| 710 |
+
for f in pdf_files:
|
| 711 |
+
# Save to temp file
|
| 712 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
|
| 713 |
+
tmp.write(f.read())
|
| 714 |
+
input_data.append(tmp.name)
|
| 715 |
+
description = pdf_description or "document"
|
| 716 |
+
original_filename = "pdf_files"
|
| 717 |
+
st.success(f"Uploaded {len(pdf_files)} PDF file(s)")
|
| 718 |
+
|
| 719 |
+
with tab_images:
|
| 720 |
+
image_files = st.file_uploader(
|
| 721 |
+
"Upload Images",
|
| 722 |
+
type=['png', 'jpg', 'jpeg', 'gif', 'webp'],
|
| 723 |
+
accept_multiple_files=True,
|
| 724 |
+
key="image_files"
|
| 725 |
+
)
|
| 726 |
+
|
| 727 |
+
image_description = st.text_input(
|
| 728 |
+
"Image Description",
|
| 729 |
+
placeholder="e.g., 'product photos', 'social media posts'",
|
| 730 |
+
help="Helps the LLM understand context",
|
| 731 |
+
key="image_desc"
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
input_type_selected = "image"
|
| 735 |
+
if image_files:
|
| 736 |
+
input_data = []
|
| 737 |
+
for f in image_files:
|
| 738 |
+
# Save to temp file
|
| 739 |
+
suffix = '.' + f.name.split('.')[-1]
|
| 740 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
| 741 |
+
tmp.write(f.read())
|
| 742 |
+
input_data.append(tmp.name)
|
| 743 |
+
description = image_description or "images"
|
| 744 |
+
original_filename = "image_files"
|
| 745 |
+
st.success(f"Uploaded {len(image_files)} image file(s)")
|
| 746 |
+
|
| 747 |
+
st.markdown("---")
|
| 748 |
+
|
| 749 |
+
# Task selection
|
| 750 |
+
st.markdown("### What would you like to do?")
|
| 751 |
+
col_btn1, col_btn2 = st.columns(2)
|
| 752 |
+
with col_btn1:
|
| 753 |
+
manual_mode = st.button("Enter Categories Manually", use_container_width=True)
|
| 754 |
+
with col_btn2:
|
| 755 |
+
auto_mode = st.button("Auto-extract Categories", use_container_width=True)
|
| 756 |
+
|
| 757 |
+
if manual_mode:
|
| 758 |
+
st.session_state.task_mode = "manual"
|
| 759 |
+
if auto_mode:
|
| 760 |
+
st.session_state.task_mode = "auto_extract"
|
| 761 |
+
|
| 762 |
+
# Auto-extract settings
|
| 763 |
+
if st.session_state.task_mode == "auto_extract":
|
| 764 |
+
st.markdown("### Auto-extract Categories")
|
| 765 |
+
st.markdown("We'll analyze your data to discover the main categories.")
|
| 766 |
+
|
| 767 |
+
max_categories = st.slider(
|
| 768 |
+
"Number of Categories to Extract",
|
| 769 |
+
min_value=3,
|
| 770 |
+
max_value=25,
|
| 771 |
+
value=12,
|
| 772 |
+
help="How many categories should be identified in your data"
|
| 773 |
+
)
|
| 774 |
+
|
| 775 |
+
# Model selection for extraction
|
| 776 |
+
st.markdown("### Model Selection")
|
| 777 |
+
model_tier = st.radio(
|
| 778 |
+
"Model Tier",
|
| 779 |
+
options=["Free Models", "Bring Your Own Key"],
|
| 780 |
+
key="extract_model_tier"
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
if model_tier == "Free Models":
|
| 784 |
+
model = st.selectbox("Model", options=FREE_MODEL_CHOICES, key="extract_model")
|
| 785 |
+
api_key = ""
|
| 786 |
+
st.info("**Free tier** - no API key required!")
|
| 787 |
else:
|
| 788 |
+
model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="extract_model_paid")
|
| 789 |
+
api_key = st.text_input("API Key", type="password", key="extract_api_key")
|
| 790 |
+
|
| 791 |
+
if st.button("Extract Categories", type="primary"):
|
| 792 |
+
if input_data is None:
|
| 793 |
+
st.error("Please upload data first")
|
| 794 |
+
else:
|
| 795 |
+
with st.spinner("Extracting categories..."):
|
| 796 |
+
mode = None
|
| 797 |
+
if input_type_selected == "pdf":
|
| 798 |
+
mode_mapping = {
|
| 799 |
+
"Image (visual documents)": "image",
|
| 800 |
+
"Text (text-heavy)": "text",
|
| 801 |
+
"Both (comprehensive)": "both"
|
| 802 |
+
}
|
| 803 |
+
mode = mode_mapping.get(pdf_mode, "image")
|
| 804 |
+
|
| 805 |
+
categories, status = run_auto_extract(
|
| 806 |
+
input_type_selected, input_data, description,
|
| 807 |
+
max_categories, model_tier, model, api_key, mode
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
if categories:
|
| 811 |
+
st.session_state.extracted_categories = categories
|
| 812 |
+
st.session_state.task_mode = "manual" # Switch to manual to show categories
|
| 813 |
+
st.success(status)
|
| 814 |
+
st.rerun()
|
| 815 |
+
else:
|
| 816 |
+
st.error(status)
|
| 817 |
+
|
| 818 |
+
# Category inputs (shown for manual mode or after extraction)
|
| 819 |
+
if st.session_state.task_mode == "manual":
|
| 820 |
+
st.markdown("### Categories")
|
| 821 |
+
st.markdown("Enter your classification categories below.")
|
| 822 |
+
|
| 823 |
+
# Pre-fill with extracted categories if available
|
| 824 |
+
if st.session_state.extracted_categories:
|
| 825 |
+
for i, cat in enumerate(st.session_state.extracted_categories[:MAX_CATEGORIES]):
|
| 826 |
+
st.session_state.categories[i] = cat
|
| 827 |
+
st.session_state.category_count = min(len(st.session_state.extracted_categories), MAX_CATEGORIES)
|
| 828 |
+
st.session_state.extracted_categories = None # Clear after use
|
| 829 |
+
|
| 830 |
+
placeholder_examples = [
|
| 831 |
+
"e.g., Positive sentiment",
|
| 832 |
+
"e.g., Negative sentiment",
|
| 833 |
+
"e.g., Product feedback",
|
| 834 |
+
"e.g., Service complaint",
|
| 835 |
+
"e.g., Feature request",
|
| 836 |
+
"e.g., Custom category"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 837 |
]
|
| 838 |
|
| 839 |
+
categories_entered = []
|
| 840 |
+
for i in range(st.session_state.category_count):
|
| 841 |
+
placeholder = placeholder_examples[i] if i < len(placeholder_examples) else "e.g., Custom category"
|
| 842 |
+
cat_value = st.text_input(
|
| 843 |
+
f"Category {i+1}",
|
| 844 |
+
value=st.session_state.categories[i],
|
| 845 |
+
placeholder=placeholder,
|
| 846 |
+
key=f"cat_{i}"
|
| 847 |
+
)
|
| 848 |
+
st.session_state.categories[i] = cat_value
|
| 849 |
+
if cat_value.strip():
|
| 850 |
+
categories_entered.append(cat_value.strip())
|
| 851 |
+
|
| 852 |
+
if st.session_state.category_count < MAX_CATEGORIES:
|
| 853 |
+
if st.button("+ Add More"):
|
| 854 |
+
st.session_state.category_count += 1
|
| 855 |
+
st.rerun()
|
| 856 |
+
|
| 857 |
+
st.markdown("### Model Selection")
|
| 858 |
+
model_tier = st.radio(
|
| 859 |
+
"Model Tier",
|
| 860 |
+
options=["Free Models", "Bring Your Own Key"],
|
| 861 |
+
key="classify_model_tier"
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
if model_tier == "Free Models":
|
| 865 |
+
model = st.selectbox("Model", options=FREE_MODEL_CHOICES, key="classify_model")
|
| 866 |
+
api_key = ""
|
| 867 |
+
st.info("**Free tier** - no API key required!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 868 |
else:
|
| 869 |
+
model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="classify_model_paid")
|
| 870 |
+
api_key = st.text_input("API Key", type="password", key="classify_api_key")
|
| 871 |
+
|
| 872 |
+
if st.button("Classify Data", type="primary", use_container_width=True):
|
| 873 |
+
if input_data is None:
|
| 874 |
+
st.error("Please upload data first")
|
| 875 |
+
elif not categories_entered:
|
| 876 |
+
st.error("Please enter at least one category")
|
| 877 |
+
else:
|
| 878 |
+
with st.spinner("Classifying data... This may take a few minutes."):
|
| 879 |
+
mode = None
|
| 880 |
+
if input_type_selected == "pdf":
|
| 881 |
+
mode_mapping = {
|
| 882 |
+
"Image (visual documents)": "image",
|
| 883 |
+
"Text (text-heavy)": "text",
|
| 884 |
+
"Both (comprehensive)": "both"
|
| 885 |
+
}
|
| 886 |
+
mode = mode_mapping.get(pdf_mode, "image")
|
| 887 |
+
|
| 888 |
+
result_df, csv_path, pdf_path, code, status = run_classify_data(
|
| 889 |
+
input_type_selected, input_data, description,
|
| 890 |
+
categories_entered, model_tier, model, api_key, mode,
|
| 891 |
+
original_filename, description
|
| 892 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 893 |
|
| 894 |
+
if result_df is not None:
|
| 895 |
+
st.session_state.results = {
|
| 896 |
+
'df': result_df,
|
| 897 |
+
'csv_path': csv_path,
|
| 898 |
+
'pdf_path': pdf_path,
|
| 899 |
+
'code': code,
|
| 900 |
+
'status': status,
|
| 901 |
+
'categories': categories_entered
|
| 902 |
+
}
|
| 903 |
+
st.success(status)
|
| 904 |
+
st.rerun()
|
| 905 |
+
else:
|
| 906 |
+
st.error(status)
|
| 907 |
+
|
| 908 |
+
with col_output:
|
| 909 |
+
st.markdown("### Results")
|
| 910 |
+
|
| 911 |
+
if st.session_state.results:
|
| 912 |
+
results = st.session_state.results
|
| 913 |
+
|
| 914 |
+
# Distribution chart
|
| 915 |
+
fig = create_distribution_chart(results['df'], results['categories'])
|
| 916 |
+
st.pyplot(fig)
|
| 917 |
+
|
| 918 |
+
# Results dataframe
|
| 919 |
+
st.dataframe(results['df'], use_container_width=True)
|
| 920 |
+
|
| 921 |
+
# Downloads
|
| 922 |
+
col_dl1, col_dl2 = st.columns(2)
|
| 923 |
+
with col_dl1:
|
| 924 |
+
with open(results['csv_path'], 'rb') as f:
|
| 925 |
+
st.download_button(
|
| 926 |
+
"Download Results (CSV)",
|
| 927 |
+
data=f,
|
| 928 |
+
file_name="classified_results.csv",
|
| 929 |
+
mime="text/csv"
|
| 930 |
+
)
|
| 931 |
+
with col_dl2:
|
| 932 |
+
with open(results['pdf_path'], 'rb') as f:
|
| 933 |
+
st.download_button(
|
| 934 |
+
"Download Methodology Report (PDF)",
|
| 935 |
+
data=f,
|
| 936 |
+
file_name="methodology_report.pdf",
|
| 937 |
+
mime="application/pdf"
|
| 938 |
+
)
|
| 939 |
|
| 940 |
+
# Code
|
| 941 |
+
with st.expander("See the Code"):
|
| 942 |
+
st.code(results['code'], language='python')
|
| 943 |
+
else:
|
| 944 |
+
st.info("Upload data, select categories, and click 'Classify Data' to see results here.")
|
| 945 |
+
|
| 946 |
+
# Reset button
|
| 947 |
+
if st.button("Reset", type="secondary"):
|
| 948 |
+
st.session_state.categories = [''] * MAX_CATEGORIES
|
| 949 |
+
st.session_state.category_count = INITIAL_CATEGORIES
|
| 950 |
+
st.session_state.task_mode = None
|
| 951 |
+
st.session_state.extracted_categories = None
|
| 952 |
+
st.session_state.results = None
|
| 953 |
+
if hasattr(st.session_state, 'example_loaded'):
|
| 954 |
+
del st.session_state.example_loaded
|
| 955 |
+
st.rerun()
|
requirements.txt
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
|
|
| 1 |
cat-llm[pdf]>=0.1.6
|
| 2 |
-
gradio==5.6.0
|
| 3 |
mistralai
|
| 4 |
pydantic==2.10.6
|
| 5 |
huggingface_hub<0.27.0
|
|
@@ -9,3 +9,4 @@ requests
|
|
| 9 |
regex
|
| 10 |
reportlab
|
| 11 |
matplotlib
|
|
|
|
|
|
| 1 |
+
streamlit>=1.32.0
|
| 2 |
cat-llm[pdf]>=0.1.6
|
|
|
|
| 3 |
mistralai
|
| 4 |
pydantic==2.10.6
|
| 5 |
huggingface_hub<0.27.0
|
|
|
|
| 9 |
regex
|
| 10 |
reportlab
|
| 11 |
matplotlib
|
| 12 |
+
Pillow
|