Spaces:
Runtime error
Runtime error
Upload 30 files
Browse files- .dockerignore +84 -0
- .gitignore +12 -0
- .python-version +1 -0
- Dockerfile +39 -20
- README.md +712 -19
- app/app.py +916 -916
- app/run_demo.py +38 -38
- dist/.gitignore +1 -0
- dist/qualivec-0.1.0-py3-none-any.whl +0 -0
- dist/qualivec-0.1.0.tar.gz +3 -0
- src/qualivec/__pycache__/embedding.cpython-312.pyc +0 -0
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Docker
|
| 2 |
+
.dockerignore
|
| 3 |
+
Dockerfile*
|
| 4 |
+
docker-compose*.yml
|
| 5 |
+
|
| 6 |
+
# Git
|
| 7 |
+
.git/
|
| 8 |
+
.gitignore
|
| 9 |
+
|
| 10 |
+
# Python Virtual Environment
|
| 11 |
+
.venv/
|
| 12 |
+
venv/
|
| 13 |
+
env/
|
| 14 |
+
ENV/
|
| 15 |
+
env.bak/
|
| 16 |
+
venv.bak/
|
| 17 |
+
|
| 18 |
+
# Python cache
|
| 19 |
+
__pycache__/
|
| 20 |
+
*.pyc
|
| 21 |
+
*.py[cod]
|
| 22 |
+
*$py.class
|
| 23 |
+
*.so
|
| 24 |
+
|
| 25 |
+
# Build artifacts
|
| 26 |
+
dist/
|
| 27 |
+
build/
|
| 28 |
+
develop-eggs/
|
| 29 |
+
downloads/
|
| 30 |
+
eggs/
|
| 31 |
+
.eggs/
|
| 32 |
+
lib/
|
| 33 |
+
lib64/
|
| 34 |
+
parts/
|
| 35 |
+
sdist/
|
| 36 |
+
var/
|
| 37 |
+
wheels/
|
| 38 |
+
share/python-wheels/
|
| 39 |
+
*.egg-info/
|
| 40 |
+
.installed.cfg
|
| 41 |
+
*.egg
|
| 42 |
+
MANIFEST
|
| 43 |
+
|
| 44 |
+
# IDE
|
| 45 |
+
.vscode/
|
| 46 |
+
.idea/
|
| 47 |
+
*.swp
|
| 48 |
+
*.swo
|
| 49 |
+
*~
|
| 50 |
+
|
| 51 |
+
# OS
|
| 52 |
+
.DS_Store
|
| 53 |
+
.DS_Store?
|
| 54 |
+
._*
|
| 55 |
+
.Spotlight-V100
|
| 56 |
+
.Trashes
|
| 57 |
+
ehthumbs.db
|
| 58 |
+
Thumbs.db
|
| 59 |
+
|
| 60 |
+
# Documentation (keep README.md)
|
| 61 |
+
docs/
|
| 62 |
+
*.md
|
| 63 |
+
!README.md
|
| 64 |
+
|
| 65 |
+
# Tests
|
| 66 |
+
tests/
|
| 67 |
+
test_*/
|
| 68 |
+
*_test.py
|
| 69 |
+
**/test_*.py
|
| 70 |
+
|
| 71 |
+
# Data files (you may want to adjust these based on your needs)
|
| 72 |
+
*.csv
|
| 73 |
+
*.json
|
| 74 |
+
*.pkl
|
| 75 |
+
*.parquet
|
| 76 |
+
|
| 77 |
+
# Logs
|
| 78 |
+
*.log
|
| 79 |
+
logs/
|
| 80 |
+
|
| 81 |
+
# Temporary files
|
| 82 |
+
tmp/
|
| 83 |
+
temp/
|
| 84 |
+
.tmp/
|
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
*.pdf
|
| 12 |
+
*.csv
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12.10
|
Dockerfile
CHANGED
|
@@ -1,21 +1,40 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
ENTRYPOINT ["python", "app/run_demo.py"]
|
|
|
|
| 1 |
+
# Dockerfile for QualiVec Streamlit Demo
|
| 2 |
+
|
| 3 |
+
# 1. Base Image
|
| 4 |
+
FROM python:3.12-slim
|
| 5 |
+
|
| 6 |
+
# 2. Set the working directory
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
# 3. Install system dependencies
|
| 10 |
+
RUN apt-get update && apt-get install -y \
|
| 11 |
+
build-essential \
|
| 12 |
+
curl \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
# 4. Install uv - the fast Python package manager
|
| 16 |
+
RUN pip install --no-cache-dir uv
|
| 17 |
+
|
| 18 |
+
# 5. Copy dependency definition files and README (required for package build)
|
| 19 |
+
COPY pyproject.toml uv.lock README.md ./
|
| 20 |
+
|
| 21 |
+
# 6. Copy source code (needed for package installation)
|
| 22 |
+
COPY src/ ./src/
|
| 23 |
+
|
| 24 |
+
# 7. Install Python dependencies using uv
|
| 25 |
+
# 'uv pip install .' reads pyproject.toml and installs the project dependencies
|
| 26 |
+
RUN uv pip install --system --no-cache-dir .
|
| 27 |
+
|
| 28 |
+
# 8. Copy the rest of the application source code
|
| 29 |
+
# Make sure you have a .dockerignore file to exclude .venv
|
| 30 |
+
COPY . .
|
| 31 |
+
|
| 32 |
+
# 9. Expose the port Streamlit runs on
|
| 33 |
+
EXPOSE 8501
|
| 34 |
+
|
| 35 |
+
# 10. Add a health check to verify the app is running
|
| 36 |
+
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 37 |
+
|
| 38 |
+
# 11. Define the entry point to use the run_demo.py script via uv
|
| 39 |
+
# ENTRYPOINT ["uv", "run", "app/run_demo.py"]
|
| 40 |
ENTRYPOINT ["python", "app/run_demo.py"]
|
README.md
CHANGED
|
@@ -1,19 +1,712 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
---
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# QualiVec
|
| 2 |
+
|
| 3 |
+
**QualiVec** is a Python library for scalable qualitative content analysis powered by Large Language Model (LLM) embeddings. It bridges qualitative content analysis with machine learning by leveraging the semantic understanding capabilities of Large Language Models. Instead of relying on simple keyword matching or manually coding large datasets, QualiVec uses embedding vectors to capture semantic meaning and perform classification based on similarity to reference vectors.
|
| 4 |
+
|
| 5 |
+
Key features:
|
| 6 |
+
- LLM-based embedding generation
|
| 7 |
+
- Semantic similarity assessment using cosine similarity
|
| 8 |
+
- Deductive and inductive coding support
|
| 9 |
+
- Reference vector creation from labeled corpora
|
| 10 |
+
- Corpus-driven clustering for robust semantic anchor construction
|
| 11 |
+
- Supports large-scale document classification
|
| 12 |
+
- Domain-agnostic and model-flexible design
|
| 13 |
+
- Human-level performance in multi-domain content analysis
|
| 14 |
+
- Bootstrap evaluation with confidence intervals
|
| 15 |
+
- Threshold optimization for classification performance
|
| 16 |
+
|
| 17 |
+
## π» Installation
|
| 18 |
+
|
| 19 |
+
```bash
|
| 20 |
+
pip install qualivec
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
For development installation:
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
git clone https://github.com/AkhilVaidya91/QualiVec.git
|
| 27 |
+
cd qualivec
|
| 28 |
+
pip install -e .
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## π₯οΈ Interactive Demo
|
| 32 |
+
|
| 33 |
+
QualiVec includes a comprehensive Streamlit web application that provides an interactive demonstration of the library's capabilities. The demo allows users to upload their own data and experience the full workflow of qualitative content analysis using LLM embeddings.
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
### Demo Features
|
| 37 |
+
|
| 38 |
+
- **Interactive Data Upload**: Upload your own CSV files for reference and labeled data
|
| 39 |
+
- **Model Configuration**: Choose from different pre-trained embedding models
|
| 40 |
+
- **Threshold Optimization**: Automatically find the optimal similarity threshold
|
| 41 |
+
- **Real-time Classification**: See classification results as they happen
|
| 42 |
+
- **Comprehensive Evaluation**: View detailed performance metrics and visualizations
|
| 43 |
+
- **Bootstrap Analysis**: Get confidence intervals for robust evaluation
|
| 44 |
+
- **Download Results**: Export classification results and metrics
|
| 45 |
+
|
| 46 |
+
### Getting Started with Demo
|
| 47 |
+
|
| 48 |
+
1. **Install Dependencies**:
|
| 49 |
+
```bash
|
| 50 |
+
pip install -e .
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
2. **Run the Demo**:
|
| 54 |
+
```bash
|
| 55 |
+
cd app
|
| 56 |
+
uv run run_demo.py
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
3. **Access the Demo**:
|
| 60 |
+
Open your browser and navigate to `http://localhost:8501`
|
| 61 |
+
|
| 62 |
+
### Demo Walkthrough
|
| 63 |
+
|
| 64 |
+
#### 1. Data Upload Page
|
| 65 |
+
Upload your reference and labeled data files. The demo validates file formats and shows data statistics.
|
| 66 |
+
|
| 67 |
+

|
| 68 |
+
|
| 69 |
+
#### 2. Configuration Page
|
| 70 |
+
Configure embedding models and optimization parameters. Choose from multiple pre-trained models and set classification thresholds.
|
| 71 |
+
|
| 72 |
+

|
| 73 |
+
|
| 74 |
+
#### 3. Classification Page
|
| 75 |
+
Run the classification process with real-time progress updates. View optimization results and threshold analysis.
|
| 76 |
+
|
| 77 |
+

|
| 78 |
+
|
| 79 |
+
#### 4. Results Page
|
| 80 |
+
Examine detailed evaluation metrics, confusion matrices, bootstrap confidence intervals, and sample predictions.
|
| 81 |
+
|
| 82 |
+

|
| 83 |
+
|
| 84 |
+
### Data Format Requirements
|
| 85 |
+
|
| 86 |
+
#### Reference Data (CSV)
|
| 87 |
+
Your reference data should contain:
|
| 88 |
+
- `tag`: The class/category label
|
| 89 |
+
- `sentence`: The example text for that category
|
| 90 |
+
|
| 91 |
+
Example:
|
| 92 |
+
|
| 93 |
+
| tag | sentence |
|
| 94 |
+
|----------|---------------------------------|
|
| 95 |
+
| Positive | This is absolutely fantastic! |
|
| 96 |
+
| Negative | This is terrible and disappointing |
|
| 97 |
+
| Neutral | This is okay I guess |
|
| 98 |
+
|
| 99 |
+
#### Labeled Data (CSV)
|
| 100 |
+
Your labeled data should contain:
|
| 101 |
+
- `sentence`: The text to be classified
|
| 102 |
+
- `Label`: The true class/category (for evaluation)
|
| 103 |
+
|
| 104 |
+
Example:
|
| 105 |
+
|
| 106 |
+
| sentence | Label |
|
| 107 |
+
|------------------------------------|----------|
|
| 108 |
+
| I love this product so much! | Positive |
|
| 109 |
+
| Not very good quality | Negative |
|
| 110 |
+
| Average product nothing special | Neutral |
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
## π Quick Start
|
| 114 |
+
|
| 115 |
+
Here's a simple example to classify text data using reference vectors:
|
| 116 |
+
|
| 117 |
+
```python
|
| 118 |
+
from qualivec.classification import Classifier
|
| 119 |
+
|
| 120 |
+
# Initialize classifier
|
| 121 |
+
classifier = Classifier(verbose=True)
|
| 122 |
+
|
| 123 |
+
# Load models
|
| 124 |
+
classifier.load_models(model_name="sentence-transformers/all-MiniLM-L6-v2", threshold=0.7)
|
| 125 |
+
|
| 126 |
+
# Prepare reference vectors
|
| 127 |
+
reference_data = classifier.prepare_reference_vectors(
|
| 128 |
+
reference_path="path/to/reference_vectors.csv",
|
| 129 |
+
class_column="class",
|
| 130 |
+
node_column="matching_node"
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Classify corpus
|
| 134 |
+
results_df = classifier.classify(
|
| 135 |
+
corpus_path="path/to/corpus.csv",
|
| 136 |
+
reference_data=reference_data,
|
| 137 |
+
sentence_column="sentence",
|
| 138 |
+
output_path="path/to/results.csv"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Display distribution of classifications
|
| 142 |
+
print(results_df["predicted_class"].value_counts())
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+

|
| 146 |
+
|
| 147 |
+
## π§© Core Concepts
|
| 148 |
+
| Concept | Description |
|
| 149 |
+
|----------------------|--------------------------------------------------------------------------------------------------|
|
| 150 |
+
| **Reference Vectors**| Semantic anchors that define each class or category, curated as representative example texts. |
|
| 151 |
+
| **Similarity Threshold** | Determines how similar a text must be to a reference vector to be classified as that class; higher values are more restrictive. |
|
| 152 |
+
| **Embedding** | Numerical vector representations of text that capture semantic meaning; similar texts have similar embeddings. |
|
| 153 |
+
| **Semantic Matching**| Uses cosine similarity between embeddings to assess how close texts are to reference vectors. |
|
| 154 |
+
| **Bootstrap Evaluation** | Statistical method for estimating uncertainty in evaluation metrics by resampling with replacement. |
|
| 155 |
+
|
| 156 |
+
## π§° Components
|
| 157 |
+
|
| 158 |
+
### Data Loading and Preparation
|
| 159 |
+
|
| 160 |
+
The `DataLoader` class handles loading and validation of data:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
from qualivec.data import DataLoader
|
| 164 |
+
|
| 165 |
+
# Initialize data loader
|
| 166 |
+
data_loader = DataLoader(verbose=True)
|
| 167 |
+
|
| 168 |
+
# Load corpus
|
| 169 |
+
corpus_df = data_loader.load_corpus(
|
| 170 |
+
filepath="path/to/corpus.csv",
|
| 171 |
+
sentence_column="sentence"
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# Load reference vectors
|
| 175 |
+
reference_df = data_loader.load_reference_vectors(
|
| 176 |
+
filepath="path/to/reference_vectors.csv",
|
| 177 |
+
class_column="class",
|
| 178 |
+
node_column="matching_node"
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
# Load labeled data for evaluation
|
| 182 |
+
labeled_df = data_loader.load_labeled_data(
|
| 183 |
+
filepath="path/to/labeled_data.csv",
|
| 184 |
+
label_column="label"
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Save results
|
| 188 |
+
data_loader.save_dataframe(df=results_df, filepath="path/to/output.csv")
|
| 189 |
+
```
|
| 190 |
+
|
| 191 |
+
### Embedding Generation
|
| 192 |
+
|
| 193 |
+
The `EmbeddingModel` class generates embeddings from text:
|
| 194 |
+
|
| 195 |
+
```python
|
| 196 |
+
from qualivec.embedding import EmbeddingModel
|
| 197 |
+
|
| 198 |
+
# Initialize embedding model
|
| 199 |
+
model = EmbeddingModel(
|
| 200 |
+
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
| 201 |
+
device=None, # Auto-selects CPU or GPU
|
| 202 |
+
cache_dir=None,
|
| 203 |
+
verbose=True
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Generate embeddings for a list of texts
|
| 207 |
+
texts = ["This is a sample text", "Another example text"]
|
| 208 |
+
embeddings = model.embed_texts(texts, batch_size=32)
|
| 209 |
+
|
| 210 |
+
# Generate embeddings from a DataFrame column
|
| 211 |
+
embeddings = model.embed_dataframe(df, text_column="sentence", batch_size=32)
|
| 212 |
+
|
| 213 |
+
# Generate embeddings for reference vectors
|
| 214 |
+
reference_data = model.embed_reference_vectors(
|
| 215 |
+
df=reference_df,
|
| 216 |
+
class_column="class",
|
| 217 |
+
node_column="matching_node",
|
| 218 |
+
batch_size=32
|
| 219 |
+
)
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
### Semantic Matching
|
| 223 |
+
|
| 224 |
+
The `SemanticMatcher` class performs semantic matching using cosine similarity:
|
| 225 |
+
|
| 226 |
+
```python
|
| 227 |
+
from qualivec.matching import SemanticMatcher
|
| 228 |
+
|
| 229 |
+
# Initialize matcher with similarity threshold
|
| 230 |
+
matcher = SemanticMatcher(threshold=0.7, verbose=True)
|
| 231 |
+
|
| 232 |
+
# Match query embeddings against reference vectors
|
| 233 |
+
match_results = matcher.match(
|
| 234 |
+
query_embeddings=query_embeddings,
|
| 235 |
+
reference_data=reference_data,
|
| 236 |
+
return_similarities=False
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
# Classify an entire corpus
|
| 240 |
+
classified_df = matcher.classify_corpus(
|
| 241 |
+
corpus_embeddings=corpus_embeddings,
|
| 242 |
+
reference_data=reference_data,
|
| 243 |
+
corpus_df=corpus_df
|
| 244 |
+
)
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
### Classification
|
| 248 |
+
|
| 249 |
+
The `Classifier` class combines embedding and matching for end-to-end classification:
|
| 250 |
+
|
| 251 |
+
```python
|
| 252 |
+
from qualivec.classification import Classifier
|
| 253 |
+
|
| 254 |
+
# Initialize classifier
|
| 255 |
+
classifier = Classifier(verbose=True)
|
| 256 |
+
|
| 257 |
+
# Load models
|
| 258 |
+
classifier.load_models(
|
| 259 |
+
model_name="sentence-transformers/all-MiniLM-L6-v2",
|
| 260 |
+
threshold=0.7
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
# Prepare reference vectors
|
| 264 |
+
reference_data = classifier.prepare_reference_vectors(
|
| 265 |
+
reference_path="path/to/reference_vectors.csv",
|
| 266 |
+
class_column="class",
|
| 267 |
+
node_column="matching_node"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
# Classify corpus
|
| 271 |
+
results_df = classifier.classify(
|
| 272 |
+
corpus_path="path/to/corpus.csv",
|
| 273 |
+
reference_data=reference_data,
|
| 274 |
+
sentence_column="sentence",
|
| 275 |
+
output_path="path/to/results.csv"
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
# Evaluate classification performance
|
| 279 |
+
eval_results = classifier.evaluate_classification(
|
| 280 |
+
labeled_path="path/to/labeled_data.csv",
|
| 281 |
+
reference_data=reference_data,
|
| 282 |
+
sentence_column="sentence",
|
| 283 |
+
label_column="label",
|
| 284 |
+
optimize_threshold=False
|
| 285 |
+
)
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
### Evaluation
|
| 289 |
+
|
| 290 |
+
The `Evaluator` class evaluates classification performance:
|
| 291 |
+
|
| 292 |
+
```python
|
| 293 |
+
from qualivec.evaluation import Evaluator
|
| 294 |
+
|
| 295 |
+
# Initialize evaluator
|
| 296 |
+
evaluator = Evaluator(verbose=True)
|
| 297 |
+
|
| 298 |
+
# Simple evaluation
|
| 299 |
+
results = evaluator.evaluate(
|
| 300 |
+
true_labels=true_labels,
|
| 301 |
+
predicted_labels=predicted_labels,
|
| 302 |
+
class_names=class_names
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
# Bootstrap evaluation with confidence intervals
|
| 306 |
+
bootstrap_results = evaluator.bootstrap_evaluate(
|
| 307 |
+
true_labels=true_labels,
|
| 308 |
+
predicted_labels=predicted_labels,
|
| 309 |
+
n_iterations=1000,
|
| 310 |
+
confidence_levels=[0.9, 0.95, 0.99],
|
| 311 |
+
random_seed=42
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# Plot confusion matrix
|
| 315 |
+
evaluator.plot_confusion_matrix(
|
| 316 |
+
confusion_matrix=results['confusion_matrix'],
|
| 317 |
+
class_names=results['confusion_matrix_labels']
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
# Plot bootstrap distributions
|
| 321 |
+
evaluator.plot_bootstrap_distributions(bootstrap_results)
|
| 322 |
+
```
|
| 323 |
+
|
| 324 |
+

|
| 325 |
+
|
| 326 |
+
### Threshold Optimization
|
| 327 |
+
|
| 328 |
+
The `ThresholdOptimizer` class finds the optimal similarity threshold:
|
| 329 |
+
|
| 330 |
+
```python
|
| 331 |
+
from qualivec.optimization import ThresholdOptimizer
|
| 332 |
+
|
| 333 |
+
# Initialize optimizer
|
| 334 |
+
optimizer = ThresholdOptimizer(verbose=True)
|
| 335 |
+
|
| 336 |
+
# Optimize threshold
|
| 337 |
+
optimization_results = optimizer.optimize(
|
| 338 |
+
query_embeddings=query_embeddings,
|
| 339 |
+
reference_data=reference_data,
|
| 340 |
+
true_labels=true_labels,
|
| 341 |
+
start=0.5,
|
| 342 |
+
end=0.9,
|
| 343 |
+
step=0.01,
|
| 344 |
+
metric="f1_macro",
|
| 345 |
+
bootstrap=True,
|
| 346 |
+
n_bootstrap=100,
|
| 347 |
+
confidence_level=0.95
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
# Plot optimization results
|
| 351 |
+
optimizer.plot_optimization_results(
|
| 352 |
+
results=optimization_results,
|
| 353 |
+
metrics=["accuracy", "precision_macro", "recall_macro", "f1_macro"]
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Plot class distribution at different thresholds
|
| 357 |
+
optimizer.plot_class_distribution(
|
| 358 |
+
results=optimization_results,
|
| 359 |
+
top_n=10
|
| 360 |
+
)
|
| 361 |
+
```
|
| 362 |
+
|
| 363 |
+
### Sampling
|
| 364 |
+
|
| 365 |
+
The `Sampler` class helps create samples for manual coding:
|
| 366 |
+
|
| 367 |
+
```python
|
| 368 |
+
from qualivec.sampling import Sampler
|
| 369 |
+
|
| 370 |
+
# Initialize sampler
|
| 371 |
+
sampler = Sampler(verbose=True)
|
| 372 |
+
|
| 373 |
+
# Random sampling
|
| 374 |
+
random_sample = sampler.sample(
|
| 375 |
+
df=corpus_df,
|
| 376 |
+
sampling_type="random",
|
| 377 |
+
sample_size=0.1, # 10% of corpus
|
| 378 |
+
seed=42,
|
| 379 |
+
label_column="Label"
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
# Stratified sampling
|
| 383 |
+
stratified_sample = sampler.sample(
|
| 384 |
+
df=corpus_df,
|
| 385 |
+
sampling_type="stratified",
|
| 386 |
+
sample_size=0.1,
|
| 387 |
+
stratify_column="category",
|
| 388 |
+
seed=42,
|
| 389 |
+
label_column="Label"
|
| 390 |
+
)
|
| 391 |
+
```
|
| 392 |
+
|
| 393 |
+
## π Usage Examples
|
| 394 |
+
|
| 395 |
+
### Preparing Reference Vectors
|
| 396 |
+
|
| 397 |
+
Reference vectors are the foundation of classification in QualiVec. Here's how to prepare them:
|
| 398 |
+
|
| 399 |
+
```python
|
| 400 |
+
# Step 1: Sample data for manual coding
|
| 401 |
+
from qualivec.sampling import Sampler
|
| 402 |
+
|
| 403 |
+
sampler = Sampler(verbose=True)
|
| 404 |
+
sample_df = sampler.sample(
|
| 405 |
+
df=corpus_df,
|
| 406 |
+
sampling_type="stratified",
|
| 407 |
+
sample_size=0.05, # 5% of corpus
|
| 408 |
+
stratify_column="document_type"
|
| 409 |
+
)
|
| 410 |
+
|
| 411 |
+
# Step 2: Save sample for manual coding
|
| 412 |
+
sample_df.to_csv("sample_for_coding.csv", index=False)
|
| 413 |
+
|
| 414 |
+
# Step 3: After manual coding, load the coded data
|
| 415 |
+
from qualivec.data import DataLoader
|
| 416 |
+
|
| 417 |
+
data_loader = DataLoader(verbose=True)
|
| 418 |
+
coded_df = data_loader.load_labeled_data(
|
| 419 |
+
filepath="coded_sample.csv",
|
| 420 |
+
label_column="coded_class"
|
| 421 |
+
)
|
| 422 |
+
|
| 423 |
+
# Step 4: Generate embeddings for reference vectors
|
| 424 |
+
from qualivec.embedding import EmbeddingModel
|
| 425 |
+
|
| 426 |
+
model = EmbeddingModel(verbose=True)
|
| 427 |
+
reference_data = model.embed_reference_vectors(
|
| 428 |
+
df=coded_df,
|
| 429 |
+
class_column="coded_class",
|
| 430 |
+
node_column="sentence"
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
# Step 5: Save reference data for future use
|
| 434 |
+
import pickle
|
| 435 |
+
with open("reference_data.pkl", "wb") as f:
|
| 436 |
+
pickle.dump(reference_data, f)
|
| 437 |
+
```
|
| 438 |
+
|
| 439 |
+
### Classifying New Data
|
| 440 |
+
|
| 441 |
+
Once reference vectors are prepared, you can classify new data:
|
| 442 |
+
|
| 443 |
+
```python
|
| 444 |
+
# Load reference data
|
| 445 |
+
import pickle
|
| 446 |
+
with open("reference_data.pkl", "rb") as f:
|
| 447 |
+
reference_data = pickle.load(f)
|
| 448 |
+
|
| 449 |
+
# Initialize classifier
|
| 450 |
+
from qualivec.classification import Classifier
|
| 451 |
+
|
| 452 |
+
classifier = Classifier(verbose=True)
|
| 453 |
+
classifier.load_models(threshold=0.7)
|
| 454 |
+
|
| 455 |
+
# Classify corpus
|
| 456 |
+
results_df = classifier.classify(
|
| 457 |
+
corpus_path="new_corpus.csv",
|
| 458 |
+
reference_data=reference_data,
|
| 459 |
+
sentence_column="sentence",
|
| 460 |
+
output_path="classified_corpus.csv"
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
# Analyze results
|
| 464 |
+
import pandas as pd
|
| 465 |
+
import matplotlib.pyplot as plt
|
| 466 |
+
|
| 467 |
+
# Distribution of classes
|
| 468 |
+
plt.figure(figsize=(10, 6))
|
| 469 |
+
results_df["predicted_class"].value_counts().plot(kind="bar")
|
| 470 |
+
plt.title("Distribution of Predicted Classes")
|
| 471 |
+
plt.tight_layout()
|
| 472 |
+
plt.show()
|
| 473 |
+
|
| 474 |
+
# Average similarity by class
|
| 475 |
+
results_df.groupby("predicted_class")["similarity_score"].mean().sort_values().plot(kind="barh")
|
| 476 |
+
plt.title("Average Similarity Score by Class")
|
| 477 |
+
plt.tight_layout()
|
| 478 |
+
plt.show()
|
| 479 |
+
```
|
| 480 |
+
|
| 481 |
+
### Evaluating Classification Performance
|
| 482 |
+
|
| 483 |
+
To assess how well your classification is performing:
|
| 484 |
+
|
| 485 |
+
```python
|
| 486 |
+
# Load labeled data
|
| 487 |
+
from qualivec.data import DataLoader
|
| 488 |
+
|
| 489 |
+
data_loader = DataLoader(verbose=True)
|
| 490 |
+
labeled_df = data_loader.load_labeled_data(
|
| 491 |
+
filepath="labeled_test_set.csv",
|
| 492 |
+
label_column="true_label"
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
# Generate embeddings
|
| 496 |
+
from qualivec.embedding import EmbeddingModel
|
| 497 |
+
|
| 498 |
+
model = EmbeddingModel(verbose=True)
|
| 499 |
+
labeled_embeddings = model.embed_dataframe(
|
| 500 |
+
df=labeled_df,
|
| 501 |
+
text_column="sentence"
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
# Initialize evaluator
|
| 505 |
+
from qualivec.evaluation import Evaluator
|
| 506 |
+
from qualivec.matching import SemanticMatcher
|
| 507 |
+
|
| 508 |
+
matcher = SemanticMatcher(threshold=0.7, verbose=True)
|
| 509 |
+
match_results = matcher.match(labeled_embeddings, reference_data)
|
| 510 |
+
predicted_labels = match_results["predicted_class"].tolist()
|
| 511 |
+
true_labels = labeled_df["true_label"].tolist()
|
| 512 |
+
|
| 513 |
+
evaluator = Evaluator(verbose=True)
|
| 514 |
+
|
| 515 |
+
# Simple evaluation
|
| 516 |
+
eval_results = evaluator.evaluate(
|
| 517 |
+
true_labels=true_labels,
|
| 518 |
+
predicted_labels=predicted_labels
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
# Bootstrap evaluation
|
| 522 |
+
bootstrap_results = evaluator.bootstrap_evaluate(
|
| 523 |
+
true_labels=true_labels,
|
| 524 |
+
predicted_labels=predicted_labels,
|
| 525 |
+
n_iterations=1000
|
| 526 |
+
)
|
| 527 |
+
|
| 528 |
+
# Plot confusion matrix
|
| 529 |
+
evaluator.plot_confusion_matrix(
|
| 530 |
+
confusion_matrix=eval_results['confusion_matrix'],
|
| 531 |
+
class_names=eval_results['confusion_matrix_labels']
|
| 532 |
+
)
|
| 533 |
+
|
| 534 |
+
# Plot bootstrap distributions
|
| 535 |
+
evaluator.plot_bootstrap_distributions(bootstrap_results)
|
| 536 |
+
```
|
| 537 |
+
|
| 538 |
+
### Optimizing Similarity Thresholds
|
| 539 |
+
|
| 540 |
+
To find the optimal similarity threshold for your classification:
|
| 541 |
+
|
| 542 |
+
```python
|
| 543 |
+
# Initialize optimizer
|
| 544 |
+
from qualivec.optimization import ThresholdOptimizer
|
| 545 |
+
|
| 546 |
+
optimizer = ThresholdOptimizer(verbose=True)
|
| 547 |
+
|
| 548 |
+
# Optimize threshold
|
| 549 |
+
optimization_results = optimizer.optimize(
|
| 550 |
+
query_embeddings=labeled_embeddings,
|
| 551 |
+
reference_data=reference_data,
|
| 552 |
+
true_labels=true_labels,
|
| 553 |
+
start=0.5,
|
| 554 |
+
end=0.9,
|
| 555 |
+
step=0.01,
|
| 556 |
+
metric="f1_macro"
|
| 557 |
+
)
|
| 558 |
+
|
| 559 |
+
# Plot optimization results
|
| 560 |
+
optimizer.plot_optimization_results(
|
| 561 |
+
results=optimization_results,
|
| 562 |
+
metrics=["accuracy", "f1_macro"]
|
| 563 |
+
)
|
| 564 |
+
|
| 565 |
+
# Plot class distribution
|
| 566 |
+
optimizer.plot_class_distribution(
|
| 567 |
+
results=optimization_results,
|
| 568 |
+
top_n=5
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
# Use the optimal threshold
|
| 572 |
+
optimal_threshold = optimization_results["optimal_threshold"]
|
| 573 |
+
print(f"Optimal threshold: {optimal_threshold}")
|
| 574 |
+
|
| 575 |
+
# Create a new matcher with the optimal threshold
|
| 576 |
+
matcher = SemanticMatcher(threshold=optimal_threshold, verbose=True)
|
| 577 |
+
```
|
| 578 |
+
|
| 579 |
+
### Sampling for Manual Coding
|
| 580 |
+
|
| 581 |
+
To create samples for manual coding or validation:
|
| 582 |
+
|
| 583 |
+
```python
|
| 584 |
+
from qualivec.sampling import Sampler
|
| 585 |
+
|
| 586 |
+
sampler = Sampler(verbose=True)
|
| 587 |
+
|
| 588 |
+
# Random sampling
|
| 589 |
+
random_sample = sampler.sample(
|
| 590 |
+
df=corpus_df,
|
| 591 |
+
sampling_type="random",
|
| 592 |
+
sample_size=100, # 100 documents
|
| 593 |
+
seed=42
|
| 594 |
+
)
|
| 595 |
+
|
| 596 |
+
# Stratified sampling by predicted class
|
| 597 |
+
stratified_sample = sampler.sample(
|
| 598 |
+
df=results_df,
|
| 599 |
+
sampling_type="stratified",
|
| 600 |
+
sample_size=0.1, # 10% of corpus
|
| 601 |
+
stratify_column="predicted_class",
|
| 602 |
+
seed=42
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
# Save samples for manual coding
|
| 606 |
+
random_sample.to_csv("random_sample_for_coding.csv", index=False)
|
| 607 |
+
stratified_sample.to_csv("stratified_sample_for_coding.csv", index=False)
|
| 608 |
+
```
|
| 609 |
+
|
| 610 |
+
## π API Reference
|
| 611 |
+
|
| 612 |
+
### DataLoader
|
| 613 |
+
|
| 614 |
+
```python
|
| 615 |
+
class DataLoader:
|
| 616 |
+
def __init__(self, verbose=True)
|
| 617 |
+
def load_corpus(self, filepath, sentence_column="sentence")
|
| 618 |
+
def load_reference_vectors(self, filepath, class_column="class", node_column="matching_node")
|
| 619 |
+
def load_labeled_data(self, filepath, label_column="label")
|
| 620 |
+
def save_dataframe(self, df, filepath)
|
| 621 |
+
def validate_labels(self, labeled_df, reference_df, label_column="label", class_column="class")
|
| 622 |
+
```
|
| 623 |
+
|
| 624 |
+
### Sampler
|
| 625 |
+
|
| 626 |
+
```python
|
| 627 |
+
class Sampler:
|
| 628 |
+
def __init__(self, verbose=True)
|
| 629 |
+
def sample(self, df, sampling_type="random", sample_size=0.1, stratify_column=None,
|
| 630 |
+
seed=None, label_column="Label")
|
| 631 |
+
```
|
| 632 |
+
|
| 633 |
+
### EmbeddingModel
|
| 634 |
+
|
| 635 |
+
```python
|
| 636 |
+
class EmbeddingModel:
|
| 637 |
+
def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2",
|
| 638 |
+
device=None, cache_dir=None, verbose=True)
|
| 639 |
+
def embed_texts(self, texts, batch_size=32)
|
| 640 |
+
def embed_dataframe(self, df, text_column, batch_size=32)
|
| 641 |
+
def embed_reference_vectors(self, df, class_column="class",
|
| 642 |
+
node_column="matching_node", batch_size=32)
|
| 643 |
+
```
|
| 644 |
+
|
| 645 |
+
### SemanticMatcher
|
| 646 |
+
|
| 647 |
+
```python
|
| 648 |
+
class SemanticMatcher:
|
| 649 |
+
def __init__(self, threshold=0.7, verbose=True)
|
| 650 |
+
def match(self, query_embeddings, reference_data, return_similarities=False)
|
| 651 |
+
def classify_corpus(self, corpus_embeddings, reference_data, corpus_df)
|
| 652 |
+
```
|
| 653 |
+
|
| 654 |
+
### Evaluator
|
| 655 |
+
|
| 656 |
+
```python
|
| 657 |
+
class Evaluator:
|
| 658 |
+
def __init__(self, verbose=True)
|
| 659 |
+
def evaluate(self, true_labels, predicted_labels, class_names=None)
|
| 660 |
+
def bootstrap_evaluate(self, true_labels, predicted_labels, n_iterations=1000,
|
| 661 |
+
confidence_levels=[0.9, 0.95, 0.99], random_seed=None)
|
| 662 |
+
def plot_confusion_matrix(self, confusion_matrix, class_names,
|
| 663 |
+
figsize=(10, 8), title="Confusion Matrix")
|
| 664 |
+
def plot_bootstrap_distributions(self, bootstrap_results, figsize=(12, 8))
|
| 665 |
+
```
|
| 666 |
+
|
| 667 |
+
### ThresholdOptimizer
|
| 668 |
+
|
| 669 |
+
```python
|
| 670 |
+
class ThresholdOptimizer:
|
| 671 |
+
def __init__(self, verbose=True)
|
| 672 |
+
def optimize(self, query_embeddings, reference_data, true_labels,
|
| 673 |
+
start=0.0, end=1.0, step=0.01, metric="f1_macro",
|
| 674 |
+
bootstrap=True, n_bootstrap=100, confidence_level=0.95, random_seed=None)
|
| 675 |
+
def plot_optimization_results(self, results, metrics=None, figsize=(12, 6))
|
| 676 |
+
def plot_class_distribution(self, results, top_n=10, figsize=(12, 8))
|
| 677 |
+
```
|
| 678 |
+
|
| 679 |
+
### Classifier
|
| 680 |
+
|
| 681 |
+
```python
|
| 682 |
+
class Classifier:
|
| 683 |
+
def __init__(self, embedding_model=None, matcher=None, verbose=True)
|
| 684 |
+
def load_models(self, model_name="sentence-transformers/all-MiniLM-L6-v2", threshold=0.7)
|
| 685 |
+
def prepare_reference_vectors(self, reference_path, class_column="class",
|
| 686 |
+
node_column="matching_node")
|
| 687 |
+
def classify(self, corpus_path, reference_data, sentence_column="sentence",
|
| 688 |
+
output_path=None)
|
| 689 |
+
def evaluate_classification(self, labeled_path, reference_data,
|
| 690 |
+
sentence_column="sentence", label_column="label",
|
| 691 |
+
optimize_threshold=False, start=0.5, end=0.9, step=0.01)
|
| 692 |
+
```
|
| 693 |
+
|
| 694 |
+
## π‘ Best Practices
|
| 695 |
+
|
| 696 |
+
1. **Reference Vector Quality**: The quality of your reference vectors greatly impacts classification performance. Ensure they are representative and distinct.
|
| 697 |
+
|
| 698 |
+
2. **Model Selection**: Larger models generally provide better semantic understanding but are slower. For simple tasks, smaller models like MiniLM may be sufficient.
|
| 699 |
+
|
| 700 |
+
3. **Threshold Tuning**: Always optimize the similarity threshold for your specific dataset and task.
|
| 701 |
+
|
| 702 |
+
4. **Evaluation**: Use bootstrap evaluation to get confidence intervals around your metrics, especially for smaller datasets.
|
| 703 |
+
|
| 704 |
+
5. **Class Imbalance**: Be aware of class imbalance in your data. Consider using stratified sampling for creating evaluation sets.
|
| 705 |
+
|
| 706 |
+
6. **Preprocessing**: Clean and preprocess your text data before embedding for best results.
|
| 707 |
+
|
| 708 |
+
7. **Out-of-Domain Detection**: Use the "Other" class (when similarity is below threshold) to identify texts that might need new reference vectors.
|
| 709 |
+
|
| 710 |
+
## π License
|
| 711 |
+
|
| 712 |
+
This project is licensed under the MIT License - see the LICENSE file for details.
|
app/app.py
CHANGED
|
@@ -1,916 +1,916 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
import pandas as pd
|
| 3 |
-
import numpy as np
|
| 4 |
-
import matplotlib.pyplot as plt
|
| 5 |
-
import seaborn as sns
|
| 6 |
-
import tempfile
|
| 7 |
-
import os
|
| 8 |
-
import sys
|
| 9 |
-
from io import StringIO
|
| 10 |
-
import plotly.express as px
|
| 11 |
-
import plotly.graph_objects as go
|
| 12 |
-
from plotly.subplots import make_subplots
|
| 13 |
-
|
| 14 |
-
# Add the parent directory to sys.path to import the module
|
| 15 |
-
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 16 |
-
|
| 17 |
-
from .data import DataLoader
|
| 18 |
-
from .embedding import EmbeddingModel
|
| 19 |
-
from .matching import SemanticMatcher
|
| 20 |
-
from .classification import Classifier
|
| 21 |
-
from .evaluation import Evaluator
|
| 22 |
-
from .optimization import ThresholdOptimizer
|
| 23 |
-
|
| 24 |
-
# Set page config
|
| 25 |
-
st.set_page_config(
|
| 26 |
-
page_title="QualiVec Demo",
|
| 27 |
-
page_icon="π",
|
| 28 |
-
layout="wide",
|
| 29 |
-
initial_sidebar_state="expanded"
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
# Custom CSS for better styling
|
| 33 |
-
st.markdown("""
|
| 34 |
-
<style>
|
| 35 |
-
.main-header {
|
| 36 |
-
font-size: 2.5rem;
|
| 37 |
-
font-weight: bold;
|
| 38 |
-
color: #2E4057;
|
| 39 |
-
text-align: center;
|
| 40 |
-
margin-bottom: 2rem;
|
| 41 |
-
}
|
| 42 |
-
.section-header {
|
| 43 |
-
font-size: 1.5rem;
|
| 44 |
-
font-weight: bold;
|
| 45 |
-
color: #048A81;
|
| 46 |
-
margin-top: 2rem;
|
| 47 |
-
margin-bottom: 1rem;
|
| 48 |
-
}
|
| 49 |
-
.metric-card {
|
| 50 |
-
background-color: #f0f2f6;
|
| 51 |
-
padding: 1rem;
|
| 52 |
-
border-radius: 0.5rem;
|
| 53 |
-
margin: 0.5rem 0;
|
| 54 |
-
}
|
| 55 |
-
.success-message {
|
| 56 |
-
background-color: #d4edda;
|
| 57 |
-
color: #155724;
|
| 58 |
-
padding: 1rem;
|
| 59 |
-
border-radius: 0.5rem;
|
| 60 |
-
margin: 1rem 0;
|
| 61 |
-
}
|
| 62 |
-
.warning-message {
|
| 63 |
-
background-color: #fff3cd;
|
| 64 |
-
color: #856404;
|
| 65 |
-
padding: 1rem;
|
| 66 |
-
border-radius: 0.5rem;
|
| 67 |
-
margin: 1rem 0;
|
| 68 |
-
}
|
| 69 |
-
</style>
|
| 70 |
-
""", unsafe_allow_html=True)
|
| 71 |
-
|
| 72 |
-
def main():
|
| 73 |
-
st.markdown('<div class="main-header">π QualiVec Demo</div>', unsafe_allow_html=True)
|
| 74 |
-
st.markdown("""
|
| 75 |
-
<div style="text-align: center; margin-bottom: 2rem;">
|
| 76 |
-
<p style="font-size: 1.2rem; color: #666;">
|
| 77 |
-
Qualitative Content Analysis with LLM Embeddings
|
| 78 |
-
</p>
|
| 79 |
-
</div>
|
| 80 |
-
""", unsafe_allow_html=True)
|
| 81 |
-
|
| 82 |
-
# Sidebar for navigation
|
| 83 |
-
st.sidebar.title("Navigation")
|
| 84 |
-
page = st.sidebar.selectbox(
|
| 85 |
-
"Choose a page",
|
| 86 |
-
["π Home", "π Data Upload", "π§ Configuration", "π― Classification", "π Results"]
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
# Initialize session state
|
| 90 |
-
if 'classifier' not in st.session_state:
|
| 91 |
-
st.session_state.classifier = None
|
| 92 |
-
if 'reference_data' not in st.session_state:
|
| 93 |
-
st.session_state.reference_data = None
|
| 94 |
-
if 'labeled_data' not in st.session_state:
|
| 95 |
-
st.session_state.labeled_data = None
|
| 96 |
-
if 'optimization_results' not in st.session_state:
|
| 97 |
-
st.session_state.optimization_results = None
|
| 98 |
-
if 'evaluation_results' not in st.session_state:
|
| 99 |
-
st.session_state.evaluation_results = None
|
| 100 |
-
|
| 101 |
-
# Route to different pages
|
| 102 |
-
if page == "π Home":
|
| 103 |
-
show_home_page()
|
| 104 |
-
elif page == "π Data Upload":
|
| 105 |
-
show_data_upload_page()
|
| 106 |
-
elif page == "π§ Configuration":
|
| 107 |
-
show_configuration_page()
|
| 108 |
-
elif page == "π― Classification":
|
| 109 |
-
show_classification_page()
|
| 110 |
-
elif page == "π Results":
|
| 111 |
-
show_results_page()
|
| 112 |
-
|
| 113 |
-
def show_home_page():
|
| 114 |
-
st.markdown('<div class="section-header">Welcome to QualiVec</div>', unsafe_allow_html=True)
|
| 115 |
-
|
| 116 |
-
col1, col2, col3 = st.columns([1, 2, 1])
|
| 117 |
-
|
| 118 |
-
with col2:
|
| 119 |
-
st.markdown("""
|
| 120 |
-
### What is QualiVec?
|
| 121 |
-
|
| 122 |
-
QualiVec is a Python library that uses Large Language Model (LLM) embeddings for qualitative content analysis. It helps researchers and analysts classify text data by comparing it against reference examples.
|
| 123 |
-
|
| 124 |
-
### Key Features:
|
| 125 |
-
- **Semantic Matching**: Uses advanced embedding models to find semantic similarity
|
| 126 |
-
- **Threshold Optimization**: Automatically finds the best similarity threshold
|
| 127 |
-
- **Comprehensive Evaluation**: Provides detailed metrics and visualizations
|
| 128 |
-
- **Bootstrap Analysis**: Confidence intervals for robust evaluation
|
| 129 |
-
|
| 130 |
-
### How It Works:
|
| 131 |
-
1. **Upload Data**: Provide reference examples and data to classify
|
| 132 |
-
2. **Configure**: Set up embedding models and parameters
|
| 133 |
-
3. **Optimize**: Find the best threshold for classification
|
| 134 |
-
4. **Classify**: Apply the model to your data
|
| 135 |
-
5. **Evaluate**: Get detailed performance metrics
|
| 136 |
-
|
| 137 |
-
### Getting Started:
|
| 138 |
-
Use the sidebar to navigate through the demo. Start with **Data Upload** to begin your analysis.
|
| 139 |
-
""")
|
| 140 |
-
|
| 141 |
-
# Add sample data info
|
| 142 |
-
st.markdown('<div class="section-header">Sample Data Format</div>', unsafe_allow_html=True)
|
| 143 |
-
|
| 144 |
-
col1, col2 = st.columns(2)
|
| 145 |
-
|
| 146 |
-
with col1:
|
| 147 |
-
st.markdown("**Reference Data Format:**")
|
| 148 |
-
sample_ref = pd.DataFrame({
|
| 149 |
-
'tag': ['Positive', 'Negative', 'Neutral'],
|
| 150 |
-
'sentence': ['This is great!', 'This is terrible', 'This is okay']
|
| 151 |
-
})
|
| 152 |
-
st.dataframe(sample_ref, use_container_width=True)
|
| 153 |
-
|
| 154 |
-
with col2:
|
| 155 |
-
st.markdown("**Labeled Data Format:**")
|
| 156 |
-
sample_labeled = pd.DataFrame({
|
| 157 |
-
'sentence': ['I love this product', 'Not very good', 'Average quality'],
|
| 158 |
-
'Label': ['Positive', 'Negative', 'Neutral']
|
| 159 |
-
})
|
| 160 |
-
st.dataframe(sample_labeled, use_container_width=True)
|
| 161 |
-
|
| 162 |
-
def show_data_upload_page():
|
| 163 |
-
st.markdown('<div class="section-header">Data Upload</div>', unsafe_allow_html=True)
|
| 164 |
-
|
| 165 |
-
col1, col2 = st.columns(2)
|
| 166 |
-
|
| 167 |
-
with col1:
|
| 168 |
-
st.markdown("### Reference Data")
|
| 169 |
-
st.markdown("Upload a CSV file containing reference examples with columns: `tag` (class) and `sentence` (example text)")
|
| 170 |
-
|
| 171 |
-
reference_file = st.file_uploader(
|
| 172 |
-
"Choose reference data file",
|
| 173 |
-
type=['csv'],
|
| 174 |
-
key='reference_file'
|
| 175 |
-
)
|
| 176 |
-
|
| 177 |
-
if reference_file is not None:
|
| 178 |
-
try:
|
| 179 |
-
reference_df = pd.read_csv(reference_file)
|
| 180 |
-
st.success("Reference data loaded successfully!")
|
| 181 |
-
st.dataframe(reference_df.head(), use_container_width=True)
|
| 182 |
-
|
| 183 |
-
# Validate columns
|
| 184 |
-
required_cols = ['tag', 'sentence']
|
| 185 |
-
missing_cols = [col for col in required_cols if col not in reference_df.columns]
|
| 186 |
-
|
| 187 |
-
if missing_cols:
|
| 188 |
-
st.error(f"Missing required columns: {missing_cols}")
|
| 189 |
-
else:
|
| 190 |
-
# Prepare reference data
|
| 191 |
-
reference_df = reference_df.rename(columns={
|
| 192 |
-
'tag': 'class',
|
| 193 |
-
'sentence': 'matching_node'
|
| 194 |
-
})
|
| 195 |
-
st.session_state.reference_data = reference_df
|
| 196 |
-
|
| 197 |
-
# Show statistics
|
| 198 |
-
st.markdown("**Data Statistics:**")
|
| 199 |
-
st.write(f"- Total examples: {len(reference_df)}")
|
| 200 |
-
st.write(f"- Unique classes: {reference_df['class'].nunique()}")
|
| 201 |
-
st.write(f"- Class distribution:")
|
| 202 |
-
st.write(reference_df['class'].value_counts())
|
| 203 |
-
|
| 204 |
-
except Exception as e:
|
| 205 |
-
st.error(f"Error loading reference data: {str(e)}")
|
| 206 |
-
|
| 207 |
-
with col2:
|
| 208 |
-
st.markdown("### Labeled Data")
|
| 209 |
-
st.markdown("Upload a CSV file containing data to classify with columns: `sentence` (text) and `Label` (true class)")
|
| 210 |
-
|
| 211 |
-
labeled_file = st.file_uploader(
|
| 212 |
-
"Choose labeled data file",
|
| 213 |
-
type=['csv'],
|
| 214 |
-
key='labeled_file'
|
| 215 |
-
)
|
| 216 |
-
|
| 217 |
-
if labeled_file is not None:
|
| 218 |
-
try:
|
| 219 |
-
labeled_df = pd.read_csv(labeled_file)
|
| 220 |
-
st.success("Labeled data loaded successfully!")
|
| 221 |
-
st.dataframe(labeled_df.head(), use_container_width=True)
|
| 222 |
-
|
| 223 |
-
# Validate columns
|
| 224 |
-
required_cols = ['sentence', 'Label']
|
| 225 |
-
missing_cols = [col for col in required_cols if col not in labeled_df.columns]
|
| 226 |
-
|
| 227 |
-
if missing_cols:
|
| 228 |
-
st.error(f"Missing required columns: {missing_cols}")
|
| 229 |
-
else:
|
| 230 |
-
# Prepare labeled data
|
| 231 |
-
labeled_df = labeled_df.rename(columns={'Label': 'label'})
|
| 232 |
-
labeled_df['label'] = labeled_df['label'].replace('0', 'Other')
|
| 233 |
-
st.session_state.labeled_data = labeled_df
|
| 234 |
-
|
| 235 |
-
# Show statistics
|
| 236 |
-
st.markdown("**Data Statistics:**")
|
| 237 |
-
st.write(f"- Total samples: {len(labeled_df)}")
|
| 238 |
-
st.write(f"- Unique labels: {labeled_df['label'].nunique()}")
|
| 239 |
-
st.write(f"- Label distribution:")
|
| 240 |
-
st.write(labeled_df['label'].value_counts())
|
| 241 |
-
|
| 242 |
-
except Exception as e:
|
| 243 |
-
st.error(f"Error loading labeled data: {str(e)}")
|
| 244 |
-
|
| 245 |
-
# Show data compatibility check
|
| 246 |
-
if st.session_state.reference_data is not None and st.session_state.labeled_data is not None:
|
| 247 |
-
st.markdown('<div class="section-header">Data Compatibility Check</div>', unsafe_allow_html=True)
|
| 248 |
-
|
| 249 |
-
ref_classes = set(st.session_state.reference_data['class'].unique())
|
| 250 |
-
labeled_classes = set(st.session_state.labeled_data['label'].unique())
|
| 251 |
-
|
| 252 |
-
# Check for unknown classes
|
| 253 |
-
unknown_classes = labeled_classes - ref_classes
|
| 254 |
-
|
| 255 |
-
if unknown_classes:
|
| 256 |
-
st.warning(f"Warning: Labels in labeled data not found in reference data: {unknown_classes}")
|
| 257 |
-
else:
|
| 258 |
-
st.success("β
Data compatibility check passed!")
|
| 259 |
-
|
| 260 |
-
# Show class overlap
|
| 261 |
-
st.markdown("**Class Overlap Analysis:**")
|
| 262 |
-
col1, col2, col3 = st.columns(3)
|
| 263 |
-
|
| 264 |
-
with col1:
|
| 265 |
-
st.metric("Reference Classes", len(ref_classes))
|
| 266 |
-
with col2:
|
| 267 |
-
st.metric("Labeled Classes", len(labeled_classes))
|
| 268 |
-
with col3:
|
| 269 |
-
st.metric("Common Classes", len(ref_classes.intersection(labeled_classes)))
|
| 270 |
-
|
| 271 |
-
def show_configuration_page():
|
| 272 |
-
st.markdown('<div class="section-header">Model Configuration</div>', unsafe_allow_html=True)
|
| 273 |
-
|
| 274 |
-
# Check if data is loaded
|
| 275 |
-
if st.session_state.reference_data is None or st.session_state.labeled_data is None:
|
| 276 |
-
st.warning("Please upload both reference and labeled data first.")
|
| 277 |
-
return
|
| 278 |
-
|
| 279 |
-
col1, col2 = st.columns(2)
|
| 280 |
-
|
| 281 |
-
with col1:
|
| 282 |
-
st.markdown("### Embedding Model")
|
| 283 |
-
|
| 284 |
-
# Model type selection
|
| 285 |
-
model_type = st.selectbox(
|
| 286 |
-
"Choose model type",
|
| 287 |
-
["HuggingFace", "Gemini"],
|
| 288 |
-
help="Select the type of embedding model to use"
|
| 289 |
-
)
|
| 290 |
-
|
| 291 |
-
# Model selection based on type
|
| 292 |
-
if model_type == "HuggingFace":
|
| 293 |
-
model_options = [
|
| 294 |
-
"sentence-transformers/all-MiniLM-L6-v2",
|
| 295 |
-
"sentence-transformers/all-mpnet-base-v2",
|
| 296 |
-
"sentence-transformers/distilbert-base-nli-mean-tokens"
|
| 297 |
-
]
|
| 298 |
-
|
| 299 |
-
selected_model = st.selectbox(
|
| 300 |
-
"Choose HuggingFace model",
|
| 301 |
-
model_options,
|
| 302 |
-
help="Select the pre-trained HuggingFace model for generating embeddings"
|
| 303 |
-
)
|
| 304 |
-
else: # Gemini
|
| 305 |
-
gemini_models = [
|
| 306 |
-
"gemini-embedding-001",
|
| 307 |
-
"text-embedding-004"
|
| 308 |
-
]
|
| 309 |
-
|
| 310 |
-
selected_model = st.selectbox(
|
| 311 |
-
"Choose Gemini model",
|
| 312 |
-
gemini_models,
|
| 313 |
-
help="Select the Gemini embedding model for generating embeddings"
|
| 314 |
-
)
|
| 315 |
-
|
| 316 |
-
# Calculate total texts to process
|
| 317 |
-
total_texts = 0
|
| 318 |
-
if st.session_state.reference_data is not None:
|
| 319 |
-
total_texts += len(st.session_state.reference_data)
|
| 320 |
-
if st.session_state.labeled_data is not None:
|
| 321 |
-
total_texts += len(st.session_state.labeled_data)
|
| 322 |
-
|
| 323 |
-
st.warning(
|
| 324 |
-
f"β οΈ **Gemini API Rate Limits (Free Tier)**\\n\\n"
|
| 325 |
-
f"- 1,500 requests per day\\n"
|
| 326 |
-
f"- Each batch of 100 texts = 1 request\\n"
|
| 327 |
-
f"- Your current dataset: ~{total_texts} texts\\n"
|
| 328 |
-
f"- Estimated requests needed: ~{(total_texts // 100) + 1}\\n\\n"
|
| 329 |
-
f"If you exceed quota, consider:\\n"
|
| 330 |
-
f"1. Using a smaller dataset\\n"
|
| 331 |
-
f"2. Switching to HuggingFace models (no limits)\\n"
|
| 332 |
-
f"3. Upgrading to a paid API plan"
|
| 333 |
-
)
|
| 334 |
-
|
| 335 |
-
st.info("π‘ Note: Using Gemini embeddings requires GOOGLE_API_KEY environment variable to be set.")
|
| 336 |
-
|
| 337 |
-
st.markdown("### Initial Threshold")
|
| 338 |
-
initial_threshold = st.slider(
|
| 339 |
-
"Initial similarity threshold",
|
| 340 |
-
min_value=0.0,
|
| 341 |
-
max_value=1.0,
|
| 342 |
-
value=0.7,
|
| 343 |
-
step=0.05,
|
| 344 |
-
help="Cosine similarity threshold for classification"
|
| 345 |
-
)
|
| 346 |
-
|
| 347 |
-
with col2:
|
| 348 |
-
st.markdown("### Optimization Parameters")
|
| 349 |
-
|
| 350 |
-
optimize_threshold = st.checkbox(
|
| 351 |
-
"Enable threshold optimization",
|
| 352 |
-
value=True,
|
| 353 |
-
help="Automatically find the best threshold"
|
| 354 |
-
)
|
| 355 |
-
|
| 356 |
-
if optimize_threshold:
|
| 357 |
-
col2_1, col2_2 = st.columns(2)
|
| 358 |
-
|
| 359 |
-
with col2_1:
|
| 360 |
-
start_threshold = st.slider(
|
| 361 |
-
"Start threshold",
|
| 362 |
-
min_value=0.0,
|
| 363 |
-
max_value=1.0,
|
| 364 |
-
value=0.5,
|
| 365 |
-
step=0.05
|
| 366 |
-
)
|
| 367 |
-
|
| 368 |
-
end_threshold = st.slider(
|
| 369 |
-
"End threshold",
|
| 370 |
-
min_value=0.0,
|
| 371 |
-
max_value=1.0,
|
| 372 |
-
value=0.9,
|
| 373 |
-
step=0.05
|
| 374 |
-
)
|
| 375 |
-
|
| 376 |
-
with col2_2:
|
| 377 |
-
step_size = st.slider(
|
| 378 |
-
"Step size",
|
| 379 |
-
min_value=0.005,
|
| 380 |
-
max_value=0.05,
|
| 381 |
-
value=0.01,
|
| 382 |
-
step=0.005
|
| 383 |
-
)
|
| 384 |
-
|
| 385 |
-
optimization_metric = st.selectbox(
|
| 386 |
-
"Optimization metric",
|
| 387 |
-
["f1_macro", "accuracy", "precision_macro", "recall_macro"]
|
| 388 |
-
)
|
| 389 |
-
|
| 390 |
-
# Load models button
|
| 391 |
-
if st.button("Initialize Models", type="primary"):
|
| 392 |
-
with st.spinner("Loading models... This may take a few minutes."):
|
| 393 |
-
try:
|
| 394 |
-
# Initialize classifier
|
| 395 |
-
classifier = Classifier(verbose=False)
|
| 396 |
-
|
| 397 |
-
# Determine model type parameter
|
| 398 |
-
model_type_param = "gemini" if model_type == "Gemini" else "huggingface"
|
| 399 |
-
|
| 400 |
-
classifier.load_models(
|
| 401 |
-
model_name=selected_model,
|
| 402 |
-
model_type=model_type_param,
|
| 403 |
-
threshold=initial_threshold
|
| 404 |
-
)
|
| 405 |
-
|
| 406 |
-
# Prepare reference vectors
|
| 407 |
-
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_ref:
|
| 408 |
-
tmp_ref_path = tmp_ref.name
|
| 409 |
-
st.session_state.reference_data.to_csv(tmp_ref_path, index=False)
|
| 410 |
-
|
| 411 |
-
try:
|
| 412 |
-
reference_data = classifier.prepare_reference_vectors(
|
| 413 |
-
reference_path=tmp_ref_path,
|
| 414 |
-
class_column='class',
|
| 415 |
-
node_column='matching_node'
|
| 416 |
-
)
|
| 417 |
-
finally:
|
| 418 |
-
# Ensure file is deleted even if an error occurs
|
| 419 |
-
try:
|
| 420 |
-
os.unlink(tmp_ref_path)
|
| 421 |
-
except (OSError, PermissionError):
|
| 422 |
-
pass # File might already be deleted or locked
|
| 423 |
-
|
| 424 |
-
st.session_state.classifier = classifier
|
| 425 |
-
st.session_state.reference_vectors = reference_data
|
| 426 |
-
st.session_state.config = {
|
| 427 |
-
'model_type': model_type,
|
| 428 |
-
'model_name': selected_model,
|
| 429 |
-
'initial_threshold': initial_threshold,
|
| 430 |
-
'optimize_threshold': optimize_threshold,
|
| 431 |
-
'start_threshold': start_threshold if optimize_threshold else None,
|
| 432 |
-
'end_threshold': end_threshold if optimize_threshold else None,
|
| 433 |
-
'step_size': step_size if optimize_threshold else None,
|
| 434 |
-
'optimization_metric': optimization_metric if optimize_threshold else None
|
| 435 |
-
}
|
| 436 |
-
|
| 437 |
-
st.success("β
Models initialized successfully!")
|
| 438 |
-
|
| 439 |
-
except Exception as e:
|
| 440 |
-
st.error(f"Error initializing models: {str(e)}")
|
| 441 |
-
|
| 442 |
-
# Show current configuration
|
| 443 |
-
if st.session_state.classifier is not None:
|
| 444 |
-
st.markdown('<div class="section-header">Current Configuration</div>', unsafe_allow_html=True)
|
| 445 |
-
|
| 446 |
-
config = st.session_state.config
|
| 447 |
-
|
| 448 |
-
col1, col2, col3 = st.columns(3)
|
| 449 |
-
|
| 450 |
-
with col1:
|
| 451 |
-
st.markdown("**Model Settings:**")
|
| 452 |
-
st.write(f"- Model type: {config['model_type']}")
|
| 453 |
-
st.write(f"- Model: {config['model_name']}")
|
| 454 |
-
st.write(f"- Initial threshold: {config['initial_threshold']}")
|
| 455 |
-
|
| 456 |
-
with col2:
|
| 457 |
-
st.markdown("**Optimization:**")
|
| 458 |
-
st.write(f"- Enabled: {config['optimize_threshold']}")
|
| 459 |
-
if config['optimize_threshold']:
|
| 460 |
-
st.write(f"- Range: {config['start_threshold']:.2f} - {config['end_threshold']:.2f}")
|
| 461 |
-
st.write(f"- Step: {config['step_size']:.3f}")
|
| 462 |
-
|
| 463 |
-
with col3:
|
| 464 |
-
st.markdown("**Data:**")
|
| 465 |
-
st.write(f"- Reference examples: {len(st.session_state.reference_data)}")
|
| 466 |
-
st.write(f"- Labeled samples: {len(st.session_state.labeled_data)}")
|
| 467 |
-
|
| 468 |
-
def show_classification_page():
|
| 469 |
-
st.markdown('<div class="section-header">Classification & Optimization</div>', unsafe_allow_html=True)
|
| 470 |
-
|
| 471 |
-
# Check if models are loaded
|
| 472 |
-
if st.session_state.classifier is None:
|
| 473 |
-
st.warning("Please configure and initialize models first.")
|
| 474 |
-
return
|
| 475 |
-
|
| 476 |
-
# Run classification
|
| 477 |
-
if st.button("Run Classification", type="primary"):
|
| 478 |
-
with st.spinner("Running classification and optimization..."):
|
| 479 |
-
try:
|
| 480 |
-
# Save labeled data to temporary file
|
| 481 |
-
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_labeled:
|
| 482 |
-
tmp_labeled_path = tmp_labeled.name
|
| 483 |
-
st.session_state.labeled_data.to_csv(tmp_labeled_path, index=False)
|
| 484 |
-
|
| 485 |
-
try:
|
| 486 |
-
# Run optimization if enabled
|
| 487 |
-
if st.session_state.config['optimize_threshold']:
|
| 488 |
-
optimization_results = st.session_state.classifier.evaluate_classification(
|
| 489 |
-
labeled_path=tmp_labeled_path,
|
| 490 |
-
reference_data=st.session_state.reference_vectors,
|
| 491 |
-
sentence_column='sentence',
|
| 492 |
-
label_column='label',
|
| 493 |
-
optimize_threshold=True,
|
| 494 |
-
start=st.session_state.config['start_threshold'],
|
| 495 |
-
end=st.session_state.config['end_threshold'],
|
| 496 |
-
step=st.session_state.config['step_size']
|
| 497 |
-
)
|
| 498 |
-
|
| 499 |
-
st.session_state.optimization_results = optimization_results
|
| 500 |
-
optimal_threshold = optimization_results["optimal_threshold"]
|
| 501 |
-
|
| 502 |
-
# Update classifier with optimal threshold
|
| 503 |
-
st.session_state.classifier.matcher = SemanticMatcher(
|
| 504 |
-
threshold=optimal_threshold,
|
| 505 |
-
verbose=False
|
| 506 |
-
)
|
| 507 |
-
|
| 508 |
-
st.success(f"β
Optimization completed! Optimal threshold: {optimal_threshold:.4f}")
|
| 509 |
-
|
| 510 |
-
else:
|
| 511 |
-
optimal_threshold = st.session_state.config['initial_threshold']
|
| 512 |
-
|
| 513 |
-
# Run evaluation
|
| 514 |
-
embedding_model = st.session_state.classifier.embedding_model
|
| 515 |
-
data_loader = DataLoader(verbose=False)
|
| 516 |
-
full_df = data_loader.load_labeled_data(tmp_labeled_path, label_column='label')
|
| 517 |
-
|
| 518 |
-
# Generate embeddings
|
| 519 |
-
full_embeddings = embedding_model.embed_dataframe(full_df, text_column='sentence')
|
| 520 |
-
|
| 521 |
-
# Classify
|
| 522 |
-
match_results = st.session_state.classifier.matcher.match(
|
| 523 |
-
full_embeddings,
|
| 524 |
-
st.session_state.reference_vectors
|
| 525 |
-
)
|
| 526 |
-
predicted_labels = match_results["predicted_class"].tolist()
|
| 527 |
-
true_labels = full_df['label'].tolist()
|
| 528 |
-
|
| 529 |
-
# Evaluate
|
| 530 |
-
evaluator = Evaluator(verbose=False)
|
| 531 |
-
eval_results = evaluator.evaluate(
|
| 532 |
-
true_labels=true_labels,
|
| 533 |
-
predicted_labels=predicted_labels,
|
| 534 |
-
class_names=list(set(true_labels) | set(predicted_labels))
|
| 535 |
-
)
|
| 536 |
-
|
| 537 |
-
# Bootstrap evaluation
|
| 538 |
-
bootstrap_results = evaluator.bootstrap_evaluate(
|
| 539 |
-
true_labels=true_labels,
|
| 540 |
-
predicted_labels=predicted_labels,
|
| 541 |
-
n_iterations=100
|
| 542 |
-
)
|
| 543 |
-
|
| 544 |
-
st.session_state.evaluation_results = eval_results
|
| 545 |
-
st.session_state.bootstrap_results = bootstrap_results
|
| 546 |
-
st.session_state.predictions = {
|
| 547 |
-
'true_labels': true_labels,
|
| 548 |
-
'predicted_labels': predicted_labels,
|
| 549 |
-
'match_results': match_results,
|
| 550 |
-
'full_df': full_df
|
| 551 |
-
}
|
| 552 |
-
|
| 553 |
-
finally:
|
| 554 |
-
# Ensure temporary file is deleted
|
| 555 |
-
try:
|
| 556 |
-
os.unlink(tmp_labeled_path)
|
| 557 |
-
except (OSError, PermissionError):
|
| 558 |
-
pass # File might already be deleted or locked
|
| 559 |
-
|
| 560 |
-
st.success("β
Classification completed successfully!")
|
| 561 |
-
|
| 562 |
-
except Exception as e:
|
| 563 |
-
st.error(f"Error during classification: {str(e)}")
|
| 564 |
-
|
| 565 |
-
# Show optimization results if available
|
| 566 |
-
if st.session_state.optimization_results is not None:
|
| 567 |
-
st.markdown('<div class="section-header">Optimization Results</div>', unsafe_allow_html=True)
|
| 568 |
-
|
| 569 |
-
results = st.session_state.optimization_results
|
| 570 |
-
|
| 571 |
-
col1, col2, col3, col4 = st.columns(4)
|
| 572 |
-
|
| 573 |
-
with col1:
|
| 574 |
-
st.metric(
|
| 575 |
-
"Optimal Threshold",
|
| 576 |
-
f"{results['optimal_threshold']:.4f}"
|
| 577 |
-
)
|
| 578 |
-
|
| 579 |
-
with col2:
|
| 580 |
-
st.metric(
|
| 581 |
-
"Accuracy",
|
| 582 |
-
f"{results['optimal_metrics']['accuracy']:.4f}"
|
| 583 |
-
)
|
| 584 |
-
|
| 585 |
-
with col3:
|
| 586 |
-
st.metric(
|
| 587 |
-
"F1 Score",
|
| 588 |
-
f"{results['optimal_metrics']['f1_macro']:.4f}"
|
| 589 |
-
)
|
| 590 |
-
|
| 591 |
-
with col4:
|
| 592 |
-
st.metric(
|
| 593 |
-
"Precision",
|
| 594 |
-
f"{results['optimal_metrics']['precision_macro']:.4f}"
|
| 595 |
-
)
|
| 596 |
-
|
| 597 |
-
# Plot optimization curve
|
| 598 |
-
st.markdown("### Optimization Curve")
|
| 599 |
-
|
| 600 |
-
opt_results = results["results_by_threshold"]
|
| 601 |
-
|
| 602 |
-
fig = make_subplots(
|
| 603 |
-
rows=2, cols=2,
|
| 604 |
-
subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall'),
|
| 605 |
-
vertical_spacing=0.1
|
| 606 |
-
)
|
| 607 |
-
|
| 608 |
-
thresholds = opt_results["thresholds"]
|
| 609 |
-
|
| 610 |
-
# Add traces
|
| 611 |
-
fig.add_trace(
|
| 612 |
-
go.Scatter(x=thresholds, y=opt_results["accuracy"], name="Accuracy"),
|
| 613 |
-
row=1, col=1
|
| 614 |
-
)
|
| 615 |
-
fig.add_trace(
|
| 616 |
-
go.Scatter(x=thresholds, y=opt_results["f1_macro"], name="F1 Score"),
|
| 617 |
-
row=1, col=2
|
| 618 |
-
)
|
| 619 |
-
fig.add_trace(
|
| 620 |
-
go.Scatter(x=thresholds, y=opt_results["precision_macro"], name="Precision"),
|
| 621 |
-
row=2, col=1
|
| 622 |
-
)
|
| 623 |
-
fig.add_trace(
|
| 624 |
-
go.Scatter(x=thresholds, y=opt_results["recall_macro"], name="Recall"),
|
| 625 |
-
row=2, col=2
|
| 626 |
-
)
|
| 627 |
-
|
| 628 |
-
# Add optimal threshold line to each subplot using shapes
|
| 629 |
-
optimal_thresh = results['optimal_threshold']
|
| 630 |
-
|
| 631 |
-
# Add vertical line as shapes to each subplot
|
| 632 |
-
shapes = []
|
| 633 |
-
for row in range(1, 3):
|
| 634 |
-
for col in range(1, 3):
|
| 635 |
-
# Calculate the subplot domain
|
| 636 |
-
xaxis = f'x{(row-1)*2 + col}' if (row-1)*2 + col > 1 else 'x'
|
| 637 |
-
shapes.append(
|
| 638 |
-
dict(
|
| 639 |
-
type="line",
|
| 640 |
-
x0=optimal_thresh, x1=optimal_thresh,
|
| 641 |
-
y0=0, y1=1,
|
| 642 |
-
yref=f"y{(row-1)*2 + col} domain" if (row-1)*2 + col > 1 else "y domain",
|
| 643 |
-
xref=xaxis,
|
| 644 |
-
line=dict(color="red", width=2, dash="dash")
|
| 645 |
-
)
|
| 646 |
-
)
|
| 647 |
-
|
| 648 |
-
fig.update_layout(shapes=shapes)
|
| 649 |
-
|
| 650 |
-
fig.update_layout(
|
| 651 |
-
title="Threshold Optimization Results",
|
| 652 |
-
showlegend=False,
|
| 653 |
-
height=600
|
| 654 |
-
)
|
| 655 |
-
|
| 656 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 657 |
-
|
| 658 |
-
def show_results_page():
|
| 659 |
-
st.markdown('<div class="section-header">Results & Evaluation</div>', unsafe_allow_html=True)
|
| 660 |
-
|
| 661 |
-
# Check if evaluation results are available
|
| 662 |
-
if st.session_state.evaluation_results is None:
|
| 663 |
-
st.warning("Please run classification first to see results.")
|
| 664 |
-
return
|
| 665 |
-
|
| 666 |
-
eval_results = st.session_state.evaluation_results
|
| 667 |
-
|
| 668 |
-
# Performance metrics
|
| 669 |
-
st.markdown("### Performance Metrics")
|
| 670 |
-
|
| 671 |
-
col1, col2, col3, col4 = st.columns(4)
|
| 672 |
-
|
| 673 |
-
with col1:
|
| 674 |
-
st.metric(
|
| 675 |
-
"Overall Accuracy",
|
| 676 |
-
f"{eval_results['accuracy']:.4f}"
|
| 677 |
-
)
|
| 678 |
-
|
| 679 |
-
with col2:
|
| 680 |
-
st.metric(
|
| 681 |
-
"Macro F1 Score",
|
| 682 |
-
f"{eval_results['f1_macro']:.4f}"
|
| 683 |
-
)
|
| 684 |
-
|
| 685 |
-
with col3:
|
| 686 |
-
st.metric(
|
| 687 |
-
"Macro Precision",
|
| 688 |
-
f"{eval_results['precision_macro']:.4f}"
|
| 689 |
-
)
|
| 690 |
-
|
| 691 |
-
with col4:
|
| 692 |
-
st.metric(
|
| 693 |
-
"Macro Recall",
|
| 694 |
-
f"{eval_results['recall_macro']:.4f}"
|
| 695 |
-
)
|
| 696 |
-
|
| 697 |
-
# Class-wise metrics
|
| 698 |
-
st.markdown("### Class-wise Performance")
|
| 699 |
-
|
| 700 |
-
class_metrics_df = pd.DataFrame({
|
| 701 |
-
'Class': list(eval_results['class_metrics']['precision'].keys()),
|
| 702 |
-
'Precision': list(eval_results['class_metrics']['precision'].values()),
|
| 703 |
-
'Recall': list(eval_results['class_metrics']['recall'].values()),
|
| 704 |
-
'F1-Score': list(eval_results['class_metrics']['f1'].values()),
|
| 705 |
-
'Support': list(eval_results['class_metrics']['support'].values())
|
| 706 |
-
})
|
| 707 |
-
|
| 708 |
-
st.dataframe(class_metrics_df, use_container_width=True)
|
| 709 |
-
|
| 710 |
-
# Confusion Matrix
|
| 711 |
-
st.markdown("### Confusion Matrix")
|
| 712 |
-
|
| 713 |
-
cm = eval_results['confusion_matrix']
|
| 714 |
-
class_names = eval_results['confusion_matrix_labels']
|
| 715 |
-
|
| 716 |
-
fig = px.imshow(
|
| 717 |
-
cm,
|
| 718 |
-
labels=dict(x="Predicted", y="True", color="Count"),
|
| 719 |
-
x=class_names,
|
| 720 |
-
y=class_names,
|
| 721 |
-
color_continuous_scale='Blues',
|
| 722 |
-
text_auto=True,
|
| 723 |
-
title="Confusion Matrix"
|
| 724 |
-
)
|
| 725 |
-
|
| 726 |
-
fig.update_layout(
|
| 727 |
-
width=600,
|
| 728 |
-
height=600
|
| 729 |
-
)
|
| 730 |
-
|
| 731 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 732 |
-
|
| 733 |
-
# Bootstrap Results
|
| 734 |
-
if st.session_state.bootstrap_results is not None:
|
| 735 |
-
st.markdown("### Bootstrap Confidence Intervals")
|
| 736 |
-
|
| 737 |
-
bootstrap_results = st.session_state.bootstrap_results
|
| 738 |
-
|
| 739 |
-
# Debug: show available keys
|
| 740 |
-
if 'confidence_intervals' in bootstrap_results:
|
| 741 |
-
metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
|
| 742 |
-
|
| 743 |
-
for metric in metrics:
|
| 744 |
-
if metric in bootstrap_results['confidence_intervals']:
|
| 745 |
-
ci_data = bootstrap_results['confidence_intervals'][metric]
|
| 746 |
-
st.markdown(f"**{metric.replace('_', ' ').title()}:**")
|
| 747 |
-
|
| 748 |
-
col1, col2, col3 = st.columns(3)
|
| 749 |
-
|
| 750 |
-
# Check available confidence levels
|
| 751 |
-
available_levels = list(ci_data.keys())
|
| 752 |
-
|
| 753 |
-
with col1:
|
| 754 |
-
if '0.95' in ci_data:
|
| 755 |
-
ci_95 = ci_data['0.95']
|
| 756 |
-
if isinstance(ci_95, dict):
|
| 757 |
-
st.write(f"95% CI: [{ci_95['lower']:.4f}, {ci_95['upper']:.4f}]")
|
| 758 |
-
elif isinstance(ci_95, (list, tuple)) and len(ci_95) >= 2:
|
| 759 |
-
st.write(f"95% CI: [{ci_95[0]:.4f}, {ci_95[1]:.4f}]")
|
| 760 |
-
else:
|
| 761 |
-
st.write("95% CI: Format not recognized")
|
| 762 |
-
elif 0.95 in ci_data:
|
| 763 |
-
ci_95 = ci_data[0.95]
|
| 764 |
-
if isinstance(ci_95, dict):
|
| 765 |
-
st.write(f"95% CI: [{ci_95['lower']:.4f}, {ci_95['upper']:.4f}]")
|
| 766 |
-
elif isinstance(ci_95, (list, tuple)) and len(ci_95) >= 2:
|
| 767 |
-
st.write(f"95% CI: [{ci_95[0]:.4f}, {ci_95[1]:.4f}]")
|
| 768 |
-
else:
|
| 769 |
-
st.write("95% CI: Format not recognized")
|
| 770 |
-
else:
|
| 771 |
-
st.write("95% CI: Not available")
|
| 772 |
-
|
| 773 |
-
with col2:
|
| 774 |
-
if '0.99' in ci_data:
|
| 775 |
-
ci_99 = ci_data['0.99']
|
| 776 |
-
if isinstance(ci_99, dict):
|
| 777 |
-
st.write(f"99% CI: [{ci_99['lower']:.4f}, {ci_99['upper']:.4f}]")
|
| 778 |
-
elif isinstance(ci_99, (list, tuple)) and len(ci_99) >= 2:
|
| 779 |
-
st.write(f"99% CI: [{ci_99[0]:.4f}, {ci_99[1]:.4f}]")
|
| 780 |
-
else:
|
| 781 |
-
st.write("99% CI: Format not recognized")
|
| 782 |
-
elif 0.99 in ci_data:
|
| 783 |
-
ci_99 = ci_data[0.99]
|
| 784 |
-
if isinstance(ci_99, dict):
|
| 785 |
-
st.write(f"99% CI: [{ci_99['lower']:.4f}, {ci_99['upper']:.4f}]")
|
| 786 |
-
elif isinstance(ci_99, (list, tuple)) and len(ci_99) >= 2:
|
| 787 |
-
st.write(f"99% CI: [{ci_99[0]:.4f}, {ci_99[1]:.4f}]")
|
| 788 |
-
else:
|
| 789 |
-
st.write("99% CI: Format not recognized")
|
| 790 |
-
else:
|
| 791 |
-
st.write("99% CI: Not available")
|
| 792 |
-
|
| 793 |
-
with col3:
|
| 794 |
-
if 'point_estimates' in bootstrap_results and metric in bootstrap_results['point_estimates']:
|
| 795 |
-
st.write(f"Point Estimate: {bootstrap_results['point_estimates'][metric]:.4f}")
|
| 796 |
-
else:
|
| 797 |
-
st.write("Point Estimate: Not available")
|
| 798 |
-
else:
|
| 799 |
-
st.info("Bootstrap confidence intervals not available.")
|
| 800 |
-
|
| 801 |
-
# Bootstrap Distribution Plot
|
| 802 |
-
st.markdown("### Bootstrap Distributions")
|
| 803 |
-
|
| 804 |
-
if 'bootstrap_distribution' in bootstrap_results:
|
| 805 |
-
fig = make_subplots(
|
| 806 |
-
rows=2, cols=2,
|
| 807 |
-
subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall')
|
| 808 |
-
)
|
| 809 |
-
|
| 810 |
-
distributions = bootstrap_results['bootstrap_distribution']
|
| 811 |
-
|
| 812 |
-
if 'accuracy' in distributions:
|
| 813 |
-
fig.add_trace(
|
| 814 |
-
go.Histogram(x=distributions['accuracy'], name="Accuracy", nbinsx=30),
|
| 815 |
-
row=1, col=1
|
| 816 |
-
)
|
| 817 |
-
if 'f1_macro' in distributions:
|
| 818 |
-
fig.add_trace(
|
| 819 |
-
go.Histogram(x=distributions['f1_macro'], name="F1 Score", nbinsx=30),
|
| 820 |
-
row=1, col=2
|
| 821 |
-
)
|
| 822 |
-
if 'precision_macro' in distributions:
|
| 823 |
-
fig.add_trace(
|
| 824 |
-
go.Histogram(x=distributions['precision_macro'], name="Precision", nbinsx=30),
|
| 825 |
-
row=2, col=1
|
| 826 |
-
)
|
| 827 |
-
if 'recall_macro' in distributions:
|
| 828 |
-
fig.add_trace(
|
| 829 |
-
go.Histogram(x=distributions['recall_macro'], name="Recall", nbinsx=30),
|
| 830 |
-
row=2, col=2
|
| 831 |
-
)
|
| 832 |
-
|
| 833 |
-
fig.update_layout(
|
| 834 |
-
title="Bootstrap Distributions",
|
| 835 |
-
showlegend=False,
|
| 836 |
-
height=600
|
| 837 |
-
)
|
| 838 |
-
|
| 839 |
-
st.plotly_chart(fig, use_container_width=True)
|
| 840 |
-
else:
|
| 841 |
-
st.info("Bootstrap distributions not available.")
|
| 842 |
-
|
| 843 |
-
# Sample predictions
|
| 844 |
-
if 'predictions' in st.session_state:
|
| 845 |
-
st.markdown("### Sample Predictions")
|
| 846 |
-
|
| 847 |
-
predictions = st.session_state.predictions
|
| 848 |
-
sample_df = predictions['full_df'].copy()
|
| 849 |
-
sample_df['predicted_class'] = predictions['predicted_labels']
|
| 850 |
-
sample_df['true_class'] = predictions['true_labels']
|
| 851 |
-
sample_df['similarity_score'] = predictions['match_results']['similarity_score']
|
| 852 |
-
sample_df['correct'] = sample_df['predicted_class'] == sample_df['true_class']
|
| 853 |
-
|
| 854 |
-
# Filter options
|
| 855 |
-
col1, col2 = st.columns(2)
|
| 856 |
-
|
| 857 |
-
with col1:
|
| 858 |
-
show_correct = st.checkbox("Show correct predictions", value=True)
|
| 859 |
-
|
| 860 |
-
with col2:
|
| 861 |
-
show_incorrect = st.checkbox("Show incorrect predictions", value=True)
|
| 862 |
-
|
| 863 |
-
# Filter data
|
| 864 |
-
if show_correct and show_incorrect:
|
| 865 |
-
filtered_df = sample_df
|
| 866 |
-
elif show_correct:
|
| 867 |
-
filtered_df = sample_df[sample_df['correct'] == True]
|
| 868 |
-
elif show_incorrect:
|
| 869 |
-
filtered_df = sample_df[sample_df['correct'] == False]
|
| 870 |
-
else:
|
| 871 |
-
filtered_df = pd.DataFrame()
|
| 872 |
-
|
| 873 |
-
if not filtered_df.empty:
|
| 874 |
-
# Sample random rows
|
| 875 |
-
n_samples = min(20, len(filtered_df))
|
| 876 |
-
sample_rows = filtered_df.sample(n=n_samples) if len(filtered_df) > n_samples else filtered_df
|
| 877 |
-
|
| 878 |
-
display_df = sample_rows[['sentence', 'true_class', 'predicted_class', 'similarity_score', 'correct']].reset_index(drop=True)
|
| 879 |
-
|
| 880 |
-
st.dataframe(display_df, use_container_width=True)
|
| 881 |
-
else:
|
| 882 |
-
st.info("No predictions to show with current filters.")
|
| 883 |
-
|
| 884 |
-
# Download results
|
| 885 |
-
st.markdown("### Download Results")
|
| 886 |
-
|
| 887 |
-
col1, col2 = st.columns(2)
|
| 888 |
-
|
| 889 |
-
with col1:
|
| 890 |
-
# Download class-wise metrics
|
| 891 |
-
csv_metrics = class_metrics_df.to_csv(index=False)
|
| 892 |
-
st.download_button(
|
| 893 |
-
label="Download Class Metrics",
|
| 894 |
-
data=csv_metrics,
|
| 895 |
-
file_name="class_metrics.csv",
|
| 896 |
-
mime="text/csv"
|
| 897 |
-
)
|
| 898 |
-
|
| 899 |
-
with col2:
|
| 900 |
-
# Download predictions
|
| 901 |
-
if 'predictions' in st.session_state:
|
| 902 |
-
predictions = st.session_state.predictions
|
| 903 |
-
results_df = predictions['full_df'].copy()
|
| 904 |
-
results_df['predicted_class'] = predictions['predicted_labels']
|
| 905 |
-
results_df['similarity_score'] = predictions['match_results']['similarity_score']
|
| 906 |
-
|
| 907 |
-
csv_results = results_df.to_csv(index=False)
|
| 908 |
-
st.download_button(
|
| 909 |
-
label="Download Predictions",
|
| 910 |
-
data=csv_results,
|
| 911 |
-
file_name="predictions.csv",
|
| 912 |
-
mime="text/csv"
|
| 913 |
-
)
|
| 914 |
-
|
| 915 |
-
if __name__ == "__main__":
|
| 916 |
-
main()
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
import tempfile
|
| 7 |
+
import os
|
| 8 |
+
import sys
|
| 9 |
+
from io import StringIO
|
| 10 |
+
import plotly.express as px
|
| 11 |
+
import plotly.graph_objects as go
|
| 12 |
+
from plotly.subplots import make_subplots
|
| 13 |
+
|
| 14 |
+
# Add the parent directory to sys.path to import the module
|
| 15 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
| 16 |
+
|
| 17 |
+
from src.qualivec.data import DataLoader
|
| 18 |
+
from src.qualivec.embedding import EmbeddingModel
|
| 19 |
+
from src.qualivec.matching import SemanticMatcher
|
| 20 |
+
from src.qualivec.classification import Classifier
|
| 21 |
+
from src.qualivec.evaluation import Evaluator
|
| 22 |
+
from src.qualivec.optimization import ThresholdOptimizer
|
| 23 |
+
|
| 24 |
+
# Set page config
|
| 25 |
+
st.set_page_config(
|
| 26 |
+
page_title="QualiVec Demo",
|
| 27 |
+
page_icon="π",
|
| 28 |
+
layout="wide",
|
| 29 |
+
initial_sidebar_state="expanded"
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# Custom CSS for better styling
|
| 33 |
+
st.markdown("""
|
| 34 |
+
<style>
|
| 35 |
+
.main-header {
|
| 36 |
+
font-size: 2.5rem;
|
| 37 |
+
font-weight: bold;
|
| 38 |
+
color: #2E4057;
|
| 39 |
+
text-align: center;
|
| 40 |
+
margin-bottom: 2rem;
|
| 41 |
+
}
|
| 42 |
+
.section-header {
|
| 43 |
+
font-size: 1.5rem;
|
| 44 |
+
font-weight: bold;
|
| 45 |
+
color: #048A81;
|
| 46 |
+
margin-top: 2rem;
|
| 47 |
+
margin-bottom: 1rem;
|
| 48 |
+
}
|
| 49 |
+
.metric-card {
|
| 50 |
+
background-color: #f0f2f6;
|
| 51 |
+
padding: 1rem;
|
| 52 |
+
border-radius: 0.5rem;
|
| 53 |
+
margin: 0.5rem 0;
|
| 54 |
+
}
|
| 55 |
+
.success-message {
|
| 56 |
+
background-color: #d4edda;
|
| 57 |
+
color: #155724;
|
| 58 |
+
padding: 1rem;
|
| 59 |
+
border-radius: 0.5rem;
|
| 60 |
+
margin: 1rem 0;
|
| 61 |
+
}
|
| 62 |
+
.warning-message {
|
| 63 |
+
background-color: #fff3cd;
|
| 64 |
+
color: #856404;
|
| 65 |
+
padding: 1rem;
|
| 66 |
+
border-radius: 0.5rem;
|
| 67 |
+
margin: 1rem 0;
|
| 68 |
+
}
|
| 69 |
+
</style>
|
| 70 |
+
""", unsafe_allow_html=True)
|
| 71 |
+
|
| 72 |
+
def main():
|
| 73 |
+
st.markdown('<div class="main-header">π QualiVec Demo</div>', unsafe_allow_html=True)
|
| 74 |
+
st.markdown("""
|
| 75 |
+
<div style="text-align: center; margin-bottom: 2rem;">
|
| 76 |
+
<p style="font-size: 1.2rem; color: #666;">
|
| 77 |
+
Qualitative Content Analysis with LLM Embeddings
|
| 78 |
+
</p>
|
| 79 |
+
</div>
|
| 80 |
+
""", unsafe_allow_html=True)
|
| 81 |
+
|
| 82 |
+
# Sidebar for navigation
|
| 83 |
+
st.sidebar.title("Navigation")
|
| 84 |
+
page = st.sidebar.selectbox(
|
| 85 |
+
"Choose a page",
|
| 86 |
+
["π Home", "π Data Upload", "π§ Configuration", "π― Classification", "π Results"]
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Initialize session state
|
| 90 |
+
if 'classifier' not in st.session_state:
|
| 91 |
+
st.session_state.classifier = None
|
| 92 |
+
if 'reference_data' not in st.session_state:
|
| 93 |
+
st.session_state.reference_data = None
|
| 94 |
+
if 'labeled_data' not in st.session_state:
|
| 95 |
+
st.session_state.labeled_data = None
|
| 96 |
+
if 'optimization_results' not in st.session_state:
|
| 97 |
+
st.session_state.optimization_results = None
|
| 98 |
+
if 'evaluation_results' not in st.session_state:
|
| 99 |
+
st.session_state.evaluation_results = None
|
| 100 |
+
|
| 101 |
+
# Route to different pages
|
| 102 |
+
if page == "π Home":
|
| 103 |
+
show_home_page()
|
| 104 |
+
elif page == "π Data Upload":
|
| 105 |
+
show_data_upload_page()
|
| 106 |
+
elif page == "π§ Configuration":
|
| 107 |
+
show_configuration_page()
|
| 108 |
+
elif page == "π― Classification":
|
| 109 |
+
show_classification_page()
|
| 110 |
+
elif page == "π Results":
|
| 111 |
+
show_results_page()
|
| 112 |
+
|
| 113 |
+
def show_home_page():
|
| 114 |
+
st.markdown('<div class="section-header">Welcome to QualiVec</div>', unsafe_allow_html=True)
|
| 115 |
+
|
| 116 |
+
col1, col2, col3 = st.columns([1, 2, 1])
|
| 117 |
+
|
| 118 |
+
with col2:
|
| 119 |
+
st.markdown("""
|
| 120 |
+
### What is QualiVec?
|
| 121 |
+
|
| 122 |
+
QualiVec is a Python library that uses Large Language Model (LLM) embeddings for qualitative content analysis. It helps researchers and analysts classify text data by comparing it against reference examples.
|
| 123 |
+
|
| 124 |
+
### Key Features:
|
| 125 |
+
- **Semantic Matching**: Uses advanced embedding models to find semantic similarity
|
| 126 |
+
- **Threshold Optimization**: Automatically finds the best similarity threshold
|
| 127 |
+
- **Comprehensive Evaluation**: Provides detailed metrics and visualizations
|
| 128 |
+
- **Bootstrap Analysis**: Confidence intervals for robust evaluation
|
| 129 |
+
|
| 130 |
+
### How It Works:
|
| 131 |
+
1. **Upload Data**: Provide reference examples and data to classify
|
| 132 |
+
2. **Configure**: Set up embedding models and parameters
|
| 133 |
+
3. **Optimize**: Find the best threshold for classification
|
| 134 |
+
4. **Classify**: Apply the model to your data
|
| 135 |
+
5. **Evaluate**: Get detailed performance metrics
|
| 136 |
+
|
| 137 |
+
### Getting Started:
|
| 138 |
+
Use the sidebar to navigate through the demo. Start with **Data Upload** to begin your analysis.
|
| 139 |
+
""")
|
| 140 |
+
|
| 141 |
+
# Add sample data info
|
| 142 |
+
st.markdown('<div class="section-header">Sample Data Format</div>', unsafe_allow_html=True)
|
| 143 |
+
|
| 144 |
+
col1, col2 = st.columns(2)
|
| 145 |
+
|
| 146 |
+
with col1:
|
| 147 |
+
st.markdown("**Reference Data Format:**")
|
| 148 |
+
sample_ref = pd.DataFrame({
|
| 149 |
+
'tag': ['Positive', 'Negative', 'Neutral'],
|
| 150 |
+
'sentence': ['This is great!', 'This is terrible', 'This is okay']
|
| 151 |
+
})
|
| 152 |
+
st.dataframe(sample_ref, use_container_width=True)
|
| 153 |
+
|
| 154 |
+
with col2:
|
| 155 |
+
st.markdown("**Labeled Data Format:**")
|
| 156 |
+
sample_labeled = pd.DataFrame({
|
| 157 |
+
'sentence': ['I love this product', 'Not very good', 'Average quality'],
|
| 158 |
+
'Label': ['Positive', 'Negative', 'Neutral']
|
| 159 |
+
})
|
| 160 |
+
st.dataframe(sample_labeled, use_container_width=True)
|
| 161 |
+
|
| 162 |
+
def show_data_upload_page():
|
| 163 |
+
st.markdown('<div class="section-header">Data Upload</div>', unsafe_allow_html=True)
|
| 164 |
+
|
| 165 |
+
col1, col2 = st.columns(2)
|
| 166 |
+
|
| 167 |
+
with col1:
|
| 168 |
+
st.markdown("### Reference Data")
|
| 169 |
+
st.markdown("Upload a CSV file containing reference examples with columns: `tag` (class) and `sentence` (example text)")
|
| 170 |
+
|
| 171 |
+
reference_file = st.file_uploader(
|
| 172 |
+
"Choose reference data file",
|
| 173 |
+
type=['csv'],
|
| 174 |
+
key='reference_file'
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
if reference_file is not None:
|
| 178 |
+
try:
|
| 179 |
+
reference_df = pd.read_csv(reference_file)
|
| 180 |
+
st.success("Reference data loaded successfully!")
|
| 181 |
+
st.dataframe(reference_df.head(), use_container_width=True)
|
| 182 |
+
|
| 183 |
+
# Validate columns
|
| 184 |
+
required_cols = ['tag', 'sentence']
|
| 185 |
+
missing_cols = [col for col in required_cols if col not in reference_df.columns]
|
| 186 |
+
|
| 187 |
+
if missing_cols:
|
| 188 |
+
st.error(f"Missing required columns: {missing_cols}")
|
| 189 |
+
else:
|
| 190 |
+
# Prepare reference data
|
| 191 |
+
reference_df = reference_df.rename(columns={
|
| 192 |
+
'tag': 'class',
|
| 193 |
+
'sentence': 'matching_node'
|
| 194 |
+
})
|
| 195 |
+
st.session_state.reference_data = reference_df
|
| 196 |
+
|
| 197 |
+
# Show statistics
|
| 198 |
+
st.markdown("**Data Statistics:**")
|
| 199 |
+
st.write(f"- Total examples: {len(reference_df)}")
|
| 200 |
+
st.write(f"- Unique classes: {reference_df['class'].nunique()}")
|
| 201 |
+
st.write(f"- Class distribution:")
|
| 202 |
+
st.write(reference_df['class'].value_counts())
|
| 203 |
+
|
| 204 |
+
except Exception as e:
|
| 205 |
+
st.error(f"Error loading reference data: {str(e)}")
|
| 206 |
+
|
| 207 |
+
with col2:
|
| 208 |
+
st.markdown("### Labeled Data")
|
| 209 |
+
st.markdown("Upload a CSV file containing data to classify with columns: `sentence` (text) and `Label` (true class)")
|
| 210 |
+
|
| 211 |
+
labeled_file = st.file_uploader(
|
| 212 |
+
"Choose labeled data file",
|
| 213 |
+
type=['csv'],
|
| 214 |
+
key='labeled_file'
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
if labeled_file is not None:
|
| 218 |
+
try:
|
| 219 |
+
labeled_df = pd.read_csv(labeled_file)
|
| 220 |
+
st.success("Labeled data loaded successfully!")
|
| 221 |
+
st.dataframe(labeled_df.head(), use_container_width=True)
|
| 222 |
+
|
| 223 |
+
# Validate columns
|
| 224 |
+
required_cols = ['sentence', 'Label']
|
| 225 |
+
missing_cols = [col for col in required_cols if col not in labeled_df.columns]
|
| 226 |
+
|
| 227 |
+
if missing_cols:
|
| 228 |
+
st.error(f"Missing required columns: {missing_cols}")
|
| 229 |
+
else:
|
| 230 |
+
# Prepare labeled data
|
| 231 |
+
labeled_df = labeled_df.rename(columns={'Label': 'label'})
|
| 232 |
+
labeled_df['label'] = labeled_df['label'].replace('0', 'Other')
|
| 233 |
+
st.session_state.labeled_data = labeled_df
|
| 234 |
+
|
| 235 |
+
# Show statistics
|
| 236 |
+
st.markdown("**Data Statistics:**")
|
| 237 |
+
st.write(f"- Total samples: {len(labeled_df)}")
|
| 238 |
+
st.write(f"- Unique labels: {labeled_df['label'].nunique()}")
|
| 239 |
+
st.write(f"- Label distribution:")
|
| 240 |
+
st.write(labeled_df['label'].value_counts())
|
| 241 |
+
|
| 242 |
+
except Exception as e:
|
| 243 |
+
st.error(f"Error loading labeled data: {str(e)}")
|
| 244 |
+
|
| 245 |
+
# Show data compatibility check
|
| 246 |
+
if st.session_state.reference_data is not None and st.session_state.labeled_data is not None:
|
| 247 |
+
st.markdown('<div class="section-header">Data Compatibility Check</div>', unsafe_allow_html=True)
|
| 248 |
+
|
| 249 |
+
ref_classes = set(st.session_state.reference_data['class'].unique())
|
| 250 |
+
labeled_classes = set(st.session_state.labeled_data['label'].unique())
|
| 251 |
+
|
| 252 |
+
# Check for unknown classes
|
| 253 |
+
unknown_classes = labeled_classes - ref_classes
|
| 254 |
+
|
| 255 |
+
if unknown_classes:
|
| 256 |
+
st.warning(f"Warning: Labels in labeled data not found in reference data: {unknown_classes}")
|
| 257 |
+
else:
|
| 258 |
+
st.success("β
Data compatibility check passed!")
|
| 259 |
+
|
| 260 |
+
# Show class overlap
|
| 261 |
+
st.markdown("**Class Overlap Analysis:**")
|
| 262 |
+
col1, col2, col3 = st.columns(3)
|
| 263 |
+
|
| 264 |
+
with col1:
|
| 265 |
+
st.metric("Reference Classes", len(ref_classes))
|
| 266 |
+
with col2:
|
| 267 |
+
st.metric("Labeled Classes", len(labeled_classes))
|
| 268 |
+
with col3:
|
| 269 |
+
st.metric("Common Classes", len(ref_classes.intersection(labeled_classes)))
|
| 270 |
+
|
| 271 |
+
def show_configuration_page():
|
| 272 |
+
st.markdown('<div class="section-header">Model Configuration</div>', unsafe_allow_html=True)
|
| 273 |
+
|
| 274 |
+
# Check if data is loaded
|
| 275 |
+
if st.session_state.reference_data is None or st.session_state.labeled_data is None:
|
| 276 |
+
st.warning("Please upload both reference and labeled data first.")
|
| 277 |
+
return
|
| 278 |
+
|
| 279 |
+
col1, col2 = st.columns(2)
|
| 280 |
+
|
| 281 |
+
with col1:
|
| 282 |
+
st.markdown("### Embedding Model")
|
| 283 |
+
|
| 284 |
+
# Model type selection
|
| 285 |
+
model_type = st.selectbox(
|
| 286 |
+
"Choose model type",
|
| 287 |
+
["HuggingFace", "Gemini"],
|
| 288 |
+
help="Select the type of embedding model to use"
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# Model selection based on type
|
| 292 |
+
if model_type == "HuggingFace":
|
| 293 |
+
model_options = [
|
| 294 |
+
"sentence-transformers/all-MiniLM-L6-v2",
|
| 295 |
+
"sentence-transformers/all-mpnet-base-v2",
|
| 296 |
+
"sentence-transformers/distilbert-base-nli-mean-tokens"
|
| 297 |
+
]
|
| 298 |
+
|
| 299 |
+
selected_model = st.selectbox(
|
| 300 |
+
"Choose HuggingFace model",
|
| 301 |
+
model_options,
|
| 302 |
+
help="Select the pre-trained HuggingFace model for generating embeddings"
|
| 303 |
+
)
|
| 304 |
+
else: # Gemini
|
| 305 |
+
gemini_models = [
|
| 306 |
+
"gemini-embedding-001",
|
| 307 |
+
"text-embedding-004"
|
| 308 |
+
]
|
| 309 |
+
|
| 310 |
+
selected_model = st.selectbox(
|
| 311 |
+
"Choose Gemini model",
|
| 312 |
+
gemini_models,
|
| 313 |
+
help="Select the Gemini embedding model for generating embeddings"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# Calculate total texts to process
|
| 317 |
+
total_texts = 0
|
| 318 |
+
if st.session_state.reference_data is not None:
|
| 319 |
+
total_texts += len(st.session_state.reference_data)
|
| 320 |
+
if st.session_state.labeled_data is not None:
|
| 321 |
+
total_texts += len(st.session_state.labeled_data)
|
| 322 |
+
|
| 323 |
+
st.warning(
|
| 324 |
+
f"β οΈ **Gemini API Rate Limits (Free Tier)**\\n\\n"
|
| 325 |
+
f"- 1,500 requests per day\\n"
|
| 326 |
+
f"- Each batch of 100 texts = 1 request\\n"
|
| 327 |
+
f"- Your current dataset: ~{total_texts} texts\\n"
|
| 328 |
+
f"- Estimated requests needed: ~{(total_texts // 100) + 1}\\n\\n"
|
| 329 |
+
f"If you exceed quota, consider:\\n"
|
| 330 |
+
f"1. Using a smaller dataset\\n"
|
| 331 |
+
f"2. Switching to HuggingFace models (no limits)\\n"
|
| 332 |
+
f"3. Upgrading to a paid API plan"
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
st.info("π‘ Note: Using Gemini embeddings requires GOOGLE_API_KEY environment variable to be set.")
|
| 336 |
+
|
| 337 |
+
st.markdown("### Initial Threshold")
|
| 338 |
+
initial_threshold = st.slider(
|
| 339 |
+
"Initial similarity threshold",
|
| 340 |
+
min_value=0.0,
|
| 341 |
+
max_value=1.0,
|
| 342 |
+
value=0.7,
|
| 343 |
+
step=0.05,
|
| 344 |
+
help="Cosine similarity threshold for classification"
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
with col2:
|
| 348 |
+
st.markdown("### Optimization Parameters")
|
| 349 |
+
|
| 350 |
+
optimize_threshold = st.checkbox(
|
| 351 |
+
"Enable threshold optimization",
|
| 352 |
+
value=True,
|
| 353 |
+
help="Automatically find the best threshold"
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
if optimize_threshold:
|
| 357 |
+
col2_1, col2_2 = st.columns(2)
|
| 358 |
+
|
| 359 |
+
with col2_1:
|
| 360 |
+
start_threshold = st.slider(
|
| 361 |
+
"Start threshold",
|
| 362 |
+
min_value=0.0,
|
| 363 |
+
max_value=1.0,
|
| 364 |
+
value=0.5,
|
| 365 |
+
step=0.05
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
end_threshold = st.slider(
|
| 369 |
+
"End threshold",
|
| 370 |
+
min_value=0.0,
|
| 371 |
+
max_value=1.0,
|
| 372 |
+
value=0.9,
|
| 373 |
+
step=0.05
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
with col2_2:
|
| 377 |
+
step_size = st.slider(
|
| 378 |
+
"Step size",
|
| 379 |
+
min_value=0.005,
|
| 380 |
+
max_value=0.05,
|
| 381 |
+
value=0.01,
|
| 382 |
+
step=0.005
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
optimization_metric = st.selectbox(
|
| 386 |
+
"Optimization metric",
|
| 387 |
+
["f1_macro", "accuracy", "precision_macro", "recall_macro"]
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
# Load models button
|
| 391 |
+
if st.button("Initialize Models", type="primary"):
|
| 392 |
+
with st.spinner("Loading models... This may take a few minutes."):
|
| 393 |
+
try:
|
| 394 |
+
# Initialize classifier
|
| 395 |
+
classifier = Classifier(verbose=False)
|
| 396 |
+
|
| 397 |
+
# Determine model type parameter
|
| 398 |
+
model_type_param = "gemini" if model_type == "Gemini" else "huggingface"
|
| 399 |
+
|
| 400 |
+
classifier.load_models(
|
| 401 |
+
model_name=selected_model,
|
| 402 |
+
model_type=model_type_param,
|
| 403 |
+
threshold=initial_threshold
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
# Prepare reference vectors
|
| 407 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_ref:
|
| 408 |
+
tmp_ref_path = tmp_ref.name
|
| 409 |
+
st.session_state.reference_data.to_csv(tmp_ref_path, index=False)
|
| 410 |
+
|
| 411 |
+
try:
|
| 412 |
+
reference_data = classifier.prepare_reference_vectors(
|
| 413 |
+
reference_path=tmp_ref_path,
|
| 414 |
+
class_column='class',
|
| 415 |
+
node_column='matching_node'
|
| 416 |
+
)
|
| 417 |
+
finally:
|
| 418 |
+
# Ensure file is deleted even if an error occurs
|
| 419 |
+
try:
|
| 420 |
+
os.unlink(tmp_ref_path)
|
| 421 |
+
except (OSError, PermissionError):
|
| 422 |
+
pass # File might already be deleted or locked
|
| 423 |
+
|
| 424 |
+
st.session_state.classifier = classifier
|
| 425 |
+
st.session_state.reference_vectors = reference_data
|
| 426 |
+
st.session_state.config = {
|
| 427 |
+
'model_type': model_type,
|
| 428 |
+
'model_name': selected_model,
|
| 429 |
+
'initial_threshold': initial_threshold,
|
| 430 |
+
'optimize_threshold': optimize_threshold,
|
| 431 |
+
'start_threshold': start_threshold if optimize_threshold else None,
|
| 432 |
+
'end_threshold': end_threshold if optimize_threshold else None,
|
| 433 |
+
'step_size': step_size if optimize_threshold else None,
|
| 434 |
+
'optimization_metric': optimization_metric if optimize_threshold else None
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
st.success("β
Models initialized successfully!")
|
| 438 |
+
|
| 439 |
+
except Exception as e:
|
| 440 |
+
st.error(f"Error initializing models: {str(e)}")
|
| 441 |
+
|
| 442 |
+
# Show current configuration
|
| 443 |
+
if st.session_state.classifier is not None:
|
| 444 |
+
st.markdown('<div class="section-header">Current Configuration</div>', unsafe_allow_html=True)
|
| 445 |
+
|
| 446 |
+
config = st.session_state.config
|
| 447 |
+
|
| 448 |
+
col1, col2, col3 = st.columns(3)
|
| 449 |
+
|
| 450 |
+
with col1:
|
| 451 |
+
st.markdown("**Model Settings:**")
|
| 452 |
+
st.write(f"- Model type: {config['model_type']}")
|
| 453 |
+
st.write(f"- Model: {config['model_name']}")
|
| 454 |
+
st.write(f"- Initial threshold: {config['initial_threshold']}")
|
| 455 |
+
|
| 456 |
+
with col2:
|
| 457 |
+
st.markdown("**Optimization:**")
|
| 458 |
+
st.write(f"- Enabled: {config['optimize_threshold']}")
|
| 459 |
+
if config['optimize_threshold']:
|
| 460 |
+
st.write(f"- Range: {config['start_threshold']:.2f} - {config['end_threshold']:.2f}")
|
| 461 |
+
st.write(f"- Step: {config['step_size']:.3f}")
|
| 462 |
+
|
| 463 |
+
with col3:
|
| 464 |
+
st.markdown("**Data:**")
|
| 465 |
+
st.write(f"- Reference examples: {len(st.session_state.reference_data)}")
|
| 466 |
+
st.write(f"- Labeled samples: {len(st.session_state.labeled_data)}")
|
| 467 |
+
|
| 468 |
+
def show_classification_page():
|
| 469 |
+
st.markdown('<div class="section-header">Classification & Optimization</div>', unsafe_allow_html=True)
|
| 470 |
+
|
| 471 |
+
# Check if models are loaded
|
| 472 |
+
if st.session_state.classifier is None:
|
| 473 |
+
st.warning("Please configure and initialize models first.")
|
| 474 |
+
return
|
| 475 |
+
|
| 476 |
+
# Run classification
|
| 477 |
+
if st.button("Run Classification", type="primary"):
|
| 478 |
+
with st.spinner("Running classification and optimization..."):
|
| 479 |
+
try:
|
| 480 |
+
# Save labeled data to temporary file
|
| 481 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_labeled:
|
| 482 |
+
tmp_labeled_path = tmp_labeled.name
|
| 483 |
+
st.session_state.labeled_data.to_csv(tmp_labeled_path, index=False)
|
| 484 |
+
|
| 485 |
+
try:
|
| 486 |
+
# Run optimization if enabled
|
| 487 |
+
if st.session_state.config['optimize_threshold']:
|
| 488 |
+
optimization_results = st.session_state.classifier.evaluate_classification(
|
| 489 |
+
labeled_path=tmp_labeled_path,
|
| 490 |
+
reference_data=st.session_state.reference_vectors,
|
| 491 |
+
sentence_column='sentence',
|
| 492 |
+
label_column='label',
|
| 493 |
+
optimize_threshold=True,
|
| 494 |
+
start=st.session_state.config['start_threshold'],
|
| 495 |
+
end=st.session_state.config['end_threshold'],
|
| 496 |
+
step=st.session_state.config['step_size']
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
st.session_state.optimization_results = optimization_results
|
| 500 |
+
optimal_threshold = optimization_results["optimal_threshold"]
|
| 501 |
+
|
| 502 |
+
# Update classifier with optimal threshold
|
| 503 |
+
st.session_state.classifier.matcher = SemanticMatcher(
|
| 504 |
+
threshold=optimal_threshold,
|
| 505 |
+
verbose=False
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
st.success(f"β
Optimization completed! Optimal threshold: {optimal_threshold:.4f}")
|
| 509 |
+
|
| 510 |
+
else:
|
| 511 |
+
optimal_threshold = st.session_state.config['initial_threshold']
|
| 512 |
+
|
| 513 |
+
# Run evaluation
|
| 514 |
+
embedding_model = st.session_state.classifier.embedding_model
|
| 515 |
+
data_loader = DataLoader(verbose=False)
|
| 516 |
+
full_df = data_loader.load_labeled_data(tmp_labeled_path, label_column='label')
|
| 517 |
+
|
| 518 |
+
# Generate embeddings
|
| 519 |
+
full_embeddings = embedding_model.embed_dataframe(full_df, text_column='sentence')
|
| 520 |
+
|
| 521 |
+
# Classify
|
| 522 |
+
match_results = st.session_state.classifier.matcher.match(
|
| 523 |
+
full_embeddings,
|
| 524 |
+
st.session_state.reference_vectors
|
| 525 |
+
)
|
| 526 |
+
predicted_labels = match_results["predicted_class"].tolist()
|
| 527 |
+
true_labels = full_df['label'].tolist()
|
| 528 |
+
|
| 529 |
+
# Evaluate
|
| 530 |
+
evaluator = Evaluator(verbose=False)
|
| 531 |
+
eval_results = evaluator.evaluate(
|
| 532 |
+
true_labels=true_labels,
|
| 533 |
+
predicted_labels=predicted_labels,
|
| 534 |
+
class_names=list(set(true_labels) | set(predicted_labels))
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
# Bootstrap evaluation
|
| 538 |
+
bootstrap_results = evaluator.bootstrap_evaluate(
|
| 539 |
+
true_labels=true_labels,
|
| 540 |
+
predicted_labels=predicted_labels,
|
| 541 |
+
n_iterations=100
|
| 542 |
+
)
|
| 543 |
+
|
| 544 |
+
st.session_state.evaluation_results = eval_results
|
| 545 |
+
st.session_state.bootstrap_results = bootstrap_results
|
| 546 |
+
st.session_state.predictions = {
|
| 547 |
+
'true_labels': true_labels,
|
| 548 |
+
'predicted_labels': predicted_labels,
|
| 549 |
+
'match_results': match_results,
|
| 550 |
+
'full_df': full_df
|
| 551 |
+
}
|
| 552 |
+
|
| 553 |
+
finally:
|
| 554 |
+
# Ensure temporary file is deleted
|
| 555 |
+
try:
|
| 556 |
+
os.unlink(tmp_labeled_path)
|
| 557 |
+
except (OSError, PermissionError):
|
| 558 |
+
pass # File might already be deleted or locked
|
| 559 |
+
|
| 560 |
+
st.success("β
Classification completed successfully!")
|
| 561 |
+
|
| 562 |
+
except Exception as e:
|
| 563 |
+
st.error(f"Error during classification: {str(e)}")
|
| 564 |
+
|
| 565 |
+
# Show optimization results if available
|
| 566 |
+
if st.session_state.optimization_results is not None:
|
| 567 |
+
st.markdown('<div class="section-header">Optimization Results</div>', unsafe_allow_html=True)
|
| 568 |
+
|
| 569 |
+
results = st.session_state.optimization_results
|
| 570 |
+
|
| 571 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 572 |
+
|
| 573 |
+
with col1:
|
| 574 |
+
st.metric(
|
| 575 |
+
"Optimal Threshold",
|
| 576 |
+
f"{results['optimal_threshold']:.4f}"
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
+
with col2:
|
| 580 |
+
st.metric(
|
| 581 |
+
"Accuracy",
|
| 582 |
+
f"{results['optimal_metrics']['accuracy']:.4f}"
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
with col3:
|
| 586 |
+
st.metric(
|
| 587 |
+
"F1 Score",
|
| 588 |
+
f"{results['optimal_metrics']['f1_macro']:.4f}"
|
| 589 |
+
)
|
| 590 |
+
|
| 591 |
+
with col4:
|
| 592 |
+
st.metric(
|
| 593 |
+
"Precision",
|
| 594 |
+
f"{results['optimal_metrics']['precision_macro']:.4f}"
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
# Plot optimization curve
|
| 598 |
+
st.markdown("### Optimization Curve")
|
| 599 |
+
|
| 600 |
+
opt_results = results["results_by_threshold"]
|
| 601 |
+
|
| 602 |
+
fig = make_subplots(
|
| 603 |
+
rows=2, cols=2,
|
| 604 |
+
subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall'),
|
| 605 |
+
vertical_spacing=0.1
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
thresholds = opt_results["thresholds"]
|
| 609 |
+
|
| 610 |
+
# Add traces
|
| 611 |
+
fig.add_trace(
|
| 612 |
+
go.Scatter(x=thresholds, y=opt_results["accuracy"], name="Accuracy"),
|
| 613 |
+
row=1, col=1
|
| 614 |
+
)
|
| 615 |
+
fig.add_trace(
|
| 616 |
+
go.Scatter(x=thresholds, y=opt_results["f1_macro"], name="F1 Score"),
|
| 617 |
+
row=1, col=2
|
| 618 |
+
)
|
| 619 |
+
fig.add_trace(
|
| 620 |
+
go.Scatter(x=thresholds, y=opt_results["precision_macro"], name="Precision"),
|
| 621 |
+
row=2, col=1
|
| 622 |
+
)
|
| 623 |
+
fig.add_trace(
|
| 624 |
+
go.Scatter(x=thresholds, y=opt_results["recall_macro"], name="Recall"),
|
| 625 |
+
row=2, col=2
|
| 626 |
+
)
|
| 627 |
+
|
| 628 |
+
# Add optimal threshold line to each subplot using shapes
|
| 629 |
+
optimal_thresh = results['optimal_threshold']
|
| 630 |
+
|
| 631 |
+
# Add vertical line as shapes to each subplot
|
| 632 |
+
shapes = []
|
| 633 |
+
for row in range(1, 3):
|
| 634 |
+
for col in range(1, 3):
|
| 635 |
+
# Calculate the subplot domain
|
| 636 |
+
xaxis = f'x{(row-1)*2 + col}' if (row-1)*2 + col > 1 else 'x'
|
| 637 |
+
shapes.append(
|
| 638 |
+
dict(
|
| 639 |
+
type="line",
|
| 640 |
+
x0=optimal_thresh, x1=optimal_thresh,
|
| 641 |
+
y0=0, y1=1,
|
| 642 |
+
yref=f"y{(row-1)*2 + col} domain" if (row-1)*2 + col > 1 else "y domain",
|
| 643 |
+
xref=xaxis,
|
| 644 |
+
line=dict(color="red", width=2, dash="dash")
|
| 645 |
+
)
|
| 646 |
+
)
|
| 647 |
+
|
| 648 |
+
fig.update_layout(shapes=shapes)
|
| 649 |
+
|
| 650 |
+
fig.update_layout(
|
| 651 |
+
title="Threshold Optimization Results",
|
| 652 |
+
showlegend=False,
|
| 653 |
+
height=600
|
| 654 |
+
)
|
| 655 |
+
|
| 656 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 657 |
+
|
| 658 |
+
def show_results_page():
|
| 659 |
+
st.markdown('<div class="section-header">Results & Evaluation</div>', unsafe_allow_html=True)
|
| 660 |
+
|
| 661 |
+
# Check if evaluation results are available
|
| 662 |
+
if st.session_state.evaluation_results is None:
|
| 663 |
+
st.warning("Please run classification first to see results.")
|
| 664 |
+
return
|
| 665 |
+
|
| 666 |
+
eval_results = st.session_state.evaluation_results
|
| 667 |
+
|
| 668 |
+
# Performance metrics
|
| 669 |
+
st.markdown("### Performance Metrics")
|
| 670 |
+
|
| 671 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 672 |
+
|
| 673 |
+
with col1:
|
| 674 |
+
st.metric(
|
| 675 |
+
"Overall Accuracy",
|
| 676 |
+
f"{eval_results['accuracy']:.4f}"
|
| 677 |
+
)
|
| 678 |
+
|
| 679 |
+
with col2:
|
| 680 |
+
st.metric(
|
| 681 |
+
"Macro F1 Score",
|
| 682 |
+
f"{eval_results['f1_macro']:.4f}"
|
| 683 |
+
)
|
| 684 |
+
|
| 685 |
+
with col3:
|
| 686 |
+
st.metric(
|
| 687 |
+
"Macro Precision",
|
| 688 |
+
f"{eval_results['precision_macro']:.4f}"
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
with col4:
|
| 692 |
+
st.metric(
|
| 693 |
+
"Macro Recall",
|
| 694 |
+
f"{eval_results['recall_macro']:.4f}"
|
| 695 |
+
)
|
| 696 |
+
|
| 697 |
+
# Class-wise metrics
|
| 698 |
+
st.markdown("### Class-wise Performance")
|
| 699 |
+
|
| 700 |
+
class_metrics_df = pd.DataFrame({
|
| 701 |
+
'Class': list(eval_results['class_metrics']['precision'].keys()),
|
| 702 |
+
'Precision': list(eval_results['class_metrics']['precision'].values()),
|
| 703 |
+
'Recall': list(eval_results['class_metrics']['recall'].values()),
|
| 704 |
+
'F1-Score': list(eval_results['class_metrics']['f1'].values()),
|
| 705 |
+
'Support': list(eval_results['class_metrics']['support'].values())
|
| 706 |
+
})
|
| 707 |
+
|
| 708 |
+
st.dataframe(class_metrics_df, use_container_width=True)
|
| 709 |
+
|
| 710 |
+
# Confusion Matrix
|
| 711 |
+
st.markdown("### Confusion Matrix")
|
| 712 |
+
|
| 713 |
+
cm = eval_results['confusion_matrix']
|
| 714 |
+
class_names = eval_results['confusion_matrix_labels']
|
| 715 |
+
|
| 716 |
+
fig = px.imshow(
|
| 717 |
+
cm,
|
| 718 |
+
labels=dict(x="Predicted", y="True", color="Count"),
|
| 719 |
+
x=class_names,
|
| 720 |
+
y=class_names,
|
| 721 |
+
color_continuous_scale='Blues',
|
| 722 |
+
text_auto=True,
|
| 723 |
+
title="Confusion Matrix"
|
| 724 |
+
)
|
| 725 |
+
|
| 726 |
+
fig.update_layout(
|
| 727 |
+
width=600,
|
| 728 |
+
height=600
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 732 |
+
|
| 733 |
+
# Bootstrap Results
|
| 734 |
+
if st.session_state.bootstrap_results is not None:
|
| 735 |
+
st.markdown("### Bootstrap Confidence Intervals")
|
| 736 |
+
|
| 737 |
+
bootstrap_results = st.session_state.bootstrap_results
|
| 738 |
+
|
| 739 |
+
# Debug: show available keys
|
| 740 |
+
if 'confidence_intervals' in bootstrap_results:
|
| 741 |
+
metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
|
| 742 |
+
|
| 743 |
+
for metric in metrics:
|
| 744 |
+
if metric in bootstrap_results['confidence_intervals']:
|
| 745 |
+
ci_data = bootstrap_results['confidence_intervals'][metric]
|
| 746 |
+
st.markdown(f"**{metric.replace('_', ' ').title()}:**")
|
| 747 |
+
|
| 748 |
+
col1, col2, col3 = st.columns(3)
|
| 749 |
+
|
| 750 |
+
# Check available confidence levels
|
| 751 |
+
available_levels = list(ci_data.keys())
|
| 752 |
+
|
| 753 |
+
with col1:
|
| 754 |
+
if '0.95' in ci_data:
|
| 755 |
+
ci_95 = ci_data['0.95']
|
| 756 |
+
if isinstance(ci_95, dict):
|
| 757 |
+
st.write(f"95% CI: [{ci_95['lower']:.4f}, {ci_95['upper']:.4f}]")
|
| 758 |
+
elif isinstance(ci_95, (list, tuple)) and len(ci_95) >= 2:
|
| 759 |
+
st.write(f"95% CI: [{ci_95[0]:.4f}, {ci_95[1]:.4f}]")
|
| 760 |
+
else:
|
| 761 |
+
st.write("95% CI: Format not recognized")
|
| 762 |
+
elif 0.95 in ci_data:
|
| 763 |
+
ci_95 = ci_data[0.95]
|
| 764 |
+
if isinstance(ci_95, dict):
|
| 765 |
+
st.write(f"95% CI: [{ci_95['lower']:.4f}, {ci_95['upper']:.4f}]")
|
| 766 |
+
elif isinstance(ci_95, (list, tuple)) and len(ci_95) >= 2:
|
| 767 |
+
st.write(f"95% CI: [{ci_95[0]:.4f}, {ci_95[1]:.4f}]")
|
| 768 |
+
else:
|
| 769 |
+
st.write("95% CI: Format not recognized")
|
| 770 |
+
else:
|
| 771 |
+
st.write("95% CI: Not available")
|
| 772 |
+
|
| 773 |
+
with col2:
|
| 774 |
+
if '0.99' in ci_data:
|
| 775 |
+
ci_99 = ci_data['0.99']
|
| 776 |
+
if isinstance(ci_99, dict):
|
| 777 |
+
st.write(f"99% CI: [{ci_99['lower']:.4f}, {ci_99['upper']:.4f}]")
|
| 778 |
+
elif isinstance(ci_99, (list, tuple)) and len(ci_99) >= 2:
|
| 779 |
+
st.write(f"99% CI: [{ci_99[0]:.4f}, {ci_99[1]:.4f}]")
|
| 780 |
+
else:
|
| 781 |
+
st.write("99% CI: Format not recognized")
|
| 782 |
+
elif 0.99 in ci_data:
|
| 783 |
+
ci_99 = ci_data[0.99]
|
| 784 |
+
if isinstance(ci_99, dict):
|
| 785 |
+
st.write(f"99% CI: [{ci_99['lower']:.4f}, {ci_99['upper']:.4f}]")
|
| 786 |
+
elif isinstance(ci_99, (list, tuple)) and len(ci_99) >= 2:
|
| 787 |
+
st.write(f"99% CI: [{ci_99[0]:.4f}, {ci_99[1]:.4f}]")
|
| 788 |
+
else:
|
| 789 |
+
st.write("99% CI: Format not recognized")
|
| 790 |
+
else:
|
| 791 |
+
st.write("99% CI: Not available")
|
| 792 |
+
|
| 793 |
+
with col3:
|
| 794 |
+
if 'point_estimates' in bootstrap_results and metric in bootstrap_results['point_estimates']:
|
| 795 |
+
st.write(f"Point Estimate: {bootstrap_results['point_estimates'][metric]:.4f}")
|
| 796 |
+
else:
|
| 797 |
+
st.write("Point Estimate: Not available")
|
| 798 |
+
else:
|
| 799 |
+
st.info("Bootstrap confidence intervals not available.")
|
| 800 |
+
|
| 801 |
+
# Bootstrap Distribution Plot
|
| 802 |
+
st.markdown("### Bootstrap Distributions")
|
| 803 |
+
|
| 804 |
+
if 'bootstrap_distribution' in bootstrap_results:
|
| 805 |
+
fig = make_subplots(
|
| 806 |
+
rows=2, cols=2,
|
| 807 |
+
subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall')
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
distributions = bootstrap_results['bootstrap_distribution']
|
| 811 |
+
|
| 812 |
+
if 'accuracy' in distributions:
|
| 813 |
+
fig.add_trace(
|
| 814 |
+
go.Histogram(x=distributions['accuracy'], name="Accuracy", nbinsx=30),
|
| 815 |
+
row=1, col=1
|
| 816 |
+
)
|
| 817 |
+
if 'f1_macro' in distributions:
|
| 818 |
+
fig.add_trace(
|
| 819 |
+
go.Histogram(x=distributions['f1_macro'], name="F1 Score", nbinsx=30),
|
| 820 |
+
row=1, col=2
|
| 821 |
+
)
|
| 822 |
+
if 'precision_macro' in distributions:
|
| 823 |
+
fig.add_trace(
|
| 824 |
+
go.Histogram(x=distributions['precision_macro'], name="Precision", nbinsx=30),
|
| 825 |
+
row=2, col=1
|
| 826 |
+
)
|
| 827 |
+
if 'recall_macro' in distributions:
|
| 828 |
+
fig.add_trace(
|
| 829 |
+
go.Histogram(x=distributions['recall_macro'], name="Recall", nbinsx=30),
|
| 830 |
+
row=2, col=2
|
| 831 |
+
)
|
| 832 |
+
|
| 833 |
+
fig.update_layout(
|
| 834 |
+
title="Bootstrap Distributions",
|
| 835 |
+
showlegend=False,
|
| 836 |
+
height=600
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
st.plotly_chart(fig, use_container_width=True)
|
| 840 |
+
else:
|
| 841 |
+
st.info("Bootstrap distributions not available.")
|
| 842 |
+
|
| 843 |
+
# Sample predictions
|
| 844 |
+
if 'predictions' in st.session_state:
|
| 845 |
+
st.markdown("### Sample Predictions")
|
| 846 |
+
|
| 847 |
+
predictions = st.session_state.predictions
|
| 848 |
+
sample_df = predictions['full_df'].copy()
|
| 849 |
+
sample_df['predicted_class'] = predictions['predicted_labels']
|
| 850 |
+
sample_df['true_class'] = predictions['true_labels']
|
| 851 |
+
sample_df['similarity_score'] = predictions['match_results']['similarity_score']
|
| 852 |
+
sample_df['correct'] = sample_df['predicted_class'] == sample_df['true_class']
|
| 853 |
+
|
| 854 |
+
# Filter options
|
| 855 |
+
col1, col2 = st.columns(2)
|
| 856 |
+
|
| 857 |
+
with col1:
|
| 858 |
+
show_correct = st.checkbox("Show correct predictions", value=True)
|
| 859 |
+
|
| 860 |
+
with col2:
|
| 861 |
+
show_incorrect = st.checkbox("Show incorrect predictions", value=True)
|
| 862 |
+
|
| 863 |
+
# Filter data
|
| 864 |
+
if show_correct and show_incorrect:
|
| 865 |
+
filtered_df = sample_df
|
| 866 |
+
elif show_correct:
|
| 867 |
+
filtered_df = sample_df[sample_df['correct'] == True]
|
| 868 |
+
elif show_incorrect:
|
| 869 |
+
filtered_df = sample_df[sample_df['correct'] == False]
|
| 870 |
+
else:
|
| 871 |
+
filtered_df = pd.DataFrame()
|
| 872 |
+
|
| 873 |
+
if not filtered_df.empty:
|
| 874 |
+
# Sample random rows
|
| 875 |
+
n_samples = min(20, len(filtered_df))
|
| 876 |
+
sample_rows = filtered_df.sample(n=n_samples) if len(filtered_df) > n_samples else filtered_df
|
| 877 |
+
|
| 878 |
+
display_df = sample_rows[['sentence', 'true_class', 'predicted_class', 'similarity_score', 'correct']].reset_index(drop=True)
|
| 879 |
+
|
| 880 |
+
st.dataframe(display_df, use_container_width=True)
|
| 881 |
+
else:
|
| 882 |
+
st.info("No predictions to show with current filters.")
|
| 883 |
+
|
| 884 |
+
# Download results
|
| 885 |
+
st.markdown("### Download Results")
|
| 886 |
+
|
| 887 |
+
col1, col2 = st.columns(2)
|
| 888 |
+
|
| 889 |
+
with col1:
|
| 890 |
+
# Download class-wise metrics
|
| 891 |
+
csv_metrics = class_metrics_df.to_csv(index=False)
|
| 892 |
+
st.download_button(
|
| 893 |
+
label="Download Class Metrics",
|
| 894 |
+
data=csv_metrics,
|
| 895 |
+
file_name="class_metrics.csv",
|
| 896 |
+
mime="text/csv"
|
| 897 |
+
)
|
| 898 |
+
|
| 899 |
+
with col2:
|
| 900 |
+
# Download predictions
|
| 901 |
+
if 'predictions' in st.session_state:
|
| 902 |
+
predictions = st.session_state.predictions
|
| 903 |
+
results_df = predictions['full_df'].copy()
|
| 904 |
+
results_df['predicted_class'] = predictions['predicted_labels']
|
| 905 |
+
results_df['similarity_score'] = predictions['match_results']['similarity_score']
|
| 906 |
+
|
| 907 |
+
csv_results = results_df.to_csv(index=False)
|
| 908 |
+
st.download_button(
|
| 909 |
+
label="Download Predictions",
|
| 910 |
+
data=csv_results,
|
| 911 |
+
file_name="predictions.csv",
|
| 912 |
+
mime="text/csv"
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
if __name__ == "__main__":
|
| 916 |
+
main()
|
app/run_demo.py
CHANGED
|
@@ -1,38 +1,38 @@
|
|
| 1 |
-
#!/usr/bin/env python3
|
| 2 |
-
"""
|
| 3 |
-
Quick launcher script for the QualiVec Streamlit demo.
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
-
import subprocess
|
| 7 |
-
import sys
|
| 8 |
-
import os
|
| 9 |
-
|
| 10 |
-
def main():
|
| 11 |
-
"""Launch the Streamlit app."""
|
| 12 |
-
|
| 13 |
-
# Get the directory of this script
|
| 14 |
-
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 15 |
-
app_path = os.path.join(script_dir, "app.py")
|
| 16 |
-
|
| 17 |
-
print("π Starting QualiVec Demo...")
|
| 18 |
-
print("π App will be available at: http://localhost:8501")
|
| 19 |
-
print("βΉοΈ Press Ctrl+C to stop the app")
|
| 20 |
-
print("-" * 50)
|
| 21 |
-
|
| 22 |
-
try:
|
| 23 |
-
# Run streamlit
|
| 24 |
-
subprocess.run([
|
| 25 |
-
sys.executable, "-m", "streamlit", "run", app_path,
|
| 26 |
-
"--server.headless", "true",
|
| 27 |
-
"--server.address=0.0.0.0",
|
| 28 |
-
"--server.port=8501",
|
| 29 |
-
"--server.enableCORS", "false",
|
| 30 |
-
"--server.enableXsrfProtection", "false"
|
| 31 |
-
])
|
| 32 |
-
except KeyboardInterrupt:
|
| 33 |
-
print("\nπ App stopped by user")
|
| 34 |
-
except Exception as e:
|
| 35 |
-
print(f"β Error starting app: {e}")
|
| 36 |
-
|
| 37 |
-
if __name__ == "__main__":
|
| 38 |
-
main()
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Quick launcher script for the QualiVec Streamlit demo.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import subprocess
|
| 7 |
+
import sys
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
def main():
|
| 11 |
+
"""Launch the Streamlit app."""
|
| 12 |
+
|
| 13 |
+
# Get the directory of this script
|
| 14 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 15 |
+
app_path = os.path.join(script_dir, "app.py")
|
| 16 |
+
|
| 17 |
+
print("π Starting QualiVec Demo...")
|
| 18 |
+
print("π App will be available at: http://localhost:8501")
|
| 19 |
+
print("βΉοΈ Press Ctrl+C to stop the app")
|
| 20 |
+
print("-" * 50)
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
# Run streamlit
|
| 24 |
+
subprocess.run([
|
| 25 |
+
sys.executable, "-m", "streamlit", "run", app_path,
|
| 26 |
+
"--server.headless", "true",
|
| 27 |
+
# "--server.address=0.0.0.0",
|
| 28 |
+
"--server.port=8501",
|
| 29 |
+
"--server.enableCORS", "false",
|
| 30 |
+
"--server.enableXsrfProtection", "false"
|
| 31 |
+
])
|
| 32 |
+
except KeyboardInterrupt:
|
| 33 |
+
print("\nπ App stopped by user")
|
| 34 |
+
except Exception as e:
|
| 35 |
+
print(f"β Error starting app: {e}")
|
| 36 |
+
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
main()
|
dist/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
*
|
dist/qualivec-0.1.0-py3-none-any.whl
ADDED
|
Binary file (19.9 kB). View file
|
|
|
dist/qualivec-0.1.0.tar.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:80b1c1f4ac5470593b6b873c82f620ec544e8f9d8ac2834d23ef81521e65625c
|
| 3 |
+
size 46670
|
src/qualivec/__pycache__/embedding.cpython-312.pyc
CHANGED
|
Binary files a/src/qualivec/__pycache__/embedding.cpython-312.pyc and b/src/qualivec/__pycache__/embedding.cpython-312.pyc differ
|
|
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|