Upload app.py with huggingface_hub
Browse files
app.py
ADDED
|
@@ -0,0 +1,2368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Streamlit app - CatVader Social Media Classifier
|
| 3 |
+
Migrated from Gradio for better mobile support
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import streamlit as st
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import tempfile
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
import sys
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
|
| 15 |
+
# Import catvader
|
| 16 |
+
try:
|
| 17 |
+
import catvader
|
| 18 |
+
CATVADER_AVAILABLE = True
|
| 19 |
+
except ImportError as e:
|
| 20 |
+
print(f"Warning: Could not import catvader: {e}")
|
| 21 |
+
CATVADER_AVAILABLE = False
|
| 22 |
+
|
| 23 |
+
MAX_CATEGORIES = 10
|
| 24 |
+
INITIAL_CATEGORIES = 3
|
| 25 |
+
MAX_FILE_SIZE_MB = 100
|
| 26 |
+
|
| 27 |
+
def count_pdf_pages(pdf_path):
    """Return the number of pages in the PDF at *pdf_path*.

    Falls back to 1 on any failure (missing file, corrupt PDF, or PyMuPDF
    unavailable) so callers can always proceed with a sane page count.
    """
    try:
        import fitz  # PyMuPDF
        document = fitz.open(pdf_path)
        try:
            return len(document)
        finally:
            document.close()
    except Exception:
        # Unreadable PDF: treat it as a single page rather than failing.
        return 1
|
| 38 |
+
|
| 39 |
+
def extract_text_from_pdfs(pdf_paths):
    """Extract text from every page of every PDF in *pdf_paths*.

    Returns a flat list containing one string per non-empty page; blank
    pages are skipped. PDFs that cannot be opened are logged and skipped.
    """
    import fitz  # PyMuPDF
    page_texts = []
    for pdf_path in pdf_paths:
        try:
            document = fitz.open(pdf_path)
            for page in document:
                content = page.get_text().strip()
                # Only keep pages that actually contain text.
                if content:
                    page_texts.append(content)
            document.close()
        except Exception as e:
            print(f"Error extracting text from {pdf_path}: {e}")
    return page_texts
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def extract_pdf_pages(pdf_paths, pdf_name_map, mode="image"):
    """
    Split PDFs into individual per-page entries.

    Returns a list of tuples describing each page:
    - mode "text":  (page_text, page_label, "text") — blank pages skipped.
    - mode "image": (temp_png_path, page_label, "image") — rendered at 2x zoom.
    - mode "both":  (temp_png_path, page_label, "image", page_text).

    *pdf_name_map* maps a PDF path to its original display name; the page
    label is "<name>_p<page_number>". Unreadable PDFs are logged and skipped.
    """
    import fitz  # PyMuPDF
    extracted = []

    for pdf_path in pdf_paths:
        display_name = pdf_name_map.get(pdf_path, os.path.basename(pdf_path).replace('.pdf', ''))
        try:
            document = fitz.open(pdf_path)
            for number, page in enumerate(document, 1):
                label = f"{display_name}_p{number}"

                if mode == "text":
                    content = page.get_text().strip()
                    if content:  # skip pages with no extractable text
                        extracted.append((content, label, "text"))
                    continue

                # "image" and "both" both render the page to a temp PNG.
                pixmap = page.get_pixmap(matrix=fitz.Matrix(2, 2))  # 2x zoom for better quality
                image_path = tempfile.NamedTemporaryFile(delete=False, suffix='.png').name
                pixmap.save(image_path)

                if mode == "both":
                    extracted.append((image_path, label, "image", page.get_text().strip()))
                else:
                    extracted.append((image_path, label, "image"))
            document.close()
        except Exception as e:
            print(f"Error extracting pages from {pdf_path}: {e}")

    return extracted
|
| 94 |
+
|
| 95 |
+
# Free models - display name -> actual API model name
# These run on the Space's own API keys (see get_api_key for key selection).
FREE_MODELS_MAP = {
    "GPT-4o Mini": "gpt-4o-mini",
    "Gemini 2.5 Flash": "gemini-2.5-flash",
    "Claude 3 Haiku": "claude-3-haiku-20240307",
    "Llama 3.3 70B": "meta-llama/Llama-3.3-70B-Instruct:groq",
    "Qwen 2.5": "Qwen/Qwen2.5-72B-Instruct",
    "DeepSeek R1": "deepseek-ai/DeepSeek-R1:novita",
    "Mistral Medium": "mistral-medium-2505",
    "Grok 4 Fast": "grok-4-fast-non-reasoning",
}
# Human-readable names shown in the UI dropdown.
FREE_MODEL_DISPLAY_NAMES = list(FREE_MODELS_MAP.keys())
# Raw API model identifiers, in the same order as the display names.
FREE_MODEL_CHOICES = list(FREE_MODELS_MAP.values())  # Keep for backward compat

# Paid models (user provides their own API key)
PAID_MODEL_CHOICES = [
    "gemini-2.5-flash",
    "gemini-2.5-pro",
    "gpt-4.1",
    "gpt-4o",
    "gpt-4o-mini",
    "claude-sonnet-4-5-20250929",
    "claude-opus-4-20250514",
    "claude-3-5-haiku-20241022",
    "mistral-large-latest",
]

# Models routed through HuggingFace
# (these use the HF_API_KEY regardless of the provider in their name).
HF_ROUTED_MODELS = [
    "meta-llama/Llama-3.3-70B-Instruct:groq",
    "deepseek-ai/DeepSeek-R1:novita",
]
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def is_free_model(model, model_tier):
    """Check if using free tier (Space pays for API)."""
    # `model` is unused but kept so the signature matches existing callers.
    return model_tier == "Free Models"
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def get_model_source(model):
    """Infer the API provider from a model name (case-insensitive substring match).

    Check order matters: e.g. a "mistral" model routed via ":novita" must fall
    through to the HuggingFace branch. Unknown names default to "huggingface".
    """
    name = model.lower()
    if "gpt" in name:
        return "openai"
    if "claude" in name:
        return "anthropic"
    if "gemini" in name:
        return "google"
    if "mistral" in name and ":novita" not in name:
        return "mistral"
    hf_markers = (":novita", ":groq", "qwen", "llama", "deepseek")
    if any(marker in name for marker in hf_markers):
        return "huggingface"
    if "sonar" in name:
        return "perplexity"
    if "grok" in name:
        return "xai"
    # Fallback: assume HuggingFace routing for anything unrecognized.
    return "huggingface"
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_api_key(model, model_tier, api_key_input):
    """Resolve the API key for a request.

    Returns a (key, provider_label) tuple. On the paid tier the user's own
    key is used (stripped; empty string when none was entered). On the free
    tier the Space's environment keys are selected by provider, with
    HF-routed models always using HF_API_KEY.
    """
    if not is_free_model(model, model_tier):
        # Paid tier: the user supplies their own key.
        if api_key_input and api_key_input.strip():
            return api_key_input.strip(), "User"
        return "", "User"

    # Free tier: pick the Space's key. HF-routed models take priority.
    if model in HF_ROUTED_MODELS:
        return os.environ.get("HF_API_KEY", ""), "HuggingFace"

    name = model.lower()
    # Ordered (marker, env var, provider label) table — order mirrors the
    # original elif chain, so first match wins.
    provider_table = [
        ("gpt", "OPENAI_API_KEY", "OpenAI"),
        ("gemini", "GOOGLE_API_KEY", "Google"),
        ("mistral", "MISTRAL_API_KEY", "Mistral"),
        ("claude", "ANTHROPIC_API_KEY", "Anthropic"),
        ("sonar", "PERPLEXITY_API_KEY", "Perplexity"),
        ("grok", "XAI_API_KEY", "xAI"),
    ]
    for marker, env_var, provider in provider_table:
        if marker in name:
            return os.environ.get(env_var, ""), provider
    return os.environ.get("HF_API_KEY", ""), "HuggingFace"
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def calculate_total_file_size(files):
    """Return the combined size in MB of the uploaded file object(s).

    Accepts a single object, a list of objects, or None (returns 0).
    Each object's `.size` attribute is preferred; otherwise the size of the
    file at `.name` is read from disk. Unreadable entries are skipped.
    """
    if files is None:
        return 0
    file_list = files if isinstance(files, list) else [files]

    size_bytes = 0
    for item in file_list:
        try:
            if hasattr(item, 'size'):
                size_bytes += item.size
            elif hasattr(item, 'name'):
                size_bytes += os.path.getsize(item.name)
        except (OSError, AttributeError):
            # Skip entries whose size can't be determined.
            pass
    return size_bytes / (1024 * 1024)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def generate_extract_code(input_type, description, model, model_source, max_categories, mode=None):
    """Generate Python code for category extraction.

    Builds a copy-pasteable snippet that calls `catvader.extract` with the
    user's current settings. The template varies by input type:
    "text" loads a CSV and passes a column; "pdf" and "image" point at a
    directory. `mode` is only emitted for the PDF template.

    Args:
        input_type: "text", "pdf", or anything else (treated as "image").
        description: Column name / data description interpolated into the snippet.
        model: Model identifier for the generated `user_model` argument.
        model_source: Provider name for the generated `model_source` argument.
        max_categories: Value for the generated `max_categories` argument.
        mode: Optional PDF processing mode ("text"/"image"/"both"); omitted when falsy.

    Returns:
        A Python source string ready to show to the user.
    """
    if input_type == "text":
        # Text template: description doubles as the CSV column name.
        return f'''import catvader
import pandas as pd

# Load your data
df = pd.read_csv("your_data.csv")

# Extract categories from the text column
result = catvader.extract(
    input_data=df["{description}"].tolist(),
    api_key="YOUR_API_KEY",
    input_type="text",
    description="{description}",
    user_model="{model}",
    model_source="{model_source}",
    max_categories={max_categories}
)

# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''
    elif input_type == "pdf":
        # Only PDFs carry a processing mode; omit the kwarg entirely when unset.
        mode_line = f',\n    mode="{mode}"' if mode else ''
        return f'''import catvader

# Extract categories from PDF documents
result = catvader.extract(
    input_data="path/to/your/pdfs/",
    api_key="YOUR_API_KEY",
    input_type="pdf",
    description="{description}"{mode_line},
    user_model="{model}",
    model_source="{model_source}",
    max_categories={max_categories}
)

# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''
    else:  # image
        return f'''import catvader

# Extract categories from images
result = catvader.extract(
    input_data="path/to/your/images/",
    api_key="YOUR_API_KEY",
    input_type="image",
    description="{description}",
    user_model="{model}",
    model_source="{model_source}",
    max_categories={max_categories}
)

# View extracted categories
print(result["top_categories"])
print(result["counts_df"])
'''
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def generate_full_code(extraction_params, classify_params):
    """Generate combined extract + classify code when categories were auto-extracted.

    Produces a two-step snippet: step 1 calls `catvader.extract` to discover
    categories; step 2 feeds them into `catvader.classify`, either with a
    single model or a multi-model (comparison/ensemble) setup.

    Args:
        extraction_params: dict with keys 'input_type', 'description', 'model',
            'max_categories' and optionally 'mode' (PDF processing mode).
        classify_params: dict with keys 'classify_mode', 'description', and
            either 'model' (Single Model) or 'models_list' plus optional
            'model_temperatures' / 'ensemble_runs' / 'consensus_threshold'
            (multi-model modes); optionally 'mode' for PDFs.

    Returns:
        A Python source string combining both steps.
    """
    ext = extraction_params
    cls = classify_params

    # Determine input data placeholder
    if ext['input_type'] == "text":
        input_placeholder = 'df["your_column"].tolist()'
        load_data = '''import pandas as pd

# Load your data
df = pd.read_csv("your_data.csv")
'''
    elif ext['input_type'] == "pdf":
        input_placeholder = '"path/to/your/pdfs/"'
        load_data = ''
    else:
        input_placeholder = '"path/to/your/images/"'
        load_data = ''

    # PDF-only processing mode; omitted from the snippet when unset.
    mode_param = f',\n    mode="{ext["mode"]}"' if ext.get('mode') else ''

    # Build extract code
    extract_code = f'''# Step 1: Extract categories from your data
extract_result = catvader.extract(
    input_data={input_placeholder},
    api_key="YOUR_API_KEY",
    description="{ext['description']}",
    user_model="{ext['model']}",
    max_categories={ext['max_categories']}{mode_param}
)

categories = extract_result["top_categories"]
print(f"Extracted {{len(categories)}} categories: {{categories}}")
'''

    # Build classify code based on mode
    if cls['classify_mode'] == "Single Model":
        classify_mode_param = f',\n    mode="{cls["mode"]}"' if cls.get('mode') and ext['input_type'] == "pdf" else ''
        classify_code = f'''
# Step 2: Classify data using extracted categories
result = catvader.classify(
    input_data={input_placeholder},
    categories=categories,
    api_key="YOUR_API_KEY",
    description="{cls['description']}",
    user_model="{cls['model']}"{classify_mode_param}
)'''
    else:
        # Multi-model mode — include per-model temperatures when set
        ens_runs = cls.get('ensemble_runs')
        model_lines = []
        if ens_runs:
            # Explicit (model, temperature) pairs; duplicate models allowed.
            for m, temp in ens_runs:
                model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})')
        else:
            model_temps = cls.get('model_temperatures', {})
            for m in cls['models_list']:
                temp = model_temps.get(m) if model_temps else None
                if temp is not None:
                    model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})')
                else:
                    model_lines.append(f'("{m}", "auto", "YOUR_API_KEY")')
        models_str = ",\n    ".join(model_lines)

        classify_mode_param = f',\n    mode="{cls["mode"]}"' if cls.get('mode') and ext['input_type'] == "pdf" else ''
        # Map the numeric threshold back to its readable keyword form.
        threshold_str = "majority" if cls['consensus_threshold'] == 0.5 else "two-thirds" if cls['consensus_threshold'] == 0.67 else "unanimous"
        consensus_param = f',\n    consensus_threshold="{threshold_str}"' if cls['classify_mode'] == "Ensemble" else ''

        classify_code = f'''
# Step 2: Classify data using extracted categories with {"ensemble voting" if cls['classify_mode'] == "Ensemble" else "model comparison"}
models = [
    {models_str}
]

result = catvader.classify(
    input_data={input_placeholder},
    categories=categories,
    models=models,
    description="{cls['description']}"{classify_mode_param}{consensus_param}
)'''

    return f'''import catvader
{load_data}
{extract_code}
{classify_code}

# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def generate_classify_code(input_type, description, categories, model, model_source, mode=None,
                           classify_mode="Single Model", models_list=None, consensus_threshold=0.5,
                           model_temperatures=None, ensemble_runs=None):
    """Generate Python code for classification.

    Builds a copy-pasteable snippet calling `catvader.classify` with the given
    categories. Single-model mode emits one `user_model`; multi-model modes
    (Comparison/Ensemble) emit a `models` list of (model, source, key[, opts])
    tuples, with per-model "creativity" temperatures when provided.

    Args:
        input_type: "text", "pdf", or anything else (treated as "image").
        description: Data description interpolated into the snippet.
        categories: List of category strings to embed in the snippet.
        model: Model identifier used in Single Model mode.
        model_source: Provider name (not used in the generated snippet body).
        mode: Optional PDF processing mode; only emitted when input_type == "pdf".
        classify_mode: "Single Model", "Comparison", or "Ensemble".
        models_list: Models for multi-model modes (ignored in Single Model).
        consensus_threshold: 0.5 / 0.67 / other, mapped to
            "majority" / "two-thirds" / "unanimous" for Ensemble snippets.
        model_temperatures: Optional {model: temperature} overrides.
        ensemble_runs: Optional explicit (model, temperature) pairs;
            takes precedence over models_list and allows duplicates.

    Returns:
        A Python source string ready to show to the user.
    """
    categories_str = ",\n    ".join([f'"{cat}"' for cat in categories])

    # Determine input data placeholder based on type
    if input_type == "text":
        input_placeholder = 'df["your_column"].tolist()'
        load_data = '''import pandas as pd

# Load your data
df = pd.read_csv("your_data.csv")
'''
    elif input_type == "pdf":
        input_placeholder = '"path/to/your/pdfs/"'
        load_data = ''
    else:  # image
        input_placeholder = '"path/to/your/images/"'
        load_data = ''

    # Generate code based on classification mode
    if classify_mode == "Single Model":
        # Single model mode
        mode_param = f',\n    mode="{mode}"' if mode and input_type == "pdf" else ''
        return f'''import catvader
{load_data}
# Define categories
categories = [
    {categories_str}
]

# Classify data (input type is auto-detected)
result = catvader.classify(
    input_data={input_placeholder},
    categories=categories,
    api_key="YOUR_API_KEY",
    description="{description}",
    user_model="{model}"{mode_param}
)

# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''
    else:
        # Multi-model mode (Comparison or Ensemble)
        # Build model tuples with per-model temperature when set
        if ensemble_runs:
            # Ensemble with explicit (model, temp) pairs (supports duplicate models)
            model_lines = []
            for m, temp in ensemble_runs:
                model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})')
            models_str = ",\n    ".join(model_lines)
        elif models_list:
            model_lines = []
            for m in models_list:
                temp = model_temperatures.get(m) if model_temperatures else None
                if temp is not None:
                    model_lines.append(f'("{m}", "auto", "YOUR_API_KEY", {{"creativity": {temp}}})')
                else:
                    model_lines.append(f'("{m}", "auto", "YOUR_API_KEY")')
            models_str = ",\n    ".join(model_lines)
        else:
            # No models provided: show a representative two-model example.
            models_str = '("gpt-4o", "auto", "YOUR_API_KEY"),\n    ("claude-sonnet-4-5-20250929", "auto", "YOUR_API_KEY")'

        mode_param = f',\n    mode="{mode}"' if mode and input_type == "pdf" else ''
        # Map numeric threshold back to string for cleaner code
        threshold_str = "majority" if consensus_threshold == 0.5 else "two-thirds" if consensus_threshold == 0.67 else "unanimous"
        consensus_param = f',\n    consensus_threshold="{threshold_str}"' if classify_mode == "Ensemble" else ''

        return f'''import catvader
{load_data}
# Define categories
categories = [
    {categories_str}
]

# Define models for {"ensemble voting" if classify_mode == "Ensemble" else "comparison"}
models = [
    {models_str}
]

# Classify with multiple models
result = catvader.classify(
    input_data={input_placeholder},
    categories=categories,
    models=models,
    description="{description}"{mode_param}{consensus_param}
)

# View results
print(result)
result.to_csv("classified_results.csv", index=False)
'''
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def generate_methodology_report_pdf(categories, model, column_name, num_rows, model_source, filename, success_rate,
                                    result_df=None, processing_time=None, prompt_template=None,
                                    data_quality=None, catvader_version=None, python_version=None,
                                    task_type="assign", extracted_categories_df=None, max_categories=None,
                                    input_type="text", description=None, classify_mode="Single Model",
                                    models_list=None, code=None, consensus_threshold=None):
    """Generate a PDF methodology report.

    Builds a multi-page ReportLab document: title/about page with the
    category-to-column mapping and citation, a classification summary page
    (plus ensemble agreement scores when available), optional reproducibility
    code, and optional charts rendered via create_distribution_chart() /
    create_classification_heatmap().

    Args:
        categories: List of category label strings (may be empty/None).
        model: Model name(s) string shown in the summary table.
        column_name: Name of the classified source column.
        num_rows: Number of rows classified.
        model_source: Provider string ("openai", "anthropic", ...).
        filename: Original upload filename shown in the summary.
        success_rate: Percentage (0-100) of successfully processed rows.
        result_df: Optional classification result DataFrame; enables the
            multi-model column mapping, agreement table, and charts.
        processing_time: Optional wall-clock seconds for the run.
        prompt_template, data_quality, extracted_categories_df,
        max_categories, input_type, description: Accepted for interface
            parity with callers; not all are rendered in this report version.
        catvader_version, python_version: Version strings for the report.
        task_type: "assign" or "extract_and_assign" (changes title/about text).
        classify_mode: "Single Model", "Model Comparison", or "Ensemble".
        models_list: Model names for multi-model chart legends.
        code: Optional reproducibility code string to embed.
        consensus_threshold: Numeric ensemble threshold (0.5 / 0.67 / 1.0).

    Returns:
        Path to the generated PDF file (a NamedTemporaryFile kept on disk).
    """
    from reportlab.lib.pagesizes import letter
    from reportlab.lib import colors
    from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
    from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Table, TableStyle, PageBreak

    # delete=False: the caller serves/downloads the file after we return.
    pdf_file = tempfile.NamedTemporaryFile(mode='wb', suffix='_methodology_report.pdf', delete=False)
    doc = SimpleDocTemplate(pdf_file.name, pagesize=letter)
    styles = getSampleStyleSheet()

    title_style = ParagraphStyle('Title', parent=styles['Heading1'], fontSize=18, spaceAfter=20)
    heading_style = ParagraphStyle('Heading', parent=styles['Heading2'], fontSize=14, spaceAfter=10, spaceBefore=15)
    normal_style = styles['Normal']
    code_style = ParagraphStyle('Code', parent=styles['Normal'], fontName='Courier', fontSize=9, leftIndent=20, spaceAfter=3)

    story = []

    if task_type == "extract_and_assign":
        report_title = "CatVader Extraction & Classification Report"
    else:
        report_title = "CatVader Classification Report"

    story.append(Paragraph(report_title, title_style))
    story.append(Paragraph(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", normal_style))
    story.append(Spacer(1, 15))

    story.append(Paragraph("About This Report", heading_style))

    if task_type == "extract_and_assign":
        about_text = """This methodology report documents the automated category extraction and classification process. \
CatVader first discovers categories from your data using LLMs, then classifies each item into those categories."""
    else:
        about_text = """This methodology report documents the classification process for reproducibility and transparency. \
CatVader restricts the prompt to a standard template that is impartial to the researcher's inclinations, ensuring \
consistent and reproducible results."""

    story.append(Paragraph(about_text, normal_style))
    story.append(Spacer(1, 15))

    if categories:
        story.append(Paragraph("Category Mapping", heading_style))

        if classify_mode in ("Ensemble", "Model Comparison") and result_df is not None:
            # Multi-model: show per-model columns and consensus columns
            story.append(Paragraph("Each model produces its own binary columns. "
                                   "Consensus columns show the majority vote result.", normal_style))
            story.append(Spacer(1, 8))

            # Detect ALL distinct model suffixes directly from the DataFrame
            # (handles same-model-different-temperature cases correctly)
            all_suffixes = _find_all_model_suffixes(result_df)

            category_data = [["Column Name", "Category Description"]]
            for i, cat in enumerate(categories, 1):
                # Per-model columns (each suffix is a unique model/temperature)
                for suffix in all_suffixes:
                    category_data.append([f"category_{i}_{suffix}", f"{cat} ({suffix})"])
                # Consensus + agreement columns
                category_data.append([f"category_{i}_consensus", f"{cat} (consensus)"])
                category_data.append([f"category_{i}_agreement", f"{cat} (agreement score)"])

            cat_table = Table(category_data, colWidths=[200, 250])
        else:
            # Single model: simple mapping
            story.append(Paragraph("Each category column contains binary values: 1 = present, 0 = not present", normal_style))
            story.append(Spacer(1, 8))

            category_data = [["Column Name", "Category Description"]]
            for i, cat in enumerate(categories, 1):
                category_data.append([f"category_{i}", cat])

            cat_table = Table(category_data, colWidths=[120, 330])

        cat_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
            ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('PADDING', (0, 0), (-1, -1), 6),
            ('BACKGROUND', (0, 1), (0, -1), colors.lightgrey),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
        ]))
        story.append(cat_table)
        story.append(Spacer(1, 15))

    story.append(Spacer(1, 30))
    story.append(Paragraph("Citation", heading_style))
    story.append(Paragraph("If you use CatVader in your research, please cite:", normal_style))
    story.append(Spacer(1, 5))
    story.append(Paragraph("Soria, C. (2025). CatVader: A Python package for LLM-based social media classification. DOI: 10.5281/zenodo.15532316", normal_style))

    # Summary section
    story.append(PageBreak())
    story.append(Paragraph("Classification Summary", title_style))
    story.append(Spacer(1, 15))

    summary_data = [
        ["Source File", filename],
        ["Source Column", column_name],
        ["Classification Mode", classify_mode],
        ["Model(s) Used", model],
        ["Model Source", model_source],
        ["Rows Classified", str(num_rows)],
        ["Number of Categories", str(len(categories)) if categories else "0"],
        ["Success Rate", f"{success_rate:.2f}%"],
    ]
    # Add consensus threshold for ensemble mode
    if classify_mode == "Ensemble" and consensus_threshold is not None:
        threshold_labels = {0.5: "Majority (50%+)", 0.67: "Two-Thirds (67%+)", 1.0: "Unanimous (100%)"}
        threshold_label = threshold_labels.get(consensus_threshold, f"Custom ({consensus_threshold:.0%})")
        summary_data.append(["Consensus Threshold", threshold_label])

    summary_table = Table(summary_data, colWidths=[150, 300])
    summary_table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('PADDING', (0, 0), (-1, -1), 6),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
    ]))
    story.append(summary_table)
    story.append(Spacer(1, 15))

    # Agreement scores table for ensemble mode
    if classify_mode == "Ensemble" and result_df is not None and categories:
        agreement_cols = [f"category_{i}_agreement" for i in range(1, len(categories) + 1)]
        # Only render when every category has an agreement column.
        has_agreement = all(col in result_df.columns for col in agreement_cols)
        if has_agreement:
            story.append(Paragraph("Ensemble Agreement Scores", heading_style))
            story.append(Paragraph(
                "Agreement shows what proportion of models agreed on each category. "
                "Higher scores indicate stronger consensus.", normal_style))
            story.append(Spacer(1, 8))

            agree_data = [["Category", "Mean Agreement", "Min Agreement"]]
            for i, cat in enumerate(categories, 1):
                col = f"category_{i}_agreement"
                mean_val = result_df[col].mean()
                min_val = result_df[col].min()
                agree_data.append([cat, f"{mean_val:.1%}", f"{min_val:.1%}"])

            agree_table = Table(agree_data, colWidths=[200, 125, 125])
            agree_table.setStyle(TableStyle([
                ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
                ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
                ('GRID', (0, 0), (-1, -1), 1, colors.black),
                ('PADDING', (0, 0), (-1, -1), 6),
                ('FONTSIZE', (0, 0), (-1, -1), 9),
            ]))
            story.append(agree_table)
            story.append(Spacer(1, 15))

    if processing_time is not None:
        story.append(Paragraph("Processing Time", heading_style))
        rows_per_min = (num_rows / processing_time) * 60 if processing_time > 0 else 0
        avg_time = processing_time / num_rows if num_rows > 0 else 0

        time_data = [
            ["Total Processing Time", f"{processing_time:.1f} seconds"],
            ["Average Time per Response", f"{avg_time:.2f} seconds"],
            ["Processing Rate", f"{rows_per_min:.1f} rows/minute"],
        ]
        time_table = Table(time_data, colWidths=[180, 270])
        time_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
            ('GRID', (0, 0), (-1, -1), 1, colors.black),
            ('PADDING', (0, 0), (-1, -1), 6),
            ('FONTSIZE', (0, 0), (-1, -1), 9),
        ]))
        story.append(time_table)

    story.append(Spacer(1, 15))
    story.append(Paragraph("Version Information", heading_style))
    version_data = [
        ["CatVader Version", catvader_version or "unknown"],
        ["Python Version", python_version or "unknown"],
        ["Timestamp", datetime.now().strftime('%Y-%m-%d %H:%M:%S')],
    ]
    version_table = Table(version_data, colWidths=[180, 270])
    version_table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (0, -1), colors.lightgrey),
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
        ('PADDING', (0, 0), (-1, -1), 6),
        ('FONTSIZE', (0, 0), (-1, -1), 9),
    ]))
    story.append(version_table)

    # Reproducibility Code section
    if code:
        story.append(PageBreak())
        story.append(Paragraph("Reproducibility Code", title_style))
        story.append(Paragraph("Use this Python code to reproduce the classification with the CatVader package:", normal_style))
        story.append(Spacer(1, 10))

        # Split code into lines and add as code-formatted paragraphs
        for line in code.strip().split('\n'):
            # Escape XML special characters for reportlab's Paragraph markup.
            # '&' must be escaped first so it does not re-escape the entities.
            # (FIX: the previous replacements were identity no-ops, so code
            # containing '<' or '&' would break Paragraph parsing.)
            escaped_line = line.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
            if escaped_line.strip():
                story.append(Paragraph(escaped_line, code_style))
            else:
                # Preserve blank lines in the code listing as vertical space.
                story.append(Spacer(1, 6))

    # Visualizations section
    if result_df is not None and categories:
        from reportlab.platypus import Image
        import io

        # Distribution chart (new page)
        story.append(PageBreak())
        story.append(Paragraph("Category Distribution", title_style))
        try:
            fig1 = create_distribution_chart(result_df, categories, classify_mode, models_list)
            img_buffer1 = io.BytesIO()
            fig1.savefig(img_buffer1, format='png', dpi=150, bbox_inches='tight')
            img_buffer1.seek(0)
            plt.close(fig1)

            # Save to temp file for reportlab
            img_temp1 = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
            img_temp1.write(img_buffer1.read())
            img_temp1.close()

            img1 = Image(img_temp1.name, width=450, height=250)
            story.append(img1)
            story.append(Spacer(1, 10))
            story.append(Paragraph("Note: Categories are not mutually exclusive—each item can belong to multiple categories.", normal_style))
        except Exception as e:
            # Charting failures must not kill the report; report them inline.
            story.append(Paragraph(f"Could not generate distribution chart: {str(e)}", normal_style))

        # Classification matrix (new page)
        story.append(PageBreak())
        story.append(Paragraph("Classification Matrix", title_style))
        try:
            fig2 = create_classification_heatmap(result_df, categories, classify_mode, models_list)
            img_buffer2 = io.BytesIO()
            fig2.savefig(img_buffer2, format='png', dpi=150, bbox_inches='tight')
            img_buffer2.seek(0)
            plt.close(fig2)

            # Save to temp file for reportlab
            img_temp2 = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
            img_temp2.write(img_buffer2.read())
            img_temp2.close()

            img2 = Image(img_temp2.name, width=450, height=300)
            story.append(img2)
            story.append(Spacer(1, 10))
            story.append(Paragraph("Orange = category present, Black = not present. Each row represents one response.", normal_style))
        except Exception as e:
            story.append(Paragraph(f"Could not generate classification matrix: {str(e)}", normal_style))

    doc.build(story)
    return pdf_file.name
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
def run_auto_extract(input_type, input_data, description, max_categories_val,
                     model_tier, model, api_key_input, mode=None, progress_callback=None):
    """Extract categories from data.

    Sizes the chunking parameters (divisions / categories_per_chunk) from the
    number of input items, then calls catvader.extract().

    Args:
        input_type: "image" gets its own chunking heuristic; anything else
            uses the text-style heuristic.
        input_data: A single item or a list of items.
        description: Free-text description of the data passed to the extractor.
        max_categories_val: Cap on returned categories (cast to int).
        model_tier: Tier selector forwarded to get_api_key().
        model: Model identifier.
        api_key_input: User-supplied API key; get_api_key() resolves the
            effective key and provider.
        mode: Optional catvader mode; only forwarded when truthy.
        progress_callback: Accepted for interface parity; unused here.

    Returns:
        (categories, message): list of category strings on success, otherwise
        (None, error message).
    """
    if not CATVADER_AVAILABLE:
        return None, "catvader package not available"

    actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
    if not actual_api_key:
        return None, f"{provider} API key not configured"

    model_source = get_model_source(model)

    try:
        # Non-list input counts as a single item for chunk sizing.
        if isinstance(input_data, list):
            num_items = len(input_data)
        else:
            num_items = 1

        if input_type == "image":
            # Roughly one division per 5 images, capped at 3.
            divisions = min(3, max(1, num_items // 5))
            categories_per_chunk = 12
        else:
            # Roughly one division per 15 items, capped at 5.
            divisions = max(1, num_items // 15)
            divisions = min(divisions, 5)
            chunk_size = num_items // max(1, divisions)
            # FIX: clamp to >= 1. With very small inputs chunk_size can be 1
            # (or 0), so min(10, chunk_size - 1) used to request zero or
            # negative categories per chunk.
            categories_per_chunk = max(1, min(10, chunk_size - 1))

        extract_kwargs = {
            'input_data': input_data,
            'api_key': actual_api_key,
            'input_type': input_type,
            'description': description,
            'user_model': model,
            'model_source': model_source,
            'divisions': divisions,
            'categories_per_chunk': categories_per_chunk,
            'max_categories': int(max_categories_val)
        }
        if mode:
            extract_kwargs['mode'] = mode

        extract_result = catvader.extract(**extract_kwargs)
        categories = extract_result.get('top_categories', [])

        if not categories:
            return None, "No categories were extracted"

        return categories, f"Extracted {len(categories)} categories successfully!"

    except Exception as e:
        # Surface the failure as a UI message rather than crashing the app.
        return None, f"Error: {str(e)}"
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
def run_classify_data(input_type, input_data, description, categories,
                      model_tier, model, api_key_input, mode=None,
                      original_filename="data", column_name="text",
                      progress_callback=None):
    """Classify data with user-provided categories.

    Runs a single-model catvader.classify() call, saves the result to a CSV,
    generates the PDF methodology report and the reproducibility code snippet.

    Args:
        input_type: Forwarded to the report/code generators ("text", "pdf", ...).
        input_data: The data to classify (passed straight to catvader).
        description: Data description used in the classification prompt.
        categories: List of category strings; must be non-empty.
        model_tier: Tier selector forwarded to get_api_key().
        model: Model identifier.
        api_key_input: User-supplied API key; resolved via get_api_key().
        mode: Optional catvader mode; only forwarded when truthy.
        original_filename: Shown as "Source File" in the report.
        column_name: Shown as "Source Column" in the report.
        progress_callback: Accepted for interface parity; unused here.

    Returns:
        (result_df, csv_path, report_pdf_path, code, message) on success;
        (None, None, None, None, error_message) on any failure.
    """
    if not CATVADER_AVAILABLE:
        return None, None, None, None, "catvader package not available"

    if not categories:
        return None, None, None, None, "Please enter at least one category"

    actual_api_key, provider = get_api_key(model, model_tier, api_key_input)
    if not actual_api_key:
        return None, None, None, None, f"{provider} API key not configured"

    model_source = get_model_source(model)

    try:
        start_time = time.time()

        classify_kwargs = {
            'input_data': input_data,
            'categories': categories,
            # Single-model mode: one (model, source, key) tuple.
            'models': [(model, model_source, actual_api_key)],
            'description': description,
        }
        if mode:
            classify_kwargs['mode'] = mode

        result = catvader.classify(**classify_kwargs)

        processing_time = time.time() - start_time
        num_items = len(result)

        # Save CSV
        # NOTE(review): writes to f.name while the handle is still open;
        # works on POSIX but would fail on Windows — confirm deploy target.
        with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
            result.to_csv(f.name, index=False)
            csv_path = f.name

        # Calculate success rate
        if 'processing_status' in result.columns:
            success_count = (result['processing_status'] == 'success').sum()
            success_rate = (success_count / len(result)) * 100
        else:
            # No per-row status column: assume everything succeeded.
            success_rate = 100.0

        # Get version info
        try:
            catvader_version = catvader.__version__
        except AttributeError:
            catvader_version = "unknown"
        python_version = sys.version.split()[0]

        # Generate methodology report
        report_pdf_path = generate_methodology_report_pdf(
            categories=categories,
            model=model,
            column_name=column_name,
            num_rows=num_items,
            model_source=model_source,
            filename=original_filename,
            success_rate=success_rate,
            result_df=result,
            processing_time=processing_time,
            catvader_version=catvader_version,
            python_version=python_version,
            task_type="assign",
            input_type=input_type,
            description=description
        )

        # Generate reproducibility code
        code = generate_classify_code(input_type, description, categories, model, model_source, mode)

        return result, csv_path, report_pdf_path, code, f"Classified {num_items} items in {processing_time:.1f}s"

    except Exception as e:
        # Surface the failure as a UI message rather than crashing the app.
        return None, None, None, None, f"Error: {str(e)}"
|
| 842 |
+
|
| 843 |
+
|
| 844 |
+
def sanitize_model_name(model: str) -> str:
    """Convert model name to column-safe suffix (matches catvader logic)."""
    import re
    # Collapse every run of non-alphanumeric characters into a single
    # underscore, trim underscores at the ends, lowercase, cap at 40 chars.
    safe = re.sub(r'[^a-zA-Z0-9]+', '_', model).strip('_').lower()
    return safe[:40]
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
def _find_model_column_suffix(result_df, model_name):
    """Return the per-model column suffix actually used in the DataFrame.

    catvader appends a creativity suffix (e.g. _tauto, _t50) to ensemble
    column names, so sanitize_model_name() alone is not enough; inspect the
    real category_1_* columns to recover the full suffix. Falls back to the
    plain sanitized name when no matching column exists.
    """
    base = "category_1_"
    wanted = base + sanitize_model_name(model_name)
    # First column (in DataFrame order) whose name starts with the prefix.
    matching = [c for c in result_df.columns if c.startswith(wanted)]
    if matching:
        return matching[0][len(base):]
    return wanted[len(base):]
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
def _find_all_model_suffixes(result_df):
|
| 871 |
+
"""Discover all distinct per-model column suffixes from the DataFrame.
|
| 872 |
+
|
| 873 |
+
Looks at category_1_* columns (excluding _consensus and _agreement)
|
| 874 |
+
to find every unique model suffix. Works even when the same model
|
| 875 |
+
appears multiple times with different temperature suffixes.
|
| 876 |
+
|
| 877 |
+
Returns:
|
| 878 |
+
List of suffix strings, e.g.
|
| 879 |
+
['claude_haiku_4_5_20251001_t0', 'claude_haiku_4_5_20251001_t25', ...]
|
| 880 |
+
"""
|
| 881 |
+
import re
|
| 882 |
+
suffixes = []
|
| 883 |
+
for col in result_df.columns:
|
| 884 |
+
m = re.match(r'^category_1_(.+)$', col)
|
| 885 |
+
if m:
|
| 886 |
+
suffix = m.group(1)
|
| 887 |
+
if suffix not in ('consensus', 'agreement'):
|
| 888 |
+
suffixes.append(suffix)
|
| 889 |
+
return suffixes
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def create_classification_heatmap(result_df, categories, classify_mode="Single Model", models_list=None):
    """Create a binary heatmap showing classification for each row.

    Builds a rows x categories 0/1 matrix from the appropriate result columns
    (consensus columns for Ensemble, the first model's columns for Model
    Comparison, plain category_N columns otherwise) and renders it as a
    black/orange imshow heatmap.

    Args:
        result_df: DataFrame with classification results
        categories: List of category names
        classify_mode: "Single Model", "Model Comparison", or "Ensemble"
        models_list: List of model names (for multi-model modes)

    Returns:
        A matplotlib Figure (an empty "No data" figure when result_df is empty).
    """
    import numpy as np

    total_rows = len(result_df)
    if total_rows == 0:
        # Nothing to plot: return a placeholder figure instead of failing.
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', fontsize=14)
        ax.axis('off')
        return fig

    # Build the binary matrix based on classify_mode
    if classify_mode == "Ensemble":
        # Use consensus columns
        col_names = [f"category_{i}_consensus" for i in range(1, len(categories) + 1)]
    elif classify_mode == "Model Comparison" and models_list:
        # Use first model's columns (detect actual suffix from DataFrame)
        suffix = _find_model_column_suffix(result_df, models_list[0])
        col_names = [f"category_{i}_{suffix}" for i in range(1, len(categories) + 1)]
    else:
        # Single model
        col_names = [f"category_{i}" for i in range(1, len(categories) + 1)]

    # Extract the binary matrix; missing columns render as all-zero (absent).
    matrix_data = []
    for col in col_names:
        if col in result_df.columns:
            matrix_data.append(result_df[col].astype(int).values)
        else:
            matrix_data.append(np.zeros(total_rows, dtype=int))

    matrix = np.array(matrix_data).T  # Rows = responses, Cols = categories

    # Create figure with appropriate sizing (height scales with row count,
    # clamped to a readable 4-20 inch range).
    fig_height = max(4, min(20, total_rows * 0.15))
    fig_width = max(8, len(categories) * 0.8)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))

    # Create custom colormap: black (0) and orange (1) - CatVader theme
    from matplotlib.colors import ListedColormap
    cmap = ListedColormap(['#1a1a1a', '#E8A33C'])

    # Plot heatmap
    im = ax.imshow(matrix, aspect='auto', cmap=cmap, vmin=0, vmax=1)

    # Set labels - remove y-axis numbers for cleaner look
    ax.set_xticks(range(len(categories)))
    ax.set_xticklabels(categories, rotation=45, ha='right', fontsize=9)
    ax.set_xlabel('Categories', fontsize=11)
    ax.set_ylabel(f'Responses (n={total_rows})', fontsize=11)
    ax.set_yticks([])  # Remove y-axis tick marks

    title = 'Classification Matrix'
    if classify_mode == "Ensemble":
        title += ' (Ensemble Consensus)'
    elif classify_mode == "Model Comparison":
        # Short display name: last path segment, no tag, max 20 chars.
        title += f' ({models_list[0].split("/")[-1].split(":")[0][:20]})'
    ax.set_title(title, fontsize=14, fontweight='bold')

    # Add legend
    from matplotlib.patches import Patch
    legend_elements = [
        Patch(facecolor='#1a1a1a', edgecolor='white', label='Not Present'),
        Patch(facecolor='#E8A33C', edgecolor='white', label='Present')
    ]
    ax.legend(handles=legend_elements, loc='upper right', bbox_to_anchor=(1.15, 1))

    plt.tight_layout()
    return fig
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
def create_distribution_chart(result_df, categories, classify_mode="Single Model", models_list=None):
    """Create a bar chart showing category distribution.

    Renders one horizontal bar per category with the percentage of rows in
    which it is present. Single Model uses category_N columns, Ensemble uses
    category_N_consensus columns, Model Comparison draws grouped bars (one
    color per model) using each model's detected column suffix.

    Args:
        result_df: DataFrame with classification results
        categories: List of category names
        classify_mode: "Single Model", "Model Comparison", or "Ensemble"
        models_list: List of model names (for multi-model modes)

    Returns:
        A matplotlib Figure (a "No data" placeholder when there is nothing
        to plot).
    """
    import numpy as np

    total_rows = len(result_df)
    if total_rows == 0:
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', fontsize=14)
        ax.axis('off')
        return fig

    # Define colors for different models
    model_colors = ['#2563eb', '#dc2626', '#16a34a', '#ca8a04', '#9333ea', '#0891b2', '#be185d', '#65a30d']

    if classify_mode == "Single Model":
        # Single model: use category_1, category_2, etc.
        fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))

        dist_data = []
        for i, cat in enumerate(categories, 1):
            col_name = f"category_{i}"
            if col_name in result_df.columns:
                count = int(result_df[col_name].sum())
                pct = (count / total_rows) * 100
                dist_data.append({"Category": cat, "Percentage": round(pct, 1)})

        # Reverse so the first category appears at the top of the barh chart.
        categories_list = [d["Category"] for d in dist_data][::-1]
        percentages = [d["Percentage"] for d in dist_data][::-1]

        bars = ax.barh(categories_list, percentages, color='#2563eb')
        ax.set_xlim(0, 100)
        ax.set_xlabel('Percentage (%)', fontsize=11)
        ax.set_title('Category Distribution (%)', fontsize=14, fontweight='bold')

        for bar, pct in zip(bars, percentages):
            ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
                    f'{pct:.1f}%', va='center', fontsize=10)

    elif classify_mode == "Ensemble":
        # Ensemble: use category_1_consensus, category_2_consensus, etc.
        fig, ax = plt.subplots(figsize=(10, max(4, len(categories) * 0.8)))

        dist_data = []
        for i, cat in enumerate(categories, 1):
            col_name = f"category_{i}_consensus"
            if col_name in result_df.columns:
                count = int(result_df[col_name].sum())
                pct = (count / total_rows) * 100
                dist_data.append({"Category": cat, "Percentage": round(pct, 1)})

        categories_list = [d["Category"] for d in dist_data][::-1]
        percentages = [d["Percentage"] for d in dist_data][::-1]

        bars = ax.barh(categories_list, percentages, color='#16a34a')
        ax.set_xlim(0, 100)
        ax.set_xlabel('Percentage (%)', fontsize=11)
        ax.set_title('Ensemble Consensus Distribution (%)', fontsize=14, fontweight='bold')

        for bar, pct in zip(bars, percentages):
            ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height()/2,
                    f'{pct:.1f}%', va='center', fontsize=10)

    else:  # Model Comparison
        # Model Comparison: grouped bars for each model
        if not models_list:
            models_list = []

        # Detect actual column suffixes from the DataFrame
        model_suffixes = [_find_model_column_suffix(result_df, m) for m in models_list]
        n_models = len(model_suffixes)
        n_categories = len(categories)

        # FIX: guard against an empty model list — bar_height below divides
        # by n_models, which previously raised ZeroDivisionError.
        if n_models == 0:
            fig, ax = plt.subplots(figsize=(10, 4))
            ax.text(0.5, 0.5, 'No data to display', ha='center', va='center', fontsize=14)
            ax.axis('off')
            return fig

        fig, ax = plt.subplots(figsize=(12, max(5, n_categories * 1.2)))

        # Gather data for each model; bars for one category share 0.8 units
        # of vertical space, split evenly between models.
        bar_height = 0.8 / n_models
        y_positions = np.arange(n_categories)

        for model_idx, (model_name, suffix) in enumerate(zip(models_list, model_suffixes)):
            model_pcts = []
            for i in range(1, n_categories + 1):
                col_name = f"category_{i}_{suffix}"
                if col_name in result_df.columns:
                    count = int(result_df[col_name].sum())
                    pct = (count / total_rows) * 100
                else:
                    pct = 0
                model_pcts.append(pct)

            # Reverse for horizontal bar chart
            model_pcts = model_pcts[::-1]
            offset = (model_idx - n_models / 2 + 0.5) * bar_height
            color = model_colors[model_idx % len(model_colors)]

            # Use shorter display name
            display_name = model_name.split('/')[-1].split(':')[0][:20]
            bars = ax.barh(y_positions + offset, model_pcts, bar_height * 0.9,
                           label=display_name, color=color, alpha=0.85)

        ax.set_yticks(y_positions)
        ax.set_yticklabels(categories[::-1])
        ax.set_xlim(0, 100)
        ax.set_xlabel('Percentage (%)', fontsize=11)
        ax.set_title('Category Distribution by Model (%)', fontsize=14, fontweight='bold')
        ax.legend(loc='lower right', fontsize=9)

    plt.tight_layout()
    return fig
|
| 1085 |
+
|
| 1086 |
+
|
| 1087 |
+
# Page config
# Browser tab title/icon and wide layout for the Streamlit app.
st.set_page_config(
    page_title="CatVader - Social Media Classifier",
    page_icon="🐱",
    layout="wide"
)
|
| 1093 |
+
|
| 1094 |
+
# Custom CSS for enhanced styling.
# Injected once at startup via unsafe_allow_html. The rules below:
#   * apply the EB Garamond serif font globally, while explicitly excluding
#     Streamlit's Material icon elements so glyphs still render;
#   * theme all standard widgets (buttons, uploaders, inputs, radios, metrics,
#     progress bars, etc.) around the amber palette #E8A33C / #D4872C;
#   * define a `.tall-button` helper class used later to make the
#     "Try Example Dataset" button match the file-uploader's height.
# NOTE(review): selectors target Streamlit's internal DOM (e.g.
# [data-testid="column"], .stButton > button) — these are not a stable public
# API and may break on Streamlit upgrades; re-verify after version bumps.
st.markdown("""
<style>
    /* Import Garamond font and apply globally */
    @import url('https://fonts.googleapis.com/css2?family=EB+Garamond:wght@400;500;600;700&display=swap');

    *:not([class*="icon"]):not([data-testid="stIconMaterial"]):not(svg):not(path) {
        font-family: 'EB Garamond', Garamond, Georgia, serif !important;
        font-size: 17px !important;
    }

    /* Preserve Streamlit icon fonts */
    [data-testid="stIconMaterial"], .material-icons, .material-symbols-rounded {
        font-family: 'Material Symbols Rounded', 'Material Icons' !important;
        font-size: 24px !important;
    }

    /* Main container styling */
    .main .block-container {
        padding-top: 2rem;
        padding-bottom: 2rem;
    }

    /* Headers with gradient accent */
    h1 {
        background: linear-gradient(90deg, #E8A33C 0%, #D4872C 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        background-clip: text;
        font-weight: 700;
    }

    /* Card-like sections */
    .stExpander {
        border: 1px solid #E8D5B5;
        border-radius: 12px;
        box-shadow: 0 2px 8px rgba(232, 163, 60, 0.08);
    }

    /* File uploader styling */
    .stFileUploader {
        border-radius: 12px;
    }

    .stFileUploader > div > div {
        border: 2px dashed #E8A33C;
        border-radius: 12px;
        background: linear-gradient(135deg, #FEFCF9 0%, #F5EFE6 100%);
    }

    /* Button styling */
    .stButton > button {
        border-radius: 8px;
        font-weight: 600;
        transition: all 0.2s ease;
        border: 2px solid #E8A33C;
        background: #FEFCF9;
        color: #D4872C;
    }

    /* Tall button for example dataset (matches file uploader height) */
    .tall-button .stButton > button {
        min-height: 107px;
        border-radius: 12px;
    }

    .stButton > button:hover {
        transform: translateY(-1px);
        box-shadow: 0 4px 12px rgba(232, 163, 60, 0.3);
        background: #F5EFE6;
    }

    /* Primary button */
    .stButton > button[kind="primary"] {
        background: linear-gradient(135deg, #E8A33C 0%, #D4872C 100%);
        border: none;
        color: white;
    }

    /* Success/info messages */
    .stSuccess {
        background-color: #E8F5E9;
        border-left: 4px solid #4CAF50;
        border-radius: 0 8px 8px 0;
    }

    .stInfo {
        background-color: #FFF8E8;
        border-left: 4px solid #E8A33C;
        border-radius: 0 8px 8px 0;
    }

    /* Radio buttons */
    .stRadio > div {
        gap: 0.5rem;
        display: flex;
        width: 100%;
    }

    .stRadio > div > label {
        background: #F5EFE6;
        padding: 0.5rem 1rem;
        border-radius: 20px;
        border: 1px solid transparent;
        transition: all 0.2s ease;
        flex: 1;
        text-align: center;
        justify-content: center;
    }

    .stRadio > div > label:hover {
        border-color: #E8A33C;
    }

    /* Text inputs */
    .stTextInput > div > div > input {
        border-radius: 8px;
        border: 1px solid #E8D5B5;
    }

    .stTextInput > div > div > input:focus {
        border-color: #E8A33C;
        box-shadow: 0 0 0 2px rgba(232, 163, 60, 0.2);
    }

    /* Select boxes */
    .stSelectbox > div > div {
        border-radius: 8px;
    }

    /* Dataframe styling */
    .stDataFrame {
        border-radius: 12px;
        overflow: hidden;
        box-shadow: 0 2px 8px rgba(0, 0, 0, 0.05);
    }

    /* Progress bar */
    .stProgress > div > div {
        background: linear-gradient(90deg, #E8A33C 0%, #D4872C 100%);
        border-radius: 10px;
    }

    /* Slider */
    .stSlider > div > div > div {
        background: #E8A33C;
    }

    /* Divider */
    hr {
        border: none;
        height: 1px;
        background: linear-gradient(90deg, transparent, #E8D5B5, transparent);
        margin: 1.5rem 0;
    }

    /* Code blocks */
    .stCodeBlock {
        border-radius: 12px;
        border: 1px solid #E8D5B5;
    }

    /* Metric cards */
    .stMetric {
        background: linear-gradient(135deg, #FEFCF9 0%, #F5EFE6 100%);
        padding: 1rem;
        border-radius: 12px;
        border: 1px solid #E8D5B5;
    }

    /* Download buttons */
    .stDownloadButton > button {
        background: #F5EFE6;
        border: 1px solid #E8A33C;
        color: #D4872C;
    }

    .stDownloadButton > button:hover {
        background: #E8A33C;
        color: white;
    }

    /* Multiselect */
    .stMultiSelect > div > div {
        border-radius: 8px;
    }

    /* Status indicator */
    .stStatus {
        border-radius: 12px;
    }

    /* Column gaps */
    [data-testid="column"] {
        padding: 0 0.5rem;
    }

    /* Logo and title alignment */
    [data-testid="column"]:first-child img {
        border-radius: 8px;
    }
</style>
""", unsafe_allow_html=True)
|
| 1297 |
+
|
| 1298 |
+
# One-time initialization of every session-state slot the app relies on.
# Each key is only assigned when absent, so values survive Streamlit reruns.
# (st.session_state supports both attribute and item access; item access is
# used here so the keys can be driven from a table.)
_SESSION_DEFAULTS = {
    'categories': [''] * MAX_CATEGORIES,        # per-slot category text inputs
    'category_count': INITIAL_CATEGORIES,       # how many category inputs to render
    'task_mode': None,                          # "manual" / "auto_extract" / None
    'extracted_categories': None,               # categories returned by auto-extraction
    'results': None,                            # classification results
    'active_tab': "survey",
    'survey_data': None,
    'pdf_data': None,
    'image_data': None,
    'extraction_params': None,                  # stores params when categories are auto-extracted
    'bluesky_df': None,                         # posts fetched from Bluesky
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
|
| 1321 |
+
|
| 1322 |
+
# Logo and title - use HTML for better alignment.
# Rendered as raw HTML (unsafe_allow_html) instead of st.columns/st.image so
# the logo and the stacked title/subtitle lines stay vertically centered.
# NOTE(review): the logo is loaded from the Hugging Face Space's own repo URL;
# confirm the path still resolves if the Space is renamed or moved.
st.markdown("""
<div style="display: flex; align-items: center; gap: 20px; margin-bottom: 10px;">
    <img src="https://huggingface.co/spaces/CatVader/social-media-classifier/resolve/main/logo.png" width="100" style="border-radius: 8px;">
    <div>
        <div style="font-size: 2.2rem; font-weight: 700; color: #333; font-family: 'EB Garamond', Garamond, Georgia, serif; line-height: 1.1;">CatVader</div>
        <div style="font-size: 1.1rem; font-weight: 500; color: #E8A33C; font-family: 'EB Garamond', Garamond, Georgia, serif; margin-bottom: 4px;">NLP for Survey Research</div>
        <div style="font-size: 1rem; font-weight: 400; color: #666; font-family: 'EB Garamond', Garamond, Georgia, serif;">Research-grade classification of social media posts, PDFs, and images using AI models.</div>
        <div style="font-size: 0.85rem; font-weight: 400; color: #888; font-family: 'EB Garamond', Garamond, Georgia, serif; margin-top: 4px;">Developed at UC Berkeley</div>
    </div>
</div>
""", unsafe_allow_html=True)
|
| 1334 |
+
|
| 1335 |
+
# About section — collapsed expander with static informational markdown:
# privacy notice, feature overview, supported providers, beta/feedback info,
# acknowledgments, project links, and the citation string.
with st.expander("About This App"):
    st.markdown("""
    **Privacy Notice:** Your data is sent to third-party LLM APIs for classification. Do not upload sensitive, confidential, or personally identifiable information (PII).

    ---

    **CatVader** is an open-source Python package for classifying and exploring social media data using Large Language Models.

    ### What It Does
    - **Extract Categories**: Discover themes and categories in your data automatically
    - **Assign Categories**: Classify data into your predefined categories
    - **Extract & Assign**: Let CatVader discover categories, then classify all your data

    ### Supported Providers
    OpenAI (GPT-4o, GPT-4o Mini), Anthropic (Claude), Google (Gemini), Mistral, HuggingFace, xAI (Grok), and Perplexity. Use the free tier or bring your own API key.

    ### Beta Test - We Want Your Feedback!
    This app is currently in **beta** and **free to use** while CatVader is under active development, made possible by **Bashir Ahmed's generous fellowship support**.

    - Found a bug? Have a feature request? Please open an issue on [GitHub](https://github.com/chrissoria/cat-vader)
    - Reach out directly: [chrissoria@berkeley.edu](mailto:chrissoria@berkeley.edu)

    ### Acknowledgments
    - **Bashir Ahmed** for his generous fellowship support that makes this free beta possible
    - **Claude Fischer** for his thoughtful feedback and collaboration on research that helped inspire this project
    - **Kevin Collins** from Survey360 for his input
    - **Fendi Tsim** for sharing it widely

    ### Links
    - **Website**: [christophersoria.com](https://christophersoria.com)
    - **PyPI**: [pip install cat-vader](https://pypi.org/project/cat-vader/)
    - **GitHub**: [github.com/chrissoria/cat-vader](https://github.com/chrissoria/cat-vader)

    ### Citation
    If you use CatVader in your research, please cite:
    ```
    Soria, C. (2025). CatVader: A Python package for LLM-based social media classification. DOI: 10.5281/zenodo.15532316
    ```
    """)
|
| 1375 |
+
|
| 1376 |
+
# Main layout: two equal-width columns — inputs/configuration on the left
# (col_input), results/output on the right (col_output, used further down).
col_input, col_output = st.columns([1, 1])
|
| 1378 |
+
|
| 1379 |
+
with col_input:
|
| 1380 |
+
# Input type selector
|
| 1381 |
+
input_type_choice = st.radio(
|
| 1382 |
+
"Input Type",
|
| 1383 |
+
options=["Social Media Posts", "PDF Documents", "Images"],
|
| 1384 |
+
horizontal=True,
|
| 1385 |
+
key="input_type_radio"
|
| 1386 |
+
)
|
| 1387 |
+
|
| 1388 |
+
# Initialize variables
|
| 1389 |
+
input_data = None
|
| 1390 |
+
input_type_selected = "text"
|
| 1391 |
+
description = ""
|
| 1392 |
+
original_filename = "data"
|
| 1393 |
+
pdf_mode = "Image (visual documents)"
|
| 1394 |
+
|
| 1395 |
+
if input_type_choice == "Social Media Posts":
|
| 1396 |
+
input_type_selected = "text"
|
| 1397 |
+
|
| 1398 |
+
data_source = st.radio(
|
| 1399 |
+
"Data Source",
|
| 1400 |
+
options=["Upload CSV/Excel", "Fetch from Bluesky"],
|
| 1401 |
+
horizontal=True,
|
| 1402 |
+
key="data_source_radio"
|
| 1403 |
+
)
|
| 1404 |
+
|
| 1405 |
+
if data_source == "Upload CSV/Excel":
|
| 1406 |
+
st.session_state.bluesky_df = None # Clear any fetched data when switching sources
|
| 1407 |
+
upload_col, example_col = st.columns([3, 1])
|
| 1408 |
+
with upload_col:
|
| 1409 |
+
uploaded_file = st.file_uploader(
|
| 1410 |
+
"Upload Data (CSV or Excel)",
|
| 1411 |
+
type=['csv', 'xlsx', 'xls'],
|
| 1412 |
+
key="survey_file"
|
| 1413 |
+
)
|
| 1414 |
+
with example_col:
|
| 1415 |
+
st.markdown("<div style='height: 27px;'></div>", unsafe_allow_html=True) # Match "Upload Data" label height
|
| 1416 |
+
st.markdown('<div class="tall-button">', unsafe_allow_html=True)
|
| 1417 |
+
if st.button("Try Example Dataset", key="example_btn", use_container_width=True):
|
| 1418 |
+
st.session_state.example_loaded = True
|
| 1419 |
+
st.markdown('</div>', unsafe_allow_html=True)
|
| 1420 |
+
|
| 1421 |
+
columns = []
|
| 1422 |
+
df = None
|
| 1423 |
+
if uploaded_file is not None:
|
| 1424 |
+
try:
|
| 1425 |
+
if uploaded_file.name.endswith('.csv'):
|
| 1426 |
+
df = pd.read_csv(uploaded_file)
|
| 1427 |
+
else:
|
| 1428 |
+
df = pd.read_excel(uploaded_file)
|
| 1429 |
+
columns = df.columns.tolist()
|
| 1430 |
+
st.success(f"Loaded {len(df):,} rows")
|
| 1431 |
+
except Exception as e:
|
| 1432 |
+
st.error(f"Error loading file: {e}")
|
| 1433 |
+
elif hasattr(st.session_state, 'example_loaded') and st.session_state.example_loaded:
|
| 1434 |
+
try:
|
| 1435 |
+
df = pd.read_csv("example_data.csv")
|
| 1436 |
+
columns = df.columns.tolist()
|
| 1437 |
+
st.success(f"Loaded example dataset ({len(df)} rows)")
|
| 1438 |
+
except:
|
| 1439 |
+
pass
|
| 1440 |
+
|
| 1441 |
+
selected_column = st.selectbox(
|
| 1442 |
+
"Column to Process",
|
| 1443 |
+
options=columns if columns else ["Upload a file first"],
|
| 1444 |
+
disabled=not columns,
|
| 1445 |
+
key="survey_column"
|
| 1446 |
+
)
|
| 1447 |
+
|
| 1448 |
+
description = selected_column if columns else ""
|
| 1449 |
+
original_filename = uploaded_file.name if uploaded_file else "example_data.csv"
|
| 1450 |
+
|
| 1451 |
+
if df is not None and columns and selected_column in columns:
|
| 1452 |
+
input_data = df[selected_column].tolist()
|
| 1453 |
+
|
| 1454 |
+
else: # Fetch from Bluesky
|
| 1455 |
+
bsky_handle = st.text_input(
|
| 1456 |
+
"Bluesky Handle",
|
| 1457 |
+
placeholder="e.g. aoc.bsky.social or @aoc.bsky.social",
|
| 1458 |
+
key="bluesky_handle_input"
|
| 1459 |
+
)
|
| 1460 |
+
bsky_num_posts = st.slider(
|
| 1461 |
+
"Number of Posts to Fetch",
|
| 1462 |
+
min_value=10, max_value=250, value=50, step=10,
|
| 1463 |
+
key="bluesky_num_posts"
|
| 1464 |
+
)
|
| 1465 |
+
if st.button("Fetch Posts", key="fetch_bluesky_btn"):
|
| 1466 |
+
handle_clean = bsky_handle.strip().lstrip("@")
|
| 1467 |
+
if not handle_clean:
|
| 1468 |
+
st.error("Please enter a Bluesky handle.")
|
| 1469 |
+
else:
|
| 1470 |
+
with st.spinner(f"Fetching {bsky_num_posts} posts from {handle_clean}..."):
|
| 1471 |
+
try:
|
| 1472 |
+
from catvader._social_media import fetch_bluesky
|
| 1473 |
+
df_bsky = fetch_bluesky(limit=bsky_num_posts, handle=handle_clean)
|
| 1474 |
+
df_bsky = df_bsky[df_bsky["media_type"] != "REPOST_FACADE"].reset_index(drop=True)
|
| 1475 |
+
st.session_state.bluesky_df = df_bsky
|
| 1476 |
+
except Exception as e:
|
| 1477 |
+
st.error(f"Error fetching posts: {e}")
|
| 1478 |
+
|
| 1479 |
+
if st.session_state.bluesky_df is not None:
|
| 1480 |
+
bsky_df = st.session_state.bluesky_df
|
| 1481 |
+
st.success(f"Fetched {len(bsky_df)} posts")
|
| 1482 |
+
st.dataframe(
|
| 1483 |
+
bsky_df[["timestamp", "text", "likes", "replies"]].head(5),
|
| 1484 |
+
use_container_width=True
|
| 1485 |
+
)
|
| 1486 |
+
handle_clean = bsky_handle.strip().lstrip("@") if bsky_handle else "bluesky"
|
| 1487 |
+
input_data = bsky_df["text"].tolist()
|
| 1488 |
+
description = f"Bluesky posts from @{handle_clean}"
|
| 1489 |
+
original_filename = f"bluesky_{handle_clean.replace('.', '_')}"
|
| 1490 |
+
|
| 1491 |
+
elif input_type_choice == "PDF Documents":
|
| 1492 |
+
input_type_selected = "pdf"
|
| 1493 |
+
|
| 1494 |
+
pdf_files = st.file_uploader(
|
| 1495 |
+
"Upload PDF Document(s)",
|
| 1496 |
+
type=['pdf'],
|
| 1497 |
+
accept_multiple_files=True,
|
| 1498 |
+
key="pdf_files"
|
| 1499 |
+
)
|
| 1500 |
+
|
| 1501 |
+
pdf_description = st.text_input(
|
| 1502 |
+
"Document Description",
|
| 1503 |
+
placeholder="e.g., 'research papers', 'interview transcripts'",
|
| 1504 |
+
help="Helps the LLM understand context",
|
| 1505 |
+
key="pdf_desc"
|
| 1506 |
+
)
|
| 1507 |
+
|
| 1508 |
+
pdf_mode = st.radio(
|
| 1509 |
+
"Processing Mode",
|
| 1510 |
+
options=["Image (visual documents)", "Text (text-heavy)", "Both (comprehensive)"],
|
| 1511 |
+
key="pdf_mode"
|
| 1512 |
+
)
|
| 1513 |
+
|
| 1514 |
+
if pdf_files:
|
| 1515 |
+
input_data = []
|
| 1516 |
+
pdf_name_map = {} # Map temp paths to original filenames
|
| 1517 |
+
for f in pdf_files:
|
| 1518 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.pdf') as tmp:
|
| 1519 |
+
tmp.write(f.read())
|
| 1520 |
+
input_data.append(tmp.name)
|
| 1521 |
+
pdf_name_map[tmp.name] = f.name.replace('.pdf', '') # Store original name without extension
|
| 1522 |
+
st.session_state.pdf_name_map = pdf_name_map
|
| 1523 |
+
description = pdf_description or "document"
|
| 1524 |
+
original_filename = "pdf_files"
|
| 1525 |
+
st.success(f"Uploaded {len(pdf_files)} PDF file(s)")
|
| 1526 |
+
|
| 1527 |
+
else: # Images
|
| 1528 |
+
input_type_selected = "image"
|
| 1529 |
+
|
| 1530 |
+
image_files = st.file_uploader(
|
| 1531 |
+
"Upload Images",
|
| 1532 |
+
type=['png', 'jpg', 'jpeg', 'gif', 'webp'],
|
| 1533 |
+
accept_multiple_files=True,
|
| 1534 |
+
key="image_files"
|
| 1535 |
+
)
|
| 1536 |
+
|
| 1537 |
+
image_description = st.text_input(
|
| 1538 |
+
"Image Description",
|
| 1539 |
+
placeholder="e.g., 'product photos', 'social media posts'",
|
| 1540 |
+
help="Helps the LLM understand context",
|
| 1541 |
+
key="image_desc"
|
| 1542 |
+
)
|
| 1543 |
+
|
| 1544 |
+
if image_files:
|
| 1545 |
+
input_data = []
|
| 1546 |
+
for f in image_files:
|
| 1547 |
+
suffix = '.' + f.name.split('.')[-1]
|
| 1548 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
|
| 1549 |
+
tmp.write(f.read())
|
| 1550 |
+
input_data.append(tmp.name)
|
| 1551 |
+
description = image_description or "images"
|
| 1552 |
+
original_filename = "image_files"
|
| 1553 |
+
st.success(f"Uploaded {len(image_files)} image file(s)")
|
| 1554 |
+
|
| 1555 |
+
st.markdown("---")
|
| 1556 |
+
|
| 1557 |
+
# Task selection
|
| 1558 |
+
st.markdown("### What would you like to do?")
|
| 1559 |
+
col_btn1, col_btn2 = st.columns(2)
|
| 1560 |
+
with col_btn1:
|
| 1561 |
+
manual_mode = st.button("Enter Categories Manually", use_container_width=True)
|
| 1562 |
+
with col_btn2:
|
| 1563 |
+
auto_mode = st.button("Auto-extract Categories", use_container_width=True)
|
| 1564 |
+
|
| 1565 |
+
if manual_mode:
|
| 1566 |
+
st.session_state.task_mode = "manual"
|
| 1567 |
+
if auto_mode:
|
| 1568 |
+
st.session_state.task_mode = "auto_extract"
|
| 1569 |
+
|
| 1570 |
+
# Auto-extract settings
|
| 1571 |
+
if st.session_state.task_mode == "auto_extract":
|
| 1572 |
+
st.markdown("### Auto-extract Categories")
|
| 1573 |
+
st.markdown("We'll analyze your data to discover the main categories.")
|
| 1574 |
+
|
| 1575 |
+
max_categories = st.slider(
|
| 1576 |
+
"Number of Categories to Extract",
|
| 1577 |
+
min_value=3,
|
| 1578 |
+
max_value=25,
|
| 1579 |
+
value=12,
|
| 1580 |
+
help="How many categories should be identified in your data"
|
| 1581 |
+
)
|
| 1582 |
+
|
| 1583 |
+
specificity = st.selectbox(
|
| 1584 |
+
"How specific should categories be?",
|
| 1585 |
+
options=["Broad", "Moderate", "Narrow"],
|
| 1586 |
+
index=0,
|
| 1587 |
+
help="Broad = general themes, Moderate = balanced detail, Narrow = highly specific categories"
|
| 1588 |
+
)
|
| 1589 |
+
|
| 1590 |
+
focus = st.text_input(
|
| 1591 |
+
"What should categories be focused around? (optional)",
|
| 1592 |
+
placeholder="e.g., 'decisions to move', 'emotional responses', 'financial factors'",
|
| 1593 |
+
help="Guide the model to prioritize extracting categories related to this focus"
|
| 1594 |
+
)
|
| 1595 |
+
|
| 1596 |
+
# Model selection for extraction
|
| 1597 |
+
st.markdown("### Model Selection")
|
| 1598 |
+
model_tier = st.radio(
|
| 1599 |
+
"Model Tier",
|
| 1600 |
+
options=["Free Models", "Bring Your Own Key"],
|
| 1601 |
+
key="extract_model_tier"
|
| 1602 |
+
)
|
| 1603 |
+
|
| 1604 |
+
if model_tier == "Free Models":
|
| 1605 |
+
model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="extract_model")
|
| 1606 |
+
model = FREE_MODELS_MAP[model_display] # Convert to actual model name
|
| 1607 |
+
api_key = ""
|
| 1608 |
+
else:
|
| 1609 |
+
model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="extract_model_paid")
|
| 1610 |
+
api_key = st.text_input("API Key", type="password", key="extract_api_key")
|
| 1611 |
+
|
| 1612 |
+
if st.button("Extract Categories", type="primary"):
|
| 1613 |
+
if input_data is None:
|
| 1614 |
+
st.error("Please upload data first")
|
| 1615 |
+
else:
|
| 1616 |
+
mode = None
|
| 1617 |
+
if input_type_selected == "pdf":
|
| 1618 |
+
mode_mapping = {
|
| 1619 |
+
"Image (visual documents)": "image",
|
| 1620 |
+
"Text (text-heavy)": "text",
|
| 1621 |
+
"Both (comprehensive)": "both"
|
| 1622 |
+
}
|
| 1623 |
+
mode = mode_mapping.get(pdf_mode, "image")
|
| 1624 |
+
|
| 1625 |
+
actual_api_key, provider = get_api_key(model, model_tier, api_key)
|
| 1626 |
+
if not actual_api_key:
|
| 1627 |
+
st.error(f"{provider} API key not configured")
|
| 1628 |
+
else:
|
| 1629 |
+
model_source = get_model_source(model)
|
| 1630 |
+
|
| 1631 |
+
# Calculate estimated time based on input size
|
| 1632 |
+
num_items = len(input_data) if isinstance(input_data, list) else 1
|
| 1633 |
+
if input_type_selected == "pdf":
|
| 1634 |
+
# PDFs take longer - estimate ~5s per page
|
| 1635 |
+
total_pages = sum(count_pdf_pages(p) for p in (input_data if isinstance(input_data, list) else [input_data]))
|
| 1636 |
+
est_seconds = total_pages * 5
|
| 1637 |
+
elif input_type_selected == "image":
|
| 1638 |
+
# Images ~4s each
|
| 1639 |
+
est_seconds = num_items * 4
|
| 1640 |
+
else:
|
| 1641 |
+
# Text ~2s per item, but batched
|
| 1642 |
+
est_seconds = max(10, num_items * 0.5)
|
| 1643 |
+
|
| 1644 |
+
# Progress tracking UI
|
| 1645 |
+
progress_bar = st.progress(0)
|
| 1646 |
+
status_text = st.empty()
|
| 1647 |
+
start_time = time.time()
|
| 1648 |
+
|
| 1649 |
+
# Progress callback for extraction
|
| 1650 |
+
def extract_progress_callback(current_step, total_steps, step_label):
|
| 1651 |
+
progress = current_step / total_steps if total_steps > 0 else 0
|
| 1652 |
+
progress_bar.progress(min(progress, 1.0))
|
| 1653 |
+
|
| 1654 |
+
elapsed = time.time() - start_time
|
| 1655 |
+
if current_step > 0:
|
| 1656 |
+
avg_time = elapsed / current_step
|
| 1657 |
+
eta_seconds = avg_time * (total_steps - current_step)
|
| 1658 |
+
eta_str = f" | ETA: {eta_seconds:.0f}s" if eta_seconds < 60 else f" | ETA: {eta_seconds/60:.1f}m"
|
| 1659 |
+
else:
|
| 1660 |
+
eta_str = ""
|
| 1661 |
+
|
| 1662 |
+
status_text.text(f"Extracting categories: {step_label} ({progress*100:.0f}%){eta_str}")
|
| 1663 |
+
|
| 1664 |
+
extract_kwargs = {
|
| 1665 |
+
'input_data': input_data,
|
| 1666 |
+
'api_key': actual_api_key,
|
| 1667 |
+
'input_type': input_type_selected,
|
| 1668 |
+
'description': description,
|
| 1669 |
+
'user_model': model,
|
| 1670 |
+
'model_source': model_source,
|
| 1671 |
+
'max_categories': int(max_categories),
|
| 1672 |
+
'specificity': specificity.lower(),
|
| 1673 |
+
'progress_callback': extract_progress_callback,
|
| 1674 |
+
}
|
| 1675 |
+
if mode:
|
| 1676 |
+
extract_kwargs['mode'] = mode
|
| 1677 |
+
if focus and focus.strip():
|
| 1678 |
+
extract_kwargs['focus'] = focus.strip()
|
| 1679 |
+
|
| 1680 |
+
try:
|
| 1681 |
+
extract_result = catvader.extract(**extract_kwargs)
|
| 1682 |
+
categories = extract_result.get('top_categories', [])
|
| 1683 |
+
|
| 1684 |
+
processing_time = time.time() - start_time
|
| 1685 |
+
progress_bar.progress(1.0)
|
| 1686 |
+
status_text.text(f"Completed in {processing_time:.1f}s")
|
| 1687 |
+
|
| 1688 |
+
if categories:
|
| 1689 |
+
st.success(f"Extracted {len(categories)} categories in {processing_time:.1f}s")
|
| 1690 |
+
st.session_state.extracted_categories = categories
|
| 1691 |
+
# Store extraction params for code generation
|
| 1692 |
+
st.session_state.extraction_params = {
|
| 1693 |
+
'model': model,
|
| 1694 |
+
'model_source': model_source,
|
| 1695 |
+
'max_categories': int(max_categories),
|
| 1696 |
+
'input_type': input_type_selected,
|
| 1697 |
+
'description': description,
|
| 1698 |
+
'mode': mode,
|
| 1699 |
+
}
|
| 1700 |
+
st.session_state.task_mode = "manual"
|
| 1701 |
+
st.rerun()
|
| 1702 |
+
else:
|
| 1703 |
+
st.error("No categories were extracted from the data")
|
| 1704 |
+
except Exception as e:
|
| 1705 |
+
st.error(f"Error: {str(e)}")
|
| 1706 |
+
|
| 1707 |
+
# Category inputs (shown for manual mode or after extraction)
|
| 1708 |
+
if st.session_state.task_mode == "manual":
|
| 1709 |
+
st.markdown("### Categories")
|
| 1710 |
+
st.markdown("Enter your classification categories below.")
|
| 1711 |
+
|
| 1712 |
+
# Pre-fill with extracted categories if available
|
| 1713 |
+
if st.session_state.extracted_categories:
|
| 1714 |
+
for i, cat in enumerate(st.session_state.extracted_categories[:MAX_CATEGORIES]):
|
| 1715 |
+
st.session_state.categories[i] = cat
|
| 1716 |
+
st.session_state.category_count = min(len(st.session_state.extracted_categories), MAX_CATEGORIES)
|
| 1717 |
+
st.session_state.extracted_categories = None # Clear after use
|
| 1718 |
+
|
| 1719 |
+
placeholder_examples = [
|
| 1720 |
+
"e.g., Positive sentiment",
|
| 1721 |
+
"e.g., Negative sentiment",
|
| 1722 |
+
"e.g., Product feedback",
|
| 1723 |
+
"e.g., Service complaint",
|
| 1724 |
+
"e.g., Feature request",
|
| 1725 |
+
"e.g., Custom category"
|
| 1726 |
+
]
|
| 1727 |
+
|
| 1728 |
+
categories_entered = []
|
| 1729 |
+
for i in range(st.session_state.category_count):
|
| 1730 |
+
placeholder = placeholder_examples[i] if i < len(placeholder_examples) else "e.g., Custom category"
|
| 1731 |
+
cat_value = st.text_input(
|
| 1732 |
+
f"Category {i+1}",
|
| 1733 |
+
value=st.session_state.categories[i],
|
| 1734 |
+
placeholder=placeholder,
|
| 1735 |
+
key=f"cat_{i}"
|
| 1736 |
+
)
|
| 1737 |
+
st.session_state.categories[i] = cat_value
|
| 1738 |
+
if cat_value.strip():
|
| 1739 |
+
categories_entered.append(cat_value.strip())
|
| 1740 |
+
|
| 1741 |
+
if st.session_state.category_count < MAX_CATEGORIES:
|
| 1742 |
+
if st.button("+ Add More"):
|
| 1743 |
+
st.session_state.category_count += 1
|
| 1744 |
+
st.rerun()
|
| 1745 |
+
|
| 1746 |
+
st.markdown("### Model Selection")
|
| 1747 |
+
|
| 1748 |
+
# Classification mode selector
|
| 1749 |
+
classify_mode = st.radio(
|
| 1750 |
+
"Classification Mode",
|
| 1751 |
+
options=["Single Model", "Model Comparison", "Ensemble"],
|
| 1752 |
+
horizontal=True,
|
| 1753 |
+
key="classify_mode",
|
| 1754 |
+
help="Single: one model. Comparison: see results from multiple models side-by-side. Ensemble: multiple models vote for consensus."
|
| 1755 |
+
)
|
| 1756 |
+
|
| 1757 |
+
model_tier = st.radio(
|
| 1758 |
+
"Model Tier",
|
| 1759 |
+
options=["Free Models", "Bring Your Own Key"],
|
| 1760 |
+
key="classify_model_tier"
|
| 1761 |
+
)
|
| 1762 |
+
|
| 1763 |
+
# Multi-model mode uses multiselect
|
| 1764 |
+
is_multi_model = classify_mode in ["Model Comparison", "Ensemble"]
|
| 1765 |
+
min_models = 3 if classify_mode == "Ensemble" else 2
|
| 1766 |
+
|
| 1767 |
+
# Track per-run temperatures: list of (model_name, temperature) for ensemble,
|
| 1768 |
+
# or dict {model_name: temperature} for model comparison
|
| 1769 |
+
model_temperatures = {}
|
| 1770 |
+
# ensemble_runs stores list of (model_name, temperature) allowing duplicate models
|
| 1771 |
+
ensemble_runs = []
|
| 1772 |
+
|
| 1773 |
+
if classify_mode == "Ensemble":
|
| 1774 |
+
# Ensemble mode: dynamic rows allowing same model multiple times with different temps
|
| 1775 |
+
if "ensemble_num_runs" not in st.session_state:
|
| 1776 |
+
st.session_state.ensemble_num_runs = 3
|
| 1777 |
+
|
| 1778 |
+
if model_tier == "Free Models":
|
| 1779 |
+
model_options = FREE_MODEL_DISPLAY_NAMES
|
| 1780 |
+
is_free = True
|
| 1781 |
+
else:
|
| 1782 |
+
model_options = PAID_MODEL_CHOICES
|
| 1783 |
+
is_free = False
|
| 1784 |
+
|
| 1785 |
+
st.markdown(f"**Model Runs** (select {min_models}+ runs)")
|
| 1786 |
+
for i in range(st.session_state.ensemble_num_runs):
|
| 1787 |
+
cols = st.columns([3, 1, 0.5])
|
| 1788 |
+
with cols[0]:
|
| 1789 |
+
default_idx = 0 if i < len(model_options) else i % len(model_options)
|
| 1790 |
+
selected = st.selectbox(
|
| 1791 |
+
f"Run {i+1}", options=model_options,
|
| 1792 |
+
index=default_idx, key=f"ensemble_model_{i}",
|
| 1793 |
+
label_visibility="collapsed"
|
| 1794 |
+
)
|
| 1795 |
+
with cols[1]:
|
| 1796 |
+
temp = st.number_input(
|
| 1797 |
+
"Temp", min_value=0.0, max_value=2.0, value=round(i * 0.25, 2),
|
| 1798 |
+
step=0.25, key=f"ensemble_temp_{i}", label_visibility="collapsed"
|
| 1799 |
+
)
|
| 1800 |
+
with cols[2]:
|
| 1801 |
+
if st.session_state.ensemble_num_runs > 3:
|
| 1802 |
+
if st.button("✕", key=f"ensemble_remove_{i}"):
|
| 1803 |
+
st.session_state.ensemble_num_runs -= 1
|
| 1804 |
+
st.rerun()
|
| 1805 |
+
|
| 1806 |
+
model_name = FREE_MODELS_MAP[selected] if is_free else selected
|
| 1807 |
+
ensemble_runs.append((model_name, temp))
|
| 1808 |
+
|
| 1809 |
+
if st.button("Add Run", key="add_ensemble_run"):
|
| 1810 |
+
st.session_state.ensemble_num_runs += 1
|
| 1811 |
+
st.rerun()
|
| 1812 |
+
|
| 1813 |
+
models_list = [r[0] for r in ensemble_runs]
|
| 1814 |
+
model_temperatures = {f"{r[0]}__run{i}": r[1] for i, r in enumerate(ensemble_runs)}
|
| 1815 |
+
api_key = "" if model_tier == "Free Models" else st.text_input("API Key", type="password", key="classify_api_key")
|
| 1816 |
+
|
| 1817 |
+
elif is_multi_model:
|
| 1818 |
+
# Model Comparison mode: multiselect (each model unique) + temperature row
|
| 1819 |
+
if model_tier == "Free Models":
|
| 1820 |
+
default_models = FREE_MODEL_DISPLAY_NAMES[:min_models] if len(FREE_MODEL_DISPLAY_NAMES) >= min_models else FREE_MODEL_DISPLAY_NAMES
|
| 1821 |
+
model_displays = st.multiselect(
|
| 1822 |
+
f"Models (select {min_models}+)",
|
| 1823 |
+
options=FREE_MODEL_DISPLAY_NAMES,
|
| 1824 |
+
default=default_models,
|
| 1825 |
+
key="classify_models_multi"
|
| 1826 |
+
)
|
| 1827 |
+
models_list = [FREE_MODELS_MAP[d] for d in model_displays]
|
| 1828 |
+
api_key = ""
|
| 1829 |
+
else:
|
| 1830 |
+
default_models = PAID_MODEL_CHOICES[:min_models] if len(PAID_MODEL_CHOICES) >= min_models else PAID_MODEL_CHOICES
|
| 1831 |
+
models_list = st.multiselect(
|
| 1832 |
+
f"Models (select {min_models}+)",
|
| 1833 |
+
options=PAID_MODEL_CHOICES,
|
| 1834 |
+
default=default_models,
|
| 1835 |
+
key="classify_models_multi_paid"
|
| 1836 |
+
)
|
| 1837 |
+
api_key = st.text_input("API Key", type="password", key="classify_api_key")
|
| 1838 |
+
|
| 1839 |
+
if models_list:
|
| 1840 |
+
st.markdown("**Model Temperature**")
|
| 1841 |
+
temp_cols = st.columns(len(models_list))
|
| 1842 |
+
for idx, (col, m) in enumerate(zip(temp_cols, models_list)):
|
| 1843 |
+
short_name = m.split('/')[-1].split(':')[0][:20]
|
| 1844 |
+
model_temperatures[m] = col.number_input(
|
| 1845 |
+
short_name,
|
| 1846 |
+
min_value=0.0,
|
| 1847 |
+
max_value=2.0,
|
| 1848 |
+
value=0.0,
|
| 1849 |
+
step=0.25,
|
| 1850 |
+
key=f"temp_{idx}",
|
| 1851 |
+
help=f"Temperature for {m} (0 = deterministic, higher = more creative)"
|
| 1852 |
+
)
|
| 1853 |
+
else:
|
| 1854 |
+
# Single model mode
|
| 1855 |
+
if model_tier == "Free Models":
|
| 1856 |
+
model_display = st.selectbox("Model", options=FREE_MODEL_DISPLAY_NAMES, key="classify_model")
|
| 1857 |
+
model = FREE_MODELS_MAP[model_display] # Convert to actual model name
|
| 1858 |
+
models_list = [model]
|
| 1859 |
+
api_key = ""
|
| 1860 |
+
else:
|
| 1861 |
+
model = st.selectbox("Model", options=PAID_MODEL_CHOICES, key="classify_model_paid")
|
| 1862 |
+
models_list = [model]
|
| 1863 |
+
api_key = st.text_input("API Key", type="password", key="classify_api_key")
|
| 1864 |
+
|
| 1865 |
+
# Ensemble-specific options
|
| 1866 |
+
consensus_threshold = 0.5 # Default
|
| 1867 |
+
if classify_mode == "Ensemble":
|
| 1868 |
+
consensus_options = {
|
| 1869 |
+
"Majority (50%+)": 0.5,
|
| 1870 |
+
"Two-Thirds (67%+)": 0.67,
|
| 1871 |
+
"Unanimous (100%)": 1.0,
|
| 1872 |
+
}
|
| 1873 |
+
consensus_choice = st.radio(
|
| 1874 |
+
"Consensus Rule",
|
| 1875 |
+
options=list(consensus_options.keys()),
|
| 1876 |
+
horizontal=True,
|
| 1877 |
+
key="consensus_choice",
|
| 1878 |
+
help="How many models must agree for a category to be marked present"
|
| 1879 |
+
)
|
| 1880 |
+
consensus_threshold = consensus_options[consensus_choice]
|
| 1881 |
+
|
| 1882 |
+
if st.button("Categorize Data", type="primary", use_container_width=True):
|
| 1883 |
+
if input_data is None:
|
| 1884 |
+
st.error("Please upload data first")
|
| 1885 |
+
elif not categories_entered:
|
| 1886 |
+
st.error("Please enter at least one category")
|
| 1887 |
+
elif classify_mode == "Model Comparison" and len(models_list) < 2:
|
| 1888 |
+
st.error("Please select at least 2 models for comparison mode")
|
| 1889 |
+
elif classify_mode == "Ensemble" and len(models_list) < 3:
|
| 1890 |
+
st.error("Please select at least 3 models for ensemble mode (needed for majority voting)")
|
| 1891 |
+
else:
|
| 1892 |
+
# Set up progress tracking
|
| 1893 |
+
mode = None
|
| 1894 |
+
if input_type_selected == "pdf":
|
| 1895 |
+
mode_mapping = {
|
| 1896 |
+
"Image (visual documents)": "image",
|
| 1897 |
+
"Text (text-heavy)": "text",
|
| 1898 |
+
"Both (comprehensive)": "both"
|
| 1899 |
+
}
|
| 1900 |
+
mode = mode_mapping.get(pdf_mode, "image")
|
| 1901 |
+
|
| 1902 |
+
# Build models tuples list
|
| 1903 |
+
# Uses 4-tuple (model, source, api_key, options) when per-model temperatures are set
|
| 1904 |
+
models_tuples = []
|
| 1905 |
+
api_key_error = None
|
| 1906 |
+
if ensemble_runs:
|
| 1907 |
+
# Ensemble mode: use ensemble_runs (model, temp) pairs directly
|
| 1908 |
+
for m, temp in ensemble_runs:
|
| 1909 |
+
actual_key, provider = get_api_key(m, model_tier, api_key)
|
| 1910 |
+
if not actual_key:
|
| 1911 |
+
api_key_error = f"{provider} API key not configured for {m}"
|
| 1912 |
+
break
|
| 1913 |
+
m_source = get_model_source(m)
|
| 1914 |
+
models_tuples.append((m, m_source, actual_key, {"creativity": temp}))
|
| 1915 |
+
else:
|
| 1916 |
+
for m in models_list:
|
| 1917 |
+
actual_key, provider = get_api_key(m, model_tier, api_key)
|
| 1918 |
+
if not actual_key:
|
| 1919 |
+
api_key_error = f"{provider} API key not configured for {m}"
|
| 1920 |
+
break
|
| 1921 |
+
m_source = get_model_source(m)
|
| 1922 |
+
temp = model_temperatures.get(m)
|
| 1923 |
+
if temp is not None and is_multi_model:
|
| 1924 |
+
models_tuples.append((m, m_source, actual_key, {"creativity": temp}))
|
| 1925 |
+
else:
|
| 1926 |
+
models_tuples.append((m, m_source, actual_key))
|
| 1927 |
+
|
| 1928 |
+
if api_key_error:
|
| 1929 |
+
st.error(api_key_error)
|
| 1930 |
+
else:
|
| 1931 |
+
items_list = input_data if isinstance(input_data, list) else [input_data]
|
| 1932 |
+
|
| 1933 |
+
# Progress UI
|
| 1934 |
+
progress_bar = st.progress(0)
|
| 1935 |
+
status_text = st.empty()
|
| 1936 |
+
start_time = time.time()
|
| 1937 |
+
|
| 1938 |
+
# For PDFs, use progress callback
|
| 1939 |
+
if input_type_selected == "pdf":
|
| 1940 |
+
# Progress callback for PDF page-by-page updates
|
| 1941 |
+
def pdf_progress_callback(current_idx, total_pages, page_label):
    """Refresh the Streamlit progress widgets for one processed PDF page.

    Closure over ``progress_bar``, ``status_text`` and ``start_time`` from
    the enclosing scope. Displays the completion percentage and, once at
    least one page has finished, an ETA extrapolated from the average
    per-page time so far.
    """
    fraction = current_idx / total_pages if total_pages > 0 else 0
    progress_bar.progress(min(fraction, 1.0))

    # An ETA is only meaningful after the first page has completed.
    eta_suffix = ""
    if current_idx > 0:
        elapsed_secs = time.time() - start_time
        remaining = (elapsed_secs / current_idx) * (total_pages - current_idx)
        if remaining < 60:
            eta_suffix = f" | ETA: {remaining:.0f}s"
        else:
            eta_suffix = f" | ETA: {remaining/60:.1f}m"

    status_text.text(
        f"Processing page {current_idx+1} of {total_pages} "
        f"({page_label}) ({fraction*100:.0f}%){eta_suffix}"
    )
|
| 1954 |
+
|
| 1955 |
+
try:
|
| 1956 |
+
# Build kwargs for classify
|
| 1957 |
+
classify_kwargs = {
|
| 1958 |
+
"input_data": items_list,
|
| 1959 |
+
"categories": categories_entered,
|
| 1960 |
+
"models": models_tuples,
|
| 1961 |
+
"description": description,
|
| 1962 |
+
"mode": mode,
|
| 1963 |
+
"progress_callback": pdf_progress_callback,
|
| 1964 |
+
}
|
| 1965 |
+
# Add consensus_threshold for ensemble mode
|
| 1966 |
+
if classify_mode == "Ensemble":
|
| 1967 |
+
classify_kwargs["consensus_threshold"] = consensus_threshold
|
| 1968 |
+
|
| 1969 |
+
result_df = catvader.classify(**classify_kwargs)
|
| 1970 |
+
|
| 1971 |
+
processing_time = time.time() - start_time
|
| 1972 |
+
total_items = len(result_df)
|
| 1973 |
+
progress_bar.progress(1.0)
|
| 1974 |
+
status_text.text(f"Completed {total_items} pages in {processing_time:.1f}s")
|
| 1975 |
+
|
| 1976 |
+
# Replace temp paths with original filenames in pdf_input column
|
| 1977 |
+
if 'pdf_input' in result_df.columns:
|
| 1978 |
+
pdf_name_map = st.session_state.get('pdf_name_map', {})
|
| 1979 |
+
def replace_temp_path(val):
    """Map a temp-file-derived cell value back to the original upload name.

    ``pdf_name_map`` (closure from the enclosing scope) maps temporary PDF
    paths to the user's original filenames. If a temp file's stem appears
    anywhere in *val*, every occurrence is substituted with the original
    name; otherwise the value is returned unchanged (as a string). NaN
    values pass through untouched so pandas missing data is preserved.
    """
    if pd.isna(val):
        return val
    val_str = str(val)
    for temp_path, orig_name in pdf_name_map.items():
        base = os.path.basename(temp_path)
        # Strip only the TRAILING ".pdf" extension. str.replace('.pdf', '')
        # would also delete a ".pdf" occurring mid-name (e.g.
        # "v1.pdf_copy.pdf" -> "v1_copy") and break the match below.
        temp_name = base[:-len('.pdf')] if base.endswith('.pdf') else base
        if temp_name in val_str:
            return val_str.replace(temp_name, orig_name)
    return val_str
|
| 1989 |
+
result_df['pdf_input'] = result_df['pdf_input'].apply(replace_temp_path)
|
| 1990 |
+
|
| 1991 |
+
all_results = [result_df]
|
| 1992 |
+
|
| 1993 |
+
except Exception as e:
|
| 1994 |
+
st.error(f"Error: {str(e)}")
|
| 1995 |
+
all_results = []
|
| 1996 |
+
|
| 1997 |
+
else:
|
| 1998 |
+
# Non-PDF processing (text, images) - process all at once
|
| 1999 |
+
total_items = len(items_list)
|
| 2000 |
+
|
| 2001 |
+
# Progress callback for item-by-item updates
|
| 2002 |
+
def item_progress_callback(current_idx, total, item_label):
    """Refresh the Streamlit progress widgets for one processed item.

    Closure over ``progress_bar``, ``status_text`` and ``start_time`` from
    the enclosing scope. ``item_label`` is accepted for callback-signature
    compatibility but is not shown in the status line. An ETA, based on the
    average per-item time so far, appears once the first item has finished.
    """
    fraction = current_idx / total if total > 0 else 0
    progress_bar.progress(min(fraction, 1.0))

    # No ETA until at least one item has completed.
    eta_suffix = ""
    if current_idx > 0:
        elapsed_secs = time.time() - start_time
        remaining = (elapsed_secs / current_idx) * (total - current_idx)
        if remaining < 60:
            eta_suffix = f" | ETA: {remaining:.0f}s"
        else:
            eta_suffix = f" | ETA: {remaining/60:.1f}m"

    status_text.text(
        f"Processing item {current_idx+1} of {total} "
        f"({fraction*100:.0f}%){eta_suffix}"
    )
|
| 2015 |
+
|
| 2016 |
+
try:
|
| 2017 |
+
# Build kwargs for classify
|
| 2018 |
+
classify_kwargs = {
|
| 2019 |
+
"input_data": items_list,
|
| 2020 |
+
"categories": categories_entered,
|
| 2021 |
+
"models": models_tuples,
|
| 2022 |
+
"description": description,
|
| 2023 |
+
"progress_callback": item_progress_callback,
|
| 2024 |
+
}
|
| 2025 |
+
# Add consensus_threshold for ensemble mode
|
| 2026 |
+
if classify_mode == "Ensemble":
|
| 2027 |
+
classify_kwargs["consensus_threshold"] = consensus_threshold
|
| 2028 |
+
|
| 2029 |
+
result_df = catvader.classify(**classify_kwargs)
|
| 2030 |
+
all_results = [result_df]
|
| 2031 |
+
|
| 2032 |
+
processing_time = time.time() - start_time
|
| 2033 |
+
progress_bar.progress(1.0)
|
| 2034 |
+
status_text.text(f"Completed {total_items} items in {processing_time:.1f}s")
|
| 2035 |
+
|
| 2036 |
+
except Exception as e:
|
| 2037 |
+
st.error(f"Error: {str(e)}")
|
| 2038 |
+
all_results = []
|
| 2039 |
+
processing_time = time.time() - start_time
|
| 2040 |
+
|
| 2041 |
+
if all_results:
|
| 2042 |
+
# Combine results
|
| 2043 |
+
result_df = pd.concat(all_results, ignore_index=True)
|
| 2044 |
+
|
| 2045 |
+
# Merge Bluesky engagement columns if available
|
| 2046 |
+
if st.session_state.get("bluesky_df") is not None:
|
| 2047 |
+
bsky_eng = st.session_state.bluesky_df.reset_index(drop=True)
|
| 2048 |
+
if len(bsky_eng) == len(result_df):
|
| 2049 |
+
for col in ["post_id", "timestamp", "likes", "replies", "reposts",
|
| 2050 |
+
"media_type", "image_url", "post_length",
|
| 2051 |
+
"contains_url", "contains_image", "is_repost"]:
|
| 2052 |
+
if col in bsky_eng.columns:
|
| 2053 |
+
result_df[col] = bsky_eng[col].values
|
| 2054 |
+
|
| 2055 |
+
# Save CSV
|
| 2056 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='_classified.csv', delete=False) as f:
|
| 2057 |
+
result_df.to_csv(f.name, index=False)
|
| 2058 |
+
csv_path = f.name
|
| 2059 |
+
|
| 2060 |
+
# Calculate success rate
|
| 2061 |
+
if 'processing_status' in result_df.columns:
|
| 2062 |
+
success_count = (result_df['processing_status'] == 'success').sum()
|
| 2063 |
+
success_rate = (success_count / len(result_df)) * 100
|
| 2064 |
+
else:
|
| 2065 |
+
success_rate = 100.0
|
| 2066 |
+
|
| 2067 |
+
# Get version info
|
| 2068 |
+
try:
|
| 2069 |
+
catvader_version = catvader.__version__
|
| 2070 |
+
except AttributeError:
|
| 2071 |
+
catvader_version = "unknown"
|
| 2072 |
+
python_version = sys.version.split()[0]
|
| 2073 |
+
|
| 2074 |
+
# For reports: create model string (single or list)
|
| 2075 |
+
if len(models_list) == 1:
|
| 2076 |
+
report_model = models_list[0]
|
| 2077 |
+
report_model_source = models_tuples[0][1]
|
| 2078 |
+
else:
|
| 2079 |
+
report_model = ", ".join(models_list)
|
| 2080 |
+
report_model_source = f"{classify_mode} ({len(models_list)} models)"
|
| 2081 |
+
|
| 2082 |
+
# Generate code first so we can include it in the PDF
|
| 2083 |
+
# If categories were auto-extracted, include both extract and classify code
|
| 2084 |
+
if st.session_state.extraction_params:
|
| 2085 |
+
classify_params = {
|
| 2086 |
+
'model': report_model,
|
| 2087 |
+
'description': description,
|
| 2088 |
+
'mode': mode,
|
| 2089 |
+
'classify_mode': classify_mode,
|
| 2090 |
+
'models_list': models_list,
|
| 2091 |
+
'consensus_threshold': consensus_threshold,
|
| 2092 |
+
'model_temperatures': model_temperatures,
|
| 2093 |
+
'ensemble_runs': ensemble_runs if ensemble_runs else None,
|
| 2094 |
+
}
|
| 2095 |
+
code = generate_full_code(st.session_state.extraction_params, classify_params)
|
| 2096 |
+
else:
|
| 2097 |
+
code = generate_classify_code(
|
| 2098 |
+
input_type_selected, description, categories_entered,
|
| 2099 |
+
report_model, report_model_source, mode,
|
| 2100 |
+
classify_mode=classify_mode, models_list=models_list,
|
| 2101 |
+
consensus_threshold=consensus_threshold,
|
| 2102 |
+
model_temperatures=model_temperatures,
|
| 2103 |
+
ensemble_runs=ensemble_runs if ensemble_runs else None,
|
| 2104 |
+
)
|
| 2105 |
+
|
| 2106 |
+
# Generate methodology report with code included
|
| 2107 |
+
pdf_path = generate_methodology_report_pdf(
|
| 2108 |
+
categories=categories_entered,
|
| 2109 |
+
model=report_model,
|
| 2110 |
+
column_name=description,
|
| 2111 |
+
num_rows=len(result_df),
|
| 2112 |
+
model_source=report_model_source,
|
| 2113 |
+
filename=original_filename,
|
| 2114 |
+
success_rate=success_rate,
|
| 2115 |
+
result_df=result_df,
|
| 2116 |
+
processing_time=processing_time,
|
| 2117 |
+
catvader_version=catvader_version,
|
| 2118 |
+
python_version=python_version,
|
| 2119 |
+
task_type="assign",
|
| 2120 |
+
input_type=input_type_selected,
|
| 2121 |
+
description=description,
|
| 2122 |
+
classify_mode=classify_mode,
|
| 2123 |
+
models_list=models_list,
|
| 2124 |
+
code=code,
|
| 2125 |
+
consensus_threshold=consensus_threshold if classify_mode == "Ensemble" else None,
|
| 2126 |
+
)
|
| 2127 |
+
|
| 2128 |
+
st.session_state.results = {
|
| 2129 |
+
'df': result_df,
|
| 2130 |
+
'csv_path': csv_path,
|
| 2131 |
+
'pdf_path': pdf_path,
|
| 2132 |
+
'code': code,
|
| 2133 |
+
'status': f"Classified {len(result_df)} items in {processing_time:.1f}s",
|
| 2134 |
+
'categories': categories_entered,
|
| 2135 |
+
'classify_mode': classify_mode,
|
| 2136 |
+
'models_list': models_list,
|
| 2137 |
+
'model_temperatures': model_temperatures,
|
| 2138 |
+
'ensemble_runs': ensemble_runs if ensemble_runs else None,
|
| 2139 |
+
}
|
| 2140 |
+
st.success(f"Classified {len(result_df)} items in {processing_time:.1f}s")
|
| 2141 |
+
st.rerun()
|
| 2142 |
+
else:
|
| 2143 |
+
st.error("No items were successfully classified")
|
| 2144 |
+
|
| 2145 |
+
with col_output:
|
| 2146 |
+
st.markdown("### Results")
|
| 2147 |
+
|
| 2148 |
+
if st.session_state.results:
|
| 2149 |
+
results = st.session_state.results
|
| 2150 |
+
|
| 2151 |
+
# Visualization selector
|
| 2152 |
+
viz_type = st.selectbox(
|
| 2153 |
+
"Visualization",
|
| 2154 |
+
options=["Category Distribution", "Classification Matrix"],
|
| 2155 |
+
key="viz_type",
|
| 2156 |
+
help="Distribution shows category percentages. Matrix shows each response's classifications."
|
| 2157 |
+
)
|
| 2158 |
+
|
| 2159 |
+
if viz_type == "Category Distribution":
|
| 2160 |
+
fig = create_distribution_chart(
|
| 2161 |
+
results['df'],
|
| 2162 |
+
results['categories'],
|
| 2163 |
+
classify_mode=results.get('classify_mode', 'Single Model'),
|
| 2164 |
+
models_list=results.get('models_list', [])
|
| 2165 |
+
)
|
| 2166 |
+
st.pyplot(fig)
|
| 2167 |
+
st.caption("Note: Categories are not mutually exclusive—each item can belong to multiple categories.")
|
| 2168 |
+
else:
|
| 2169 |
+
fig = create_classification_heatmap(
|
| 2170 |
+
results['df'],
|
| 2171 |
+
results['categories'],
|
| 2172 |
+
classify_mode=results.get('classify_mode', 'Single Model'),
|
| 2173 |
+
models_list=results.get('models_list', [])
|
| 2174 |
+
)
|
| 2175 |
+
st.pyplot(fig)
|
| 2176 |
+
st.caption("Green = category present, Black = not present. Each row is one response.")
|
| 2177 |
+
|
| 2178 |
+
# Results dataframe (hide technical columns from display)
|
| 2179 |
+
display_df = results['df'].copy()
|
| 2180 |
+
cols_to_hide = ['model_response', 'json', 'raw_response', 'raw_json']
|
| 2181 |
+
display_df = display_df.drop(columns=[c for c in cols_to_hide if c in display_df.columns])
|
| 2182 |
+
st.dataframe(display_df, use_container_width=True)
|
| 2183 |
+
|
| 2184 |
+
# Downloads
|
| 2185 |
+
col_dl1, col_dl2, col_dl3 = st.columns(3)
|
| 2186 |
+
with col_dl1:
|
| 2187 |
+
with open(results['csv_path'], 'rb') as f:
|
| 2188 |
+
st.download_button(
|
| 2189 |
+
"Download CSV",
|
| 2190 |
+
data=f,
|
| 2191 |
+
file_name="classified_results.csv",
|
| 2192 |
+
mime="text/csv"
|
| 2193 |
+
)
|
| 2194 |
+
with col_dl2:
|
| 2195 |
+
with open(results['pdf_path'], 'rb') as f:
|
| 2196 |
+
st.download_button(
|
| 2197 |
+
"Download Report",
|
| 2198 |
+
data=f,
|
| 2199 |
+
file_name="methodology_report.pdf",
|
| 2200 |
+
mime="application/pdf"
|
| 2201 |
+
)
|
| 2202 |
+
with col_dl3:
|
| 2203 |
+
# Generate both plots and save to a single PDF
|
| 2204 |
+
import io
|
| 2205 |
+
from matplotlib.backends.backend_pdf import PdfPages
|
| 2206 |
+
|
| 2207 |
+
plot_buffer = io.BytesIO()
|
| 2208 |
+
with PdfPages(plot_buffer) as pdf:
|
| 2209 |
+
# Distribution chart
|
| 2210 |
+
fig1 = create_distribution_chart(
|
| 2211 |
+
results['df'],
|
| 2212 |
+
results['categories'],
|
| 2213 |
+
classify_mode=results.get('classify_mode', 'Single Model'),
|
| 2214 |
+
models_list=results.get('models_list', [])
|
| 2215 |
+
)
|
| 2216 |
+
pdf.savefig(fig1, bbox_inches='tight')
|
| 2217 |
+
plt.close(fig1)
|
| 2218 |
+
|
| 2219 |
+
# Classification matrix
|
| 2220 |
+
fig2 = create_classification_heatmap(
|
| 2221 |
+
results['df'],
|
| 2222 |
+
results['categories'],
|
| 2223 |
+
classify_mode=results.get('classify_mode', 'Single Model'),
|
| 2224 |
+
models_list=results.get('models_list', [])
|
| 2225 |
+
)
|
| 2226 |
+
pdf.savefig(fig2, bbox_inches='tight')
|
| 2227 |
+
plt.close(fig2)
|
| 2228 |
+
|
| 2229 |
+
plot_buffer.seek(0)
|
| 2230 |
+
st.download_button(
|
| 2231 |
+
"Download Plots",
|
| 2232 |
+
data=plot_buffer,
|
| 2233 |
+
file_name="classification_plots.pdf",
|
| 2234 |
+
mime="application/pdf"
|
| 2235 |
+
)
|
| 2236 |
+
|
| 2237 |
+
# Code
|
| 2238 |
+
with st.expander("See the Code"):
|
| 2239 |
+
st.code(results['code'], language='python')
|
| 2240 |
+
else:
|
| 2241 |
+
st.info("Upload data, select categories, and click 'Categorize Data' to see results here.")
|
| 2242 |
+
|
| 2243 |
+
# Bottom buttons
|
| 2244 |
+
col_reset, col_code = st.columns(2)
|
| 2245 |
+
with col_reset:
|
| 2246 |
+
if st.button("Reset", type="secondary", use_container_width=True):
|
| 2247 |
+
st.session_state.categories = [''] * MAX_CATEGORIES
|
| 2248 |
+
st.session_state.category_count = INITIAL_CATEGORIES
|
| 2249 |
+
st.session_state.task_mode = None
|
| 2250 |
+
st.session_state.extracted_categories = None
|
| 2251 |
+
st.session_state.extraction_params = None
|
| 2252 |
+
st.session_state.results = None
|
| 2253 |
+
if hasattr(st.session_state, 'example_loaded'):
|
| 2254 |
+
del st.session_state.example_loaded
|
| 2255 |
+
st.rerun()
|
| 2256 |
+
|
| 2257 |
+
with col_code:
|
| 2258 |
+
if st.button("See in Code", use_container_width=True):
|
| 2259 |
+
st.session_state.show_code_modal = True
|
| 2260 |
+
|
| 2261 |
+
# Code modal/dialog
|
| 2262 |
+
if st.session_state.get('show_code_modal'):
|
| 2263 |
+
st.markdown("---")
|
| 2264 |
+
st.markdown("### Reproducibility Code")
|
| 2265 |
+
st.markdown("Use this code to reproduce the classification with the CatVader Python package:")
|
| 2266 |
+
|
| 2267 |
+
# Use results code if available, otherwise generate from current parameters
|
| 2268 |
+
if st.session_state.results:
|
| 2269 |
+
code_to_show = st.session_state.results['code']
|
| 2270 |
+
else:
|
| 2271 |
+
# Get current categories from session state
|
| 2272 |
+
current_categories = [c for c in st.session_state.categories[:st.session_state.category_count] if c.strip()]
|
| 2273 |
+
|
| 2274 |
+
# Determine current input type and description
|
| 2275 |
+
input_type_map = {"Social Media Posts": "text", "PDF Documents": "pdf", "Images": "image"}
|
| 2276 |
+
current_input_type = input_type_map.get(st.session_state.get('input_type_radio', 'Social Media Posts'), 'text')
|
| 2277 |
+
current_description = st.session_state.get('survey_column', '') or st.session_state.get('pdf_desc', '') or st.session_state.get('image_desc', '') or 'your_data'
|
| 2278 |
+
|
| 2279 |
+
# Get current classification mode and models
|
| 2280 |
+
current_classify_mode = st.session_state.get('classify_mode', 'Single Model')
|
| 2281 |
+
current_model_tier = st.session_state.get('classify_model_tier', 'Free Models')
|
| 2282 |
+
|
| 2283 |
+
if current_classify_mode in ["Model Comparison", "Ensemble"]:
|
| 2284 |
+
# Multi-model mode
|
| 2285 |
+
if current_model_tier == 'Free Models':
|
| 2286 |
+
model_displays = st.session_state.get('classify_models_multi', [])
|
| 2287 |
+
current_models_list = [FREE_MODELS_MAP.get(d, d) for d in model_displays]
|
| 2288 |
+
else:
|
| 2289 |
+
current_models_list = st.session_state.get('classify_models_multi_paid', [])
|
| 2290 |
+
current_model = ", ".join(current_models_list) if current_models_list else "gpt-4o-mini"
|
| 2291 |
+
current_model_source = f"{current_classify_mode} ({len(current_models_list)} models)"
|
| 2292 |
+
else:
|
| 2293 |
+
# Single model mode
|
| 2294 |
+
if current_model_tier == 'Free Models':
|
| 2295 |
+
model_display = st.session_state.get('classify_model', 'GPT-4o Mini')
|
| 2296 |
+
current_model = FREE_MODELS_MAP.get(model_display, 'gpt-4o-mini')
|
| 2297 |
+
else:
|
| 2298 |
+
current_model = st.session_state.get('classify_model_paid', 'gpt-4o-mini')
|
| 2299 |
+
current_models_list = [current_model]
|
| 2300 |
+
current_model_source = get_model_source(current_model)
|
| 2301 |
+
|
| 2302 |
+
# Get consensus threshold for ensemble mode
|
| 2303 |
+
consensus_options = {"Majority (50%+)": 0.5, "Two-Thirds (67%+)": 0.67, "Unanimous (100%)": 1.0}
|
| 2304 |
+
current_consensus = consensus_options.get(st.session_state.get('consensus_choice', 'Majority (50%+)'), 0.5)
|
| 2305 |
+
|
| 2306 |
+
# Get PDF mode if applicable
|
| 2307 |
+
current_mode = None
|
| 2308 |
+
if current_input_type == "pdf":
|
| 2309 |
+
mode_mapping = {
|
| 2310 |
+
"Image (visual documents)": "image",
|
| 2311 |
+
"Text (text-heavy)": "text",
|
| 2312 |
+
"Both (comprehensive)": "both"
|
| 2313 |
+
}
|
| 2314 |
+
current_mode = mode_mapping.get(st.session_state.get('pdf_mode', 'Image (visual documents)'), 'image')
|
| 2315 |
+
|
| 2316 |
+
if current_categories:
|
| 2317 |
+
# Check if categories were auto-extracted
|
| 2318 |
+
if st.session_state.extraction_params:
|
| 2319 |
+
current_temperatures = results.get('model_temperatures', {})
|
| 2320 |
+
classify_params = {
|
| 2321 |
+
'model': current_model,
|
| 2322 |
+
'description': current_description,
|
| 2323 |
+
'mode': current_mode,
|
| 2324 |
+
'classify_mode': current_classify_mode,
|
| 2325 |
+
'models_list': current_models_list,
|
| 2326 |
+
'consensus_threshold': current_consensus,
|
| 2327 |
+
'model_temperatures': current_temperatures,
|
| 2328 |
+
'ensemble_runs': results.get('ensemble_runs'),
|
| 2329 |
+
}
|
| 2330 |
+
code_to_show = generate_full_code(st.session_state.extraction_params, classify_params)
|
| 2331 |
+
else:
|
| 2332 |
+
current_temperatures = results.get('model_temperatures', {})
|
| 2333 |
+
code_to_show = generate_classify_code(
|
| 2334 |
+
current_input_type, current_description, current_categories,
|
| 2335 |
+
current_model, current_model_source, current_mode,
|
| 2336 |
+
classify_mode=current_classify_mode, models_list=current_models_list,
|
| 2337 |
+
consensus_threshold=current_consensus,
|
| 2338 |
+
model_temperatures=current_temperatures,
|
| 2339 |
+
ensemble_runs=results.get('ensemble_runs'),
|
| 2340 |
+
)
|
| 2341 |
+
else:
|
| 2342 |
+
code_to_show = '''import catvader
|
| 2343 |
+
|
| 2344 |
+
# Define your categories
|
| 2345 |
+
categories = [
|
| 2346 |
+
"Category 1",
|
| 2347 |
+
"Category 2",
|
| 2348 |
+
# Add more categories...
|
| 2349 |
+
]
|
| 2350 |
+
|
| 2351 |
+
# Classify your data
|
| 2352 |
+
result = catvader.classify(
|
| 2353 |
+
input_data=df["your_column"].tolist(),
|
| 2354 |
+
categories=categories,
|
| 2355 |
+
api_key="YOUR_API_KEY",
|
| 2356 |
+
description="your_description",
|
| 2357 |
+
user_model="gpt-4o-mini"
|
| 2358 |
+
)
|
| 2359 |
+
|
| 2360 |
+
# View results
|
| 2361 |
+
print(result)
|
| 2362 |
+
result.to_csv("classified_results.csv", index=False)
|
| 2363 |
+
'''
|
| 2364 |
+
|
| 2365 |
+
st.code(code_to_show, language='python')
|
| 2366 |
+
if st.button("Close"):
|
| 2367 |
+
st.session_state.show_code_modal = False
|
| 2368 |
+
st.rerun()
|