chrissoria committed · verified
Commit a5c1ab5 · 1 Parent(s): dfbf34e

Upload app.py with huggingface_hub

Files changed (1): app.py (+2368 -0)
app.py ADDED (2368 lines)
"""
Streamlit app - CatVader Social Media Classifier
Migrated from Gradio for better mobile support
"""

import streamlit as st
import pandas as pd
import tempfile
import os
import time
import sys
from datetime import datetime
import matplotlib.pyplot as plt

# Import catvader
try:
    import catvader
    CATVADER_AVAILABLE = True
except ImportError as e:
    print(f"Warning: Could not import catvader: {e}")
    CATVADER_AVAILABLE = False

MAX_CATEGORIES = 10
INITIAL_CATEGORIES = 3
MAX_FILE_SIZE_MB = 100

def count_pdf_pages(pdf_path):
    """Count the number of pages in a PDF file."""
    try:
        import fitz  # PyMuPDF
        doc = fitz.open(pdf_path)
        page_count = len(doc)
        doc.close()
        return page_count
    except Exception:
        return 1  # Default to 1 if can't read


def extract_text_from_pdfs(pdf_paths):
    """Extract text from all pages of all PDFs, returning a list of page texts."""
    import fitz  # PyMuPDF
    all_texts = []
    for pdf_path in pdf_paths:
        try:
            doc = fitz.open(pdf_path)
            for page in doc:
                text = page.get_text().strip()
                if text:  # Only add non-empty pages
                    all_texts.append(text)
            doc.close()
        except Exception as e:
            print(f"Error extracting text from {pdf_path}: {e}")
    return all_texts


def extract_pdf_pages(pdf_paths, pdf_name_map, mode="image"):
    """
    Extract individual pages from PDFs.
    Returns a list of (page_data, page_label, kind) tuples; in "both" mode the
    extracted text is appended as a fourth element.
    For image mode: page_data is the path to a temp image file.
    For text mode: page_data is the extracted text.
    """
    import fitz  # PyMuPDF
    pages = []

    for pdf_path in pdf_paths:
        orig_name = pdf_name_map.get(pdf_path, os.path.basename(pdf_path).replace('.pdf', ''))
        try:
            doc = fitz.open(pdf_path)
            for page_num, page in enumerate(doc, 1):
                page_label = f"{orig_name}_p{page_num}"

                if mode == "text":
                    # Extract text
                    text = page.get_text().strip()
                    if text:
                        pages.append((text, page_label, "text"))
                else:
                    # Render as image (for "image" or "both" mode)
                    pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))  # 2x zoom for better quality
                    img_path = tempfile.NamedTemporaryFile(delete=False, suffix='.png').name
                    pix.save(img_path)

                    if mode == "both":
                        text = page.get_text().strip()
                        pages.append((img_path, page_label, "image", text))
                    else:
                        pages.append((img_path, page_label, "image"))
            doc.close()
        except Exception as e:
            print(f"Error extracting pages from {pdf_path}: {e}")

    return pages

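To make the page-labeling scheme above concrete (each page becomes `<original name>_p<page number>`, with `pdf_name_map` able to override the filename-derived base), here is a standalone sketch that builds the labels without touching PyMuPDF. The helper name `page_labels` and the paths are hypothetical, introduced only for illustration:

```python
import os

def page_labels(pdf_path, num_pages, pdf_name_map=None):
    """Build per-page labels the same way extract_pdf_pages does."""
    pdf_name_map = pdf_name_map or {}
    orig_name = pdf_name_map.get(
        pdf_path, os.path.basename(pdf_path).replace('.pdf', '')
    )
    return [f"{orig_name}_p{n}" for n in range(1, num_pages + 1)]

print(page_labels("/tmp/survey.pdf", 3))
# A name-map entry overrides the filename-derived label:
print(page_labels("/tmp/x.pdf", 2, {"/tmp/x.pdf": "wave1"}))
```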
# Free models - display name -> actual API model name
FREE_MODELS_MAP = {
    "GPT-4o Mini": "gpt-4o-mini",
    "Gemini 2.5 Flash": "gemini-2.5-flash",
    "Claude 3 Haiku": "claude-3-haiku-20240307",
    "Llama 3.3 70B": "meta-llama/Llama-3.3-70B-Instruct:groq",
    "Qwen 2.5": "Qwen/Qwen2.5-72B-Instruct",
    "DeepSeek R1": "deepseek-ai/DeepSeek-R1:novita",
    "Mistral Medium": "mistral-medium-2505",
    "Grok 4 Fast": "grok-4-fast-non-reasoning",
}
FREE_MODEL_DISPLAY_NAMES = list(FREE_MODELS_MAP.keys())
FREE_MODEL_CHOICES = list(FREE_MODELS_MAP.values())  # Keep for backward compat

# Paid models (user provides their own API key)
PAID_MODEL_CHOICES = [
    "gemini-2.5-flash",
    "gemini-2.5-pro",
    "gpt-4.1",
    "gpt-4o",
    "gpt-4o-mini",
    "claude-sonnet-4-5-20250929",
    "claude-opus-4-20250514",
    "claude-3-5-haiku-20241022",
    "mistral-large-latest",
]

# Models routed through HuggingFace
HF_ROUTED_MODELS = [
    "meta-llama/Llama-3.3-70B-Instruct:groq",
    "deepseek-ai/DeepSeek-R1:novita",
]

def is_free_model(model, model_tier):
    """Check if using free tier (Space pays for API)."""
    return model_tier == "Free Models"


def get_model_source(model):
    """Auto-detect model source."""
    model_lower = model.lower()
    if "gpt" in model_lower:
        return "openai"
    elif "claude" in model_lower:
        return "anthropic"
    elif "gemini" in model_lower:
        return "google"
    elif "mistral" in model_lower and ":novita" not in model_lower:
        return "mistral"
    elif any(x in model_lower for x in [":novita", ":groq", "qwen", "llama", "deepseek"]):
        return "huggingface"
    elif "sonar" in model_lower:
        return "perplexity"
    elif "grok" in model_lower:
        return "xai"
    return "huggingface"

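The substring routing above is order-sensitive (the `mistral`-but-not-`:novita` check runs before the HuggingFace catch-all), so it helps to exercise it in isolation. This sketch reproduces the function verbatim purely so the dispatch can be tested standalone:

```python
def get_model_source(model):
    """Auto-detect the provider from the model name (copy of the app's logic)."""
    model_lower = model.lower()
    if "gpt" in model_lower:
        return "openai"
    elif "claude" in model_lower:
        return "anthropic"
    elif "gemini" in model_lower:
        return "google"
    elif "mistral" in model_lower and ":novita" not in model_lower:
        return "mistral"
    elif any(x in model_lower for x in [":novita", ":groq", "qwen", "llama", "deepseek"]):
        return "huggingface"
    elif "sonar" in model_lower:
        return "perplexity"
    elif "grok" in model_lower:
        return "xai"
    return "huggingface"

print(get_model_source("gpt-4o-mini"))                              # openai
print(get_model_source("meta-llama/Llama-3.3-70B-Instruct:groq"))   # huggingface
```

Note that provider-routed names like the `:groq`-suffixed Llama model fall through the first four branches and match on the suffix list, which is why the HF-routed free models resolve correctly.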
def get_api_key(model, model_tier, api_key_input):
    """Get the appropriate API key based on model and tier."""
    if is_free_model(model, model_tier):
        if model in HF_ROUTED_MODELS:
            return os.environ.get("HF_API_KEY", ""), "HuggingFace"
        elif "gpt" in model.lower():
            return os.environ.get("OPENAI_API_KEY", ""), "OpenAI"
        elif "gemini" in model.lower():
            return os.environ.get("GOOGLE_API_KEY", ""), "Google"
        elif "mistral" in model.lower():
            return os.environ.get("MISTRAL_API_KEY", ""), "Mistral"
        elif "claude" in model.lower():
            return os.environ.get("ANTHROPIC_API_KEY", ""), "Anthropic"
        elif "sonar" in model.lower():
            return os.environ.get("PERPLEXITY_API_KEY", ""), "Perplexity"
        elif "grok" in model.lower():
            return os.environ.get("XAI_API_KEY", ""), "xAI"
        else:
            return os.environ.get("HF_API_KEY", ""), "HuggingFace"
    else:
        if api_key_input and api_key_input.strip():
            return api_key_input.strip(), "User"
        return "", "User"


def calculate_total_file_size(files):
    """Calculate total size of uploaded files in MB."""
    if files is None:
        return 0
    if not isinstance(files, list):
        files = [files]

    total_bytes = 0
    for f in files:
        try:
            if hasattr(f, 'size'):
                total_bytes += f.size
            elif hasattr(f, 'name'):
                total_bytes += os.path.getsize(f.name)
        except (OSError, AttributeError):
            pass
    return total_bytes / (1024 * 1024)

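The bytes-to-MB conversion above can be sanity-checked with plain files on disk; a minimal sketch, with temporary files standing in for Streamlit upload objects (the `total_size_mb` helper is hypothetical and keeps only the `os.path.getsize` branch):

```python
import os
import tempfile

def total_size_mb(paths):
    """Sum on-disk file sizes and convert bytes to megabytes."""
    total_bytes = sum(os.path.getsize(p) for p in paths)
    return total_bytes / (1024 * 1024)

# Two files of exactly 1 MiB each should total 2.0 MB
paths = []
for _ in range(2):
    f = tempfile.NamedTemporaryFile(delete=False)
    f.write(b"\0" * (1024 * 1024))
    f.close()
    paths.append(f.name)

mb = total_size_mb(paths)
print(mb)  # 2.0

for p in paths:
    os.remove(p)
```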
def generate_extract_code(input_type, description, model, model_source, max_categories, mode=None):
    """Generate Python code for category extraction."""
    if input_type == "text":
        return f'''import catvader
import pandas as pd

# Load your data
df = pd.read_csv("your_data.csv")

# Extract categories from the text column
result = catvader.extract(
    input_data=df["{description}"].tolist(),
    api_key="YOUR_API_KEY",
    input_type="text",
    description="{description}",
    user_model="{model}",
    model_source="{model_source}",
    max_categories={max_categories}
)

# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''
    elif input_type == "pdf":
        mode_line = f',\n    mode="{mode}"' if mode else ''
        return f'''import catvader

# Extract categories from PDF documents
result = catvader.extract(
    input_data="path/to/your/pdfs/",
    api_key="YOUR_API_KEY",
    input_type="pdf",
    description="{description}"{mode_line},
    user_model="{model}",
    model_source="{model_source}",
    max_categories={max_categories}
)

# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''
    else:  # image
        return f'''import catvader

# Extract categories from images
result = catvader.extract(
    input_data="path/to/your/images/",
    api_key="YOUR_API_KEY",
    input_type="image",
    description="{description}",
    user_model="{model}",
    model_source="{model_source}",
    max_categories={max_categories}
)

# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''


def generate_full_code(extraction_params, classify_params):
    """Generate combined extract + classify code when categories were auto-extracted."""
    ext = extraction_params
    cls = classify_params

    # Determine input data placeholder
    if ext['input_type'] == "text":
        input_placeholder = 'df["your_column"].tolist()'
        load_data = '''import pandas as pd

# Load your data
df = pd.read_csv("your_data.csv")
'''
    elif ext['input_type'] == "pdf":
        input_placeholder = '"path/to/your/pdfs/"'
        load_data = ''
    else:
        input_placeholder = '"path/to/your/images/"'
        load_data = ''

    mode_param = f',\n    mode="{ext["mode"]}"' if ext.get('mode') else ''

    # Build extract code
    extract_code = f'''# Step 1: Extract categories from your data
extract_result = catvader.extract(
    input_data={input_placeholder},
    api_key="YOUR_API_KEY",
    description="{ext['description']}",
    user_model="{ext['model']}",
    max_categories={ext['max_categories']}{mode_param}
)

categories = extract_result["top_categories"]
print(f"Extracted {{len(categories)}} categories: {{categories}}")
'''

    # Build classify code based on mode
    if cls['classify_mode'] == "Single Model":
        classify_mode_param = f',\n    mode="{cls["mode"]}"' if cls.get('mode') and ext['input_type'] == "pdf" else ''
        classify_code = f'''
# Step 2: Classify data using extracted categories
result = catvader.classify(
    input_data={input_placeholder},
    categories=categories,
    api_key="YOUR_API_KEY",
    description="{cls['description']}",
    user_model="{cls['model']}"{classify_mode_param}
)'''
    else:
        # Multi-model mode: include per-model temperatures when set
        ens_runs = cls.get('ensemble_runs')
        model_lines = []
        if ens_runs:
            for m, temp in ens_runs:
                model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})')
        else:
            model_temps = cls.get('model_temperatures', {})
            for m in cls['models_list']:
                temp = model_temps.get(m) if model_temps else None
                if temp is not None:
                    model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})')
                else:
                    model_lines.append(f'("{m}", "auto", "YOUR_API_KEY")')
        models_str = ",\n    ".join(model_lines)

        classify_mode_param = f',\n    mode="{cls["mode"]}"' if cls.get('mode') and ext['input_type'] == "pdf" else ''
        threshold_str = "majority" if cls['consensus_threshold'] == 0.5 else "two-thirds" if cls['consensus_threshold'] == 0.67 else "unanimous"
        consensus_param = f',\n    consensus_threshold="{threshold_str}"' if cls['classify_mode'] == "Ensemble" else ''

        classify_code = f'''
# Step 2: Classify data using extracted categories with {"ensemble voting" if cls['classify_mode'] == "Ensemble" else "model comparison"}
models = [
    {models_str}
]

result = catvader.classify(
    input_data={input_placeholder},
    categories=categories,
    models=models,
    description="{cls['description']}"{classify_mode_param}{consensus_param}
)'''

    return f'''import catvader
{load_data}
{extract_code}
{classify_code}

# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''


def generate_classify_code(input_type, description, categories, model, model_source, mode=None,
                           classify_mode="Single Model", models_list=None, consensus_threshold=0.5,
                           model_temperatures=None, ensemble_runs=None):
    """Generate Python code for classification."""
    categories_str = ",\n    ".join([f'"{cat}"' for cat in categories])

    # Determine input data placeholder based on type
    if input_type == "text":
        input_placeholder = 'df["your_column"].tolist()'
        load_data = '''import pandas as pd

# Load your data
df = pd.read_csv("your_data.csv")
'''
    elif input_type == "pdf":
        input_placeholder = '"path/to/your/pdfs/"'
        load_data = ''
    else:  # image
        input_placeholder = '"path/to/your/images/"'
        load_data = ''

    # Generate code based on classification mode
    if classify_mode == "Single Model":
        # Single model mode
        mode_param = f',\n    mode="{mode}"' if mode and input_type == "pdf" else ''
        return f'''import catvader
{load_data}
# Define categories
categories = [
    {categories_str}
]

# Classify data (input type is auto-detected)
result = catvader.classify(
    input_data={input_placeholder},
    categories=categories,
    api_key="YOUR_API_KEY",
    description="{description}",
    user_model="{model}"{mode_param}
)

# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''
    else:
        # Multi-model mode (Comparison or Ensemble)
        # Build model tuples with per-model temperature when set
        if ensemble_runs:
            # Ensemble with explicit (model, temp) pairs (supports duplicate models)
            model_lines = []
            for m, temp in ensemble_runs:
                model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})')
            models_str = ",\n    ".join(model_lines)
        elif models_list:
            model_lines = []
            for m in models_list:
                temp = model_temperatures.get(m) if model_temperatures else None
                if temp is not None:
                    model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})')
                else:
                    model_lines.append(f'("{m}", "auto", "YOUR_API_KEY")')
            models_str = ",\n    ".join(model_lines)
        else:
            models_str = '("gpt-4o", "auto", "YOUR_API_KEY"),\n    ("claude-sonnet-4-5-20250929", "auto", "YOUR_API_KEY")'

        mode_param = f',\n    mode="{mode}"' if mode and input_type == "pdf" else ''
        # Map numeric threshold back to string for cleaner code
        threshold_str = "majority" if consensus_threshold == 0.5 else "two-thirds" if consensus_threshold == 0.67 else "unanimous"
        consensus_param = f',\n    consensus_threshold="{threshold_str}"' if classify_mode == "Ensemble" else ''

        return f'''import catvader
{load_data}
# Define categories
categories = [
    {categories_str}
]

# Define models for {"ensemble voting" if classify_mode == "Ensemble" else "comparison"}
models = [
    {models_str}
]

# Classify with multiple models
result = catvader.classify(
    input_data={input_placeholder},
    categories=categories,
    models=models,
    description="{description}"{mode_param}{consensus_param}
)

# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''

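The numeric-to-label round trip used when emitting `consensus_threshold` (0.5 to "majority", 0.67 to "two-thirds", anything else to "unanimous") can be shown on its own. A small sketch of that mapping, rewriting the chained conditional as a dict lookup; `threshold_to_str` is a hypothetical name for illustration:

```python
THRESHOLD_LABELS = {0.5: "majority", 0.67: "two-thirds", 1.0: "unanimous"}

def threshold_to_str(consensus_threshold):
    """Mirror the chained conditional used in generate_classify_code."""
    return THRESHOLD_LABELS.get(consensus_threshold, "unanimous")

print(threshold_to_str(0.5))   # majority
print(threshold_to_str(0.67))  # two-thirds
print(threshold_to_str(0.8))   # falls through to unanimous, as in the original
```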
451
+ def generate_methodology_report_pdf(categories, model, column_name, num_rows, model_source, filename, success_rate,
452
+ result_df=None, processing_time=None, prompt_template=None,
453
+ data_quality=None, catvader_version=None, python_version=None,
454
+ task_type="assign", extracted_categories_df=None, max_categories=None,
455
+ input_type="text", description=None, classify_mode="Single Model",
456
+ models_list=None, code=None, consensus_threshold=None):
457
+ """Generate a PDF methodology report."""
458
+ from reportlab.lib.pagesizes import letter
459
+ from reportlab.lib import colors
460
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
461
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak
462
+
463
+ pdf_file = tempfile.NamedTemporaryFile(mode='wb', suffix='_methodology_report.pdf', delete=False)
464
+ doc = SimpleDocTemplate(pdf_file.name, pagesize=letter)
465
+ styles = getSampleStyleSheet()
466
+
467
+ title_style = ParagraphStyle('Title', parent=styles['Heading1'], fontSize=18, spaceAfter=20)
468
+ heading_style = ParagraphStyle('Heading', parent=styles['Heading2'], fontSize=14, spaceAfter=10, spaceBefore=15)
469
+ normal_style = styles['Normal']
470
+ code_style = ParagraphStyle('Code', parent=styles['Normal'], fontName='Courier', fontSize=9, leftIndent=20, spaceAfter=3)
471
+
472
+ story = []
473
+
474
+ if task_type == "extract_and_assign":
475
+ report_title = "CatVader Extraction & Classification Report"
476
+ else:
477
+ report_title = "CatVader Classification Report"
478
+
479
+ story.append(Paragraph(report_title, title_style))
480
+ story.append(Paragraph(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", normal_style))
481
+ story.append(Spacer(1, 15))
482
+
483
+ story.append(Paragraph("About This Report", heading_style))
484
+
485
+ if task_type == "extract_and_assign":
486
+ about_text = """This methodology report documents the automated category extraction and classification process. \
487
+ CatVader first discovers categories from your data using LLMs, then classifies each item into those categories."""
488
+ else:
489
+ about_text = """This methodology report documents the classification process for reproducibility and transparency. \
490
+ CatVader restricts the prompt to a standard template that is impartial to the researcher's inclinations, ensuring \
491
+ consistent and reproducible results."""
492
+
493
+ story.append(Paragraph(about_text, normal_style))
494
+ story.append(Spacer(1, 15))
495
+
496
+ if categories:
497
+ story.append(Paragraph("Category Mapping", heading_style))
498
+
499
+ if classify_mode in ("Ensemble", "Model Comparison") and result_df is not None:
500
+ # Multi-model: show per-model columns and consensus columns
501
+ story.append(Paragraph("Each model produces its own binary columns. "
502
+ "Consensus columns show the majority vote result.", normal_style))
503
+ story.append(Spacer(1, 8))
504
+
505
+ # Detect ALL distinct model suffixes directly from the DataFrame
506
+ # (handles same-model-different-temperature cases correctly)
507
+ all_suffixes = _find_all_model_suffixes(result_df)
508
+
509
+ category_data = [["Column Name", "Category Description"]]
510
+ for i, cat in enumerate(categories, 1):
511
+ # Per-model columns (each suffix is a unique model/temperature)
512
+ for suffix in all_suffixes:
513
+ category_data.append([f"category_{i}_{suffix}", f"{cat} ({suffix})"])
514
+ # Consensus + agreement columns
515
+ category_data.append([f"category_{i}_consensus", f"{cat} (consensus)"])
516
+ category_data.append([f"category_{i}_agreement", f"{cat} (agreement score)"])
517
+
518
+ cat_table = Table(category_data, colWidths=[200, 250])
519
+ else:
520
+ # Single model: simple mapping
521
+ story.append(Paragraph("Each category column contains binary values: 1 = present, 0 = not present", normal_style))
522
+ story.append(Spacer(1, 8))
523
+
524
+ category_data = [["Column Name", "Category Description"]]
525
+ for i, cat in enumerate(categories, 1):
526
+ category_data.append([f"category_{i}", cat])
527
+
528
+ cat_table = Table(category_data, colWidths=[120, 330])
529
+
530
+ cat_table.setStyle(TableStyle([
531
+ ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
532
+ ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
533
+ ('GRID', (0, 0), (-1, -1), 1, colors.black),
534
+ ('PADDING', (0, 0), (-1, -1), 6),
535
+ ('BACKGROUND', (0, 1), (0, -1), colors.lightgrey),
536
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
537
+ ]))
538
+ story.append(cat_table)
539
+ story.append(Spacer(1, 15))
540
+
541
+ story.append(Spacer(1, 30))
542
+ story.append(Paragraph("Citation", heading_style))
543
+ story.append(Paragraph("If you use CatVader in your research, please cite:", normal_style))
544
+ story.append(Spacer(1, 5))
545
+ story.append(Paragraph("Soria, C. (2025). CatVader: A Python package for LLM-based social media classification. DOI: 10.5281/zenodo.15532316", normal_style))
546
+
547
+ # Summary section
548
+ story.append(PageBreak())
549
+ story.append(Paragraph("Classification Summary", title_style))
550
+ story.append(Spacer(1, 15))
551
+
552
+ summary_data = [
553
+ ["Source File", filename],
554
+ ["Source Column", column_name],
555
+ ["Classification Mode", classify_mode],
556
+ ["Model(s) Used", model],
557
+ ["Model Source", model_source],
558
+ ["Rows Classified", str(num_rows)],
559
+ ["Number of Categories", str(len(categories)) if categories else "0"],
560
+ ["Success Rate", f"{success_rate:.2f}%"],
561
+ ]
562
+ # Add consensus threshold for ensemble mode
563
+ if classify_mode == "Ensemble" and consensus_threshold is not None:
564
+ threshold_labels = {0.5: "Majority (50%+)", 0.67: "Two-Thirds (67%+)", 1.0: "Unanimous (100%)"}
565
+ threshold_label = threshold_labels.get(consensus_threshold, f"Custom ({consensus_threshold:.0%})")
566
+ summary_data.append(["Consensus Threshold", threshold_label])
567
+
568
+ summary_table = Table(summary_data, colWidths=[150, 300])
569
+ summary_table.setStyle(TableStyle([
570
+ ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
571
+ ('GRID', (0, 0), (-1, -1), 1, colors.black),
572
+ ('PADDING', (0, 0), (-1, -1), 6),
573
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
574
+ ]))
575
+ story.append(summary_table)
576
+ story.append(Spacer(1, 15))
577
+
578
+ # Agreement scores table for ensemble mode
579
+ if classify_mode == "Ensemble" and result_df is not None and categories:
580
+ agreement_cols = [f"category_{i}_agreement" for i in range(1, len(categories) + 1)]
581
+ has_agreement = all(col in result_df.columns for col in agreement_cols)
582
+ if has_agreement:
583
+ story.append(Paragraph("Ensemble Agreement Scores", heading_style))
584
+ story.append(Paragraph(
585
+ "Agreement shows what proportion of models agreed on each category. "
586
+ "Higher scores indicate stronger consensus.", normal_style))
587
+ story.append(Spacer(1, 8))
588
+
589
+ agree_data = [["Category", "Mean Agreement", "Min Agreement"]]
590
+ for i, cat in enumerate(categories, 1):
591
+ col = f"category_{i}_agreement"
592
+ mean_val = result_df[col].mean()
593
+ min_val = result_df[col].min()
594
+ agree_data.append([cat, f"{mean_val:.1%}", f"{min_val:.1%}"])
595
+
596
+ agree_table = Table(agree_data, colWidths=[200, 125, 125])
597
+ agree_table.setStyle(TableStyle([
598
+ ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
599
+ ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
600
+ ('GRID', (0, 0), (-1, -1), 1, colors.black),
601
+ ('PADDING', (0, 0), (-1, -1), 6),
602
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
603
+ ]))
604
+ story.append(agree_table)
605
+ story.append(Spacer(1, 15))
606
+
607
+ if processing_time is not None:
608
+ story.append(Paragraph("Processing Time", heading_style))
609
+ rows_per_min = (num_rows / processing_time) * 60 if processing_time > 0 else 0
610
+ avg_time = processing_time / num_rows if num_rows > 0 else 0
611
+
612
+ time_data = [
613
+ ["Total Processing Time", f"{processing_time:.1f} seconds"],
614
+ ["Average Time per Response", f"{avg_time:.2f} seconds"],
615
+ ["Processing Rate", f"{rows_per_min:.1f} rows/minute"],
616
+ ]
617
+ time_table = Table(time_data, colWidths=[180, 270])
618
+ time_table.setStyle(TableStyle([
619
+ ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
620
+ ('GRID', (0, 0), (-1, -1), 1, colors.black),
621
+ ('PADDING', (0, 0), (-1, -1), 6),
622
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
623
+ ]))
624
+ story.append(time_table)
625
+
626
+ story.append(Spacer(1, 15))
627
+ story.append(Paragraph("Version Information", heading_style))
628
+ version_data = [
629
+ ["CatVader Version", catvader_version or "unknown"],
630
+ ["Python Version", python_version or "unknown"],
631
+ ["Timestamp", datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
632
+ ]
633
+ version_table = Table(version_data, colWidths=[180, 270])
634
+ version_table.setStyle(TableStyle([
635
+ ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
636
+ ('GRID', (0, 0), (-1, -1), 1, colors.black),
637
+ ('PADDING', (0, 0), (-1, -1), 6),
638
+ ('FONTSIZE', (0, 0), (-1, -1), 9),
639
+ ]))
640
+ story.append(version_table)
641
+
642
+ # Reproducibility Code section
643
+ if code:
644
+ story.append(PageBreak())
645
+ story.append(Paragraph("Reproducibility Code", title_style))
646
+ story.append(Paragraph("Use this Python code to reproduce the classification with the CatVader package:", normal_style))
647
+ story.append(Spacer(1, 10))
648
+
649
+ # Split code into lines and add as code-formatted paragraphs
650
+ for line in code.strip().split('\n'):
651
+ # Escape special characters for reportlab
652
+ escaped_line = line.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
653
+ if escaped_line.strip():
654
+ story.append(Paragraph(escaped_line, code_style))
655
+ else:
656
+ story.append(Spacer(1, 6))
657
+
658
+ # Visualizations section
659
+ if result_df is not None and categories:
660
+ from reportlab.platypus import Image
661
+ import io
662
+
+     # Distribution chart (new page)
+     story.append(PageBreak())
+     story.append(Paragraph("Category Distribution", title_style))
+     try:
+         fig1 = create_distribution_chart(result_df, categories, classify_mode, models_list)
+         img_buffer1 = io.BytesIO()
+         fig1.savefig(img_buffer1, format='png', dpi=150, bbox_inches='tight')
+         img_buffer1.seek(0)
+         plt.close(fig1)
+
+         # Save to temp file for reportlab
+         img_temp1 = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
+         img_temp1.write(img_buffer1.read())
+         img_temp1.close()
+
+         img1 = Image(img_temp1.name, width=450, height=250)
+         story.append(img1)
+         story.append(Spacer(1, 10))
+         story.append(Paragraph("Note: Categories are not mutually exclusive—each item can belong to multiple categories.", normal_style))
+     except Exception as e:
+         story.append(Paragraph(f"Could not generate distribution chart: {str(e)}", normal_style))
+
+     # Classification matrix (new page)
+     story.append(PageBreak())
+     story.append(Paragraph("Classification Matrix", title_style))
+     try:
+         fig2 = create_classification_heatmap(result_df, categories, classify_mode, models_list)
+         img_buffer2 = io.BytesIO()
+         fig2.savefig(img_buffer2, format='png', dpi=150, bbox_inches='tight')
+         img_buffer2.seek(0)
+         plt.close(fig2)
+
+         # Save to temp file for reportlab
+         img_temp2 = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
+         img_temp2.write(img_buffer2.read())
+         img_temp2.close()
+
+         img2 = Image(img_temp2.name, width=450, height=300)
+         story.append(img2)
+         story.append(Spacer(1, 10))
+         story.append(Paragraph("Orange = category present, Black = not present. Each row represents one response.", normal_style))
+     except Exception as e:
+         story.append(Paragraph(f"Could not generate classification matrix: {str(e)}", normal_style))
+
+     doc.build(story)
+     return pdf_file.name
+
+
+ def run_auto_extract(input_type, input_data, description, max_categories_val,
+                      model_tier, model, api_key_input, mode=None, progress_callback=None):
+     """Extract categories from data."""
+     if not CATVADER_AVAILABLE:
+         return None, "catvader package not available"
+
+     actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
+     if not actual_api_key:
+         return None, f"{provider} API key not configured"
+
+     model_source = get_model_source(model)
+
+     try:
+         if isinstance(input_data, list):
+             num_items = len(input_data)
+         else:
+             num_items = 1
+
+         if input_type == "image":
+             divisions = min(3, max(1, num_items // 5))
+             categories_per_chunk = 12
+         else:
+             divisions = max(1, num_items // 15)
+             divisions = min(divisions, 5)
+             chunk_size = num_items // max(1, divisions)
+             categories_per_chunk = max(1, min(10, chunk_size - 1))  # floor of 1 so tiny inputs don't yield 0
+
+         extract_kwargs = {
+             'input_data': input_data,
+             'api_key': actual_api_key,
+             'input_type': input_type,
+             'description': description,
+             'user_model': model,
+             'model_source': model_source,
+             'divisions': divisions,
+             'categories_per_chunk': categories_per_chunk,
+             'max_categories': int(max_categories_val)
+         }
+         if mode:
+             extract_kwargs['mode'] = mode
+
+         extract_result = catvader.extract(**extract_kwargs)
+         categories = extract_result.get('top_categories', [])
+
+         if not categories:
+             return None, "No categories were extracted"
+
+         return categories, f"Extracted {len(categories)} categories successfully!"
+
+     except Exception as e:
+         return None, f"Error: {str(e)}"
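For reviewers of this diff, the chunking heuristic in `run_auto_extract` above can be read in isolation. This is an illustrative standalone sketch (not part of `app.py`) that mirrors the text/image branching, including a floor of one category per chunk for very small inputs:

```python
def plan_chunks(num_items: int, input_type: str) -> tuple:
    """Mirror of the division/categories-per-chunk heuristic above."""
    if input_type == "image":
        # Images: at most 3 divisions, roughly one per 5 images
        divisions = min(3, max(1, num_items // 5))
        categories_per_chunk = 12
    else:
        # Text: roughly one division per 15 items, capped at 5
        divisions = min(max(1, num_items // 15), 5)
        chunk_size = num_items // max(1, divisions)
        categories_per_chunk = max(1, min(10, chunk_size - 1))
    return divisions, categories_per_chunk

print(plan_chunks(100, "text"))   # 5 divisions, 10 categories per chunk
print(plan_chunks(10, "text"))    # 1 division, 9 categories per chunk
print(plan_chunks(30, "image"))   # 3 divisions, 12 categories per chunk
```

The single-item edge case resolves to one division with one category per chunk, which is why the `max(1, ...)` floor matters.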
+
+
+ def run_classify_data(input_type, input_data, description, categories,
+                       model_tier, model, api_key_input, mode=None,
+                       original_filename="data", column_name="text",
+                       progress_callback=None):
+     """Classify data with user-provided categories."""
+     if not CATVADER_AVAILABLE:
+         return None, None, None, None, "catvader package not available"
+
+     if not categories:
+         return None, None, None, None, "Please enter at least one category"
+
+     actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
+     if not actual_api_key:
+         return None, None, None, None, f"{provider} API key not configured"
+
+     model_source = get_model_source(model)
+
+     try:
+         start_time = time.time()
+
+         classify_kwargs = {
+             'input_data': input_data,
+             'categories': categories,
+             'models': [(model, model_source, actual_api_key)],
+             'description': description,
+         }
+         if mode:
+             classify_kwargs['mode'] = mode
+
+         result = catvader.classify(**classify_kwargs)
+
+         processing_time = time.time() - start_time
+         num_items = len(result)
+
+         # Save CSV
+         with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
+             result.to_csv(f.name, index=False)
+             csv_path = f.name
+
+         # Calculate success rate
+         if 'processing_status' in result.columns:
+             success_count = (result['processing_status'] == 'success').sum()
+             success_rate = (success_count / len(result)) * 100
+         else:
+             success_rate = 100.0
+
+         # Get version info
+         try:
+             catvader_version = catvader.__version__
+         except AttributeError:
+             catvader_version = "unknown"
+         python_version = sys.version.split()[0]
+
+         # Generate methodology report
+         report_pdf_path = generate_methodology_report_pdf(
+             categories=categories,
+             model=model,
+             column_name=column_name,
+             num_rows=num_items,
+             model_source=model_source,
+             filename=original_filename,
+             success_rate=success_rate,
+             result_df=result,
+             processing_time=processing_time,
+             catvader_version=catvader_version,
+             python_version=python_version,
+             task_type="assign",
+             input_type=input_type,
+             description=description
+         )
+
+         # Generate reproducibility code
+         code = generate_classify_code(input_type, description, categories, model, model_source, mode)
+
+         return result, csv_path, report_pdf_path, code, f"Classified {num_items} items in {processing_time:.1f}s"
+
+     except Exception as e:
+         return None, None, None, None, f"Error: {str(e)}"
+
+
+ def sanitize_model_name(model: str) -> str:
+     """Convert model name to column-safe suffix (matches catvader logic)."""
+     import re
+     sanitized = re.sub(r'[^a-zA-Z0-9]', '_', model)
+     sanitized = re.sub(r'_+', '_', sanitized)
+     sanitized = sanitized.strip('_').lower()
+     return sanitized[:40]
+
+
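As a quick sanity check on the sanitizer above, here is a standalone re-implementation with example inputs (the model names are hypothetical, chosen only to illustrate the transformation):

```python
import re

def sanitize(model: str) -> str:
    """Replace non-alphanumerics with '_', collapse runs, lowercase, cap at 40 chars."""
    sanitized = re.sub(r'[^a-zA-Z0-9]', '_', model)
    sanitized = re.sub(r'_+', '_', sanitized)
    return sanitized.strip('_').lower()[:40]

print(sanitize("claude-haiku-4.5"))    # claude_haiku_4_5
print(sanitize("openai/gpt-4o-mini"))  # openai_gpt_4o_mini
```

Note the 40-character cap: two very long model names that differ only after position 40 would collide, which is one reason the suffix-discovery helpers below read real column names instead of recomputing them.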
+ def _find_model_column_suffix(result_df, model_name):
+     """Find the actual column suffix used for a model in the DataFrame.
+
+     catvader appends a creativity suffix (e.g. _tauto, _t50) to ensemble column
+     names, so we can't just use sanitize_model_name(). This function looks at
+     the real DataFrame columns to discover the full suffix.
+     """
+     sanitized = sanitize_model_name(model_name)
+     prefix = f"category_1_{sanitized}"
+     for col in result_df.columns:
+         if col.startswith(prefix):
+             # Return everything after "category_1_"
+             return col[len("category_1_"):]
+     # Fallback: return just the sanitized name
+     return sanitized
+
+
+ def _find_all_model_suffixes(result_df):
+     """Discover all distinct per-model column suffixes from the DataFrame.
+
+     Looks at category_1_* columns (excluding _consensus and _agreement)
+     to find every unique model suffix. Works even when the same model
+     appears multiple times with different temperature suffixes.
+
+     Returns:
+         List of suffix strings, e.g.
+         ['claude_haiku_4_5_20251001_t0', 'claude_haiku_4_5_20251001_t25', ...]
+     """
+     import re
+     suffixes = []
+     for col in result_df.columns:
+         m = re.match(r'^category_1_(.+)$', col)
+         if m:
+             suffix = m.group(1)
+             if suffix not in ('consensus', 'agreement'):
+                 suffixes.append(suffix)
+     return suffixes
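The suffix-discovery logic above only needs column names, so it can be sketched without pandas. The column names here are made-up examples in the `category_N_<suffix>` shape the helpers expect:

```python
import re

# Hypothetical result columns: one model at two temperatures, plus ensemble columns
cols = [
    "text",
    "category_1_gpt_4o_t0", "category_2_gpt_4o_t0",
    "category_1_gpt_4o_t50",
    "category_1_consensus", "category_1_agreement",
]

suffixes = []
for col in cols:
    m = re.match(r'^category_1_(.+)$', col)
    # Keep per-model suffixes; skip the ensemble summary columns
    if m and m.group(1) not in ('consensus', 'agreement'):
        suffixes.append(m.group(1))

print(suffixes)  # ['gpt_4o_t0', 'gpt_4o_t50']
```

Anchoring on `category_1_` works because catvader emits one column per category with the same suffix, so the first category is enough to enumerate models.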
+
+
+ def create_classification_heatmap(result_df, categories, classify_mode="Single Model", models_list=None):
+     """Create a binary heatmap showing classification for each row.
+
+     Args:
+         result_df: DataFrame with classification results
+         categories: List of category names
+         classify_mode: "Single Model", "Model Comparison", or "Ensemble"
+         models_list: List of model names (for multi-model modes)
+     """
+     import numpy as np
+
+     total_rows = len(result_df)
+     if total_rows == 0:
+         fig, ax = plt.subplots(figsize=(10, 4))
+         ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', fontsize=14)
+         ax.axis('off')
+         return fig
+
+     # Build the binary matrix based on classify_mode
+     if classify_mode == "Ensemble":
+         # Use consensus columns
+         col_names = [f"category_{i}_consensus" for i in range(1, len(categories) + 1)]
+     elif classify_mode == "Model Comparison" and models_list:
+         # Use first model's columns (detect actual suffix from DataFrame)
+         suffix = _find_model_column_suffix(result_df, models_list[0])
+         col_names = [f"category_{i}_{suffix}" for i in range(1, len(categories) + 1)]
+     else:
+         # Single model
+         col_names = [f"category_{i}" for i in range(1, len(categories) + 1)]
+
+     # Extract the binary matrix
+     matrix_data = []
+     for col in col_names:
+         if col in result_df.columns:
+             matrix_data.append(result_df[col].astype(int).values)
+         else:
+             matrix_data.append(np.zeros(total_rows, dtype=int))
+
+     matrix = np.array(matrix_data).T  # Rows = responses, Cols = categories
+
+     # Create figure with appropriate sizing
+     fig_height = max(4, min(20, total_rows * 0.15))
+     fig_width = max(8, len(categories) * 0.8)
+     fig, ax = plt.subplots(figsize=(fig_width, fig_height))
+
+     # Create custom colormap: black (0) and orange (1) - CatVader theme
+     from matplotlib.colors import ListedColormap
+     cmap = ListedColormap(['#1a1a1a', '#E8A33C'])
+
+     # Plot heatmap
+     ax.imshow(matrix, aspect='auto', cmap=cmap, vmin=0, vmax=1)
+
+     # Set labels - remove y-axis numbers for cleaner look
+     ax.set_xticks(range(len(categories)))
+     ax.set_xticklabels(categories, rotation=45, ha='right', fontsize=9)
+     ax.set_xlabel('Categories', fontsize=11)
+     ax.set_ylabel(f'Responses (n={total_rows})', fontsize=11)
+     ax.set_yticks([])  # Remove y-axis tick marks
+
+     title = 'Classification Matrix'
+     if classify_mode == "Ensemble":
+         title += ' (Ensemble Consensus)'
+     elif classify_mode == "Model Comparison" and models_list:
+         title += f' ({models_list[0].split("/")[-1].split(":")[0][:20]})'
+     ax.set_title(title, fontsize=14, fontweight='bold')
+
+     # Add legend
+     from matplotlib.patches import Patch
+     legend_elements = [
+         Patch(facecolor='#1a1a1a', edgecolor='white', label='Not Present'),
+         Patch(facecolor='#E8A33C', edgecolor='white', label='Present')
+     ]
+     ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.15, 1))
+
+     plt.tight_layout()
+     return fig
+
+
+ def create_distribution_chart(result_df, categories, classify_mode="Single Model", models_list=None):
+     """Create a bar chart showing category distribution.
+
+     Args:
+         result_df: DataFrame with classification results
+         categories: List of category names
+         classify_mode: "Single Model", "Model Comparison", or "Ensemble"
+         models_list: List of model names (for multi-model modes)
+     """
+     import numpy as np
+
+     total_rows = len(result_df)
+     if total_rows == 0:
+         fig, ax = plt.subplots(figsize=(10, 4))
+         ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', fontsize=14)
+         ax.axis('off')
+         return fig
+
+     # Define colors for different models
+     model_colors = ['#2563eb', '#dc2626', '#16a34a', '#ca8a04', '#9333ea', '#0891b2', '#be185d', '#65a30d']
+
+     if classify_mode == "Single Model":
+         # Single model: use category_1, category_2, etc.
+         fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))
+
+         dist_data = []
+         for i, cat in enumerate(categories, 1):
+             col_name = f"category_{i}"
+             if col_name in result_df.columns:
+                 count = int(result_df[col_name].sum())
+                 pct = (count / total_rows) * 100
+                 dist_data.append({"Category": cat, "Percentage": round(pct, 1)})
+
+         categories_list = [d["Category"] for d in dist_data][::-1]
+         percentages = [d["Percentage"] for d in dist_data][::-1]
+
+         bars = ax.barh(categories_list, percentages, color='#2563eb')
+         ax.set_xlim(0, 100)
+         ax.set_xlabel('Percentage (%)', fontsize=11)
+         ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold')
+
+         for bar, pct in zip(bars, percentages):
+             ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
+                     f'{pct:.1f}%', va='center', fontsize=10)
+
+     elif classify_mode == "Ensemble":
+         # Ensemble: use category_1_consensus, category_2_consensus, etc.
+         fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))
+
+         dist_data = []
+         for i, cat in enumerate(categories, 1):
+             col_name = f"category_{i}_consensus"
+             if col_name in result_df.columns:
+                 count = int(result_df[col_name].sum())
+                 pct = (count / total_rows) * 100
+                 dist_data.append({"Category": cat, "Percentage": round(pct, 1)})
+
+         categories_list = [d["Category"] for d in dist_data][::-1]
+         percentages = [d["Percentage"] for d in dist_data][::-1]
+
+         bars = ax.barh(categories_list, percentages, color='#16a34a')
+         ax.set_xlim(0, 100)
+         ax.set_xlabel('Percentage (%)', fontsize=11)
+         ax.set_title('Ensemble Consensus Distribution (%)', fontsize=14, fontweight='bold')
+
+         for bar, pct in zip(bars, percentages):
+             ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
+                     f'{pct:.1f}%', va='center', fontsize=10)
+
+     else:  # Model Comparison
+         # Model Comparison: grouped bars for each model
+         if not models_list:
+             models_list = []
+
+         # Detect actual column suffixes from the DataFrame
+         model_suffixes = [_find_model_column_suffix(result_df, m) for m in models_list]
+         n_models = len(model_suffixes)
+         n_categories = len(categories)
+
+         fig, ax = plt.subplots(figsize=(12, max(5, n_categories * 1.2)))
+
+         # Gather data for each model (guard against an empty model list)
+         bar_height = 0.8 / max(1, n_models)
+         y_positions = np.arange(n_categories)
+
+         for model_idx, (model_name, suffix) in enumerate(zip(models_list, model_suffixes)):
+             model_pcts = []
+             for i in range(1, n_categories + 1):
+                 col_name = f"category_{i}_{suffix}"
+                 if col_name in result_df.columns:
+                     count = int(result_df[col_name].sum())
+                     pct = (count / total_rows) * 100
+                 else:
+                     pct = 0
+                 model_pcts.append(pct)
+
+             # Reverse for horizontal bar chart
+             model_pcts = model_pcts[::-1]
+             offset = (model_idx - n_models / 2 + 0.5) * bar_height
+             color = model_colors[model_idx % len(model_colors)]
+
+             # Use shorter display name
+             display_name = model_name.split('/')[-1].split(':')[0][:20]
+             ax.barh(y_positions + offset, model_pcts, bar_height * 0.9,
+                     label=display_name, color=color, alpha=0.85)
+
+         ax.set_yticks(y_positions)
+         ax.set_yticklabels(categories[::-1])
+         ax.set_xlim(0, 100)
+         ax.set_xlabel('Percentage (%)', fontsize=11)
+         ax.set_title('Category Distribution by Model (%)', fontsize=14, fontweight='bold')
+         ax.legend(loc='lower right', fontsize=9)
+
+     plt.tight_layout()
+     return fig
+
+
+ # Page config
+ st.set_page_config(
+     page_title="CatVader - Social Media Classifier",
+     page_icon="🐱",
+     layout="wide"
+ )
+
+ # Custom CSS for enhanced styling
+ st.markdown("""
+ <style>
+ /* Import Garamond font and apply globally */
+ @import url('https://fonts.googleapis.com/css2?family=EB+Garamond:wght@400;500;600;700&display=swap');
+
+ *:not([class*="icon"]):not([data-testid="stIconMaterial"]):not(svg):not(path) {
+     font-family: 'EB Garamond', Garamond, Georgia, serif !important;
+     font-size: 17px !important;
+ }
+
+ /* Preserve Streamlit icon fonts */
+ [data-testid="stIconMaterial"], .material-icons, .material-symbols-rounded {
+     font-family: 'Material Symbols Rounded', 'Material Icons' !important;
+     font-size: 24px !important;
+ }
+
+ /* Main container styling */
+ .main .block-container {
+     padding-top: 2rem;
+     padding-bottom: 2rem;
+ }
+
+ /* Headers with gradient accent */
+ h1 {
+     background: linear-gradient(90deg, #E8A33C 0%, #D4872C 100%);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     background-clip: text;
+     font-weight: 700;
+ }
+
+ /* Card-like sections */
+ .stExpander {
+     border: 1px solid #E8D5B5;
+     border-radius: 12px;
+     box-shadow: 0 2px 8px rgba(232, 163, 60, 0.08);
+ }
+
+ /* File uploader styling */
+ .stFileUploader {
+     border-radius: 12px;
+ }
+
+ .stFileUploader > div > div {
+     border: 2px dashed #E8A33C;
+     border-radius: 12px;
+     background: linear-gradient(135deg, #FEFCF9 0%, #F5EFE6 100%);
+ }
+
+ /* Button styling */
+ .stButton > button {
+     border-radius: 8px;
+     font-weight: 600;
+     transition: all 0.2s ease;
+     border: 2px solid #E8A33C;
+     background: #FEFCF9;
+     color: #D4872C;
+ }
+
+ /* Tall button for example dataset (matches file uploader height) */
+ .tall-button .stButton > button {
+     min-height: 107px;
+     border-radius: 12px;
+ }
+
+ .stButton > button:hover {
+     transform: translateY(-1px);
+     box-shadow: 0 4px 12px rgba(232, 163, 60, 0.3);
+     background: #F5EFE6;
+ }
+
+ /* Primary button */
+ .stButton > button[kind="primary"] {
+     background: linear-gradient(135deg, #E8A33C 0%, #D4872C 100%);
+     border: none;
+     color: white;
+ }
+
+ /* Success/info messages */
+ .stSuccess {
+     background-color: #E8F5E9;
+     border-left: 4px solid #4CAF50;
+     border-radius: 0 8px 8px 0;
+ }
+
+ .stInfo {
+     background-color: #FFF8E8;
+     border-left: 4px solid #E8A33C;
+     border-radius: 0 8px 8px 0;
+ }
+
+ /* Radio buttons */
+ .stRadio > div {
+     gap: 0.5rem;
+     display: flex;
+     width: 100%;
+ }
+
+ .stRadio > div > label {
+     background: #F5EFE6;
+     padding: 0.5rem 1rem;
+     border-radius: 20px;
+     border: 1px solid transparent;
+     transition: all 0.2s ease;
+     flex: 1;
+     text-align: center;
+     justify-content: center;
+ }
+
+ .stRadio > div > label:hover {
+     border-color: #E8A33C;
+ }
+
+ /* Text inputs */
+ .stTextInput > div > div > input {
+     border-radius: 8px;
+     border: 1px solid #E8D5B5;
+ }
+
+ .stTextInput > div > div > input:focus {
+     border-color: #E8A33C;
+     box-shadow: 0 0 0 2px rgba(232, 163, 60, 0.2);
+ }
+
+ /* Select boxes */
+ .stSelectbox > div > div {
+     border-radius: 8px;
+ }
+
+ /* Dataframe styling */
+ .stDataFrame {
+     border-radius: 12px;
+     overflow: hidden;
+     box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
+ }
+
+ /* Progress bar */
+ .stProgress > div > div {
+     background: linear-gradient(90deg, #E8A33C 0%, #D4872C 100%);
+     border-radius: 10px;
+ }
+
+ /* Slider */
+ .stSlider > div > div > div {
+     background: #E8A33C;
+ }
+
+ /* Divider */
+ hr {
+     border: none;
+     height: 1px;
+     background: linear-gradient(90deg, transparent, #E8D5B5, transparent);
+     margin: 1.5rem 0;
+ }
+
+ /* Code blocks */
+ .stCodeBlock {
+     border-radius: 12px;
+     border: 1px solid #E8D5B5;
+ }
+
+ /* Metric cards */
+ .stMetric {
+     background: linear-gradient(135deg, #FEFCF9 0%, #F5EFE6 100%);
+     padding: 1rem;
+     border-radius: 12px;
+     border: 1px solid #E8D5B5;
+ }
+
+ /* Download buttons */
+ .stDownloadButton > button {
+     background: #F5EFE6;
+     border: 1px solid #E8A33C;
+     color: #D4872C;
+ }
+
+ .stDownloadButton > button:hover {
+     background: #E8A33C;
+     color: white;
+ }
+
+ /* Multiselect */
+ .stMultiSelect > div > div {
+     border-radius: 8px;
+ }
+
+ /* Status indicator */
+ .stStatus {
+     border-radius: 12px;
+ }
+
+ /* Column gaps */
+ [data-testid="column"] {
+     padding: 0 0.5rem;
+ }
+
+ /* Logo and title alignment */
+ [data-testid="column"]:first-child img {
+     border-radius: 8px;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ # Initialize session state
+ if 'categories' not in st.session_state:
+     st.session_state.categories = [''] * MAX_CATEGORIES
+ if 'category_count' not in st.session_state:
+     st.session_state.category_count = INITIAL_CATEGORIES
+ if 'task_mode' not in st.session_state:
+     st.session_state.task_mode = None
+ if 'extracted_categories' not in st.session_state:
+     st.session_state.extracted_categories = None
+ if 'results' not in st.session_state:
+     st.session_state.results = None
+ if 'active_tab' not in st.session_state:
+     st.session_state.active_tab = "survey"
+ if 'survey_data' not in st.session_state:
+     st.session_state.survey_data = None
+ if 'pdf_data' not in st.session_state:
+     st.session_state.pdf_data = None
+ if 'image_data' not in st.session_state:
+     st.session_state.image_data = None
+ if 'extraction_params' not in st.session_state:
+     st.session_state.extraction_params = None  # Stores params when categories are auto-extracted
+ if 'bluesky_df' not in st.session_state:
+     st.session_state.bluesky_df = None
+
+ # Logo and title - use HTML for better alignment
+ st.markdown("""
+ <div style="display: flex; align-items: center; gap: 20px; margin-bottom: 10px;">
+     <img src="https://huggingface.co/spaces/CatVader/social-media-classifier/resolve/main/logo.png" width="100" style="border-radius: 8px;">
+     <div>
+         <div style="font-size: 2.2rem; font-weight: 700; color: #333; font-family: 'EB Garamond', Garamond, Georgia, serif; line-height: 1.1;">CatVader</div>
+         <div style="font-size: 1.1rem; font-weight: 500; color: #E8A33C; font-family: 'EB Garamond', Garamond, Georgia, serif; margin-bottom: 4px;">NLP for Survey Research</div>
+         <div style="font-size: 1rem; font-weight: 400; color: #666; font-family: 'EB Garamond', Garamond, Georgia, serif;">Research-grade classification of social media posts, PDFs, and images using AI models.</div>
+         <div style="font-size: 0.85rem; font-weight: 400; color: #888; font-family: 'EB Garamond', Garamond, Georgia, serif; margin-top: 4px;">Developed at UC Berkeley</div>
+     </div>
+ </div>
+ """, unsafe_allow_html=True)
+
+ # About section
+ with st.expander("About This App"):
+     st.markdown("""
+ **Privacy Notice:** Your data is sent to third-party LLM APIs for classification. Do not upload sensitive, confidential, or personally identifiable information (PII).
+
+ ---
+
+ **CatVader** is an open-source Python package for classifying and exploring social media data using Large Language Models.
+
+ ### What It Does
+ - **Extract Categories**: Discover themes and categories in your data automatically
+ - **Assign Categories**: Classify data into your predefined categories
+ - **Extract & Assign**: Let CatVader discover categories, then classify all your data
+
+ ### Supported Providers
+ OpenAI (GPT-4o, GPT-4o Mini), Anthropic (Claude), Google (Gemini), Mistral, HuggingFace, xAI (Grok), and Perplexity. Use the free tier or bring your own API key.
+
+ ### Beta Test - We Want Your Feedback!
+ This app is currently in **beta** and **free to use** while CatVader is under active development, made possible by **Bashir Ahmed's generous fellowship support**.
+
+ - Found a bug? Have a feature request? Please open an issue on [GitHub](https://github.com/chrissoria/cat-vader)
+ - Reach out directly: [chrissoria@berkeley.edu](mailto:chrissoria@berkeley.edu)
+
+ ### Acknowledgments
+ - **Bashir Ahmed** for his generous fellowship support that makes this free beta possible
+ - **Claude Fischer** for his thoughtful feedback and collaboration on research that helped inspire this project
+ - **Kevin Collins** from Survey360 for his input
+ - **Fendi Tsim** for sharing it widely
+
+ ### Links
+ - **Website**: [christophersoria.com](https://christophersoria.com)
+ - **PyPI**: [pip install cat-vader](https://pypi.org/project/cat-vader/)
+ - **GitHub**: [github.com/chrissoria/cat-vader](https://github.com/chrissoria/cat-vader)
+
+ ### Citation
+ If you use CatVader in your research, please cite:
+ ```
+ Soria, C. (2025). CatVader: A Python package for LLM-based social media classification. DOI: 10.5281/zenodo.15532316
+ ```
+ """)
+
+ # Main layout
+ col_input, col_output = st.columns([1, 1])
+
+ with col_input:
+     # Input type selector
+     input_type_choice = st.radio(
+         "Input Type",
+         options=["Social Media Posts", "PDF Documents", "Images"],
+         horizontal=True,
+         key="input_type_radio"
+     )
+
+     # Initialize variables
+     input_data = None
+     input_type_selected = "text"
+     description = ""
+     original_filename = "data"
+     pdf_mode = "Image (visual documents)"
+
+     if input_type_choice == "Social Media Posts":
+         input_type_selected = "text"
+
+         data_source = st.radio(
+             "Data Source",
+             options=["Upload CSV/Excel", "Fetch from Bluesky"],
+             horizontal=True,
+             key="data_source_radio"
+         )
+
+         if data_source == "Upload CSV/Excel":
+             st.session_state.bluesky_df = None  # Clear any fetched data when switching sources
+             upload_col, example_col = st.columns([3, 1])
+             with upload_col:
+                 uploaded_file = st.file_uploader(
+                     "Upload Data (CSV or Excel)",
+                     type=['csv', 'xlsx', 'xls'],
+                     key="survey_file"
+                 )
+             with example_col:
+                 st.markdown("<div style='height: 27px;'></div>", unsafe_allow_html=True)  # Match "Upload Data" label height
+                 st.markdown('<div class="tall-button">', unsafe_allow_html=True)
+                 if st.button("Try Example Dataset", key="example_btn", use_container_width=True):
+                     st.session_state.example_loaded = True
+                 st.markdown('</div>', unsafe_allow_html=True)
+
+             columns = []
+             df = None
+             if uploaded_file is not None:
+                 try:
+                     if uploaded_file.name.endswith('.csv'):
+                         df = pd.read_csv(uploaded_file)
+                     else:
+                         df = pd.read_excel(uploaded_file)
+                     columns = df.columns.tolist()
+                     st.success(f"Loaded {len(df):,} rows")
+                 except Exception as e:
+                     st.error(f"Error loading file: {e}")
+             elif hasattr(st.session_state, 'example_loaded') and st.session_state.example_loaded:
+                 try:
+                     df = pd.read_csv("example_data.csv")
+                     columns = df.columns.tolist()
+                     st.success(f"Loaded example dataset ({len(df)} rows)")
+                 except Exception:
+                     pass
+
+             selected_column = st.selectbox(
+                 "Column to Process",
+                 options=columns if columns else ["Upload a file first"],
+                 disabled=not columns,
+                 key="survey_column"
+             )
+
+             description = selected_column if columns else ""
+             original_filename = uploaded_file.name if uploaded_file else "example_data.csv"
+
+             if df is not None and columns and selected_column in columns:
+                 input_data = df[selected_column].tolist()
+
+         else:  # Fetch from Bluesky
+             bsky_handle = st.text_input(
+                 "Bluesky Handle",
+                 placeholder="e.g. aoc.bsky.social or @aoc.bsky.social",
+                 key="bluesky_handle_input"
+             )
+             bsky_num_posts = st.slider(
+                 "Number of Posts to Fetch",
+                 min_value=10, max_value=250, value=50, step=10,
+                 key="bluesky_num_posts"
+             )
+             if st.button("Fetch Posts", key="fetch_bluesky_btn"):
+                 handle_clean = bsky_handle.strip().lstrip("@")
+                 if not handle_clean:
+                     st.error("Please enter a Bluesky handle.")
+                 else:
+                     with st.spinner(f"Fetching {bsky_num_posts} posts from {handle_clean}..."):
+                         try:
+                             from catvader._social_media import fetch_bluesky
+                             df_bsky = fetch_bluesky(limit=bsky_num_posts, handle=handle_clean)
+                             df_bsky = df_bsky[df_bsky["media_type"] != "REPOST_FACADE"].reset_index(drop=True)
+                             st.session_state.bluesky_df = df_bsky
+                         except Exception as e:
+                             st.error(f"Error fetching posts: {e}")
+
+             if st.session_state.bluesky_df is not None:
+                 bsky_df = st.session_state.bluesky_df
+                 st.success(f"Fetched {len(bsky_df)} posts")
+                 st.dataframe(
+                     bsky_df[["timestamp", "text", "likes", "replies"]].head(5),
+                     use_container_width=True
+                 )
+                 handle_clean = bsky_handle.strip().lstrip("@") if bsky_handle else "bluesky"
+                 input_data = bsky_df["text"].tolist()
+                 description = f"Bluesky posts from @{handle_clean}"
+                 original_filename = f"bluesky_{handle_clean.replace('.', '_')}"
+
+     elif input_type_choice == "PDF Documents":
+         input_type_selected = "pdf"
+
+         pdf_files = st.file_uploader(
+             "Upload PDF Document(s)",
+             type=['pdf'],
+             accept_multiple_files=True,
+             key="pdf_files"
+         )
+
+         pdf_description = st.text_input(
+             "Document Description",
+             placeholder="e.g., 'research papers', 'interview transcripts'",
+             help="Helps the LLM understand context",
+             key="pdf_desc"
+         )
+
+         pdf_mode = st.radio(
+             "Processing Mode",
+             options=["Image (visual documents)", "Text (text-heavy)", "Both (comprehensive)"],
+             key="pdf_mode"
+         )
+
+         if pdf_files:
+             input_data = []
+             pdf_name_map = {}  # Map temp paths to original filenames
+             for f in pdf_files:
+                 with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
+                     tmp.write(f.read())
+                     input_data.append(tmp.name)
+                     pdf_name_map[tmp.name] = f.name.replace('.pdf', '')  # Store original name without extension
+             st.session_state.pdf_name_map = pdf_name_map
+             description = pdf_description or "document"
+             original_filename = "pdf_files"
+             st.success(f"Uploaded {len(pdf_files)} PDF file(s)")
+
+     else:  # Images
+         input_type_selected = "image"
+
+         image_files = st.file_uploader(
+             "Upload Images",
+             type=['png', 'jpg', 'jpeg', 'gif', 'webp'],
+             accept_multiple_files=True,
+             key="image_files"
+         )
+
+         image_description = st.text_input(
+             "Image Description",
+             placeholder="e.g., 'product photos', 'social media posts'",
+             help="Helps the LLM understand context",
+             key="image_desc"
+         )
+
+         if image_files:
+             input_data = []
+             for f in image_files:
+                 suffix = '.' + f.name.split('.')[-1]
+                 with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
+                     tmp.write(f.read())
+                     input_data.append(tmp.name)
+             description = image_description or "images"
+             original_filename = "image_files"
+             st.success(f"Uploaded {len(image_files)} image file(s)")
+
+     st.markdown("---")
+
+     # Task selection
+     st.markdown("### What would you like to do?")
+     col_btn1, col_btn2 = st.columns(2)
+     with col_btn1:
+         manual_mode = st.button("Enter Categories Manually", use_container_width=True)
+     with col_btn2:
+         auto_mode = st.button("Auto-extract Categories", use_container_width=True)
+
+     if manual_mode:
+         st.session_state.task_mode = "manual"
+     if auto_mode:
+         st.session_state.task_mode = "auto_extract"
+
+     # Auto-extract settings
+     if st.session_state.task_mode == "auto_extract":
+         st.markdown("### Auto-extract Categories")
+         st.markdown("We'll analyze your data to discover the main categories.")
+
+         max_categories = st.slider(
+             "Number of Categories to Extract",
+             min_value=3,
+             max_value=25,
+             value=12,
+             help="How many categories should be identified in your data"
+         )
+
+         specificity = st.selectbox(
+             "How specific should categories be?",
+             options=["Broad", "Moderate", "Narrow"],
+             index=0,
+             help="Broad = general themes, Moderate = balanced detail, Narrow = highly specific categories"
+         )
+
+         focus = st.text_input(
+             "What should categories be focused around? (optional)",
+             placeholder="e.g., 'decisions to move', 'emotional responses', 'financial factors'",
+             help="Guide the model to prioritize extracting categories related to this focus"
+         )
+
+         # Model selection for extraction
+         st.markdown("### Model Selection")
+         model_tier = st.radio(
+             "Model Tier",
+             options=["Free Models", "Bring Your Own Key"],
+             key="extract_model_tier"
+         )
+
+         if model_tier == "Free Models":
+             model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="extract_model")
+             model = FREE_MODELS_MAP[model_display]  # Convert to actual model name
+             api_key = ""
+         else:
+             model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="extract_model_paid")
+             api_key = st.text_input("API Key", type="password", key="extract_api_key")
+
+ if st.button("Extract Categories", type="primary"):
1613
+ if input_data is None:
1614
+ st.error("Please upload data first")
1615
+ else:
1616
+ mode = None
1617
+ if input_type_selected == "pdf":
1618
+ mode_mapping = {
1619
+ "Image (visual documents)": "image",
1620
+ "Text (text-heavy)": "text",
1621
+ "Both (comprehensive)": "both"
1622
+ }
1623
+ mode = mode_mapping.get(pdf_mode, "image")
1624
+
1625
+ actual_api_key, provider = get_api_key(model, model_tier, api_key)
1626
+ if not actual_api_key:
1627
+ st.error(f"{provider} API key not configured")
1628
+ else:
1629
+ model_source = get_model_source(model)
1630
+
1631
+ # Calculate estimated time based on input size
1632
+ num_items = len(input_data) if isinstance(input_data, list) else 1
1633
+ if input_type_selected == "pdf":
1634
+ # PDFs take longer - estimate ~5s per page
1635
+ total_pages = sum(count_pdf_pages(p) for p in (input_data if isinstance(input_data, list) else [input_data]))
1636
+ est_seconds = total_pages * 5
1637
+ elif input_type_selected == "image":
1638
+ # Images ~4s each
1639
+ est_seconds = num_items * 4
1640
+ else:
1641
+ # Text ~2s per item, but batched
1642
+ est_seconds = max(10, num_items * 0.5)
1643
+
1644
+ # Progress tracking UI
1645
+ progress_bar = st.progress(0)
1646
+ status_text = st.empty()
1647
+ start_time = time.time()
1648
+
1649
+ # Progress callback for extraction
1650
+ def extract_progress_callback(current_step, total_steps, step_label):
1651
+ progress = current_step / total_steps if total_steps > 0 else 0
1652
+ progress_bar.progress(min(progress, 1.0))
1653
+
1654
+ elapsed = time.time() - start_time
1655
+ if current_step > 0:
1656
+ avg_time = elapsed / current_step
1657
+ eta_seconds = avg_time * (total_steps - current_step)
1658
+ eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
1659
+ else:
1660
+ eta_str = ""
1661
+
1662
+ status_text.text(f"Extracting categories: {step_label} ({progress*100:.0f}%){eta_str}")
1663
+
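The ETA math in the callback above can be factored into a small pure function. This is a sketch only (the `format_eta` name is ours, not part of the app): it mirrors the elapsed/average/remaining logic and the seconds-vs-minutes display rule.

```python
def format_eta(elapsed: float, done: int, total: int) -> str:
    """Project remaining time from the average time per completed step."""
    if done <= 0:
        return ""
    eta_seconds = (elapsed / done) * (total - done)
    # Same display rule as the callback: seconds under a minute, minutes otherwise
    if eta_seconds < 60:
        return f" | ETA: {eta_seconds:.0f}s"
    return f" | ETA: {eta_seconds/60:.1f}m"

print(format_eta(10.0, 2, 6))  # 5s per completed step, 4 steps left -> " | ETA: 20s"
```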
1664
+ extract_kwargs = {
1665
+ 'input_data': input_data,
1666
+ 'api_key': actual_api_key,
1667
+ 'input_type': input_type_selected,
1668
+ 'description': description,
1669
+ 'user_model': model,
1670
+ 'model_source': model_source,
1671
+ 'max_categories': int(max_categories),
1672
+ 'specificity': specificity.lower(),
1673
+ 'progress_callback': extract_progress_callback,
1674
+ }
1675
+ if mode:
1676
+ extract_kwargs['mode'] = mode
1677
+ if focus and focus.strip():
1678
+ extract_kwargs['focus'] = focus.strip()
1679
+
1680
+ try:
1681
+ extract_result = catvader.extract(**extract_kwargs)
1682
+ categories = extract_result.get('top_categories', [])
1683
+
1684
+ processing_time = time.time() - start_time
1685
+ progress_bar.progress(1.0)
1686
+ status_text.text(f"Completed in {processing_time:.1f}s")
1687
+
1688
+ if categories:
1689
+ st.success(f"Extracted {len(categories)} categories in {processing_time:.1f}s")
1690
+ st.session_state.extracted_categories = categories
1691
+ # Store extraction params for code generation
1692
+ st.session_state.extraction_params = {
1693
+ 'model': model,
1694
+ 'model_source': model_source,
1695
+ 'max_categories': int(max_categories),
1696
+ 'input_type': input_type_selected,
1697
+ 'description': description,
1698
+ 'mode': mode,
1699
+ }
1700
+ st.session_state.task_mode = "manual"
1701
+ st.rerun()
1702
+ else:
1703
+ st.error("No categories were extracted from the data")
1704
+ except Exception as e:
1705
+ st.error(f"Error: {str(e)}")
1706
+
1707
+ # Category inputs (shown for manual mode or after extraction)
1708
+ if st.session_state.task_mode == "manual":
1709
+ st.markdown("### Categories")
1710
+ st.markdown("Enter your classification categories below.")
1711
+
1712
+ # Pre-fill with extracted categories if available
1713
+ if st.session_state.extracted_categories:
1714
+ for i, cat in enumerate(st.session_state.extracted_categories[:MAX_CATEGORIES]):
1715
+ st.session_state.categories[i] = cat
1716
+ st.session_state.category_count = min(len(st.session_state.extracted_categories), MAX_CATEGORIES)
1717
+ st.session_state.extracted_categories = None # Clear after use
1718
+
1719
+ placeholder_examples = [
1720
+ "e.g., Positive sentiment",
1721
+ "e.g., Negative sentiment",
1722
+ "e.g., Product feedback",
1723
+ "e.g., Service complaint",
1724
+ "e.g., Feature request",
1725
+ "e.g., Custom category"
1726
+ ]
1727
+
1728
+ categories_entered = []
1729
+ for i in range(st.session_state.category_count):
1730
+ placeholder = placeholder_examples[i] if i < len(placeholder_examples) else "e.g., Custom category"
1731
+ cat_value = st.text_input(
1732
+ f"Category {i+1}",
1733
+ value=st.session_state.categories[i],
1734
+ placeholder=placeholder,
1735
+ key=f"cat_{i}"
1736
+ )
1737
+ st.session_state.categories[i] = cat_value
1738
+ if cat_value.strip():
1739
+ categories_entered.append(cat_value.strip())
1740
+
1741
+ if st.session_state.category_count < MAX_CATEGORIES:
1742
+ if st.button("+ Add More"):
1743
+ st.session_state.category_count += 1
1744
+ st.rerun()
1745
+
1746
+ st.markdown("### Model Selection")
1747
+
1748
+ # Classification mode selector
1749
+ classify_mode = st.radio(
1750
+ "Classification Mode",
1751
+ options=["Single Model", "Model Comparison", "Ensemble"],
1752
+ horizontal=True,
1753
+ key="classify_mode",
1754
+ help="Single: one model. Comparison: see results from multiple models side-by-side. Ensemble: multiple models vote for consensus."
1755
+ )
1756
+
1757
+ model_tier = st.radio(
1758
+ "Model Tier",
1759
+ options=["Free Models", "Bring Your Own Key"],
1760
+ key="classify_model_tier"
1761
+ )
1762
+
1763
+ # Multi-model mode uses multiselect
1764
+ is_multi_model = classify_mode in ["Model Comparison", "Ensemble"]
1765
+ min_models = 3 if classify_mode == "Ensemble" else 2
1766
+
1767
+ # Track per-run temperatures: list of (model_name, temperature) for ensemble,
1768
+ # or dict {model_name: temperature} for model comparison
1769
+ model_temperatures = {}
1770
+ # ensemble_runs stores list of (model_name, temperature) allowing duplicate models
1771
+ ensemble_runs = []
1772
+
1773
+ if classify_mode == "Ensemble":
1774
+ # Ensemble mode: dynamic rows allowing same model multiple times with different temps
1775
+ if "ensemble_num_runs" not in st.session_state:
1776
+ st.session_state.ensemble_num_runs = 3
1777
+
1778
+ if model_tier == "Free Models":
1779
+ model_options = FREE_MODEL_DISPLAY_NAMES
1780
+ is_free = True
1781
+ else:
1782
+ model_options = PAID_MODEL_CHOICES
1783
+ is_free = False
1784
+
1785
+ st.markdown(f"**Model Runs** (select {min_models}+ runs)")
1786
+ for i in range(st.session_state.ensemble_num_runs):
1787
+ cols = st.columns([3, 1, 0.5])
1788
+ with cols[0]:
1789
+ default_idx = 0 if i < len(model_options) else i % len(model_options)
1790
+ selected = st.selectbox(
1791
+ f"Run {i+1}", options=model_options,
1792
+ index=default_idx, key=f"ensemble_model_{i}",
1793
+ label_visibility="collapsed"
1794
+ )
1795
+ with cols[1]:
1796
+ temp = st.number_input(
1797
+ "Temp", min_value=0.0, max_value=2.0, value=round(i * 0.25, 2),
1798
+ step=0.25, key=f"ensemble_temp_{i}", label_visibility="collapsed"
1799
+ )
1800
+ with cols[2]:
1801
+ if st.session_state.ensemble_num_runs > 3:
1802
+ if st.button("✕", key=f"ensemble_remove_{i}"):
1803
+ st.session_state.ensemble_num_runs -= 1
1804
+ st.rerun()
1805
+
1806
+ model_name = FREE_MODELS_MAP[selected] if is_free else selected
1807
+ ensemble_runs.append((model_name, temp))
1808
+
1809
+ if st.button("Add Run", key="add_ensemble_run"):
1810
+ st.session_state.ensemble_num_runs += 1
1811
+ st.rerun()
1812
+
1813
+ models_list = [r[0] for r in ensemble_runs]
1814
+ model_temperatures = {f"{r[0]}__run{i}": r[1] for i, r in enumerate(ensemble_runs)}
1815
+ api_key = "" if model_tier == "Free Models" else st.text_input("API Key", type="password", key="classify_api_key")
1816
+
1817
+ elif is_multi_model:
1818
+ # Model Comparison mode: multiselect (each model unique) + temperature row
1819
+ if model_tier == "Free Models":
1820
+ default_models = FREE_MODEL_DISPLAY_NAMES[:min_models]  # slicing is safe even when fewer options exist
1821
+ model_displays = st.multiselect(
1822
+ f"Models (select {min_models}+)",
1823
+ options=FREE_MODEL_DISPLAY_NAMES,
1824
+ default=default_models,
1825
+ key="classify_models_multi"
1826
+ )
1827
+ models_list = [FREE_MODELS_MAP[d] for d in model_displays]
1828
+ api_key = ""
1829
+ else:
1830
+ default_models = PAID_MODEL_CHOICES[:min_models]  # slicing is safe even when fewer options exist
1831
+ models_list = st.multiselect(
1832
+ f"Models (select {min_models}+)",
1833
+ options=PAID_MODEL_CHOICES,
1834
+ default=default_models,
1835
+ key="classify_models_multi_paid"
1836
+ )
1837
+ api_key = st.text_input("API Key", type="password", key="classify_api_key")
1838
+
1839
+ if models_list:
1840
+ st.markdown("**Model Temperature**")
1841
+ temp_cols = st.columns(len(models_list))
1842
+ for idx, (col, m) in enumerate(zip(temp_cols, models_list)):
1843
+ short_name = m.split('/')[-1].split(':')[0][:20]
1844
+ model_temperatures[m] = col.number_input(
1845
+ short_name,
1846
+ min_value=0.0,
1847
+ max_value=2.0,
1848
+ value=0.0,
1849
+ step=0.25,
1850
+ key=f"temp_{idx}",
1851
+ help=f"Temperature for {m} (0 = deterministic, higher = more creative)"
1852
+ )
1853
+ else:
1854
+ # Single model mode
1855
+ if model_tier == "Free Models":
1856
+ model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="classify_model")
1857
+ model = FREE_MODELS_MAP[model_display] # Convert to actual model name
1858
+ models_list = [model]
1859
+ api_key = ""
1860
+ else:
1861
+ model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="classify_model_paid")
1862
+ models_list = [model]
1863
+ api_key = st.text_input("API Key", type="password", key="classify_api_key")
1864
+
1865
+ # Ensemble-specific options
1866
+ consensus_threshold = 0.5 # Default
1867
+ if classify_mode == "Ensemble":
1868
+ consensus_options = {
1869
+ "Majority (50%+)": 0.5,
1870
+ "Two-Thirds (67%+)": 0.67,
1871
+ "Unanimous (100%)": 1.0,
1872
+ }
1873
+ consensus_choice = st.radio(
1874
+ "Consensus Rule",
1875
+ options=list(consensus_options.keys()),
1876
+ horizontal=True,
1877
+ key="consensus_choice",
1878
+ help="How many models must agree for a category to be marked present"
1879
+ )
1880
+ consensus_threshold = consensus_options[consensus_choice]
1881
+
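The selected threshold feeds a simple voting rule. As an illustration of the idea (a sketch, not the catvader internals), a category is marked present when the fraction of agreeing model runs meets the threshold:

```python
def consensus_vote(votes: list[int], threshold: float) -> int:
    """votes: one 0/1 judgment per model run for a single category."""
    agree_fraction = sum(votes) / len(votes)
    return int(agree_fraction >= threshold)

# Three model runs, two say the category is present:
print(consensus_vote([1, 1, 0], 0.5))  # majority rule  -> 1
print(consensus_vote([1, 1, 0], 1.0))  # unanimous rule -> 0
```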
1882
+ if st.button("Categorize Data", type="primary", use_container_width=True):
1883
+ if input_data is None:
1884
+ st.error("Please upload data first")
1885
+ elif not categories_entered:
1886
+ st.error("Please enter at least one category")
1887
+ elif classify_mode == "Model Comparison" and len(models_list) < 2:
1888
+ st.error("Please select at least 2 models for comparison mode")
1889
+ elif classify_mode == "Ensemble" and len(models_list) < 3:
1890
+ st.error("Please select at least 3 models for ensemble mode (needed for majority voting)")
1891
+ else:
1892
+ # Set up progress tracking
1893
+ mode = None
1894
+ if input_type_selected == "pdf":
1895
+ mode_mapping = {
1896
+ "Image (visual documents)": "image",
1897
+ "Text (text-heavy)": "text",
1898
+ "Both (comprehensive)": "both"
1899
+ }
1900
+ mode = mode_mapping.get(pdf_mode, "image")
1901
+
1902
+ # Build models tuples list
1903
+ # Uses 4-tuple (model, source, api_key, options) when per-model temperatures are set
1904
+ models_tuples = []
1905
+ api_key_error = None
1906
+ if ensemble_runs:
1907
+ # Ensemble mode: use ensemble_runs (model, temp) pairs directly
1908
+ for m, temp in ensemble_runs:
1909
+ actual_key, provider = get_api_key(m, model_tier, api_key)
1910
+ if not actual_key:
1911
+ api_key_error = f"{provider} API key not configured for {m}"
1912
+ break
1913
+ m_source = get_model_source(m)
1914
+ models_tuples.append((m, m_source, actual_key, {"creativity": temp}))
1915
+ else:
1916
+ for m in models_list:
1917
+ actual_key, provider = get_api_key(m, model_tier, api_key)
1918
+ if not actual_key:
1919
+ api_key_error = f"{provider} API key not configured for {m}"
1920
+ break
1921
+ m_source = get_model_source(m)
1922
+ temp = model_temperatures.get(m)
1923
+ if temp is not None and is_multi_model:
1924
+ models_tuples.append((m, m_source, actual_key, {"creativity": temp}))
1925
+ else:
1926
+ models_tuples.append((m, m_source, actual_key))
1927
+
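For reference, the tuples assembled above come in two shapes. A minimal sketch of both (the model name and "sk-..." key are placeholders; the `creativity` option carries the per-run temperature set earlier):

```python
# 3-tuple: model, source, API key -- library defaults apply
run_default = ("gpt-4o-mini", "OpenAI", "sk-...")

# 4-tuple: an extra options dict carries a per-run temperature
run_tempered = ("gpt-4o-mini", "OpenAI", "sk-...", {"creativity": 0.5})

for run in (run_default, run_tempered):
    model, source, key = run[:3]
    options = run[3] if len(run) == 4 else {}
    print(model, options.get("creativity", 0.0))
```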
1928
+ if api_key_error:
1929
+ st.error(api_key_error)
1930
+ else:
1931
+ items_list = input_data if isinstance(input_data, list) else [input_data]
1932
+
1933
+ # Progress UI
1934
+ progress_bar = st.progress(0)
1935
+ status_text = st.empty()
1936
+ start_time = time.time()
1937
+
1938
+ # For PDFs, use progress callback
1939
+ if input_type_selected == "pdf":
1940
+ # Progress callback for PDF page-by-page updates
1941
+ def pdf_progress_callback(current_idx, total_pages, page_label):
1942
+ progress = current_idx / total_pages if total_pages > 0 else 0
1943
+ progress_bar.progress(min(progress, 1.0))
1944
+
1945
+ elapsed = time.time() - start_time
1946
+ if current_idx > 0:
1947
+ avg_time = elapsed / current_idx
1948
+ eta_seconds = avg_time * (total_pages - current_idx)
1949
+ eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
1950
+ else:
1951
+ eta_str = ""
1952
+
1953
+ status_text.text(f"Processing page {current_idx+1} of {total_pages} ({page_label}) ({progress*100:.0f}%){eta_str}")
1954
+
1955
+ try:
1956
+ # Build kwargs for classify
1957
+ classify_kwargs = {
1958
+ "input_data": items_list,
1959
+ "categories": categories_entered,
1960
+ "models": models_tuples,
1961
+ "description": description,
1962
+ "mode": mode,
1963
+ "progress_callback": pdf_progress_callback,
1964
+ }
1965
+ # Add consensus_threshold for ensemble mode
1966
+ if classify_mode == "Ensemble":
1967
+ classify_kwargs["consensus_threshold"] = consensus_threshold
1968
+
1969
+ result_df = catvader.classify(**classify_kwargs)
1970
+
1971
+ processing_time = time.time() - start_time
1972
+ total_items = len(result_df)
1973
+ progress_bar.progress(1.0)
1974
+ status_text.text(f"Completed {total_items} pages in {processing_time:.1f}s")
1975
+
1976
+ # Replace temp paths with original filenames in pdf_input column
1977
+ if 'pdf_input' in result_df.columns:
1978
+ pdf_name_map = st.session_state.get('pdf_name_map', {})
1979
+ def replace_temp_path(val):
1980
+ if pd.isna(val):
1981
+ return val
1982
+ val_str = str(val)
1983
+ for temp_path, orig_name in pdf_name_map.items():
1984
+ # Check if the temp path's filename (without extension) is in the value
1985
+ temp_name = os.path.splitext(os.path.basename(temp_path))[0]
1986
+ if temp_name in val_str:
1987
+ return val_str.replace(temp_name, orig_name)
1988
+ return val_str
1989
+ result_df['pdf_input'] = result_df['pdf_input'].apply(replace_temp_path)
1990
+
1991
+ all_results = [result_df]
1992
+
1993
+ except Exception as e:
1994
+ st.error(f"Error: {str(e)}")
1995
+ all_results = []
1996
+
1997
+ else:
1998
+ # Non-PDF processing (text, images) - process all at once
1999
+ total_items = len(items_list)
2000
+
2001
+ # Progress callback for item-by-item updates
2002
+ def item_progress_callback(current_idx, total, item_label):
2003
+ progress = current_idx / total if total > 0 else 0
2004
+ progress_bar.progress(min(progress, 1.0))
2005
+
2006
+ elapsed = time.time() - start_time
2007
+ if current_idx > 0:
2008
+ avg_time = elapsed / current_idx
2009
+ eta_seconds = avg_time * (total - current_idx)
2010
+ eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
2011
+ else:
2012
+ eta_str = ""
2013
+
2014
+ status_text.text(f"Processing item {current_idx+1} of {total} ({progress*100:.0f}%){eta_str}")
2015
+
2016
+ try:
2017
+ # Build kwargs for classify
2018
+ classify_kwargs = {
2019
+ "input_data": items_list,
2020
+ "categories": categories_entered,
2021
+ "models": models_tuples,
2022
+ "description": description,
2023
+ "progress_callback": item_progress_callback,
2024
+ }
2025
+ # Add consensus_threshold for ensemble mode
2026
+ if classify_mode == "Ensemble":
2027
+ classify_kwargs["consensus_threshold"] = consensus_threshold
2028
+
2029
+ result_df = catvader.classify(**classify_kwargs)
2030
+ all_results = [result_df]
2031
+
2032
+ processing_time = time.time() - start_time
2033
+ progress_bar.progress(1.0)
2034
+ status_text.text(f"Completed {total_items} items in {processing_time:.1f}s")
2035
+
2036
+ except Exception as e:
2037
+ st.error(f"Error: {str(e)}")
2038
+ all_results = []
2039
+ processing_time = time.time() - start_time
2040
+
2041
+ if all_results:
2042
+ # Combine results
2043
+ result_df = pd.concat(all_results, ignore_index=True)
2044
+
2045
+ # Merge Bluesky engagement columns if available
2046
+ if st.session_state.get("bluesky_df") is not None:
2047
+ bsky_eng = st.session_state.bluesky_df.reset_index(drop=True)
2048
+ if len(bsky_eng) == len(result_df):
2049
+ for col in ["post_id", "timestamp", "likes", "replies", "reposts",
2050
+ "media_type", "image_url", "post_length",
2051
+ "contains_url", "contains_image", "is_repost"]:
2052
+ if col in bsky_eng.columns:
2053
+ result_df[col] = bsky_eng[col].values
2054
+
2055
+ # Save CSV
2056
+ with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
2057
+ result_df.to_csv(f, index=False)  # write via the open handle (reopening f.name fails on Windows)
2058
+ csv_path = f.name
2059
+
2060
+ # Calculate success rate
2061
+ if 'processing_status' in result_df.columns:
2062
+ success_count = (result_df['processing_status'] == 'success').sum()
2063
+ success_rate = (success_count / len(result_df)) * 100
2064
+ else:
2065
+ success_rate = 100.0
2066
+
2067
+ # Get version info
2068
+ try:
2069
+ catvader_version = catvader.__version__
2070
+ except AttributeError:
2071
+ catvader_version = "unknown"
2072
+ python_version = sys.version.split()[0]
2073
+
2074
+ # For reports: create model string (single or list)
2075
+ if len(models_list) == 1:
2076
+ report_model = models_list[0]
2077
+ report_model_source = models_tuples[0][1]
2078
+ else:
2079
+ report_model = ", ".join(models_list)
2080
+ report_model_source = f"{classify_mode} ({len(models_list)} models)"
2081
+
2082
+ # Generate code first so we can include it in the PDF
2083
+ # If categories were auto-extracted, include both extract and classify code
2084
+ if st.session_state.extraction_params:
2085
+ classify_params = {
2086
+ 'model': report_model,
2087
+ 'description': description,
2088
+ 'mode': mode,
2089
+ 'classify_mode': classify_mode,
2090
+ 'models_list': models_list,
2091
+ 'consensus_threshold': consensus_threshold,
2092
+ 'model_temperatures': model_temperatures,
2093
+ 'ensemble_runs': ensemble_runs if ensemble_runs else None,
2094
+ }
2095
+ code = generate_full_code(st.session_state.extraction_params, classify_params)
2096
+ else:
2097
+ code = generate_classify_code(
2098
+ input_type_selected, description, categories_entered,
2099
+ report_model, report_model_source, mode,
2100
+ classify_mode=classify_mode, models_list=models_list,
2101
+ consensus_threshold=consensus_threshold,
2102
+ model_temperatures=model_temperatures,
2103
+ ensemble_runs=ensemble_runs if ensemble_runs else None,
2104
+ )
2105
+
2106
+ # Generate methodology report with code included
2107
+ pdf_path = generate_methodology_report_pdf(
2108
+ categories=categories_entered,
2109
+ model=report_model,
2110
+ column_name=description,
2111
+ num_rows=len(result_df),
2112
+ model_source=report_model_source,
2113
+ filename=original_filename,
2114
+ success_rate=success_rate,
2115
+ result_df=result_df,
2116
+ processing_time=processing_time,
2117
+ catvader_version=catvader_version,
2118
+ python_version=python_version,
2119
+ task_type="assign",
2120
+ input_type=input_type_selected,
2121
+ description=description,
2122
+ classify_mode=classify_mode,
2123
+ models_list=models_list,
2124
+ code=code,
2125
+ consensus_threshold=consensus_threshold if classify_mode == "Ensemble" else None,
2126
+ )
2127
+
2128
+ st.session_state.results = {
2129
+ 'df': result_df,
2130
+ 'csv_path': csv_path,
2131
+ 'pdf_path': pdf_path,
2132
+ 'code': code,
2133
+ 'status': f"Classified {len(result_df)} items in {processing_time:.1f}s",
2134
+ 'categories': categories_entered,
2135
+ 'classify_mode': classify_mode,
2136
+ 'models_list': models_list,
2137
+ 'model_temperatures': model_temperatures,
2138
+ 'ensemble_runs': ensemble_runs if ensemble_runs else None,
2139
+ }
2140
+ st.success(f"Classified {len(result_df)} items in {processing_time:.1f}s")
2141
+ st.rerun()
2142
+ else:
2143
+ st.error("No items were successfully classified")
2144
+
2145
+ with col_output:
2146
+ st.markdown("### Results")
2147
+
2148
+ if st.session_state.results:
2149
+ results = st.session_state.results
2150
+
2151
+ # Visualization selector
2152
+ viz_type = st.selectbox(
2153
+ "Visualization",
2154
+ options=["Category Distribution", "Classification Matrix"],
2155
+ key="viz_type",
2156
+ help="Distribution shows category percentages. Matrix shows each response's classifications."
2157
+ )
2158
+
2159
+ if viz_type == "Category Distribution":
2160
+ fig = create_distribution_chart(
2161
+ results['df'],
2162
+ results['categories'],
2163
+ classify_mode=results.get('classify_mode', 'Single Model'),
2164
+ models_list=results.get('models_list', [])
2165
+ )
2166
+ st.pyplot(fig)
2167
+ st.caption("Note: Categories are not mutually exclusive—each item can belong to multiple categories.")
2168
+ else:
2169
+ fig = create_classification_heatmap(
2170
+ results['df'],
2171
+ results['categories'],
2172
+ classify_mode=results.get('classify_mode', 'Single Model'),
2173
+ models_list=results.get('models_list', [])
2174
+ )
2175
+ st.pyplot(fig)
2176
+ st.caption("Green = category present, Black = not present. Each row is one response.")
2177
+
2178
+ # Results dataframe (hide technical columns from display)
2179
+ display_df = results['df'].copy()
2180
+ cols_to_hide = ['model_response', 'json', 'raw_response', 'raw_json']
2181
+ display_df = display_df.drop(columns=[c for c in cols_to_hide if c in display_df.columns])
2182
+ st.dataframe(display_df, use_container_width=True)
2183
+
2184
+ # Downloads
2185
+ col_dl1, col_dl2, col_dl3 = st.columns(3)
2186
+ with col_dl1:
2187
+ with open(results['csv_path'], 'rb') as f:
2188
+ st.download_button(
2189
+ "Download CSV",
2190
+ data=f,
2191
+ file_name="classified_results.csv",
2192
+ mime="text/csv"
2193
+ )
2194
+ with col_dl2:
2195
+ with open(results['pdf_path'], 'rb') as f:
2196
+ st.download_button(
2197
+ "Download Report",
2198
+ data=f,
2199
+ file_name="methodology_report.pdf",
2200
+ mime="application/pdf"
2201
+ )
2202
+ with col_dl3:
2203
+ # Generate both plots and save to a single PDF
2204
+ import io
2205
+ from matplotlib.backends.backend_pdf import PdfPages
2206
+
2207
+ plot_buffer = io.BytesIO()
2208
+ with PdfPages(plot_buffer) as pdf:
2209
+ # Distribution chart
2210
+ fig1 = create_distribution_chart(
2211
+ results['df'],
2212
+ results['categories'],
2213
+ classify_mode=results.get('classify_mode', 'Single Model'),
2214
+ models_list=results.get('models_list', [])
2215
+ )
2216
+ pdf.savefig(fig1, bbox_inches='tight')
2217
+ plt.close(fig1)
2218
+
2219
+ # Classification matrix
2220
+ fig2 = create_classification_heatmap(
2221
+ results['df'],
2222
+ results['categories'],
2223
+ classify_mode=results.get('classify_mode', 'Single Model'),
2224
+ models_list=results.get('models_list', [])
2225
+ )
2226
+ pdf.savefig(fig2, bbox_inches='tight')
2227
+ plt.close(fig2)
2228
+
2229
+ plot_buffer.seek(0)
2230
+ st.download_button(
2231
+ "Download Plots",
2232
+ data=plot_buffer,
2233
+ file_name="classification_plots.pdf",
2234
+ mime="application/pdf"
2235
+ )
2236
+
2237
+ # Code
2238
+ with st.expander("See the Code"):
2239
+ st.code(results['code'], language='python')
2240
+ else:
2241
+ st.info("Upload data, select categories, and click 'Categorize Data' to see results here.")
2242
+
2243
+ # Bottom buttons
2244
+ col_reset, col_code = st.columns(2)
2245
+ with col_reset:
2246
+ if st.button("Reset", type="secondary", use_container_width=True):
2247
+ st.session_state.categories = [''] * MAX_CATEGORIES
2248
+ st.session_state.category_count = INITIAL_CATEGORIES
2249
+ st.session_state.task_mode = None
2250
+ st.session_state.extracted_categories = None
2251
+ st.session_state.extraction_params = None
2252
+ st.session_state.results = None
2253
+ if hasattr(st.session_state, 'example_loaded'):
2254
+ del st.session_state.example_loaded
2255
+ st.rerun()
2256
+
2257
+ with col_code:
2258
+ if st.button("See in Code", use_container_width=True):
2259
+ st.session_state.show_code_modal = True
2260
+
2261
+ # Code modal/dialog
2262
+ if st.session_state.get('show_code_modal'):
2263
+ st.markdown("---")
2264
+ st.markdown("### Reproducibility Code")
2265
+ st.markdown("Use this code to reproduce the classification with the CatVader Python package:")
2266
+
2267
+ # Use results code if available, otherwise generate from current parameters
2268
+ if st.session_state.results:
2269
+ code_to_show = st.session_state.results['code']
2270
+ else:
2271
+ # Get current categories from session state
2272
+ current_categories = [c for c in st.session_state.categories[:st.session_state.category_count] if c.strip()]
2273
+
2274
+ # Determine current input type and description
2275
+ input_type_map = {"Social Media Posts": "text", "PDF Documents": "pdf", "Images": "image"}
2276
+ current_input_type = input_type_map.get(st.session_state.get('input_type_radio', 'Social Media Posts'), 'text')
2277
+ current_description = st.session_state.get('survey_column', '') or st.session_state.get('pdf_desc', '') or st.session_state.get('image_desc', '') or 'your_data'
2278
+
2279
+ # Get current classification mode and models
2280
+ current_classify_mode = st.session_state.get('classify_mode', 'Single Model')
2281
+ current_model_tier = st.session_state.get('classify_model_tier', 'Free Models')
2282
+
2283
+ if current_classify_mode in ["Model Comparison", "Ensemble"]:
2284
+ # Multi-model mode
2285
+ if current_model_tier == 'Free Models':
2286
+ model_displays = st.session_state.get('classify_models_multi', [])
2287
+ current_models_list = [FREE_MODELS_MAP.get(d, d) for d in model_displays]
2288
+ else:
2289
+ current_models_list = st.session_state.get('classify_models_multi_paid', [])
2290
+ current_model = ", ".join(current_models_list) if current_models_list else "gpt-4o-mini"
2291
+ current_model_source = f"{current_classify_mode} ({len(current_models_list)} models)"
2292
+ else:
2293
+ # Single model mode
2294
+ if current_model_tier == 'Free Models':
2295
+ model_display = st.session_state.get('classify_model', 'GPT-4o Mini')
2296
+ current_model = FREE_MODELS_MAP.get(model_display, 'gpt-4o-mini')
2297
+ else:
2298
+ current_model = st.session_state.get('classify_model_paid', 'gpt-4o-mini')
2299
+ current_models_list = [current_model]
2300
+ current_model_source = get_model_source(current_model)
2301
+
2302
+ # Get consensus threshold for ensemble mode
2303
+ consensus_options = {"Majority (50%+)": 0.5, "Two-Thirds (67%+)": 0.67, "Unanimous (100%)": 1.0}
2304
+ current_consensus = consensus_options.get(st.session_state.get('consensus_choice', 'Majority (50%+)'), 0.5)
2305
+
2306
+ # Get PDF mode if applicable
2307
+ current_mode = None
2308
+ if current_input_type == "pdf":
2309
+ mode_mapping = {
2310
+ "Image (visual documents)": "image",
2311
+ "Text (text-heavy)": "text",
2312
+ "Both (comprehensive)": "both"
2313
+ }
2314
+ current_mode = mode_mapping.get(st.session_state.get('pdf_mode', 'Image (visual documents)'), 'image')
2315
+
2316
+ if current_categories:
2317
+ # Check if categories were auto-extracted
2318
+ if st.session_state.extraction_params:
2319
+ current_temperatures = {}  # no completed run in this branch to reuse temperatures from
2320
+ classify_params = {
2321
+ 'model': current_model,
2322
+ 'description': current_description,
2323
+ 'mode': current_mode,
2324
+ 'classify_mode': current_classify_mode,
2325
+ 'models_list': current_models_list,
2326
+ 'consensus_threshold': current_consensus,
2327
+ 'model_temperatures': current_temperatures,
2328
+ 'ensemble_runs': None,  # no completed run in this branch to reuse
2329
+ }
2330
+ code_to_show = generate_full_code(st.session_state.extraction_params, classify_params)
2331
+ else:
2332
+ current_temperatures = {}  # no completed run in this branch to reuse temperatures from
2333
+ code_to_show = generate_classify_code(
2334
+ current_input_type, current_description, current_categories,
2335
+ current_model, current_model_source, current_mode,
2336
+ classify_mode=current_classify_mode, models_list=current_models_list,
2337
+ consensus_threshold=current_consensus,
2338
+ model_temperatures=current_temperatures,
2339
+ ensemble_runs=None,  # no completed run in this branch to reuse
2340
+ )
2341
+ else:
2342
+ code_to_show = '''import catvader
+ import pandas as pd
+
+ # Load the data you want to classify
+ df = pd.read_csv("your_data.csv")
+
2344
+ # Define your categories
2345
+ categories = [
2346
+ "Category 1",
2347
+ "Category 2",
2348
+ # Add more categories...
2349
+ ]
2350
+
2351
+ # Classify your data
2352
+ result = catvader.classify(
2353
+ input_data=df["your_column"].tolist(),
2354
+ categories=categories,
2355
+ api_key="YOUR_API_KEY",
2356
+ description="your_description",
2357
+ user_model="gpt-4o-mini"
2358
+ )
2359
+
2360
+ # View results
2361
+ print(result)
2362
+ result.to_csv("classified_results.csv", index=False)
2363
+ '''
2364
+
2365
+ st.code(code_to_show, language='python')
2366
+ if st.button("Close"):
2367
+ st.session_state.show_code_modal = False
2368
+ st.rerun()