Spaces:
Running
Running
| """ | |
| Gradio app - Step 5g: Add actual catllm classification | |
| """ | |
| import gradio as gr | |
| import pandas as pd | |
| import tempfile | |
| import os | |
| import time | |
| import sys | |
| from datetime import datetime | |
| import matplotlib.pyplot as plt | |
# Import catllm behind a guard so the UI still loads (and can show a clear
# error at classification time) when the package is missing from the Space.
try:
    import catllm
    CATLLM_AVAILABLE = True
except ImportError as e:
    print(f"Warning: Could not import catllm: {e}")
    CATLLM_AVAILABLE = False
# UI limits for the dynamic category textboxes: up to MAX_CATEGORIES fields,
# with INITIAL_CATEGORIES visible on load / after reset.
MAX_CATEGORIES = 10
INITIAL_CATEGORIES = 3

# Free models (uses Space secrets - no user API key needed)
FREE_MODEL_CHOICES = [
    "Qwen/Qwen3-VL-235B-A22B-Instruct:novita",
    "deepseek-ai/DeepSeek-V3.1:novita",
    "meta-llama/Llama-4.1-8B-Instruct:novita",
    "gemini-2.5-flash",
    "gpt-4o",
    "mistral-medium-2505",
    "claude-3-haiku-20240307",
    "sonar",
    "grok-4-fast-non-reasoning",
]

# Paid models (user provides their own API key)
PAID_MODEL_CHOICES = [
    "gpt-4.1",
    "gpt-4o",
    "gpt-4o-mini",
    "claude-sonnet-4-5-20250929",
    "claude-opus-4-20250514",
    "claude-3-5-haiku-20241022",
    "gemini-2.5-pro",
    "gemini-2.5-flash",
    "mistral-large-latest",
]

# Models routed through HuggingFace
# (the ":novita" suffix selects the inference provider on the HF router,
#  so these need the HF_API_KEY secret rather than a provider-specific one)
HF_ROUTED_MODELS = [
    "Qwen/Qwen3-VL-235B-A22B-Instruct:novita",
    "deepseek-ai/DeepSeek-V3.1:novita",
    "meta-llama/Llama-4.1-8B-Instruct:novita",
]
def is_free_model(model, model_tier):
    """Check if using free tier (Space pays for API).

    Only the selected tier decides this; ``model`` is accepted for interface
    stability with callers that pass both values.
    """
    return model_tier == "Free Models"
def generate_methodology_report_pdf(categories, model, column_name, num_rows, model_source, filename, success_rate,
                                    result_df=None, processing_time=None, prompt_template=None,
                                    data_quality=None, catllm_version=None, python_version=None):
    """Generate a PDF methodology report for reproducibility and transparency.

    Args:
        categories: Category description strings, in the order they were numbered.
        model: Name of the LLM used for classification.
        column_name: Spreadsheet column that was classified.
        num_rows: Number of responses submitted for classification.
        model_source: Provider identifier (e.g. "openai", "huggingface").
        filename: Name of the uploaded source file (shown in the summary and
            interpolated into the reproducibility code snippet).
        success_rate: Percentage of rows classified successfully.
        result_df: Optional DataFrame of results; enables the sample and
            distribution pages.
        processing_time: Optional wall-clock classification time in seconds.
        prompt_template: Optional prompt template text documented verbatim.
        data_quality: Optional dict with keys null_count, avg_length,
            min_length, max_length, error_count.
        catllm_version: Optional catllm package version string.
        python_version: Optional Python version string.

    Returns:
        Filesystem path of the generated PDF.
    """
    from reportlab.lib.pagesizes import letter
    from reportlab.lib import colors
    from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
    from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak

    def _escape(text):
        # Paragraph() parses its text as XML markup, so XML-special characters
        # must be entity-escaped ('&' first, so it doesn't re-escape entities).
        return text.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')

    # Reserve a temp path for the PDF, then close the handle immediately:
    # reportlab reopens the path itself, and keeping the 'wb' handle open
    # would leak it (and lock the file on Windows).
    pdf_file = tempfile.NamedTemporaryFile(mode='wb', suffix='_methodology_report.pdf', delete=False)
    pdf_path = pdf_file.name
    pdf_file.close()
    doc = SimpleDocTemplate(pdf_path, pagesize=letter)
    styles = getSampleStyleSheet()
    # Custom styles
    title_style = ParagraphStyle('Title', parent=styles['Heading1'], fontSize=18, spaceAfter=20)
    heading_style = ParagraphStyle('Heading', parent=styles['Heading2'], fontSize=14, spaceAfter=10, spaceBefore=15)
    normal_style = styles['Normal']
    code_style = ParagraphStyle('Code', parent=styles['Normal'], fontName='Courier', fontSize=9, leftIndent=20, spaceAfter=3)
    story = []

    # === PAGE 1: Title, About, Category Mapping ===
    story.append(Paragraph("CatLLM Methodology Report", title_style))
    story.append(Paragraph(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", normal_style))
    story.append(Spacer(1, 15))
    # About CatLLM - addressing prompt hacking
    story.append(Paragraph("About This Report", heading_style))
    about_text = """This methodology report documents the classification process for reproducibility and transparency. \
CatLLM addresses an issue identified by researchers in "Prompt-Hacking: The New p-Hacking?" (Kosch & Feger, 2025; \
arXiv:2504.14571): researchers could keep modifying prompts to obtain outputs that support desired conclusions, and \
this variability in pseudo-natural language poses a challenge for reproducibility since each prompt, even if only \
slightly altered, can yield different outputs, making it impossible to replicate findings reliably. CatLLM restricts \
the prompt to a standard template that is impartial to the researcher's inclinations, ensuring \
consistent and reproducible results."""
    story.append(Paragraph(about_text, normal_style))
    story.append(Spacer(1, 15))
    # Category mapping: which output column corresponds to which category.
    story.append(Paragraph("Category Mapping", heading_style))
    story.append(Paragraph("Each category column contains binary values: 1 = present, 0 = not present", normal_style))
    story.append(Spacer(1, 8))
    category_data = [["Column Name", "Category Description"]]
    for i, cat in enumerate(categories, 1):
        category_data.append([f"category_{i}", cat])
    cat_table = Table(category_data, colWidths=[120, 330])
    cat_table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('PADDING', (0, 0), (-1, -1), 6),
        ('BACKGROUND', (0, 1), (0, -1), colors.lightgrey),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
    ]))
    story.append(cat_table)
    story.append(Spacer(1, 15))
    # Other columns produced by catllm.multi_class.
    story.append(Paragraph("Other Output Columns", heading_style))
    other_cols = [
        ["Column Name", "Description"],
        ["survey_input", "The original text that was classified"],
        ["model_response", "Raw response from the LLM"],
        ["json", "Extracted JSON with category assignments"],
        ["processing_status", "'success' if classification worked, 'error' if failed"],
        ["categories_id", "Comma-separated list of assigned category numbers"],
    ]
    other_table = Table(other_cols, colWidths=[120, 330])
    other_table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('PADDING', (0, 0), (-1, -1), 6),
        ('BACKGROUND', (0, 1), (0, -1), colors.lightgrey),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
    ]))
    story.append(other_table)
    # Citation at end of page 1
    story.append(Spacer(1, 30))
    story.append(Paragraph("Citation", heading_style))
    story.append(Paragraph("If you use CatLLM in your research, please cite:", normal_style))
    story.append(Spacer(1, 5))
    story.append(Paragraph("Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DOI: 10.5281/zenodo.15532316", normal_style))

    # === PAGE 2: Sample Results (only when results are available) ===
    if result_df is not None and len(result_df) > 0:
        story.append(PageBreak())
        story.append(Paragraph("Sample Results (First 5 Rows)", title_style))
        story.append(Paragraph("Example classifications showing original text and assigned categories:", normal_style))
        story.append(Spacer(1, 15))
        sample_data = [["Original Text (truncated)", "Assigned Categories"]]
        sample_df = result_df.head(5)
        for _, row in sample_df.iterrows():
            # Get original text, truncate to 80 chars
            original_text = str(row.get('survey_input', ''))[:80]
            if len(str(row.get('survey_input', ''))) > 80:
                original_text += "..."
            # Get assigned categories
            assigned = row.get('categories_id', '')
            if pd.isna(assigned) or assigned == '':
                assigned = "None"
            sample_data.append([original_text, str(assigned)])
        sample_table = Table(sample_data, colWidths=[320, 130])
        sample_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('PADDING', (0, 0), (-1, -1), 8),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
            ('VALIGN', (0, 0), (-1, -1), 'TOP'),
        ]))
        story.append(sample_table)

    # === PAGE 3: Category Distribution ===
    story.append(PageBreak())
    story.append(Paragraph("Category Distribution", title_style))
    story.append(Paragraph("Count and percentage of responses assigned to each category:", normal_style))
    story.append(Spacer(1, 15))
    if result_df is not None:
        dist_data = [["Category", "Description", "Count", "Percentage"]]
        total_rows = len(result_df)
        for i, cat in enumerate(categories, 1):
            col_name = f"category_{i}"
            if col_name in result_df.columns:
                count = int(result_df[col_name].sum())
                pct = (count / total_rows) * 100 if total_rows > 0 else 0
                dist_data.append([col_name, cat[:40], str(count), f"{pct:.1f}%"])
            else:
                dist_data.append([col_name, cat[:40], "N/A", "N/A"])
        dist_table = Table(dist_data, colWidths=[80, 200, 60, 80])
        dist_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('PADDING', (0, 0), (-1, -1), 6),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
            ('ALIGN', (2, 1), (-1, -1), 'CENTER'),
        ]))
        story.append(dist_table)
        story.append(Spacer(1, 15))
        story.append(Paragraph("<i>Note: Percentages may sum to more than 100% as responses can be assigned to multiple categories.</i>", normal_style))

    # === PAGE 4: Classification Summary (Expanded) ===
    story.append(PageBreak())
    story.append(Paragraph("Classification Summary", title_style))
    story.append(Spacer(1, 15))
    # Basic summary
    story.append(Paragraph("Classification Details", heading_style))
    summary_data = [
        ["Source File", filename],
        ["Source Column", column_name],
        ["Model Used", model],
        ["Model Source", model_source],
        ["Temperature", "default"],
        ["Rows Classified", str(num_rows)],
        ["Number of Categories", str(len(categories))],
        ["Success Rate", f"{success_rate:.2f}%"],
    ]
    summary_table = Table(summary_data, colWidths=[150, 300])
    summary_table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('PADDING', (0, 0), (-1, -1), 6),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
    ]))
    story.append(summary_table)
    story.append(Spacer(1, 15))
    # Processing Time
    if processing_time is not None:
        story.append(Paragraph("Processing Time", heading_style))
        rows_per_min = (num_rows / processing_time) * 60 if processing_time > 0 else 0
        avg_time = processing_time / num_rows if num_rows > 0 else 0
        time_data = [
            ["Total Processing Time", f"{processing_time:.1f} seconds"],
            ["Average Time per Response", f"{avg_time:.2f} seconds"],
            ["Processing Rate", f"{rows_per_min:.1f} rows/minute"],
        ]
        time_table = Table(time_data, colWidths=[180, 270])
        time_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('PADDING', (0, 0), (-1, -1), 6),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
        ]))
        story.append(time_table)
        story.append(Spacer(1, 15))
    # Data Quality Notes
    if data_quality is not None:
        story.append(Paragraph("Data Quality Notes", heading_style))
        quality_data = [
            ["Empty/Null Inputs Skipped", str(data_quality.get('null_count', 0))],
            ["Average Text Length", f"{data_quality.get('avg_length', 0)} characters"],
            ["Min Text Length", f"{data_quality.get('min_length', 0)} characters"],
            ["Max Text Length", f"{data_quality.get('max_length', 0)} characters"],
            ["Responses with Errors", str(data_quality.get('error_count', 0))],
        ]
        quality_table = Table(quality_data, colWidths=[180, 270])
        quality_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('PADDING', (0, 0), (-1, -1), 6),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
        ]))
        story.append(quality_table)
        story.append(Spacer(1, 15))
    # Version Information
    story.append(Paragraph("Version Information", heading_style))
    version_data = [
        ["CatLLM Version", catllm_version or "unknown"],
        ["Python Version", python_version or "unknown"],
        ["Timestamp", datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
    ]
    version_table = Table(version_data, colWidths=[180, 270])
    version_table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('PADDING', (0, 0), (-1, -1), 6),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
    ]))
    story.append(version_table)

    # === PAGE 5: Prompt Template ===
    story.append(PageBreak())
    story.append(Paragraph("Prompt Template Used", title_style))
    story.append(Paragraph("The following prompt template was sent to the LLM for each classification:", normal_style))
    story.append(Spacer(1, 15))
    if prompt_template:
        # Show the template with placeholders
        story.append(Paragraph("Template with Placeholders:", heading_style))
        story.append(Spacer(1, 8))
        for line in prompt_template.split('\n'):
            escaped_line = _escape(line)
            if escaped_line.strip():
                story.append(Paragraph(escaped_line, code_style))
            else:
                story.append(Spacer(1, 5))
        story.append(Spacer(1, 20))
        # Show example with actual categories
        story.append(Paragraph("Example with Your Categories:", heading_style))
        story.append(Spacer(1, 8))
        categories_list = "\n".join([f" {i}. {cat}" for i, cat in enumerate(categories, 1)])
        example_prompt = f'''Categorize this survey response "[YOUR TEXT HERE]" into the following categories:
{categories_list}
Provide your work in JSON format where the number belonging to each category
is the key and a 1 if the category is present and a 0 if not.'''
        for line in example_prompt.split('\n'):
            escaped_line = _escape(line)
            if escaped_line.strip():
                story.append(Paragraph(escaped_line, code_style))
            else:
                story.append(Spacer(1, 5))

    # === PAGE 6: Reproducibility Code ===
    story.append(PageBreak())
    story.append(Paragraph("Reproducibility Code", title_style))
    story.append(Paragraph("Use the following Python code to reproduce this classification:", normal_style))
    story.append(Spacer(1, 15))
    # Build categories list string
    categories_str = ", ".join([f'"{cat}"' for cat in categories])
    # Interpolate the actual source filename so the snippet is runnable as-is.
    code_text = f'''import catllm
import pandas as pd

# Load your survey data
df = pd.read_csv("{filename}")

# Define your categories
categories = [{categories_str}]

# Classify the responses
result = catllm.multi_class(
    survey_input=df["{column_name}"].tolist(),
    categories=categories,
    api_key="YOUR_API_KEY",
    user_model="{model}",
    model_source="{model_source}"
)

# View results
print(result)

# Save to CSV
result.to_csv("classified_results.csv", index=False)'''
    # Split code into lines and add each as a paragraph
    for line in code_text.split('\n'):
        if line.strip() == '':
            story.append(Spacer(1, 5))
        else:
            story.append(Paragraph(_escape(line), code_style))
    doc.build(story)
    return pdf_path
def get_model_source(model):
    """Auto-detect model source. All HF router models (novita, groq, etc) use 'huggingface'."""
    name = model.lower()
    # Ordered guards: provider-specific names win; anything else (qwen, llama,
    # deepseek, and any ":novita"/":groq" routed variant) falls through to the
    # HuggingFace router.
    if "gpt" in name:
        return "openai"
    if "claude" in name:
        return "anthropic"
    if "gemini" in name:
        return "google"
    if "mistral" in name and ":novita" not in name:
        return "mistral"
    return "huggingface"
def load_example_dataset():
    """Load the example dataset for users to try the app.

    Returns a (file_path, column-dropdown update, status message) triple;
    on failure the path is None and the status carries the error.
    """
    example_path = "example_data.csv"
    try:
        df = pd.read_csv(example_path)
    except Exception as e:
        return None, gr.update(choices=[], value=None), f"**Error loading example:** {str(e)}"
    columns = df.columns.tolist()
    default_column = columns[0] if columns else None
    status = f"Loaded example dataset ({len(df)} rows). Select column and click Classify."
    return example_path, gr.update(choices=columns, value=default_column), status
def load_columns(file):
    """Read the uploaded spreadsheet and populate the column dropdown.

    Accepts either a path string or a gradio file object; CSVs go through
    pandas' CSV reader, everything else through the Excel reader.
    """
    if file is None:
        return gr.update(choices=[], value=None), "Please upload a file first"
    try:
        path = file if isinstance(file, str) else file.name
        reader = pd.read_csv if path.endswith('.csv') else pd.read_excel
        df = reader(path)
        columns = df.columns.tolist()
        num_rows = len(df)
        if num_rows > 1000:
            # ~1.5 seconds per row estimate
            est_minutes = round(num_rows * 1.5 / 60)
            status_msg = f"⚠️ **Large dataset** ({num_rows:,} rows). Classification may take ~{est_minutes} minutes. Select column and click Classify."
        else:
            status_msg = f"Loaded {num_rows:,} rows. Select column and click Classify."
        dropdown = gr.update(choices=columns, value=columns[0] if columns else None)
        return dropdown, status_msg
    except Exception as e:
        return gr.update(choices=[], value=None), f"**Error:** {str(e)}"
def classify_data(spreadsheet_file, spreadsheet_column,
                  cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10,
                  model_tier, model, model_source_input, api_key_input):
    """Main classification function with progress updates. Yields status updates then final results.

    Generator consumed by gradio: every yield is a 5-tuple of
    (distribution_plot, sample_results, full_results, download_files, status).
    Intermediate yields carry None in the first four slots and a progress /
    error message in the last; the final yield fills all five on success.
    """
    if not CATLLM_AVAILABLE:
        yield None, None, None, None, "**Error:** catllm package not available"
        return
    # Collect the non-empty category textboxes, stripped of whitespace.
    all_cats = [cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10]
    categories = [c.strip() for c in all_cats if c and c.strip()]
    if not categories:
        yield None, None, None, None, "**Error:** Please enter at least one category"
        return
    actual_model = model
    # Get API key based on tier
    if is_free_model(model, model_tier):
        # Free tier - use Space secrets. The env var is chosen by provider:
        # HF-routed models use HF_API_KEY; otherwise match on a substring of
        # the model name.
        if model in HF_ROUTED_MODELS:
            actual_api_key = os.environ.get("HF_API_KEY", "")
            if not actual_api_key:
                yield None, None, None, None, "**Error:** HuggingFace API key not configured in Space secrets"
                return
        elif "gpt" in model.lower():
            actual_api_key = os.environ.get("OPENAI_API_KEY", "")
            if not actual_api_key:
                yield None, None, None, None, "**Error:** OpenAI API key not configured in Space secrets"
                return
        elif "gemini" in model.lower():
            actual_api_key = os.environ.get("GOOGLE_API_KEY", "")
            if not actual_api_key:
                yield None, None, None, None, "**Error:** Google API key not configured in Space secrets"
                return
        elif "mistral" in model.lower():
            actual_api_key = os.environ.get("MISTRAL_API_KEY", "")
            if not actual_api_key:
                yield None, None, None, None, "**Error:** Mistral API key not configured in Space secrets"
                return
        elif "claude" in model.lower():
            actual_api_key = os.environ.get("ANTHROPIC_API_KEY", "")
            if not actual_api_key:
                yield None, None, None, None, "**Error:** Anthropic API key not configured in Space secrets"
                return
        elif "sonar" in model.lower():
            actual_api_key = os.environ.get("PERPLEXITY_API_KEY", "")
            if not actual_api_key:
                yield None, None, None, None, "**Error:** Perplexity API key not configured in Space secrets"
                return
        elif "grok" in model.lower():
            actual_api_key = os.environ.get("XAI_API_KEY", "")
            if not actual_api_key:
                yield None, None, None, None, "**Error:** xAI API key not configured in Space secrets"
                return
        else:
            # Fallback for unrecognized free models: try the HF router key.
            actual_api_key = os.environ.get("HF_API_KEY", "")
    else:
        # Paid tier - user provides their own API key
        if api_key_input and api_key_input.strip():
            actual_api_key = api_key_input.strip()
        else:
            yield None, None, None, None, f"**Error:** Please provide your API key for {model}"
            return
    # Use user-selected model_source, or auto-detect if "auto"
    if model_source_input == "auto":
        model_source = get_model_source(actual_model)
    else:
        model_source = model_source_input
    try:
        if not spreadsheet_file:
            yield None, None, None, None, "**Error:** Please upload a file"
            return
        if not spreadsheet_column:
            yield None, None, None, None, "**Error:** Please select a column to classify"
            return
        # Gradio may hand us a path string or a file object with .name.
        file_path = spreadsheet_file if isinstance(spreadsheet_file, str) else spreadsheet_file.name
        if file_path.endswith('.csv'):
            df = pd.read_csv(file_path)
        else:
            df = pd.read_excel(file_path)
        if spreadsheet_column not in df.columns:
            yield None, None, None, None, f"**Error:** Column '{spreadsheet_column}' not found"
            return
        input_data = df[spreadsheet_column].tolist()
        # Progress update: data loaded
        yield None, None, None, None, f"⏳ **Loading data...** Found {len(input_data)} responses to classify."
        # Calculate data quality metrics before classification
        text_series = df[spreadsheet_column].dropna().astype(str)
        data_quality = {
            'null_count': int(df[spreadsheet_column].isna().sum()),
            'avg_length': round(text_series.str.len().mean(), 1) if len(text_series) > 0 else 0,
            'min_length': int(text_series.str.len().min()) if len(text_series) > 0 else 0,
            'max_length': int(text_series.str.len().max()) if len(text_series) > 0 else 0,
            'error_count': 0  # Will be updated after classification
        }
        # Progress update: starting classification
        yield None, None, None, None, f"🔄 **Classifying {len(input_data)} responses...** This may take a moment."
        # Capture timing
        start_time = time.time()
        result = catllm.multi_class(
            survey_input=input_data,
            categories=categories,
            api_key=actual_api_key,
            user_model=actual_model,
            model_source=model_source
        )
        processing_time = time.time() - start_time
        # Update error count from results
        if 'processing_status' in result.columns:
            data_quality['error_count'] = int((result['processing_status'] == 'error').sum())
        # Save CSV for download
        with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
            result.to_csv(f.name, index=False)
            csv_path = f.name
        # Get original filename for methodology report
        original_filename = file_path.split("/")[-1]
        # Calculate success rate
        if 'processing_status' in result.columns:
            success_count = (result['processing_status'] == 'success').sum()
            success_rate = (success_count / len(result)) * 100
        else:
            success_rate = 100.0
        # Build prompt template for documentation (chain of thought - default)
        prompt_template = '''Categorize this survey response "{response}" into the following categories that apply:
{categories}
Let's think step by step:
1. First, identify the main themes mentioned in the response
2. Then, match each theme to the relevant categories
3. Finally, assign 1 to matching categories and 0 to non-matching categories
Provide your work in JSON format where the number belonging to each category is the key and a 1 if the category is present and a 0 if it is not present as key values.'''
        # Get version info
        try:
            catllm_version = catllm.__version__
        except AttributeError:
            catllm_version = "unknown"
        python_version = sys.version.split()[0]
        # Progress update: generating report
        yield None, None, None, None, f"📄 **Generating methodology report...** Classification complete in {processing_time:.1f}s."
        # Generate PDF methodology report with all new data
        pdf_path = generate_methodology_report_pdf(
            categories=categories,
            model=actual_model,
            column_name=spreadsheet_column,
            num_rows=len(input_data),
            model_source=model_source,
            filename=original_filename,
            success_rate=success_rate,
            result_df=result,
            processing_time=processing_time,
            prompt_template=prompt_template,
            data_quality=data_quality,
            catllm_version=catllm_version,
            python_version=python_version
        )
        # Build distribution data and create matplotlib plot
        dist_data = []
        total_rows = len(result)
        for i, cat in enumerate(categories, 1):
            col_name = f"category_{i}"
            if col_name in result.columns:
                count = int(result[col_name].sum())
                pct = (count / total_rows) * 100 if total_rows > 0 else 0
                dist_data.append({
                    "Category": cat,
                    "Percentage": round(pct, 1)
                })
        # Create matplotlib horizontal bar chart
        fig, ax = plt.subplots(figsize=(10, max(4, len(dist_data) * 0.8)))
        categories_list = [d["Category"] for d in dist_data]
        percentages = [d["Percentage"] for d in dist_data]
        # Reverse order so first category is at top
        categories_list = categories_list[::-1]
        percentages = percentages[::-1]
        bars = ax.barh(categories_list, percentages, color='#2563eb')
        ax.set_xlim(0, 100)
        ax.set_xlabel('Percentage (%)', fontsize=11)
        ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold')
        # Add percentage labels on bars
        for bar, pct in zip(bars, percentages):
            ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
                    f'{pct:.1f}%', va='center', fontsize=10)
        plt.tight_layout()
        distribution_fig = fig
        # Build sample results DataFrame (first 5 rows)
        sample_data = []
        for _, row in result.head(5).iterrows():
            original_text = str(row.get('survey_input', ''))[:100]
            if len(str(row.get('survey_input', ''))) > 100:
                original_text += "..."
            assigned = row.get('categories_id', '')
            if pd.isna(assigned) or assigned == '':
                assigned = "None"
            sample_data.append({
                "Original Text": original_text,
                "Assigned Categories": str(assigned)
            })
        sample_df = pd.DataFrame(sample_data)
        # Final yield: distribution plot (visible), samples (visible), full results (visible), files, status
        yield (
            gr.update(value=distribution_fig, visible=True),
            gr.update(value=sample_df, visible=True),
            gr.update(value=result, visible=True),
            [csv_path, pdf_path],
            f"✅ **Success!** Classified {len(input_data)} responses in {processing_time:.1f}s"
        )
    except Exception as e:
        yield None, None, None, None, f"**Error:** {str(e)}"
def add_category_field(current_count):
    """Reveal one more category textbox, capped at MAX_CATEGORIES.

    Returns visibility updates for every category box, one for the
    "+ Add More" button (hidden once the cap is reached), and the new count.
    """
    new_count = min(current_count + 1, MAX_CATEGORIES)
    box_updates = [gr.update(visible=(idx < new_count)) for idx in range(MAX_CATEGORIES)]
    button_update = gr.update(visible=(new_count < MAX_CATEGORIES))
    return box_updates + [button_update, new_count]
def reset_all():
    """Reset all inputs and outputs to initial state.

    Returns one update per bound component, in the same order the reset
    button's outputs are wired.
    """
    outputs = [
        None,                               # spreadsheet_file
        gr.update(choices=[], value=None),  # spreadsheet_column
    ]
    # Category boxes: first INITIAL_CATEGORIES visible, the rest hidden, all cleared.
    outputs += [
        gr.update(value="", visible=(idx < INITIAL_CATEGORIES))
        for idx in range(MAX_CATEGORIES)
    ]
    outputs += [
        gr.update(visible=True),  # add_category_btn
        INITIAL_CATEGORIES,       # category_count
        "Free Models",            # model_tier
        FREE_MODEL_CHOICES[0],    # model
        "auto",                   # model_source
        "",                       # api_key
        "**Free tier** - no API key required! We cover the cost while CatLLM is in review.",  # api_key_status
        "Ready to classify",      # status
        gr.update(value=None, visible=False),  # distribution_plot
        gr.update(value=None, visible=False),  # sample_results
        gr.update(value=None, visible=False),  # results
        None,                                  # download_file
        gr.update(value="", visible=False),    # code_output
    ]
    return outputs
def generate_code(spreadsheet_file, spreadsheet_column,
                  cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10,
                  model_tier, model, model_source_input):
    """Generate Python code snippet based on user inputs.

    Returns a gr.update that shows the code textbox populated with a runnable
    catllm snippet mirroring the current UI settings. ``model_tier`` is
    accepted for wiring parity with classify_data but does not affect the
    generated code.
    """
    all_cats = [cat1, cat2, cat3, cat4, cat5, cat6, cat7, cat8, cat9, cat10]
    categories = [c.strip() for c in all_cats if c and c.strip()]
    actual_model = model
    # Determine model source
    if model_source_input == "auto":
        model_source = get_model_source(actual_model)
    else:
        model_source = model_source_input
    # Get filename from uploaded file (fall back to a placeholder name)
    if spreadsheet_file:
        file_path = spreadsheet_file if isinstance(spreadsheet_file, str) else spreadsheet_file.name
        filename = file_path.split("/")[-1]
    else:
        filename = "your_survey_data.csv"
    # Build categories list string
    categories_str = ", ".join([f'"{cat}"' for cat in categories]) if categories else '"Category1", "Category2"'
    # Build the code snippet. Fix: interpolate the computed filename (it was
    # previously unused and the snippet hard-coded a placeholder path).
    code = f'''import catllm
import pandas as pd

# Load your survey data
df = pd.read_csv("{filename}")

# Define your categories
categories = [{categories_str}]

# Classify the responses
result = catllm.multi_class(
    survey_input=df["{spreadsheet_column or 'your_column'}"].tolist(),
    categories=categories,
    api_key="YOUR_API_KEY",  # Replace with your actual API key
    user_model="{actual_model}",
    model_source="{model_source}"
)

# View results
print(result)

# Save to CSV
result.to_csv("classified_results.csv", index=False)
'''
    return gr.update(value=code, visible=True)
# Global CSS injected into the Blocks app: force one sans-serif family across
# every widget for a consistent look.
custom_css = """
* {
    font-family: Helvetica, Arial, sans-serif !important;
}
"""
# Top-level UI definition. Component construction order/nesting determines the
# rendered layout, so statements below must not be reordered.
with gr.Blocks(title="CatLLM - Survey Response Classifier", theme=gr.themes.Soft(), css=custom_css) as demo:
    # Header: logo image plus title and one-line description.
    gr.Image("logo.png", show_label=False, show_download_button=False, height=100, container=False)
    gr.Markdown("# CatLLM - Survey Response Classifier")
    gr.Markdown("Classify survey responses into custom categories using LLMs.")
    # Collapsible project background: features, beta-feedback links, citation.
    with gr.Accordion("About This App", open=False):
        gr.Markdown("""
**CatLLM** is an open-source Python package for classifying text data using Large Language Models.
### What It Does
- Classifies survey responses, open-ended text, and other unstructured data into custom categories
- Supports multiple LLM providers: OpenAI, Anthropic, Google, HuggingFace, and more
- Returns structured results with category assignments for each response
- Tested on over 40,000 rows of data with a 100% structured output rate (actual output rate ~99.98% due to occasional server errors)
### Beta Test - We Want Your Feedback!
This app is currently in **beta** and **free to use** while CatLLM is under review for publication. We're actively accepting feedback to improve the tool for researchers.
- Found a bug? Have a feature request? Please open an issue or submit a PR on [GitHub](https://github.com/chrissoria/cat-llm) so other researchers can benefit!
- **Open to collaboration** on research projects involving LLM-based classification
- Reach out directly: [chrissoria@berkeley.edu](mailto:chrissoria@berkeley.edu)
- Connect on [LinkedIn](https://www.linkedin.com/in/chris-soria-9340931a/)
### Links
- 📦 **PyPI**: [pip install cat-llm](https://pypi.org/project/cat-llm/)
- 💻 **GitHub**: [github.com/chrissoria/cat-llm](https://github.com/chrissoria/cat-llm)
- 🌐 **Author**: [christophersoria.com](https://christophersoria.com)
### Citation
If you use CatLLM in your research, please cite:
```
Soria, C. (2025). CatLLM: A Python package for LLM-based text classification. DOI: 10.5281/zenodo.15532316
```
""")
    # Per-session count of visible category textboxes (updated by the
    # "+ Add More Categories" handler wired further below).
    category_count = gr.State(value=INITIAL_CATEGORIES)
    with gr.Row():
        # Left column: all user inputs (data upload, categories, model choice).
        with gr.Column():
            spreadsheet_file = gr.File(
                label="Upload Survey Data (CSV or Excel)",
                file_types=[".csv", ".xlsx", ".xls"]
            )
            example_btn = gr.Button("📋 Try Example Dataset", variant="secondary", size="sm")
            # Choices are filled in by the file-upload handler once columns are known.
            spreadsheet_column = gr.Dropdown(
                label="Column to Classify",
                choices=[],
                info="Select the column containing text to classify"
            )
            gr.Markdown("### Categories")
            category_inputs = []
            placeholder_examples = [
                "e.g., Positive sentiment",
                "e.g., Negative sentiment",
                "e.g., Product feedback",
                "e.g., Service complaint",
                "e.g., Feature request",
                "e.g., Custom category"
            ]
            # Pre-create all MAX_CATEGORIES textboxes up front; only the first
            # INITIAL_CATEGORIES start visible. The rest are presumably toggled
            # visible by add_category_btn (components can't be added after launch).
            for i in range(MAX_CATEGORIES):
                visible = i < INITIAL_CATEGORIES
                # Fall back to a generic placeholder past the example list.
                placeholder = placeholder_examples[i] if i < len(placeholder_examples) else "e.g., Custom category"
                cat_input = gr.Textbox(
                    label=f"Category {i+1}",
                    placeholder=placeholder,
                    visible=visible
                )
                category_inputs.append(cat_input)
            add_category_btn = gr.Button("+ Add More Categories", variant="secondary", size="sm")
            gr.Markdown("### Model")
            # Free tier uses the Space's own API keys; paid tier needs the user's key.
            model_tier = gr.Radio(
                choices=["Free Models", "Bring Your Own Key"],
                value="Free Models",
                label="Model Tier",
                info="Free models use our API keys. 'Bring Your Own Key' lets you use your own API key."
            )
            model = gr.Dropdown(
                choices=FREE_MODEL_CHOICES,
                value="Qwen/Qwen3-VL-235B-A22B-Instruct:novita",
                label="Model",
                allow_custom_value=True
            )
            model_source = gr.Dropdown(
                choices=["auto", "openai", "anthropic", "google", "mistral", "xai", "huggingface", "perplexity"],
                value="auto",
                label="Model Source",
                info="Auto-detects from model name, or select manually."
            )
            # Hidden on the default (free) tier; shown by the tier-change handler.
            api_key = gr.Textbox(
                label="API Key",
                type="password",
                placeholder="Enter your API key",
                info="Required for 'Bring Your Own Key' tier",
                visible=False
            )
            api_key_status = gr.Markdown("**Free tier** - no API key required! We cover the cost while CatLLM is in review.")
            classify_btn = gr.Button("Classify", variant="primary", size="lg")
            with gr.Row():
                see_code_btn = gr.Button("See the Code", variant="secondary")
                reset_btn = gr.Button("Reset", variant="stop")
        # Right column: outputs. Result widgets start hidden until a run finishes.
        with gr.Column():
            status = gr.Markdown("Ready to classify")
            distribution_plot = gr.Plot(
                label="Category Distribution (%)",
                visible=False
            )
            sample_results = gr.DataFrame(label="Sample Results (First 5 Rows)", visible=False)
            results = gr.DataFrame(label="Full Classification Results", visible=False)
            download_file = gr.File(label="Download Results (CSV + Methodology Report)", file_count="multiple")
            code_output = gr.Code(
                label="Python Code",
                language="python",
                visible=False
            )
| # Event handlers | |
| def update_model_tier(tier): | |
| """Update model choices and API key visibility based on tier.""" | |
| if tier == "Free Models": | |
| return ( | |
| gr.update(choices=FREE_MODEL_CHOICES, value=FREE_MODEL_CHOICES[0]), | |
| gr.update(visible=False), | |
| "**Free tier** - no API key required! We cover the cost while CatLLM is in review." | |
| ) | |
| else: | |
| return ( | |
| gr.update(choices=PAID_MODEL_CHOICES, value=PAID_MODEL_CHOICES[0]), | |
| gr.update(visible=True), | |
| "**Bring Your Own Key** - enter your API key below." | |
| ) | |
    # --- Event wiring (handlers are defined earlier in this file) ---
    # Tier toggle: swap model list, show/hide API-key box, update status note.
    model_tier.change(
        fn=update_model_tier,
        inputs=[model_tier],
        outputs=[model, api_key, api_key_status]
    )
    # New upload: populate the column dropdown from the file's header.
    spreadsheet_file.change(
        fn=load_columns,
        inputs=[spreadsheet_file],
        outputs=[spreadsheet_column, status]
    )
    # Demo button: load the bundled example dataset in place of an upload.
    example_btn.click(
        fn=load_example_dataset,
        inputs=[],
        outputs=[spreadsheet_file, spreadsheet_column, status]
    )
    # Reveal the next hidden category textbox and bump the visible count.
    add_category_btn.click(
        fn=add_category_field,
        inputs=[category_count],
        outputs=category_inputs + [add_category_btn, category_count]
    )
    # Main action: run classification and populate plot, tables, and downloads.
    classify_btn.click(
        fn=classify_data,
        inputs=[spreadsheet_file, spreadsheet_column] + category_inputs + [model_tier, model, model_source, api_key],
        outputs=[distribution_plot, sample_results, results, download_file, status]
    )
    # Show an equivalent standalone Python snippet for the current settings
    # (note: api_key is deliberately not passed to the code generator).
    see_code_btn.click(
        fn=generate_code,
        inputs=[spreadsheet_file, spreadsheet_column] + category_inputs + [model_tier, model, model_source],
        outputs=[code_output]
    )
    # Restore every input and output widget to its initial state.
    reset_btn.click(
        fn=reset_all,
        inputs=[],
        outputs=[spreadsheet_file, spreadsheet_column] + category_inputs + [add_category_btn, category_count, model_tier, model, model_source, api_key, api_key_status, status, distribution_plot, sample_results, results, download_file, code_output]
    )
# Script entry point: bind to all interfaces on 7860, the port HF Spaces expects.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)