# Source: Hugging Face Space file view (uploader: chrissoria)
# Commit message: "Fix: Use Gemini Flash as default, fix unavailable HF models"
# Commit: 4934097
"""
Streamlit app - CatLLM Survey Response Classifier
Migrated from Gradio for better mobile support
"""
import streamlit as st
import pandas as pd
import tempfile
import os
import time
import sys
from datetime import datetime
import matplotlib.pyplot as plt
# Import catllm
try:
    import catllm
except ImportError as e:
    # Keep the app loadable without the package; CATLLM_AVAILABLE records the outcome.
    print(f"Warning: Could not import catllm: {e}")
    CATLLM_AVAILABLE = False
else:
    CATLLM_AVAILABLE = True

# App-wide limits
MAX_CATEGORIES = 10       # size of the category list kept in session state
INITIAL_CATEGORIES = 3    # category inputs shown before the user adds more
MAX_FILE_SIZE_MB = 100
def count_pdf_pages(pdf_path):
    """Count the number of pages in a PDF file.

    Args:
        pdf_path: Path of the PDF to inspect.

    Returns:
        The page count, or 1 when PyMuPDF cannot be imported or the file
        cannot be opened/read (best-effort fallback used for time estimates).
    """
    try:
        import fitz  # PyMuPDF

        # The context manager closes the document even if len() raises.
        with fitz.open(pdf_path) as doc:
            return len(doc)
    except Exception:
        return 1  # Default to 1 if can't read
def extract_text_from_pdfs(pdf_paths):
    """Extract text from all pages of all PDFs, returning list of page texts.

    Args:
        pdf_paths: Iterable of PDF file paths.

    Returns:
        A flat list with the stripped text of every non-empty page, in the
        order the PDFs and pages were visited. A PDF that fails to open or
        parse is reported on stdout and skipped; pages collected before the
        failure are kept.
    """
    import fitz  # PyMuPDF

    all_texts = []
    for pdf_path in pdf_paths:
        try:
            # Context manager guarantees the document is closed even when a
            # page raises part-way through the loop.
            with fitz.open(pdf_path) as doc:
                for page in doc:
                    text = page.get_text().strip()
                    if text:  # Only add non-empty pages
                        all_texts.append(text)
        except Exception as e:
            print(f"Error extracting text from {pdf_path}: {e}")
    return all_texts
def extract_pdf_pages(pdf_paths, pdf_name_map, mode="image"):
    """
    Extract individual pages from PDFs.

    Args:
        pdf_paths: Iterable of PDF file paths.
        pdf_name_map: Mapping of path -> display name used in page labels.
            Paths missing from the map fall back to the file's base name
            with a trailing ".pdf" removed.
        mode: "text", "image" or "both". Any value other than "text"
            renders the page as an image.

    Returns:
        A list of tuples whose shape depends on ``mode``:
        - "text":  (text, page_label, "text"); pages without text are skipped
        - "image": (image_path, page_label, "image")
        - "both":  (image_path, page_label, "image", text)
        ``image_path`` points to a temporary PNG that is NOT deleted here.
        A PDF that fails is reported on stdout and skipped.
    """
    import fitz  # PyMuPDF

    pages = []
    for pdf_path in pdf_paths:
        # Strip only a trailing extension (case-insensitive) rather than every
        # ".pdf" substring anywhere in the name.
        base_name = os.path.basename(pdf_path)
        if base_name.lower().endswith('.pdf'):
            base_name = base_name[:-len('.pdf')]
        orig_name = pdf_name_map.get(pdf_path, base_name)
        try:
            with fitz.open(pdf_path) as doc:
                for page_num, page in enumerate(doc, 1):
                    page_label = f"{orig_name}_p{page_num}"
                    if mode == "text":
                        # Extract text
                        text = page.get_text().strip()
                        if text:
                            pages.append((text, page_label, "text"))
                        continue

                    # Render as image (for image or both mode)
                    pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))  # 2x zoom for better quality
                    # Reserve a path, then release the handle before PyMuPDF
                    # writes to it (avoids a leaked descriptor per page).
                    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp:
                        img_path = tmp.name
                    pix.save(img_path)
                    if mode == "both":
                        text = page.get_text().strip()
                        pages.append((img_path, page_label, "image", text))
                    else:
                        pages.append((img_path, page_label, "image"))
        except Exception as e:
            print(f"Error extracting pages from {pdf_path}: {e}")
    return pages
# Free models - display name -> actual API model name
FREE_MODELS_MAP = dict(
    [
        ("Gemini 2.5 Flash", "gemini-2.5-flash"),
        ("GPT-4o Mini", "gpt-4o-mini"),
        ("Claude 3 Haiku", "claude-3-haiku-20240307"),
        ("Llama 3.3 70B", "meta-llama/Llama-3.3-70B-Instruct:groq"),
        ("DeepSeek R1", "deepseek-ai/DeepSeek-R1:novita"),
        ("Mistral Medium", "mistral-medium-2505"),
        ("Grok 4 Fast", "grok-4-fast-non-reasoning"),
    ]
)
# Labels shown in the UI, in insertion order; the first entry is the default choice.
FREE_MODEL_DISPLAY_NAMES = [*FREE_MODELS_MAP]
FREE_MODEL_CHOICES = [*FREE_MODELS_MAP.values()]  # Keep for backward compat

# Paid models (user provides their own API key)
PAID_MODEL_CHOICES = [
    "gemini-2.5-flash",
    "gemini-2.5-pro",
    "gpt-4.1",
    "gpt-4o",
    "gpt-4o-mini",
    "claude-sonnet-4-5-20250929",
    "claude-opus-4-20250514",
    "claude-3-5-haiku-20241022",
    "mistral-large-latest",
]

# Models routed through HuggingFace
HF_ROUTED_MODELS = [
    "meta-llama/Llama-3.3-70B-Instruct:groq",
    "deepseek-ai/DeepSeek-R1:novita",
]
def is_free_model(model, model_tier):
    """Tell whether the free tier is selected (the Space pays for the API).

    Only ``model_tier`` is inspected; ``model`` is accepted but not used.
    """
    free_tier_label = "Free Models"
    return model_tier == free_tier_label
def get_model_source(model):
    """Auto-detect model source.

    Args:
        model: Model identifier, e.g. "gpt-4o" or
            "meta-llama/Llama-3.3-70B-Instruct:groq".

    Returns:
        One of "openai", "anthropic", "google", "mistral", "huggingface",
        "perplexity" or "xai". Unrecognised names default to "huggingface".
    """
    model_lower = model.lower()

    # Provider-routed ids ("...:groq" / "...:novita") go through HuggingFace
    # regardless of which vendor keyword the model name happens to contain,
    # so this check must run before the keyword scan.
    if ":novita" in model_lower or ":groq" in model_lower:
        return "huggingface"

    # First matching keyword wins; order mirrors the original precedence.
    keyword_sources = (
        ("gpt", "openai"),
        ("claude", "anthropic"),
        ("gemini", "google"),
        ("mistral", "mistral"),
        ("qwen", "huggingface"),
        ("llama", "huggingface"),
        ("deepseek", "huggingface"),
        ("sonar", "perplexity"),
        ("grok", "xai"),
    )
    for keyword, source in keyword_sources:
        if keyword in model_lower:
            return source
    return "huggingface"
def get_api_key(model, model_tier, api_key_input):
    """Resolve the API key to use for a model/tier combination.

    Returns:
        A ``(key, provider_label)`` tuple. On the free tier the key is read
        from the provider's environment variable ("" when unset); otherwise
        it is the stripped user input ("" when blank) labelled "User".
    """
    if not is_free_model(model, model_tier):
        user_key = api_key_input.strip() if api_key_input else ""
        return user_key, "User"

    # HuggingFace is both the route for HF-hosted models and the fallback.
    env_var, provider = "HF_API_KEY", "HuggingFace"
    if model not in HF_ROUTED_MODELS:
        # First keyword found in the lower-cased model name decides the provider.
        keyword_table = (
            ("gpt", ("OPENAI_API_KEY", "OpenAI")),
            ("gemini", ("GOOGLE_API_KEY", "Google")),
            ("mistral", ("MISTRAL_API_KEY", "Mistral")),
            ("claude", ("ANTHROPIC_API_KEY", "Anthropic")),
            ("sonar", ("PERPLEXITY_API_KEY", "Perplexity")),
            ("grok", ("XAI_API_KEY", "xAI")),
        )
        lowered = model.lower()
        for keyword, target in keyword_table:
            if keyword in lowered:
                env_var, provider = target
                break
    return os.environ.get(env_var, ""), provider
def calculate_total_file_size(files):
    """Return the combined size of uploaded files in megabytes.

    Accepts ``None`` (yields 0), a single file object, or a list of them.
    An object's ``size`` attribute is preferred; otherwise the file at its
    ``name`` path is measured. Objects that cannot be measured count as 0.
    """
    if files is None:
        return 0

    bytes_per_mb = 1024 * 1024
    file_list = files if isinstance(files, list) else [files]

    byte_count = 0
    for item in file_list:
        try:
            if hasattr(item, 'size'):
                byte_count += item.size
            elif hasattr(item, 'name'):
                byte_count += os.path.getsize(item.name)
        except (OSError, AttributeError):
            # Unreadable entry: skip it instead of failing the whole total.
            continue
    return byte_count / bytes_per_mb
def generate_extract_code(input_type, description, model, model_source, max_categories, mode=None):
    """Generate Python code for category extraction.

    Args:
        input_type: "text" or "pdf"; any other value produces the image snippet.
        description: Data description. For text input it is also emitted as
            the DataFrame column name.
        model: Model name written into ``user_model``.
        model_source: Provider name written into ``model_source``.
        max_categories: Value written into ``max_categories``.
        mode: Optional PDF processing mode; only used by the PDF snippet.

    Returns:
        A string containing a runnable ``catllm.extract`` example.
    """
    import json

    def quote(value):
        # Emit a double-quoted literal with quotes/backslashes/newlines escaped
        # so user-supplied text cannot break the generated source. Plain
        # values render exactly as before ("value").
        return json.dumps(str(value), ensure_ascii=False)

    desc_lit = quote(description)
    model_lit = quote(model)
    source_lit = quote(model_source)

    if input_type == "text":
        return f'''import catllm
import pandas as pd
# Load your data
df = pd.read_csv("your_data.csv")
# Extract categories from the text column
result = catllm.extract(
input_data=df[{desc_lit}].tolist(),
api_key="YOUR_API_KEY",
input_type="text",
description={desc_lit},
user_model={model_lit},
model_source={source_lit},
max_categories={max_categories}
)
# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''
    elif input_type == "pdf":
        mode_line = f',\n mode={quote(mode)}' if mode else ''
        return f'''import catllm
# Extract categories from PDF documents
result = catllm.extract(
input_data="path/to/your/pdfs/",
api_key="YOUR_API_KEY",
input_type="pdf",
description={desc_lit}{mode_line},
user_model={model_lit},
model_source={source_lit},
max_categories={max_categories}
)
# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''
    else:  # image
        return f'''import catllm
# Extract categories from images
result = catllm.extract(
input_data="path/to/your/images/",
api_key="YOUR_API_KEY",
input_type="image",
description={desc_lit},
user_model={model_lit},
model_source={source_lit},
max_categories={max_categories}
)
# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''
def generate_classify_code(input_type, description, categories, model, model_source, mode=None, classify_mode="Single Model", models_list=None):
    """Generate Python code for classification.

    Args:
        input_type: "text" or "pdf"; any other value uses the image placeholder.
        description: Data description written into the snippet.
        categories: Iterable of category names.
        model: Model name used by the single-model snippet.
        model_source: Accepted for signature compatibility; not emitted.
        mode: Optional PDF mode; only emitted when ``input_type`` is "pdf".
        classify_mode: "Single Model" for one model; "Ensemble" adds a
            consensus threshold; any other value yields a plain multi-model
            (comparison) snippet.
        models_list: Model names for multi-model snippets. When empty, two
            example models are emitted instead.

    Returns:
        A string containing a runnable ``catllm.classify`` example.
    """
    import json

    def quote(value):
        # Double-quoted literal with quotes/backslashes/newlines escaped so
        # user text cannot break the generated source; plain values are
        # rendered exactly as "value".
        return json.dumps(str(value), ensure_ascii=False)

    categories_str = ",\n ".join([quote(cat) for cat in categories])
    desc_lit = quote(description)

    # Determine input data placeholder based on type
    if input_type == "text":
        input_placeholder = 'df["your_column"].tolist()'
        load_data = '''import pandas as pd
# Load your data
df = pd.read_csv("your_data.csv")
'''
    elif input_type == "pdf":
        input_placeholder = '"path/to/your/pdfs/"'
        load_data = ''
    else:  # image
        input_placeholder = '"path/to/your/images/"'
        load_data = ''

    # The mode argument is only meaningful for PDF input (same in both snippets).
    mode_param = f',\n mode={quote(mode)}' if mode and input_type == "pdf" else ''

    # Generate code based on classification mode
    if classify_mode == "Single Model":
        return f'''import catllm
{load_data}
# Define categories
categories = [
{categories_str}
]
# Classify data (input type is auto-detected)
result = catllm.classify(
input_data={input_placeholder},
categories=categories,
api_key="YOUR_API_KEY",
description={desc_lit},
user_model={quote(model)}{mode_param}
)
# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''

    # Multi-model mode (Comparison or Ensemble)
    if models_list:
        models_str = ",\n ".join([f'({quote(m)}, "auto", "YOUR_API_KEY")' for m in models_list])
    else:
        models_str = '("gpt-4o", "auto", "YOUR_API_KEY"),\n ("claude-sonnet-4-5-20250929", "auto", "YOUR_API_KEY")'
    is_ensemble = classify_mode == "Ensemble"
    consensus_param = ',\n consensus_threshold=0.5' if is_ensemble else ''
    purpose = "ensemble voting" if is_ensemble else "comparison"
    return f'''import catllm
{load_data}
# Define categories
categories = [
{categories_str}
]
# Define models for {purpose}
models = [
{models_str}
]
# Classify with multiple models
result = catllm.classify(
input_data={input_placeholder},
categories=categories,
models=models,
description={desc_lit}{mode_param}{consensus_param}
)
# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''
def generate_methodology_report_pdf(categories, model, column_name, num_rows, model_source, filename, success_rate,
                                    result_df=None, processing_time=None, prompt_template=None,
                                    data_quality=None, catllm_version=None, python_version=None,
                                    task_type="assign", extracted_categories_df=None, max_categories=None,
                                    input_type="text", description=None):
    """Generate a PDF methodology report.

    The report contains an introduction, the category-to-column mapping, a
    citation, and a second page with a run summary, optional timing figures
    and version information.

    Args:
        categories: Category names; entry ``i`` is listed as ``category_{i}``.
        model, model_source: Model name and provider shown in the summary.
        column_name, filename: Source column and file shown in the summary.
        num_rows: Number of classified rows.
        success_rate: Percentage (0-100) shown with two decimals.
        processing_time: Total seconds; when not None a timing table is added.
        catllm_version, python_version: Shown as "unknown" when falsy.
        task_type: "extract_and_assign" switches the title and intro text.
        result_df, prompt_template, data_quality, extracted_categories_df,
        max_categories, input_type, description: Accepted for interface
            compatibility; not used by this function.

    Returns:
        Path of the generated temporary PDF file (the caller owns cleanup).
    """
    from reportlab.lib.pagesizes import letter
    from reportlab.lib import colors
    from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
    from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak

    def key_value_table(rows, col_widths):
        # Two-column table with a shaded label column; shared by the
        # summary, timing and version sections.
        table = Table(rows, colWidths=col_widths)
        table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('PADDING', (0, 0), (-1, -1), 6),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
        ]))
        return table

    # Reserve a unique path, then release the handle before reportlab opens
    # the same path for writing (an open handle leaks a descriptor and blocks
    # the write on platforms with exclusive file locking).
    with tempfile.NamedTemporaryFile(mode='wb', suffix='_methodology_report.pdf', delete=False) as pdf_file:
        pdf_path = pdf_file.name

    doc = SimpleDocTemplate(pdf_path, pagesize=letter)
    styles = getSampleStyleSheet()
    title_style = ParagraphStyle('Title', parent=styles['Heading1'], fontSize=18, spaceAfter=20)
    heading_style = ParagraphStyle('Heading', parent=styles['Heading2'], fontSize=14, spaceAfter=10, spaceBefore=15)
    normal_style = styles['Normal']

    is_extract_and_assign = task_type == "extract_and_assign"
    story = []

    # --- Page 1: title, introduction, category mapping, citation ---
    if is_extract_and_assign:
        report_title = "CatLLM Extraction & Classification Report"
    else:
        report_title = "CatLLM Classification Report"
    story.append(Paragraph(report_title, title_style))
    story.append(Paragraph(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", normal_style))
    story.append(Spacer(1, 15))

    story.append(Paragraph("About This Report", heading_style))
    if is_extract_and_assign:
        about_text = (
            "This methodology report documents the automated category extraction and classification process. "
            "CatLLM first discovers categories from your data using LLMs, then classifies each item into those categories."
        )
    else:
        about_text = (
            "This methodology report documents the classification process for reproducibility and transparency. "
            "CatLLM restricts the prompt to a standard template that is impartial to the researcher's inclinations, ensuring "
            "consistent and reproducible results."
        )
    story.append(Paragraph(about_text, normal_style))
    story.append(Spacer(1, 15))

    if categories:
        story.append(Paragraph("Category Mapping", heading_style))
        story.append(Paragraph("Each category column contains binary values: 1 = present, 0 = not present", normal_style))
        story.append(Spacer(1, 8))
        category_data = [["Column Name", "Category Description"]]
        for i, cat in enumerate(categories, 1):
            category_data.append([f"category_{i}", cat])
        cat_table = Table(category_data, colWidths=[120, 330])
        cat_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('PADDING', (0, 0), (-1, -1), 6),
            ('BACKGROUND', (0, 1), (0, -1), colors.lightgrey),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
        ]))
        story.append(cat_table)
        story.append(Spacer(1, 15))

    story.append(Spacer(1, 30))
    story.append(Paragraph("Citation", heading_style))
    story.append(Paragraph("If you use CatLLM in your research, please cite:", normal_style))
    story.append(Spacer(1, 5))
    story.append(Paragraph("Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DOI: 10.5281/zenodo.15532316", normal_style))

    # --- Page 2: summary section ---
    story.append(PageBreak())
    story.append(Paragraph("Classification Summary", title_style))
    story.append(Spacer(1, 15))
    summary_data = [
        ["Source File", filename],
        ["Source Column", column_name],
        ["Model Used", model],
        ["Model Source", model_source],
        ["Rows Classified", str(num_rows)],
        ["Number of Categories", str(len(categories)) if categories else "0"],
        ["Success Rate", f"{success_rate:.2f}%"],
    ]
    story.append(key_value_table(summary_data, [150, 300]))
    story.append(Spacer(1, 15))

    if processing_time is not None:
        story.append(Paragraph("Processing Time", heading_style))
        # Guard both divisions so zero rows / zero elapsed time render as 0.
        rows_per_min = (num_rows / processing_time) * 60 if processing_time > 0 else 0
        avg_time = processing_time / num_rows if num_rows > 0 else 0
        time_data = [
            ["Total Processing Time", f"{processing_time:.1f} seconds"],
            ["Average Time per Response", f"{avg_time:.2f} seconds"],
            ["Processing Rate", f"{rows_per_min:.1f} rows/minute"],
        ]
        story.append(key_value_table(time_data, [180, 270]))
        story.append(Spacer(1, 15))

    story.append(Paragraph("Version Information", heading_style))
    version_data = [
        ["CatLLM Version", catllm_version or "unknown"],
        ["Python Version", python_version or "unknown"],
        ["Timestamp", datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
    ]
    story.append(key_value_table(version_data, [180, 270]))

    doc.build(story)
    return pdf_path
def run_auto_extract(input_type, input_data, description, max_categories_val,
                     model_tier, model, api_key_input, mode=None, progress_callback=None):
    """Extract categories from data.

    Args:
        input_type: "text", "pdf" or "image"; forwarded to ``catllm.extract``.
        input_data: A list of items, or a single item (counted as one).
        description: Data description forwarded to catllm.
        max_categories_val: Upper bound on returned categories (cast to int).
        model_tier, model, api_key_input: Used to resolve the API key and
            the model source.
        mode: Optional processing mode; forwarded only when truthy.
        progress_callback: Accepted for interface compatibility; not used.

    Returns:
        ``(categories, message)`` on success, or ``(None, error_message)``
        when catllm is unavailable, no key is configured, nothing was
        extracted, or the extraction raised.
    """
    if not CATLLM_AVAILABLE:
        return None, "catllm package not available"

    actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
    if not actual_api_key:
        return None, f"{provider} API key not configured"

    model_source = get_model_source(model)
    try:
        num_items = len(input_data) if isinstance(input_data, list) else 1

        if input_type == "image":
            divisions = min(3, max(1, num_items // 5))
            categories_per_chunk = 12
        else:
            # One division per ~15 items, between 1 and 5.
            divisions = min(5, max(1, num_items // 15))
            chunk_size = num_items // divisions
            # Clamp to at least 1: with 0 or 1 items, chunk_size - 1 would
            # otherwise request zero or a negative number of categories.
            categories_per_chunk = max(1, min(10, chunk_size - 1))

        extract_kwargs = {
            'input_data': input_data,
            'api_key': actual_api_key,
            'input_type': input_type,
            'description': description,
            'user_model': model,
            'model_source': model_source,
            'divisions': divisions,
            'categories_per_chunk': categories_per_chunk,
            'max_categories': int(max_categories_val)
        }
        if mode:
            extract_kwargs['mode'] = mode

        extract_result = catllm.extract(**extract_kwargs)
        categories = extract_result.get('top_categories', [])
        if not categories:
            return None, "No categories were extracted"
        return categories, f"Extracted {len(categories)} categories successfully!"
    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        return None, f"Error: {str(e)}"
def run_classify_data(input_type, input_data, description, categories,
                      model_tier, model, api_key_input, mode=None,
                      original_filename="data", column_name="text",
                      progress_callback=None):
    """Classify data with user-provided categories.

    Runs ``catllm.classify_ensemble`` with a single model, saves the result
    as CSV, builds the methodology PDF and the reproducibility snippet.

    Args:
        input_type: Input kind, recorded in the report and snippet.
        input_data: Data passed to catllm as ``survey_input``.
        description: Passed to catllm as ``input_description``.
        categories: Non-empty list of category names.
        model_tier, model, api_key_input: Used to resolve the API key and
            the model source.
        mode: Optional PDF mode, forwarded as ``pdf_mode`` when truthy.
        original_filename, column_name: Shown in the report.
        progress_callback: Accepted for interface compatibility; not used.

    Returns:
        ``(result_df, csv_path, report_pdf_path, code, message)`` on
        success; ``(None, None, None, None, error_message)`` otherwise.
    """
    if not CATLLM_AVAILABLE:
        return None, None, None, None, "catllm package not available"
    if not categories:
        return None, None, None, None, "Please enter at least one category"

    actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
    if not actual_api_key:
        return None, None, None, None, f"{provider} API key not configured"

    model_source = get_model_source(model)
    try:
        start_time = time.time()
        classify_kwargs = {
            'survey_input': input_data,
            'categories': categories,
            'models': [(model, model_source, actual_api_key)],
            'input_description': description,
        }
        if mode:
            classify_kwargs['pdf_mode'] = mode
        result = catllm.classify_ensemble(**classify_kwargs)
        processing_time = time.time() - start_time
        num_items = len(result)

        # Save CSV: reserve the path, close the handle, then let pandas open
        # it (writing by path while our handle is still open fails on
        # platforms with exclusive file locking).
        with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
            csv_path = f.name
        result.to_csv(csv_path, index=False)

        # Calculate success rate
        if 'processing_status' in result.columns:
            success_count = (result['processing_status'] == 'success').sum()
            # An empty result would divide by zero; report 0% instead of
            # surfacing "Error: division by zero" to the user.
            success_rate = (success_count / num_items) * 100 if num_items else 0.0
        else:
            success_rate = 100.0

        # Get version info
        try:
            catllm_version = catllm.__version__
        except AttributeError:
            catllm_version = "unknown"
        python_version = sys.version.split()[0]

        # Generate methodology report
        report_pdf_path = generate_methodology_report_pdf(
            categories=categories,
            model=model,
            column_name=column_name,
            num_rows=num_items,
            model_source=model_source,
            filename=original_filename,
            success_rate=success_rate,
            result_df=result,
            processing_time=processing_time,
            catllm_version=catllm_version,
            python_version=python_version,
            task_type="assign",
            input_type=input_type,
            description=description
        )

        # Generate reproducibility code
        code = generate_classify_code(input_type, description, categories, model, model_source, mode)
        return result, csv_path, report_pdf_path, code, f"Classified {num_items} items in {processing_time:.1f}s"
    except Exception as e:
        # UI boundary: report the failure as a message instead of raising.
        return None, None, None, None, f"Error: {str(e)}"
def sanitize_model_name(model: str) -> str:
    """Turn a model name into a column-safe suffix (matches catllm logic).

    Each run of characters outside ``[a-zA-Z0-9]`` becomes one underscore,
    edge underscores are trimmed, the result is lower-cased and finally cut
    to 40 characters (the cut happens after trimming).
    """
    import re

    # Replacing whole runs at once equals "replace each char, then collapse".
    collapsed = re.sub(r'[^a-zA-Z0-9]+', '_', model)
    trimmed = collapsed.strip('_')
    return trimmed.lower()[:40]
def create_distribution_chart(result_df, categories, classify_mode="Single Model", models_list=None):
    """Create a bar chart showing category distribution.

    Args:
        result_df: DataFrame with classification results
        categories: List of category names
        classify_mode: "Single Model", "Model Comparison", or "Ensemble"
        models_list: List of model names (for multi-model modes)

    Returns:
        A matplotlib figure. Percentages are column sums divided by the
        number of rows. When there are no rows, or comparison mode has no
        models, the figure only shows an explanatory message.
    """
    import numpy as np

    def message_figure(message):
        # Placeholder figure for situations with nothing to plot.
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=14)
        ax.axis('off')
        return fig

    total_rows = len(result_df)
    if total_rows == 0:
        return message_figure('No data to display')

    # Define colors for different models
    model_colors = ['#2563eb', '#dc2626', '#16a34a', '#ca8a04', '#9333ea', '#0891b2', '#be185d', '#65a30d']

    if classify_mode in ("Single Model", "Ensemble"):
        # Both modes plot one series; they differ only in column suffix,
        # bar colour and title.
        if classify_mode == "Single Model":
            # Single model: use category_1, category_2, etc.
            column_suffix, bar_color, title = "", '#2563eb', 'Category Distribution (%)'
        else:
            # Ensemble: use category_1_consensus, category_2_consensus, etc.
            column_suffix, bar_color, title = "_consensus", '#16a34a', 'Ensemble Consensus Distribution (%)'

        fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))
        labels = []
        percentages = []
        for i, cat in enumerate(categories, 1):
            col_name = f"category_{i}{column_suffix}"
            if col_name in result_df.columns:  # categories without a column are omitted
                count = int(result_df[col_name].sum())
                labels.append(cat)
                percentages.append(round((count / total_rows) * 100, 1))
        # barh draws bottom-up, so reverse to list the first category on top.
        labels.reverse()
        percentages.reverse()

        bars = ax.barh(labels, percentages, color=bar_color)
        ax.set_xlim(0, 100)
        ax.set_xlabel('Percentage (%)', fontsize=11)
        ax.set_title(title, fontsize=14, fontweight='bold')
        for bar, pct in zip(bars, percentages):
            ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
                    f'{pct:.1f}%', va='center', fontsize=10)
    else:  # Model Comparison
        # Model Comparison: grouped bars for each model
        models_list = models_list or []
        if not models_list:
            # Without models the bar height (0.8 / n_models) would divide by zero.
            return message_figure('No models selected for comparison')

        sanitized_names = [sanitize_model_name(m) for m in models_list]
        n_models = len(sanitized_names)
        n_categories = len(categories)
        fig, ax = plt.subplots(figsize=(12, max(5, n_categories * 1.2)))

        # Gather data for each model
        bar_height = 0.8 / n_models
        y_positions = np.arange(n_categories)
        for model_idx, (model_name, sanitized) in enumerate(zip(models_list, sanitized_names)):
            model_pcts = []
            for i in range(1, n_categories + 1):
                col_name = f"category_{i}_{sanitized}"
                if col_name in result_df.columns:
                    count = int(result_df[col_name].sum())
                    pct = (count / total_rows) * 100
                else:
                    pct = 0
                model_pcts.append(pct)
            # Reverse for horizontal bar chart
            model_pcts = model_pcts[::-1]
            # Centre the group of per-model bars on each category tick.
            offset = (model_idx - n_models / 2 + 0.5) * bar_height
            color = model_colors[model_idx % len(model_colors)]
            # Use shorter display name
            display_name = model_name.split('/')[-1].split(':')[0][:20]
            ax.barh(y_positions + offset, model_pcts, bar_height * 0.9,
                    label=display_name, color=color, alpha=0.85)

        ax.set_yticks(y_positions)
        ax.set_yticklabels(categories[::-1])
        ax.set_xlim(0, 100)
        ax.set_xlabel('Percentage (%)', fontsize=11)
        ax.set_title('Category Distribution by Model (%)', fontsize=14, fontweight='bold')
        ax.legend(loc='lower right', fontsize=9)

    plt.tight_layout()
    return fig
# Page config
# Sets the page title and icon and selects the wide layout for the app.
st.set_page_config(
    page_title="CatLLM - Research Data Classifier",
    page_icon="🐱",
    layout="wide"
)
# Initialize session state
# Each key is seeded only when absent, so values set on earlier reruns survive.
_SESSION_DEFAULTS = {
    'categories': [''] * MAX_CATEGORIES,
    'category_count': INITIAL_CATEGORIES,
    'task_mode': None,
    'extracted_categories': None,
    'results': None,
    'active_tab': "survey",
    'survey_data': None,
    'pdf_data': None,
    'image_data': None,
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Logo and title
# Two columns with a 1:6 width ratio: logo on the left, heading on the right.
col_logo, col_title = st.columns([1, 6])
with col_logo:
    st.image("logo.png", width=100)
with col_title:
    st.title("CatLLM - Research Data Classifier")
    st.markdown("Research-grade categorization of survey responses, PDFs, and images using LLMs.")
# About section
# Expander holding the privacy notice, feature overview, beta/feedback note,
# project links and the citation text.
with st.expander("About This App"):
    st.markdown("""
**Privacy Notice:** Your data is sent to third-party LLM APIs for classification. Do not upload sensitive, confidential, or personally identifiable information (PII).
---
**CatLLM** is an open-source Python package for classifying text and document data using Large Language Models.
### What It Does
- **Extract Categories**: Discover themes and categories in your data automatically
- **Assign Categories**: Classify data into your predefined categories
- **Extract & Assign**: Let CatLLM discover categories, then classify all your data
### Beta Test - We Want Your Feedback!
This app is currently in **beta** and **free to use** while CatLLM is under review for publication, made possible by **Bashir Ahmed's generous fellowship support**.
- Found a bug? Have a feature request? Please open an issue on [GitHub](https://github.com/chrissoria/cat-llm)
- Reach out directly: [chrissoria@berkeley.edu](mailto:chrissoria@berkeley.edu)
### Links
- **PyPI**: [pip install cat-llm](https://pypi.org/project/cat-llm/)
- **GitHub**: [github.com/chrissoria/cat-llm](https://github.com/chrissoria/cat-llm)
### Citation
If you use CatLLM in your research, please cite:
```
Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DOI: 10.5281/zenodo.15532316
```
""")
# Main layout
col_input, col_output = st.columns([1, 1])
with col_input:
# Input type selector
input_type_choice = st.radio(
"Input Type",
options=["Survey Responses", "PDF Documents", "Images"],
horizontal=True,
key="input_type_radio"
)
# Initialize variables
input_data = None
input_type_selected = "text"
description = ""
original_filename = "data"
pdf_mode = "Image (visual documents)"
if input_type_choice == "Survey Responses":
input_type_selected = "text"
uploaded_file = st.file_uploader(
"Upload Data (CSV or Excel)",
type=['csv', 'xlsx', 'xls'],
key="survey_file"
)
if st.button("Try Example Dataset", key="example_btn"):
st.session_state.example_loaded = True
columns = []
df = None
if uploaded_file is not None:
try:
if uploaded_file.name.endswith('.csv'):
df = pd.read_csv(uploaded_file)
else:
df = pd.read_excel(uploaded_file)
columns = df.columns.tolist()
st.success(f"Loaded {len(df):,} rows")
except Exception as e:
st.error(f"Error loading file: {e}")
elif hasattr(st.session_state, 'example_loaded') and st.session_state.example_loaded:
try:
df = pd.read_csv("example_data.csv")
columns = df.columns.tolist()
st.success(f"Loaded example dataset ({len(df)} rows)")
except:
pass
selected_column = st.selectbox(
"Column to Process",
options=columns if columns else ["Upload a file first"],
disabled=not columns,
key="survey_column"
)
description = selected_column if columns else ""
original_filename = uploaded_file.name if uploaded_file else "example_data.csv"
if df is not None and columns and selected_column in columns:
input_data = df[selected_column].tolist()
elif input_type_choice == "PDF Documents":
input_type_selected = "pdf"
pdf_files = st.file_uploader(
"Upload PDF Document(s)",
type=['pdf'],
accept_multiple_files=True,
key="pdf_files"
)
pdf_description = st.text_input(
"Document Description",
placeholder="e.g., 'research papers', 'interview transcripts'",
help="Helps the LLM understand context",
key="pdf_desc"
)
pdf_mode = st.radio(
"Processing Mode",
options=["Image (visual documents)", "Text (text-heavy)", "Both (comprehensive)"],
key="pdf_mode"
)
if pdf_files:
input_data = []
pdf_name_map = {} # Map temp paths to original filenames
for f in pdf_files:
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
tmp.write(f.read())
input_data.append(tmp.name)
pdf_name_map[tmp.name] = f.name.replace('.pdf', '') # Store original name without extension
st.session_state.pdf_name_map = pdf_name_map
description = pdf_description or "document"
original_filename = "pdf_files"
st.success(f"Uploaded {len(pdf_files)} PDF file(s)")
else: # Images
input_type_selected = "image"
image_files = st.file_uploader(
"Upload Images",
type=['png', 'jpg', 'jpeg', 'gif', 'webp'],
accept_multiple_files=True,
key="image_files"
)
image_description = st.text_input(
"Image Description",
placeholder="e.g., 'product photos', 'social media posts'",
help="Helps the LLM understand context",
key="image_desc"
)
if image_files:
input_data = []
for f in image_files:
suffix = '.' + f.name.split('.')[-1]
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
tmp.write(f.read())
input_data.append(tmp.name)
description = image_description or "images"
original_filename = "image_files"
st.success(f"Uploaded {len(image_files)} image file(s)")
st.markdown("---")
# Task selection
st.markdown("### What would you like to do?")
col_btn1, col_btn2 = st.columns(2)
with col_btn1:
manual_mode = st.button("Enter Categories Manually", use_container_width=True)
with col_btn2:
auto_mode = st.button("Auto-extract Categories", use_container_width=True)
if manual_mode:
st.session_state.task_mode = "manual"
if auto_mode:
st.session_state.task_mode = "auto_extract"
# Auto-extract settings
if st.session_state.task_mode == "auto_extract":
st.markdown("### Auto-extract Categories")
st.markdown("We'll analyze your data to discover the main categories.")
max_categories = st.slider(
"Number of Categories to Extract",
min_value=3,
max_value=25,
value=12,
help="How many categories should be identified in your data"
)
specificity = st.selectbox(
"How specific should categories be?",
options=["Broad", "Moderate", "Narrow"],
index=0,
help="Broad = general themes, Moderate = balanced detail, Narrow = highly specific categories"
)
focus = st.text_input(
"What should categories be focused around? (optional)",
placeholder="e.g., 'decisions to move', 'emotional responses', 'financial factors'",
help="Guide the model to prioritize extracting categories related to this focus"
)
# Model selection for extraction
st.markdown("### Model Selection")
model_tier = st.radio(
"Model Tier",
options=["Free Models", "Bring Your Own Key"],
key="extract_model_tier"
)
if model_tier == "Free Models":
model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="extract_model")
model = FREE_MODELS_MAP[model_display] # Convert to actual model name
api_key = ""
else:
model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="extract_model_paid")
api_key = st.text_input("API Key", type="password", key="extract_api_key")
if st.button("Extract Categories", type="primary"):
if input_data is None:
st.error("Please upload data first")
else:
mode = None
if input_type_selected == "pdf":
mode_mapping = {
"Image (visual documents)": "image",
"Text (text-heavy)": "text",
"Both (comprehensive)": "both"
}
mode = mode_mapping.get(pdf_mode, "image")
actual_api_key, provider = get_api_key(model, model_tier, api_key)
if not actual_api_key:
st.error(f"{provider} API key not configured")
else:
model_source = get_model_source(model)
# Calculate estimated time based on input size
num_items = len(input_data) if isinstance(input_data, list) else 1
if input_type_selected == "pdf":
# PDFs take longer - estimate ~5s per page
total_pages = sum(count_pdf_pages(p) for p in (input_data if isinstance(input_data, list) else [input_data]))
est_seconds = total_pages * 5
elif input_type_selected == "image":
# Images ~4s each
est_seconds = num_items * 4
else:
# Text ~2s per item, but batched
est_seconds = max(10, num_items * 0.5)
est_time_str = f"{est_seconds:.0f}s" if est_seconds < 60 else f"{est_seconds/60:.1f}m"
# Animated status indicator
with st.status(f"Extracting categories (est. {est_time_str})...", expanded=True) as status:
st.write("Analyzing your data to discover categories...")
start_time = time.time()
extract_kwargs = {
'input_data': input_data,
'api_key': actual_api_key,
'input_type': input_type_selected,
'description': description,
'user_model': model,
'model_source': model_source,
'max_categories': int(max_categories),
'specificity': specificity.lower()
}
if mode:
extract_kwargs['mode'] = mode
if focus and focus.strip():
extract_kwargs['focus'] = focus.strip()
try:
extract_result = catllm.extract(**extract_kwargs)
categories = extract_result.get('top_categories', [])
processing_time = time.time() - start_time
if categories:
status.update(label=f"Extracted {len(categories)} categories in {processing_time:.1f}s", state="complete", expanded=False)
st.session_state.extracted_categories = categories
st.session_state.task_mode = "manual"
st.rerun()
else:
status.update(label="No categories extracted", state="error")
st.error("No categories were extracted from the data")
except Exception as e:
status.update(label="Extraction failed", state="error")
st.error(f"Error: {str(e)}")
# Category inputs (shown for manual mode or after extraction)
if st.session_state.task_mode == "manual":
st.markdown("### Categories")
st.markdown("Enter your classification categories below.")
# Pre-fill with extracted categories if available
if st.session_state.extracted_categories:
for i, cat in enumerate(st.session_state.extracted_categories[:MAX_CATEGORIES]):
st.session_state.categories[i] = cat
st.session_state.category_count = min(len(st.session_state.extracted_categories), MAX_CATEGORIES)
st.session_state.extracted_categories = None # Clear after use
placeholder_examples = [
"e.g., Positive sentiment",
"e.g., Negative sentiment",
"e.g., Product feedback",
"e.g., Service complaint",
"e.g., Feature request",
"e.g., Custom category"
]
categories_entered = []
for i in range(st.session_state.category_count):
placeholder = placeholder_examples[i] if i < len(placeholder_examples) else "e.g., Custom category"
cat_value = st.text_input(
f"Category {i+1}",
value=st.session_state.categories[i],
placeholder=placeholder,
key=f"cat_{i}"
)
st.session_state.categories[i] = cat_value
if cat_value.strip():
categories_entered.append(cat_value.strip())
if st.session_state.category_count < MAX_CATEGORIES:
if st.button("+ Add More"):
st.session_state.category_count += 1
st.rerun()
st.markdown("### Model Selection")
# Classification mode selector
classify_mode = st.radio(
"Classification Mode",
options=["Single Model", "Model Comparison", "Ensemble"],
horizontal=True,
key="classify_mode",
help="Single: one model. Comparison: see results from multiple models side-by-side. Ensemble: multiple models vote for consensus."
)
model_tier = st.radio(
"Model Tier",
options=["Free Models", "Bring Your Own Key"],
key="classify_model_tier"
)
# Multi-model mode uses multiselect
is_multi_model = classify_mode in ["Model Comparison", "Ensemble"]
if model_tier == "Free Models":
if is_multi_model:
model_displays = st.multiselect(
"Models (select 2+)",
options=FREE_MODEL_DISPLAY_NAMES,
default=[FREE_MODEL_DISPLAY_NAMES[0], FREE_MODEL_DISPLAY_NAMES[1]] if len(FREE_MODEL_DISPLAY_NAMES) >= 2 else FREE_MODEL_DISPLAY_NAMES[:1],
key="classify_models_multi"
)
models_list = [FREE_MODELS_MAP[d] for d in model_displays]
else:
model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="classify_model")
model = FREE_MODELS_MAP[model_display] # Convert to actual model name
models_list = [model]
api_key = ""
else:
if is_multi_model:
models_list = st.multiselect(
"Models (select 2+)",
options=PAID_MODEL_CHOICES,
default=[PAID_MODEL_CHOICES[0], PAID_MODEL_CHOICES[1]] if len(PAID_MODEL_CHOICES) >= 2 else PAID_MODEL_CHOICES[:1],
key="classify_models_multi_paid"
)
else:
model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="classify_model_paid")
models_list = [model]
api_key = st.text_input("API Key", type="password", key="classify_api_key")
# Ensemble-specific options
if classify_mode == "Ensemble":
consensus_threshold = st.slider(
"Consensus Threshold",
min_value=0.0,
max_value=1.0,
value=0.5,
step=0.1,
key="consensus_threshold",
help="Minimum agreement ratio needed for consensus (0.5 = majority vote)"
)
if st.button("Categorize Data", type="primary", use_container_width=True):
if input_data is None:
st.error("Please upload data first")
elif not categories_entered:
st.error("Please enter at least one category")
elif is_multi_model and len(models_list) < 2:
st.error("Please select at least 2 models for comparison/ensemble mode")
else:
# Translate the selected PDF mode label into the short mode name (defaults to "image")
mode = None
if input_type_selected == "pdf":
mode_mapping = {
"Image (visual documents)": "image",
"Text (text-heavy)": "text",
"Both (comprehensive)": "both"
}
mode = mode_mapping.get(pdf_mode, "image")
# Build models tuples list: [(model, source, api_key), ...]
models_tuples = []
api_key_error = None
for m in models_list:
actual_key, provider = get_api_key(m, model_tier, api_key)
if not actual_key:
api_key_error = f"{provider} API key not configured for {m}"
break
m_source = get_model_source(m)
models_tuples.append((m, m_source, actual_key))
if api_key_error:
st.error(api_key_error)
else:
items_list = input_data if isinstance(input_data, list) else [input_data]
# Progress UI
progress_bar = st.progress(0)
status_text = st.empty()
start_time = time.time()
# For PDFs, use progress callback
if input_type_selected == "pdf":
# Progress callback for PDF page-by-page updates
def pdf_progress_callback(current_idx, total_pages, page_label):
    """Refresh the progress bar and status line during PDF classification.

    Reads ``progress_bar``, ``status_text`` and ``start_time`` from the
    enclosing scope.

    Args:
        current_idx: Position reported by the classifier. Presumably the
            zero-based index of the page about to be processed (i.e. the
            number of pages already finished) -- TODO confirm against the
            catllm progress-callback contract.
        total_pages: Total number of pages in this run.
        page_label: Label of the current page, echoed in the status line.
    """
    # Completed fraction, capped at 1.0 so the bar and the text always agree.
    fraction = current_idx / total_pages if total_pages > 0 else 0
    fraction = min(fraction, 1.0)
    progress_bar.progress(fraction)

    elapsed = time.time() - start_time
    if current_idx > 0:
        # Never report negative remaining work if the index overshoots.
        remaining = max(total_pages - current_idx, 0)
        eta_seconds = (elapsed / current_idx) * remaining
        if eta_seconds < 60:
            eta_str = f" | ETA: {eta_seconds:.0f}s"
        else:
            eta_str = f" | ETA: {eta_seconds/60:.1f}m"
    else:
        # No timing sample yet, so no estimate can be shown.
        eta_str = ""

    # Keep the displayed page number within the run ("page 4 of 4", never "5 of 4").
    shown_page = min(current_idx + 1, total_pages)
    status_text.text(
        f"Processing page {shown_page} of {total_pages} ({page_label}) ({fraction*100:.0f}%){eta_str}"
    )
try:
# Build kwargs for classify_ensemble
classify_kwargs = {
"survey_input": items_list,
"categories": categories_entered,
"models": models_tuples,
"input_description": description,
"pdf_mode": mode,
"progress_callback": pdf_progress_callback,
}
# Add consensus_threshold for ensemble mode
if classify_mode == "Ensemble":
classify_kwargs["consensus_threshold"] = consensus_threshold
result_df = catllm.classify_ensemble(**classify_kwargs)
processing_time = time.time() - start_time
total_items = len(result_df)
progress_bar.progress(1.0)
status_text.text(f"Completed {total_items} pages in {processing_time:.1f}s")
# Replace temp paths with original filenames in pdf_input column
if 'pdf_input' in result_df.columns:
pdf_name_map = st.session_state.get('pdf_name_map', {})
def replace_temp_path(val):
    """Swap a temporary PDF file name inside ``val`` for the original name.

    Reads ``pdf_name_map`` (temp path -> original name) from the enclosing
    scope.

    Args:
        val: A cell value from the ``pdf_input`` column.

    Returns:
        ``val`` unchanged when it is missing (``pd.isna``); otherwise its
        string form, with the first matching temp-file stem replaced by the
        mapped original name. Values that match no entry are returned as
        their string form.
    """
    if pd.isna(val):
        return val
    val_str = str(val)
    for temp_path, orig_name in pdf_name_map.items():
        # Stem of the temp file: its basename with the '.pdf' text removed.
        temp_name = os.path.basename(temp_path).replace('.pdf', '')
        # An empty stem is a substring of every string, and
        # str.replace('', x) would insert x between all characters, so an
        # entry that yields no stem must never be treated as a match.
        if temp_name and temp_name in val_str:
            return val_str.replace(temp_name, orig_name)
    return val_str
result_df['pdf_input'] = result_df['pdf_input'].apply(replace_temp_path)
all_results = [result_df]
except Exception as e:
st.error(f"Error: {str(e)}")
all_results = []
else:
# Non-PDF processing (text, images) - process all at once
total_items = len(items_list)
# Progress callback for item-by-item updates
def item_progress_callback(current_idx, total, item_label):
    """Refresh the progress bar and status line during text/image classification.

    Reads ``progress_bar``, ``status_text`` and ``start_time`` from the
    enclosing scope.

    Args:
        current_idx: Position reported by the classifier. Presumably the
            zero-based index of the item about to be processed (i.e. the
            number of items already finished) -- TODO confirm against the
            catllm progress-callback contract.
        total: Total number of items in this run.
        item_label: Label of the current item; accepted for the callback
            signature but not shown in the status line.
    """
    # Completed fraction, capped at 1.0 so the bar and the text always agree.
    fraction = current_idx / total if total > 0 else 0
    fraction = min(fraction, 1.0)
    progress_bar.progress(fraction)

    elapsed = time.time() - start_time
    if current_idx > 0:
        # Never report negative remaining work if the index overshoots.
        remaining = max(total - current_idx, 0)
        eta_seconds = (elapsed / current_idx) * remaining
        if eta_seconds < 60:
            eta_str = f" | ETA: {eta_seconds:.0f}s"
        else:
            eta_str = f" | ETA: {eta_seconds/60:.1f}m"
    else:
        # No timing sample yet, so no estimate can be shown.
        eta_str = ""

    # Keep the displayed item number within the run ("item 3 of 3", never "4 of 3").
    shown_item = min(current_idx + 1, total)
    status_text.text(
        f"Processing item {shown_item} of {total} ({fraction*100:.0f}%){eta_str}"
    )
try:
# Build kwargs for classify_ensemble
classify_kwargs = {
"survey_input": items_list,
"categories": categories_entered,
"models": models_tuples,
"input_description": description,
"progress_callback": item_progress_callback,
}
# Add consensus_threshold for ensemble mode
if classify_mode == "Ensemble":
classify_kwargs["consensus_threshold"] = consensus_threshold
result_df = catllm.classify_ensemble(**classify_kwargs)
all_results = [result_df]
processing_time = time.time() - start_time
progress_bar.progress(1.0)
status_text.text(f"Completed {total_items} items in {processing_time:.1f}s")
except Exception as e:
st.error(f"Error: {str(e)}")
all_results = []
processing_time = time.time() - start_time
if all_results:
# Combine results
result_df = pd.concat(all_results, ignore_index=True)
# Save CSV
with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
result_df.to_csv(f.name, index=False)
csv_path = f.name
# Calculate success rate
if 'processing_status' in result_df.columns:
success_count = (result_df['processing_status'] == 'success').sum()
success_rate = (success_count / len(result_df)) * 100
else:
success_rate = 100.0
# Get version info
try:
catllm_version = catllm.__version__
except AttributeError:
catllm_version = "unknown"
python_version = sys.version.split()[0]
# For reports: create model string (single or list)
if len(models_list) == 1:
report_model = models_list[0]
report_model_source = models_tuples[0][1]
else:
report_model = ", ".join(models_list)
report_model_source = f"{classify_mode} ({len(models_list)} models)"
# Generate methodology report
pdf_path = generate_methodology_report_pdf(
categories=categories_entered,
model=report_model,
column_name=description,
num_rows=len(result_df),
model_source=report_model_source,
filename=original_filename,
success_rate=success_rate,
result_df=result_df,
processing_time=processing_time,
catllm_version=catllm_version,
python_version=python_version,
task_type="assign",
input_type=input_type_selected,
description=description
)
# Generate code
code = generate_classify_code(
input_type_selected, description, categories_entered,
report_model, report_model_source, mode,
classify_mode=classify_mode, models_list=models_list
)
st.session_state.results = {
'df': result_df,
'csv_path': csv_path,
'pdf_path': pdf_path,
'code': code,
'status': f"Classified {len(result_df)} items in {processing_time:.1f}s",
'categories': categories_entered,
'classify_mode': classify_mode,
'models_list': models_list,
}
st.success(f"Classified {len(result_df)} items in {processing_time:.1f}s")
st.rerun()
else:
st.error("No items were successfully classified")
with col_output:
st.markdown("### Results")
if st.session_state.results:
results = st.session_state.results
# Distribution chart
fig = create_distribution_chart(
results['df'],
results['categories'],
classify_mode=results.get('classify_mode', 'Single Model'),
models_list=results.get('models_list', [])
)
st.pyplot(fig)
st.caption("Note: Categories are not mutually exclusive—each item can belong to multiple categories.")
# Results dataframe (hide technical columns from display)
display_df = results['df'].copy()
cols_to_hide = ['model_response', 'json', 'raw_response', 'raw_json']
display_df = display_df.drop(columns=[c for c in cols_to_hide if c in display_df.columns])
st.dataframe(display_df, use_container_width=True)
# Downloads
col_dl1, col_dl2 = st.columns(2)
with col_dl1:
with open(results['csv_path'], 'rb') as f:
st.download_button(
"Download Results (CSV)",
data=f,
file_name="classified_results.csv",
mime="text/csv"
)
with col_dl2:
with open(results['pdf_path'], 'rb') as f:
st.download_button(
"Download Methodology Report (PDF)",
data=f,
file_name="methodology_report.pdf",
mime="application/pdf"
)
# Code
with st.expander("See the Code"):
st.code(results['code'], language='python')
else:
st.info("Upload data, select categories, and click 'Categorize Data' to see results here.")
# Bottom buttons
col_reset, col_code = st.columns(2)
with col_reset:
if st.button("Reset", type="secondary", use_container_width=True):
st.session_state.categories = [''] * MAX_CATEGORIES
st.session_state.category_count = INITIAL_CATEGORIES
st.session_state.task_mode = None
st.session_state.extracted_categories = None
st.session_state.results = None
if hasattr(st.session_state, 'example_loaded'):
del st.session_state.example_loaded
st.rerun()
with col_code:
if st.session_state.results:
if st.button("See in Code", use_container_width=True):
st.session_state.show_code_modal = True
# Code modal/dialog
if st.session_state.get('show_code_modal') and st.session_state.results:
st.markdown("---")
st.markdown("### Reproducibility Code")
st.markdown("Use this code to reproduce the classification with the CatLLM Python package:")
st.code(st.session_state.results['code'], language='python')
if st.button("Close"):
st.session_state.show_code_modal = False
st.rerun()