akhil-vaidya committed
Commit 1b04a15 · verified · 1 Parent(s): 4402c16

Upload 30 files

.dockerignore ADDED
@@ -0,0 +1,84 @@
# Docker
.dockerignore
Dockerfile*
docker-compose*.yml

# Git
.git/
.gitignore

# Python Virtual Environment
.venv/
venv/
env/
ENV/
env.bak/
venv.bak/

# Python cache
__pycache__/
*.pyc
*.py[cod]
*$py.class
*.so

# Build artifacts
dist/
build/
develop-eggs/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db

# Documentation (keep README.md)
docs/
*.md
!README.md

# Tests
tests/
test_*/
*_test.py
**/test_*.py

# Data files (you may want to adjust these based on your needs)
*.csv
*.json
*.pkl
*.parquet

# Logs
*.log
logs/

# Temporary files
tmp/
temp/
.tmp/
.gitignore ADDED
@@ -0,0 +1,12 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv
*.pdf
*.csv
.python-version ADDED
@@ -0,0 +1 @@
3.12.10
Dockerfile CHANGED
@@ -1,21 +1,40 @@
- FROM python:3.13.5-slim
-
- WORKDIR /app
-
- RUN apt-get update && apt-get install -y \
-     build-essential \
-     curl \
-     git \
-     && rm -rf /var/lib/apt/lists/*
-
- COPY requirements.txt ./
- COPY src/ ./src/
- COPY app/ ./app/
-
- RUN pip3 install -r requirements.txt
-
- EXPOSE 8501
-
- HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
-

# Dockerfile for QualiVec Streamlit Demo

# 1. Base image
FROM python:3.12-slim

# 2. Set the working directory
WORKDIR /app

# 3. Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    && rm -rf /var/lib/apt/lists/*

# 4. Install uv, a fast Python package manager
RUN pip install --no-cache-dir uv

# 5. Copy dependency definition files and README (required for the package build)
COPY pyproject.toml uv.lock README.md ./

# 6. Copy source code (needed for package installation)
COPY src/ ./src/

# 7. Install Python dependencies using uv
#    'uv pip install .' reads pyproject.toml and installs the project dependencies
RUN uv pip install --system --no-cache-dir .

# 8. Copy the rest of the application source code
#    Make sure you have a .dockerignore file to exclude .venv
COPY . .

# 9. Expose the port Streamlit runs on
EXPOSE 8501

# 10. Add a health check to verify the app is running
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health

# 11. Define the entry point (alternatively run via uv)
# ENTRYPOINT ["uv", "run", "app/run_demo.py"]
ENTRYPOINT ["python", "app/run_demo.py"]
README.md CHANGED
@@ -1,19 +1,712 @@
- ---
- title: Quailvec
- emoji: 🚀
- colorFrom: red
- colorTo: red
- sdk: docker
- app_port: 8501
- tags:
-   - streamlit
- pinned: false
- short_description: Streamlit template space
- ---
-
- # Welcome to Streamlit!
-
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community forums](https://discuss.streamlit.io).
# QualiVec

**QualiVec** is a Python library for scalable qualitative content analysis powered by Large Language Model (LLM) embeddings. It bridges qualitative content analysis with machine learning by leveraging the semantic understanding capabilities of LLMs. Instead of relying on simple keyword matching or manually coding large datasets, QualiVec uses embedding vectors to capture semantic meaning and performs classification based on similarity to reference vectors.

Key features:
- LLM-based embedding generation
- Semantic similarity assessment using cosine similarity
- Deductive and inductive coding support
- Reference vector creation from labeled corpora
- Corpus-driven clustering for robust semantic anchor construction
- Large-scale document classification
- Domain-agnostic and model-flexible design
- Human-level performance in multi-domain content analysis
- Bootstrap evaluation with confidence intervals
- Threshold optimization for classification performance

## 💻 Installation

```bash
pip install qualivec
```

For a development installation:

```bash
git clone https://github.com/AkhilVaidya91/QualiVec.git
cd qualivec
pip install -e .
```

## 🖥️ Interactive Demo

QualiVec includes a comprehensive Streamlit web application that provides an interactive demonstration of the library's capabilities. The demo lets users upload their own data and walk through the full workflow of qualitative content analysis using LLM embeddings.

### Demo Features

- **Interactive Data Upload**: Upload your own CSV files for reference and labeled data
- **Model Configuration**: Choose from different pre-trained embedding models
- **Threshold Optimization**: Automatically find the optimal similarity threshold
- **Real-time Classification**: See classification results as they happen
- **Comprehensive Evaluation**: View detailed performance metrics and visualizations
- **Bootstrap Analysis**: Get confidence intervals for robust evaluation
- **Download Results**: Export classification results and metrics

### Getting Started with the Demo

1. **Install dependencies**:
   ```bash
   pip install -e .
   ```

2. **Run the demo**:
   ```bash
   cd app
   uv run run_demo.py
   ```

3. **Access the demo**:
   Open your browser and navigate to `http://localhost:8501`.

### Demo Walkthrough

#### 1. Data Upload Page
Upload your reference and labeled data files. The demo validates file formats and shows data statistics.

![Data Upload Interface](assets/data_upload.png)

#### 2. Configuration Page
Configure embedding models and optimization parameters. Choose from multiple pre-trained models and set classification thresholds.

![Configuration Interface](assets/config.png)

#### 3. Classification Page
Run the classification process with real-time progress updates. View optimization results and threshold analysis.

![Classification Process](assets/optim.png)

#### 4. Results Page
Examine detailed evaluation metrics, confusion matrices, bootstrap confidence intervals, and sample predictions.

![Results Dashboard](assets/bootstrap.png)

### Data Format Requirements

#### Reference Data (CSV)
Your reference data should contain:
- `tag`: The class/category label
- `sentence`: The example text for that category

Example:

| tag      | sentence                           |
|----------|------------------------------------|
| Positive | This is absolutely fantastic!      |
| Negative | This is terrible and disappointing |
| Neutral  | This is okay I guess               |

#### Labeled Data (CSV)
Your labeled data should contain:
- `sentence`: The text to be classified
- `Label`: The true class/category (for evaluation)

Example:

| sentence                         | Label    |
|----------------------------------|----------|
| I love this product so much!     | Positive |
| Not very good quality            | Negative |
| Average product nothing special  | Neutral  |
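Before uploading, it can help to sanity-check both CSVs with pandas; a minimal sketch (the `check_columns` helper below is illustrative, not part of QualiVec):

```python
import pandas as pd

def check_columns(df: pd.DataFrame, required: set) -> None:
    """Raise if any required column is missing from the DataFrame."""
    missing = required - set(df.columns)
    if missing:
        raise ValueError(f"Missing columns: {sorted(missing)}")

# Tiny stand-ins for the two CSVs described above
reference_df = pd.DataFrame({
    "tag": ["Positive", "Negative", "Neutral"],
    "sentence": ["This is absolutely fantastic!",
                 "This is terrible and disappointing",
                 "This is okay I guess"],
})
labeled_df = pd.DataFrame({
    "sentence": ["I love this product so much!"],
    "Label": ["Positive"],
})

check_columns(reference_df, {"tag", "sentence"})
check_columns(labeled_df, {"sentence", "Label"})
```

This mirrors the format validation the demo performs at upload time.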

## 🚀 Quick Start

Here's a simple example that classifies text data using reference vectors:

```python
from qualivec.classification import Classifier

# Initialize classifier
classifier = Classifier(verbose=True)

# Load models
classifier.load_models(model_name="sentence-transformers/all-MiniLM-L6-v2", threshold=0.7)

# Prepare reference vectors
reference_data = classifier.prepare_reference_vectors(
    reference_path="path/to/reference_vectors.csv",
    class_column="class",
    node_column="matching_node"
)

# Classify corpus
results_df = classifier.classify(
    corpus_path="path/to/corpus.csv",
    reference_data=reference_data,
    sentence_column="sentence",
    output_path="path/to/results.csv"
)

# Display distribution of classifications
print(results_df["predicted_class"].value_counts())
```

![QualiVec Classification Results](assets/distributions.png)

## 🧩 Core Concepts

| Concept | Description |
|---------|-------------|
| **Reference Vectors** | Semantic anchors that define each class or category, curated as representative example texts. |
| **Similarity Threshold** | Determines how similar a text must be to a reference vector to be classified as that class; higher values are more restrictive. |
| **Embedding** | Numerical vector representation of text that captures semantic meaning; similar texts have similar embeddings. |
| **Semantic Matching** | Uses cosine similarity between embeddings to assess how close texts are to reference vectors. |
| **Bootstrap Evaluation** | Statistical method for estimating uncertainty in evaluation metrics by resampling with replacement. |
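To make the matching concept concrete, here is a minimal, self-contained sketch of cosine-similarity classification against reference vectors (plain numpy, not QualiVec's actual implementation), including the fallback to an "Other" label when no similarity clears the threshold:

```python
import numpy as np

def classify_by_similarity(query, references, threshold=0.7):
    """Assign the label of the most similar reference vector,
    or "Other" if no cosine similarity reaches the threshold."""
    best_label, best_sim = "Other", threshold
    q = query / np.linalg.norm(query)
    for label, ref in references.items():
        r = ref / np.linalg.norm(ref)
        sim = float(np.dot(q, r))  # cosine similarity of unit vectors
        if sim >= best_sim:
            best_label, best_sim = label, sim
    return best_label, best_sim

# Toy 3-dimensional "embeddings" standing in for real LLM embeddings
references = {
    "Positive": np.array([1.0, 0.2, 0.0]),
    "Negative": np.array([0.0, 0.2, 1.0]),
}
label, sim = classify_by_similarity(np.array([0.9, 0.3, 0.1]), references)
```

Raising the threshold makes the "Other" outcome more likely, which is exactly the trade-off the threshold optimizer explores.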

## 🧰 Components

### Data Loading and Preparation

The `DataLoader` class handles loading and validation of data:

```python
from qualivec.data import DataLoader

# Initialize data loader
data_loader = DataLoader(verbose=True)

# Load corpus
corpus_df = data_loader.load_corpus(
    filepath="path/to/corpus.csv",
    sentence_column="sentence"
)

# Load reference vectors
reference_df = data_loader.load_reference_vectors(
    filepath="path/to/reference_vectors.csv",
    class_column="class",
    node_column="matching_node"
)

# Load labeled data for evaluation
labeled_df = data_loader.load_labeled_data(
    filepath="path/to/labeled_data.csv",
    label_column="label"
)

# Save results
data_loader.save_dataframe(df=results_df, filepath="path/to/output.csv")
```

### Embedding Generation

The `EmbeddingModel` class generates embeddings from text:

```python
from qualivec.embedding import EmbeddingModel

# Initialize embedding model
model = EmbeddingModel(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    device=None,  # Auto-selects CPU or GPU
    cache_dir=None,
    verbose=True
)

# Generate embeddings for a list of texts
texts = ["This is a sample text", "Another example text"]
embeddings = model.embed_texts(texts, batch_size=32)

# Generate embeddings from a DataFrame column
embeddings = model.embed_dataframe(df, text_column="sentence", batch_size=32)

# Generate embeddings for reference vectors
reference_data = model.embed_reference_vectors(
    df=reference_df,
    class_column="class",
    node_column="matching_node",
    batch_size=32
)
```

### Semantic Matching

The `SemanticMatcher` class performs semantic matching using cosine similarity:

```python
from qualivec.matching import SemanticMatcher

# Initialize matcher with similarity threshold
matcher = SemanticMatcher(threshold=0.7, verbose=True)

# Match query embeddings against reference vectors
match_results = matcher.match(
    query_embeddings=query_embeddings,
    reference_data=reference_data,
    return_similarities=False
)

# Classify an entire corpus
classified_df = matcher.classify_corpus(
    corpus_embeddings=corpus_embeddings,
    reference_data=reference_data,
    corpus_df=corpus_df
)
```

### Classification

The `Classifier` class combines embedding and matching for end-to-end classification:

```python
from qualivec.classification import Classifier

# Initialize classifier
classifier = Classifier(verbose=True)

# Load models
classifier.load_models(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    threshold=0.7
)

# Prepare reference vectors
reference_data = classifier.prepare_reference_vectors(
    reference_path="path/to/reference_vectors.csv",
    class_column="class",
    node_column="matching_node"
)

# Classify corpus
results_df = classifier.classify(
    corpus_path="path/to/corpus.csv",
    reference_data=reference_data,
    sentence_column="sentence",
    output_path="path/to/results.csv"
)

# Evaluate classification performance
eval_results = classifier.evaluate_classification(
    labeled_path="path/to/labeled_data.csv",
    reference_data=reference_data,
    sentence_column="sentence",
    label_column="label",
    optimize_threshold=False
)
```

### Evaluation

The `Evaluator` class evaluates classification performance:

```python
from qualivec.evaluation import Evaluator

# Initialize evaluator
evaluator = Evaluator(verbose=True)

# Simple evaluation
results = evaluator.evaluate(
    true_labels=true_labels,
    predicted_labels=predicted_labels,
    class_names=class_names
)

# Bootstrap evaluation with confidence intervals
bootstrap_results = evaluator.bootstrap_evaluate(
    true_labels=true_labels,
    predicted_labels=predicted_labels,
    n_iterations=1000,
    confidence_levels=[0.9, 0.95, 0.99],
    random_seed=42
)

# Plot confusion matrix
evaluator.plot_confusion_matrix(
    confusion_matrix=results['confusion_matrix'],
    class_names=results['confusion_matrix_labels']
)

# Plot bootstrap distributions
evaluator.plot_bootstrap_distributions(bootstrap_results)
```

![QualiVec Confusion Matrix](assets/confusion_matrix.png)

### Threshold Optimization

The `ThresholdOptimizer` class finds the optimal similarity threshold:

```python
from qualivec.optimization import ThresholdOptimizer

# Initialize optimizer
optimizer = ThresholdOptimizer(verbose=True)

# Optimize threshold
optimization_results = optimizer.optimize(
    query_embeddings=query_embeddings,
    reference_data=reference_data,
    true_labels=true_labels,
    start=0.5,
    end=0.9,
    step=0.01,
    metric="f1_macro",
    bootstrap=True,
    n_bootstrap=100,
    confidence_level=0.95
)

# Plot optimization results
optimizer.plot_optimization_results(
    results=optimization_results,
    metrics=["accuracy", "precision_macro", "recall_macro", "f1_macro"]
)

# Plot class distribution at different thresholds
optimizer.plot_class_distribution(
    results=optimization_results,
    top_n=10
)
```

### Sampling

The `Sampler` class helps create samples for manual coding:

```python
from qualivec.sampling import Sampler

# Initialize sampler
sampler = Sampler(verbose=True)

# Random sampling
random_sample = sampler.sample(
    df=corpus_df,
    sampling_type="random",
    sample_size=0.1,  # 10% of corpus
    seed=42,
    label_column="Label"
)

# Stratified sampling
stratified_sample = sampler.sample(
    df=corpus_df,
    sampling_type="stratified",
    sample_size=0.1,
    stratify_column="category",
    seed=42,
    label_column="Label"
)
```
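Under the hood, stratified sampling draws the same fraction from each stratum so that class proportions are preserved. A minimal pandas sketch of the idea (not QualiVec's actual implementation):

```python
import pandas as pd

# An imbalanced corpus: 80 documents in stratum "A", 20 in stratum "B"
df = pd.DataFrame({
    "sentence": [f"text {i}" for i in range(100)],
    "category": ["A"] * 80 + ["B"] * 20,
})

# Draw 10% from each category, preserving the 80/20 split
sample = df.groupby("category", group_keys=False).sample(frac=0.1, random_state=42)
```

Conceptually, `sampling_type="stratified"` performs the same proportion-preserving draw over the `stratify_column`.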

## 📚 Usage Examples

### Preparing Reference Vectors

Reference vectors are the foundation of classification in QualiVec. Here's how to prepare them:

```python
# Step 1: Sample data for manual coding
from qualivec.sampling import Sampler

sampler = Sampler(verbose=True)
sample_df = sampler.sample(
    df=corpus_df,
    sampling_type="stratified",
    sample_size=0.05,  # 5% of corpus
    stratify_column="document_type"
)

# Step 2: Save sample for manual coding
sample_df.to_csv("sample_for_coding.csv", index=False)

# Step 3: After manual coding, load the coded data
from qualivec.data import DataLoader

data_loader = DataLoader(verbose=True)
coded_df = data_loader.load_labeled_data(
    filepath="coded_sample.csv",
    label_column="coded_class"
)

# Step 4: Generate embeddings for reference vectors
from qualivec.embedding import EmbeddingModel

model = EmbeddingModel(verbose=True)
reference_data = model.embed_reference_vectors(
    df=coded_df,
    class_column="coded_class",
    node_column="sentence"
)

# Step 5: Save reference data for future use
import pickle
with open("reference_data.pkl", "wb") as f:
    pickle.dump(reference_data, f)
```

### Classifying New Data

Once reference vectors are prepared, you can classify new data:

```python
# Load reference data
import pickle
with open("reference_data.pkl", "rb") as f:
    reference_data = pickle.load(f)

# Initialize classifier
from qualivec.classification import Classifier

classifier = Classifier(verbose=True)
classifier.load_models(threshold=0.7)

# Classify corpus
results_df = classifier.classify(
    corpus_path="new_corpus.csv",
    reference_data=reference_data,
    sentence_column="sentence",
    output_path="classified_corpus.csv"
)

# Analyze results
import pandas as pd
import matplotlib.pyplot as plt

# Distribution of classes
plt.figure(figsize=(10, 6))
results_df["predicted_class"].value_counts().plot(kind="bar")
plt.title("Distribution of Predicted Classes")
plt.tight_layout()
plt.show()

# Average similarity by class
results_df.groupby("predicted_class")["similarity_score"].mean().sort_values().plot(kind="barh")
plt.title("Average Similarity Score by Class")
plt.tight_layout()
plt.show()
```

### Evaluating Classification Performance

To assess how well your classification is performing:

```python
# Load labeled data
from qualivec.data import DataLoader

data_loader = DataLoader(verbose=True)
labeled_df = data_loader.load_labeled_data(
    filepath="labeled_test_set.csv",
    label_column="true_label"
)

# Generate embeddings
from qualivec.embedding import EmbeddingModel

model = EmbeddingModel(verbose=True)
labeled_embeddings = model.embed_dataframe(
    df=labeled_df,
    text_column="sentence"
)

# Match against reference vectors
from qualivec.evaluation import Evaluator
from qualivec.matching import SemanticMatcher

matcher = SemanticMatcher(threshold=0.7, verbose=True)
match_results = matcher.match(labeled_embeddings, reference_data)
predicted_labels = match_results["predicted_class"].tolist()
true_labels = labeled_df["true_label"].tolist()

evaluator = Evaluator(verbose=True)

# Simple evaluation
eval_results = evaluator.evaluate(
    true_labels=true_labels,
    predicted_labels=predicted_labels
)

# Bootstrap evaluation
bootstrap_results = evaluator.bootstrap_evaluate(
    true_labels=true_labels,
    predicted_labels=predicted_labels,
    n_iterations=1000
)

# Plot confusion matrix
evaluator.plot_confusion_matrix(
    confusion_matrix=eval_results['confusion_matrix'],
    class_names=eval_results['confusion_matrix_labels']
)

# Plot bootstrap distributions
evaluator.plot_bootstrap_distributions(bootstrap_results)
```

### Optimizing Similarity Thresholds

To find the optimal similarity threshold for your classification:

```python
# Initialize optimizer
from qualivec.optimization import ThresholdOptimizer

optimizer = ThresholdOptimizer(verbose=True)

# Optimize threshold
optimization_results = optimizer.optimize(
    query_embeddings=labeled_embeddings,
    reference_data=reference_data,
    true_labels=true_labels,
    start=0.5,
    end=0.9,
    step=0.01,
    metric="f1_macro"
)

# Plot optimization results
optimizer.plot_optimization_results(
    results=optimization_results,
    metrics=["accuracy", "f1_macro"]
)

# Plot class distribution
optimizer.plot_class_distribution(
    results=optimization_results,
    top_n=5
)

# Use the optimal threshold
optimal_threshold = optimization_results["optimal_threshold"]
print(f"Optimal threshold: {optimal_threshold}")

# Create a new matcher with the optimal threshold
matcher = SemanticMatcher(threshold=optimal_threshold, verbose=True)
```

### Sampling for Manual Coding

To create samples for manual coding or validation:

```python
from qualivec.sampling import Sampler

sampler = Sampler(verbose=True)

# Random sampling
random_sample = sampler.sample(
    df=corpus_df,
    sampling_type="random",
    sample_size=100,  # 100 documents
    seed=42
)

# Stratified sampling by predicted class
stratified_sample = sampler.sample(
    df=results_df,
    sampling_type="stratified",
    sample_size=0.1,  # 10% of corpus
    stratify_column="predicted_class",
    seed=42
)

# Save samples for manual coding
random_sample.to_csv("random_sample_for_coding.csv", index=False)
stratified_sample.to_csv("stratified_sample_for_coding.csv", index=False)
```

## 📖 API Reference

### DataLoader

```python
class DataLoader:
    def __init__(self, verbose=True)
    def load_corpus(self, filepath, sentence_column="sentence")
    def load_reference_vectors(self, filepath, class_column="class", node_column="matching_node")
    def load_labeled_data(self, filepath, label_column="label")
    def save_dataframe(self, df, filepath)
    def validate_labels(self, labeled_df, reference_df, label_column="label", class_column="class")
```

### Sampler

```python
class Sampler:
    def __init__(self, verbose=True)
    def sample(self, df, sampling_type="random", sample_size=0.1, stratify_column=None,
               seed=None, label_column="Label")
```

### EmbeddingModel

```python
class EmbeddingModel:
    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2",
                 device=None, cache_dir=None, verbose=True)
    def embed_texts(self, texts, batch_size=32)
    def embed_dataframe(self, df, text_column, batch_size=32)
    def embed_reference_vectors(self, df, class_column="class",
                                node_column="matching_node", batch_size=32)
```

### SemanticMatcher

```python
class SemanticMatcher:
    def __init__(self, threshold=0.7, verbose=True)
    def match(self, query_embeddings, reference_data, return_similarities=False)
    def classify_corpus(self, corpus_embeddings, reference_data, corpus_df)
```

### Evaluator

```python
class Evaluator:
    def __init__(self, verbose=True)
    def evaluate(self, true_labels, predicted_labels, class_names=None)
    def bootstrap_evaluate(self, true_labels, predicted_labels, n_iterations=1000,
                           confidence_levels=[0.9, 0.95, 0.99], random_seed=None)
    def plot_confusion_matrix(self, confusion_matrix, class_names,
                              figsize=(10, 8), title="Confusion Matrix")
    def plot_bootstrap_distributions(self, bootstrap_results, figsize=(12, 8))
```

### ThresholdOptimizer

```python
class ThresholdOptimizer:
    def __init__(self, verbose=True)
    def optimize(self, query_embeddings, reference_data, true_labels,
                 start=0.0, end=1.0, step=0.01, metric="f1_macro",
                 bootstrap=True, n_bootstrap=100, confidence_level=0.95, random_seed=None)
    def plot_optimization_results(self, results, metrics=None, figsize=(12, 6))
    def plot_class_distribution(self, results, top_n=10, figsize=(12, 8))
```

### Classifier

```python
class Classifier:
    def __init__(self, embedding_model=None, matcher=None, verbose=True)
    def load_models(self, model_name="sentence-transformers/all-MiniLM-L6-v2", threshold=0.7)
    def prepare_reference_vectors(self, reference_path, class_column="class",
                                  node_column="matching_node")
    def classify(self, corpus_path, reference_data, sentence_column="sentence",
                 output_path=None)
    def evaluate_classification(self, labeled_path, reference_data,
                                sentence_column="sentence", label_column="label",
                                optimize_threshold=False, start=0.5, end=0.9, step=0.01)
```

## 💡 Best Practices

1. **Reference Vector Quality**: The quality of your reference vectors greatly impacts classification performance. Ensure they are representative and distinct.

2. **Model Selection**: Larger models generally provide better semantic understanding but are slower. For simple tasks, smaller models like MiniLM may be sufficient.

3. **Threshold Tuning**: Always optimize the similarity threshold for your specific dataset and task.

4. **Evaluation**: Use bootstrap evaluation to get confidence intervals around your metrics, especially for smaller datasets.
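The bootstrap idea in practice, as a plain-numpy sketch of a percentile confidence interval (illustrative, not QualiVec's implementation): resample prediction/label pairs with replacement, recompute the metric each time, and take percentiles of the resulting distribution:

```python
import numpy as np

def bootstrap_accuracy_ci(true, pred, n_iterations=1000, confidence=0.95, seed=42):
    """Percentile bootstrap CI for accuracy: resample (true, pred) pairs
    with replacement and recompute accuracy each iteration."""
    rng = np.random.default_rng(seed)
    true, pred = np.asarray(true), np.asarray(pred)
    n = len(true)
    scores = []
    for _ in range(n_iterations):
        idx = rng.integers(0, n, size=n)  # sample indices with replacement
        scores.append(np.mean(true[idx] == pred[idx]))
    alpha = (1 - confidence) / 2
    return np.quantile(scores, alpha), np.quantile(scores, 1 - alpha)

# Toy labels with 80% agreement
low, high = bootstrap_accuracy_ci(
    ["Pos", "Neg", "Pos", "Neg", "Pos"] * 20,
    ["Pos", "Neg", "Neg", "Neg", "Pos"] * 20,
)
```

The smaller the dataset, the wider this interval gets, which is why bootstrapping matters most for small evaluation sets.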

5. **Class Imbalance**: Be aware of class imbalance in your data. Consider using stratified sampling for creating evaluation sets.

6. **Preprocessing**: Clean and preprocess your text data before embedding for best results.

7. **Out-of-Domain Detection**: Use the "Other" class (assigned when similarity falls below the threshold) to identify texts that might need new reference vectors.

## 📄 License

This project is licensed under the MIT License; see the LICENSE file for details.
app/app.py CHANGED
@@ -1,916 +1,916 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
- import matplotlib.pyplot as plt
5
- import seaborn as sns
6
- import tempfile
7
- import os
8
- import sys
9
- from io import StringIO
10
- import plotly.express as px
11
- import plotly.graph_objects as go
12
- from plotly.subplots import make_subplots
13
-
14
- # Add the parent directory to sys.path to import the module
15
- sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
16
-
17
- from .data import DataLoader
18
- from .embedding import EmbeddingModel
19
- from .matching import SemanticMatcher
20
- from .classification import Classifier
21
- from .evaluation import Evaluator
22
- from .optimization import ThresholdOptimizer
23
-
24
- # Set page config
25
- st.set_page_config(
26
- page_title="QualiVec Demo",
27
- page_icon="πŸ”",
28
- layout="wide",
29
- initial_sidebar_state="expanded"
30
- )
31
-
32
- # Custom CSS for better styling
33
- st.markdown("""
34
- <style>
35
- .main-header {
36
- font-size: 2.5rem;
37
- font-weight: bold;
38
- color: #2E4057;
39
- text-align: center;
40
- margin-bottom: 2rem;
41
- }
42
- .section-header {
43
- font-size: 1.5rem;
44
- font-weight: bold;
45
- color: #048A81;
46
- margin-top: 2rem;
47
- margin-bottom: 1rem;
48
- }
49
- .metric-card {
50
- background-color: #f0f2f6;
51
- padding: 1rem;
52
- border-radius: 0.5rem;
53
- margin: 0.5rem 0;
54
- }
55
- .success-message {
56
- background-color: #d4edda;
57
- color: #155724;
58
- padding: 1rem;
59
- border-radius: 0.5rem;
60
- margin: 1rem 0;
61
- }
62
- .warning-message {
63
- background-color: #fff3cd;
64
- color: #856404;
65
- padding: 1rem;
66
- border-radius: 0.5rem;
67
- margin: 1rem 0;
68
- }
69
- </style>
70
- """, unsafe_allow_html=True)
71
-
72
- def main():
73
- st.markdown('<div class="main-header">πŸ” QualiVec Demo</div>', unsafe_allow_html=True)
74
- st.markdown("""
75
- <div style="text-align: center; margin-bottom: 2rem;">
76
- <p style="font-size: 1.2rem; color: #666;">
77
- Qualitative Content Analysis with LLM Embeddings
78
- </p>
79
- </div>
80
- """, unsafe_allow_html=True)
81
-
82
- # Sidebar for navigation
83
- st.sidebar.title("Navigation")
84
- page = st.sidebar.selectbox(
85
- "Choose a page",
86
- ["🏠 Home", "πŸ“Š Data Upload", "πŸ”§ Configuration", "🎯 Classification", "πŸ“ˆ Results"]
87
- )
88
-
89
- # Initialize session state
90
- if 'classifier' not in st.session_state:
91
- st.session_state.classifier = None
92
- if 'reference_data' not in st.session_state:
93
- st.session_state.reference_data = None
94
- if 'labeled_data' not in st.session_state:
95
- st.session_state.labeled_data = None
96
- if 'optimization_results' not in st.session_state:
97
- st.session_state.optimization_results = None
98
- if 'evaluation_results' not in st.session_state:
99
- st.session_state.evaluation_results = None
100
-
101
- # Route to different pages
102
- if page == "🏠 Home":
103
- show_home_page()
104
- elif page == "πŸ“Š Data Upload":
105
- show_data_upload_page()
106
- elif page == "πŸ”§ Configuration":
107
- show_configuration_page()
108
- elif page == "🎯 Classification":
109
- show_classification_page()
110
- elif page == "πŸ“ˆ Results":
111
- show_results_page()
112
-
113
- def show_home_page():
114
- st.markdown('<div class="section-header">Welcome to QualiVec</div>', unsafe_allow_html=True)
115
-
116
- col1, col2, col3 = st.columns([1, 2, 1])
117
-
118
- with col2:
119
- st.markdown("""
120
- ### What is QualiVec?
121
-
122
- QualiVec is a Python library that uses Large Language Model (LLM) embeddings for qualitative content analysis. It helps researchers and analysts classify text data by comparing it against reference examples.
123
-
124
- ### Key Features:
125
- - **Semantic Matching**: Uses advanced embedding models to find semantic similarity
126
- - **Threshold Optimization**: Automatically finds the best similarity threshold
127
- - **Comprehensive Evaluation**: Provides detailed metrics and visualizations
128
- - **Bootstrap Analysis**: Confidence intervals for robust evaluation
129
-
130
- ### How It Works:
131
- 1. **Upload Data**: Provide reference examples and data to classify
132
- 2. **Configure**: Set up embedding models and parameters
133
- 3. **Optimize**: Find the best threshold for classification
134
- 4. **Classify**: Apply the model to your data
135
- 5. **Evaluate**: Get detailed performance metrics
136
-
137
- ### Getting Started:
138
- Use the sidebar to navigate through the demo. Start with **Data Upload** to begin your analysis.
139
- """)
140
-
141
- # Add sample data info
142
- st.markdown('<div class="section-header">Sample Data Format</div>', unsafe_allow_html=True)
143
-
144
- col1, col2 = st.columns(2)
145
-
146
- with col1:
147
- st.markdown("**Reference Data Format:**")
148
- sample_ref = pd.DataFrame({
149
- 'tag': ['Positive', 'Negative', 'Neutral'],
150
- 'sentence': ['This is great!', 'This is terrible', 'This is okay']
151
- })
152
- st.dataframe(sample_ref, use_container_width=True)
153
-
154
- with col2:
155
- st.markdown("**Labeled Data Format:**")
156
- sample_labeled = pd.DataFrame({
157
- 'sentence': ['I love this product', 'Not very good', 'Average quality'],
158
- 'Label': ['Positive', 'Negative', 'Neutral']
159
- })
160
- st.dataframe(sample_labeled, use_container_width=True)
161
-
162
- def show_data_upload_page():
163
- st.markdown('<div class="section-header">Data Upload</div>', unsafe_allow_html=True)
164
-
165
- col1, col2 = st.columns(2)
166
-
167
- with col1:
168
- st.markdown("### Reference Data")
169
- st.markdown("Upload a CSV file containing reference examples with columns: `tag` (class) and `sentence` (example text)")
170
-
171
- reference_file = st.file_uploader(
172
- "Choose reference data file",
173
- type=['csv'],
174
- key='reference_file'
175
- )
176
-
177
- if reference_file is not None:
178
- try:
179
- reference_df = pd.read_csv(reference_file)
180
- st.success("Reference data loaded successfully!")
181
- st.dataframe(reference_df.head(), use_container_width=True)
182
-
183
- # Validate columns
184
- required_cols = ['tag', 'sentence']
185
- missing_cols = [col for col in required_cols if col not in reference_df.columns]
186
-
187
- if missing_cols:
188
- st.error(f"Missing required columns: {missing_cols}")
189
- else:
190
- # Prepare reference data
191
- reference_df = reference_df.rename(columns={
192
- 'tag': 'class',
193
- 'sentence': 'matching_node'
194
- })
195
- st.session_state.reference_data = reference_df
196
-
197
- # Show statistics
198
- st.markdown("**Data Statistics:**")
199
- st.write(f"- Total examples: {len(reference_df)}")
200
- st.write(f"- Unique classes: {reference_df['class'].nunique()}")
201
- st.write(f"- Class distribution:")
202
- st.write(reference_df['class'].value_counts())
203
-
204
- except Exception as e:
205
- st.error(f"Error loading reference data: {str(e)}")
206
-
207
- with col2:
208
- st.markdown("### Labeled Data")
209
- st.markdown("Upload a CSV file containing data to classify with columns: `sentence` (text) and `Label` (true class)")
210
-
211
- labeled_file = st.file_uploader(
212
- "Choose labeled data file",
213
- type=['csv'],
214
- key='labeled_file'
215
- )
216
-
217
- if labeled_file is not None:
218
- try:
219
- labeled_df = pd.read_csv(labeled_file)
220
- st.success("Labeled data loaded successfully!")
221
- st.dataframe(labeled_df.head(), use_container_width=True)
222
-
223
- # Validate columns
224
- required_cols = ['sentence', 'Label']
225
- missing_cols = [col for col in required_cols if col not in labeled_df.columns]
226
-
227
- if missing_cols:
228
- st.error(f"Missing required columns: {missing_cols}")
229
- else:
230
- # Prepare labeled data
231
- labeled_df = labeled_df.rename(columns={'Label': 'label'})
232
- labeled_df['label'] = labeled_df['label'].replace({0: 'Other', '0': 'Other'})  # handle numeric and string zeros
233
- st.session_state.labeled_data = labeled_df
234
-
235
- # Show statistics
236
- st.markdown("**Data Statistics:**")
237
- st.write(f"- Total samples: {len(labeled_df)}")
238
- st.write(f"- Unique labels: {labeled_df['label'].nunique()}")
239
- st.write(f"- Label distribution:")
240
- st.write(labeled_df['label'].value_counts())
241
-
242
- except Exception as e:
243
- st.error(f"Error loading labeled data: {str(e)}")
244
-
245
- # Show data compatibility check
246
- if st.session_state.reference_data is not None and st.session_state.labeled_data is not None:
247
- st.markdown('<div class="section-header">Data Compatibility Check</div>', unsafe_allow_html=True)
248
-
249
- ref_classes = set(st.session_state.reference_data['class'].unique())
250
- labeled_classes = set(st.session_state.labeled_data['label'].unique())
251
-
252
- # Check for unknown classes
253
- unknown_classes = labeled_classes - ref_classes
254
-
255
- if unknown_classes:
256
- st.warning(f"Warning: Labels in labeled data not found in reference data: {unknown_classes}")
257
- else:
258
- st.success("βœ… Data compatibility check passed!")
259
-
260
- # Show class overlap
261
- st.markdown("**Class Overlap Analysis:**")
262
- col1, col2, col3 = st.columns(3)
263
-
264
- with col1:
265
- st.metric("Reference Classes", len(ref_classes))
266
- with col2:
267
- st.metric("Labeled Classes", len(labeled_classes))
268
- with col3:
269
- st.metric("Common Classes", len(ref_classes.intersection(labeled_classes)))
270
-
271
- def show_configuration_page():
272
- st.markdown('<div class="section-header">Model Configuration</div>', unsafe_allow_html=True)
273
-
274
- # Check if data is loaded
275
- if st.session_state.reference_data is None or st.session_state.labeled_data is None:
276
- st.warning("Please upload both reference and labeled data first.")
277
- return
278
-
279
- col1, col2 = st.columns(2)
280
-
281
- with col1:
282
- st.markdown("### Embedding Model")
283
-
284
- # Model type selection
285
- model_type = st.selectbox(
286
- "Choose model type",
287
- ["HuggingFace", "Gemini"],
288
- help="Select the type of embedding model to use"
289
- )
290
-
291
- # Model selection based on type
292
- if model_type == "HuggingFace":
293
- model_options = [
294
- "sentence-transformers/all-MiniLM-L6-v2",
295
- "sentence-transformers/all-mpnet-base-v2",
296
- "sentence-transformers/distilbert-base-nli-mean-tokens"
297
- ]
298
-
299
- selected_model = st.selectbox(
300
- "Choose HuggingFace model",
301
- model_options,
302
- help="Select the pre-trained HuggingFace model for generating embeddings"
303
- )
304
- else: # Gemini
305
- gemini_models = [
306
- "gemini-embedding-001",
307
- "text-embedding-004"
308
- ]
309
-
310
- selected_model = st.selectbox(
311
- "Choose Gemini model",
312
- gemini_models,
313
- help="Select the Gemini embedding model for generating embeddings"
314
- )
315
-
316
- # Calculate total texts to process
317
- total_texts = 0
318
- if st.session_state.reference_data is not None:
319
- total_texts += len(st.session_state.reference_data)
320
- if st.session_state.labeled_data is not None:
321
- total_texts += len(st.session_state.labeled_data)
322
-
323
- st.warning(
324
- f"⚠️ **Gemini API Rate Limits (Free Tier)**\n\n"
325
- f"- 1,500 requests per day\n"
326
- f"- Each batch of 100 texts = 1 request\n"
327
- f"- Your current dataset: ~{total_texts} texts\n"
328
- f"- Estimated requests needed: ~{(total_texts // 100) + 1}\n\n"
329
- f"If you exceed quota, consider:\n"
330
- f"1. Using a smaller dataset\n"
331
- f"2. Switching to HuggingFace models (no limits)\n"
332
- f"3. Upgrading to a paid API plan"
333
- )
334
-
335
- st.info("πŸ’‘ Note: Using Gemini embeddings requires GOOGLE_API_KEY environment variable to be set.")
336
-
337
- st.markdown("### Initial Threshold")
338
- initial_threshold = st.slider(
339
- "Initial similarity threshold",
340
- min_value=0.0,
341
- max_value=1.0,
342
- value=0.7,
343
- step=0.05,
344
- help="Cosine similarity threshold for classification"
345
- )
346
-
347
- with col2:
348
- st.markdown("### Optimization Parameters")
349
-
350
- optimize_threshold = st.checkbox(
351
- "Enable threshold optimization",
352
- value=True,
353
- help="Automatically find the best threshold"
354
- )
355
-
356
- if optimize_threshold:
357
- col2_1, col2_2 = st.columns(2)
358
-
359
- with col2_1:
360
- start_threshold = st.slider(
361
- "Start threshold",
362
- min_value=0.0,
363
- max_value=1.0,
364
- value=0.5,
365
- step=0.05
366
- )
367
-
368
- end_threshold = st.slider(
369
- "End threshold",
370
- min_value=0.0,
371
- max_value=1.0,
372
- value=0.9,
373
- step=0.05
374
- )
375
-
376
- with col2_2:
377
- step_size = st.slider(
378
- "Step size",
379
- min_value=0.005,
380
- max_value=0.05,
381
- value=0.01,
382
- step=0.005
383
- )
384
-
385
- optimization_metric = st.selectbox(
386
- "Optimization metric",
387
- ["f1_macro", "accuracy", "precision_macro", "recall_macro"]
388
- )
389
-
390
- # Load models button
391
- if st.button("Initialize Models", type="primary"):
392
- with st.spinner("Loading models... This may take a few minutes."):
393
- try:
394
- # Initialize classifier
395
- classifier = Classifier(verbose=False)
396
-
397
- # Determine model type parameter
398
- model_type_param = "gemini" if model_type == "Gemini" else "huggingface"
399
-
400
- classifier.load_models(
401
- model_name=selected_model,
402
- model_type=model_type_param,
403
- threshold=initial_threshold
404
- )
405
-
406
- # Prepare reference vectors
407
- with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_ref:
408
- tmp_ref_path = tmp_ref.name
409
- st.session_state.reference_data.to_csv(tmp_ref_path, index=False)
410
-
411
- try:
412
- reference_data = classifier.prepare_reference_vectors(
413
- reference_path=tmp_ref_path,
414
- class_column='class',
415
- node_column='matching_node'
416
- )
417
- finally:
418
- # Ensure file is deleted even if an error occurs
419
- try:
420
- os.unlink(tmp_ref_path)
421
- except (OSError, PermissionError):
422
- pass # File might already be deleted or locked
423
-
424
- st.session_state.classifier = classifier
425
- st.session_state.reference_vectors = reference_data
426
- st.session_state.config = {
427
- 'model_type': model_type,
428
- 'model_name': selected_model,
429
- 'initial_threshold': initial_threshold,
430
- 'optimize_threshold': optimize_threshold,
431
- 'start_threshold': start_threshold if optimize_threshold else None,
432
- 'end_threshold': end_threshold if optimize_threshold else None,
433
- 'step_size': step_size if optimize_threshold else None,
434
- 'optimization_metric': optimization_metric if optimize_threshold else None
435
- }
436
-
437
- st.success("βœ… Models initialized successfully!")
438
-
439
- except Exception as e:
440
- st.error(f"Error initializing models: {str(e)}")
441
-
442
- # Show current configuration
443
- if st.session_state.classifier is not None:
444
- st.markdown('<div class="section-header">Current Configuration</div>', unsafe_allow_html=True)
445
-
446
- config = st.session_state.config
447
-
448
- col1, col2, col3 = st.columns(3)
449
-
450
- with col1:
451
- st.markdown("**Model Settings:**")
452
- st.write(f"- Model type: {config['model_type']}")
453
- st.write(f"- Model: {config['model_name']}")
454
- st.write(f"- Initial threshold: {config['initial_threshold']}")
455
-
456
- with col2:
457
- st.markdown("**Optimization:**")
458
- st.write(f"- Enabled: {config['optimize_threshold']}")
459
- if config['optimize_threshold']:
460
- st.write(f"- Range: {config['start_threshold']:.2f} - {config['end_threshold']:.2f}")
461
- st.write(f"- Step: {config['step_size']:.3f}")
462
-
463
- with col3:
464
- st.markdown("**Data:**")
465
- st.write(f"- Reference examples: {len(st.session_state.reference_data)}")
466
- st.write(f"- Labeled samples: {len(st.session_state.labeled_data)}")
467
-
468
- def show_classification_page():
469
- st.markdown('<div class="section-header">Classification & Optimization</div>', unsafe_allow_html=True)
470
-
471
- # Check if models are loaded
472
- if st.session_state.classifier is None:
473
- st.warning("Please configure and initialize models first.")
474
- return
475
-
476
- # Run classification
477
- if st.button("Run Classification", type="primary"):
478
- with st.spinner("Running classification and optimization..."):
479
- try:
480
- # Save labeled data to temporary file
481
- with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_labeled:
482
- tmp_labeled_path = tmp_labeled.name
483
- st.session_state.labeled_data.to_csv(tmp_labeled_path, index=False)
484
-
485
- try:
486
- # Run optimization if enabled
487
- if st.session_state.config['optimize_threshold']:
488
- optimization_results = st.session_state.classifier.evaluate_classification(
489
- labeled_path=tmp_labeled_path,
490
- reference_data=st.session_state.reference_vectors,
491
- sentence_column='sentence',
492
- label_column='label',
493
- optimize_threshold=True,
494
- start=st.session_state.config['start_threshold'],
495
- end=st.session_state.config['end_threshold'],
496
- step=st.session_state.config['step_size']
497
- )
498
-
499
- st.session_state.optimization_results = optimization_results
500
- optimal_threshold = optimization_results["optimal_threshold"]
501
-
502
- # Update classifier with optimal threshold
503
- st.session_state.classifier.matcher = SemanticMatcher(
504
- threshold=optimal_threshold,
505
- verbose=False
506
- )
507
-
508
- st.success(f"βœ… Optimization completed! Optimal threshold: {optimal_threshold:.4f}")
509
-
510
- else:
511
- optimal_threshold = st.session_state.config['initial_threshold']
512
-
513
- # Run evaluation
514
- embedding_model = st.session_state.classifier.embedding_model
515
- data_loader = DataLoader(verbose=False)
516
- full_df = data_loader.load_labeled_data(tmp_labeled_path, label_column='label')
517
-
518
- # Generate embeddings
519
- full_embeddings = embedding_model.embed_dataframe(full_df, text_column='sentence')
520
-
521
- # Classify
522
- match_results = st.session_state.classifier.matcher.match(
523
- full_embeddings,
524
- st.session_state.reference_vectors
525
- )
526
- predicted_labels = match_results["predicted_class"].tolist()
527
- true_labels = full_df['label'].tolist()
528
-
529
- # Evaluate
530
- evaluator = Evaluator(verbose=False)
531
- eval_results = evaluator.evaluate(
532
- true_labels=true_labels,
533
- predicted_labels=predicted_labels,
534
- class_names=list(set(true_labels) | set(predicted_labels))
535
- )
536
-
537
- # Bootstrap evaluation
538
- bootstrap_results = evaluator.bootstrap_evaluate(
539
- true_labels=true_labels,
540
- predicted_labels=predicted_labels,
541
- n_iterations=100
542
- )
543
-
544
- st.session_state.evaluation_results = eval_results
545
- st.session_state.bootstrap_results = bootstrap_results
546
- st.session_state.predictions = {
547
- 'true_labels': true_labels,
548
- 'predicted_labels': predicted_labels,
549
- 'match_results': match_results,
550
- 'full_df': full_df
551
- }
552
-
553
- finally:
554
- # Ensure temporary file is deleted
555
- try:
556
- os.unlink(tmp_labeled_path)
557
- except (OSError, PermissionError):
558
- pass # File might already be deleted or locked
559
-
560
- st.success("βœ… Classification completed successfully!")
561
-
562
- except Exception as e:
563
- st.error(f"Error during classification: {str(e)}")
564
-
565
- # Show optimization results if available
566
- if st.session_state.optimization_results is not None:
567
- st.markdown('<div class="section-header">Optimization Results</div>', unsafe_allow_html=True)
568
-
569
- results = st.session_state.optimization_results
570
-
571
- col1, col2, col3, col4 = st.columns(4)
572
-
573
- with col1:
574
- st.metric(
575
- "Optimal Threshold",
576
- f"{results['optimal_threshold']:.4f}"
577
- )
578
-
579
- with col2:
580
- st.metric(
581
- "Accuracy",
582
- f"{results['optimal_metrics']['accuracy']:.4f}"
583
- )
584
-
585
- with col3:
586
- st.metric(
587
- "F1 Score",
588
- f"{results['optimal_metrics']['f1_macro']:.4f}"
589
- )
590
-
591
- with col4:
592
- st.metric(
593
- "Precision",
594
- f"{results['optimal_metrics']['precision_macro']:.4f}"
595
- )
596
-
597
- # Plot optimization curve
598
- st.markdown("### Optimization Curve")
599
-
600
- opt_results = results["results_by_threshold"]
601
-
602
- fig = make_subplots(
603
- rows=2, cols=2,
604
- subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall'),
605
- vertical_spacing=0.1
606
- )
607
-
608
- thresholds = opt_results["thresholds"]
609
-
610
- # Add traces
611
- fig.add_trace(
612
- go.Scatter(x=thresholds, y=opt_results["accuracy"], name="Accuracy"),
613
- row=1, col=1
614
- )
615
- fig.add_trace(
616
- go.Scatter(x=thresholds, y=opt_results["f1_macro"], name="F1 Score"),
617
- row=1, col=2
618
- )
619
- fig.add_trace(
620
- go.Scatter(x=thresholds, y=opt_results["precision_macro"], name="Precision"),
621
- row=2, col=1
622
- )
623
- fig.add_trace(
624
- go.Scatter(x=thresholds, y=opt_results["recall_macro"], name="Recall"),
625
- row=2, col=2
626
- )
627
-
628
- # Add optimal threshold line to each subplot using shapes
629
- optimal_thresh = results['optimal_threshold']
630
-
631
- # Add vertical line as shapes to each subplot
632
- shapes = []
633
- for row in range(1, 3):
634
- for col in range(1, 3):
635
- # Calculate the subplot domain
636
- xaxis = f'x{(row-1)*2 + col}' if (row-1)*2 + col > 1 else 'x'
637
- shapes.append(
638
- dict(
639
- type="line",
640
- x0=optimal_thresh, x1=optimal_thresh,
641
- y0=0, y1=1,
642
- yref=f"y{(row-1)*2 + col} domain" if (row-1)*2 + col > 1 else "y domain",
643
- xref=xaxis,
644
- line=dict(color="red", width=2, dash="dash")
645
- )
646
- )
647
-
648
- fig.update_layout(shapes=shapes)
649
-
650
- fig.update_layout(
651
- title="Threshold Optimization Results",
652
- showlegend=False,
653
- height=600
654
- )
655
-
656
- st.plotly_chart(fig, use_container_width=True)
657
-
658
- def show_results_page():
659
- st.markdown('<div class="section-header">Results & Evaluation</div>', unsafe_allow_html=True)
660
-
661
- # Check if evaluation results are available
662
- if st.session_state.evaluation_results is None:
663
- st.warning("Please run classification first to see results.")
664
- return
665
-
666
- eval_results = st.session_state.evaluation_results
667
-
668
- # Performance metrics
669
- st.markdown("### Performance Metrics")
670
-
671
- col1, col2, col3, col4 = st.columns(4)
672
-
673
- with col1:
674
- st.metric(
675
- "Overall Accuracy",
676
- f"{eval_results['accuracy']:.4f}"
677
- )
678
-
679
- with col2:
680
- st.metric(
681
- "Macro F1 Score",
682
- f"{eval_results['f1_macro']:.4f}"
683
- )
684
-
685
- with col3:
686
- st.metric(
687
- "Macro Precision",
688
- f"{eval_results['precision_macro']:.4f}"
689
- )
690
-
691
- with col4:
692
- st.metric(
693
- "Macro Recall",
694
- f"{eval_results['recall_macro']:.4f}"
695
- )
696
-
697
- # Class-wise metrics
698
- st.markdown("### Class-wise Performance")
699
-
700
- class_metrics_df = pd.DataFrame({
701
- 'Class': list(eval_results['class_metrics']['precision'].keys()),
702
- 'Precision': list(eval_results['class_metrics']['precision'].values()),
703
- 'Recall': list(eval_results['class_metrics']['recall'].values()),
704
- 'F1-Score': list(eval_results['class_metrics']['f1'].values()),
705
- 'Support': list(eval_results['class_metrics']['support'].values())
706
- })
707
-
708
- st.dataframe(class_metrics_df, use_container_width=True)
709
-
710
- # Confusion Matrix
711
- st.markdown("### Confusion Matrix")
712
-
713
- cm = eval_results['confusion_matrix']
714
- class_names = eval_results['confusion_matrix_labels']
715
-
716
- fig = px.imshow(
717
- cm,
718
- labels=dict(x="Predicted", y="True", color="Count"),
719
- x=class_names,
720
- y=class_names,
721
- color_continuous_scale='Blues',
722
- text_auto=True,
723
- title="Confusion Matrix"
724
- )
725
-
726
- fig.update_layout(
727
- width=600,
728
- height=600
729
- )
730
-
731
- st.plotly_chart(fig, use_container_width=True)
732
-
733
- # Bootstrap Results
734
- if st.session_state.bootstrap_results is not None:
735
- st.markdown("### Bootstrap Confidence Intervals")
736
-
737
- bootstrap_results = st.session_state.bootstrap_results
738
-
739
- # Debug: show available keys
740
- if 'confidence_intervals' in bootstrap_results:
741
- metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
742
-
743
- for metric in metrics:
744
- if metric in bootstrap_results['confidence_intervals']:
745
- ci_data = bootstrap_results['confidence_intervals'][metric]
746
- st.markdown(f"**{metric.replace('_', ' ').title()}:**")
747
-
748
- col1, col2, col3 = st.columns(3)
749
-
750
- # Check available confidence levels
751
- available_levels = list(ci_data.keys())
752
-
753
- with col1:
754
- if '0.95' in ci_data:
755
- ci_95 = ci_data['0.95']
756
- if isinstance(ci_95, dict):
757
- st.write(f"95% CI: [{ci_95['lower']:.4f}, {ci_95['upper']:.4f}]")
758
- elif isinstance(ci_95, (list, tuple)) and len(ci_95) >= 2:
759
- st.write(f"95% CI: [{ci_95[0]:.4f}, {ci_95[1]:.4f}]")
760
- else:
761
- st.write("95% CI: Format not recognized")
762
- elif 0.95 in ci_data:
763
- ci_95 = ci_data[0.95]
764
- if isinstance(ci_95, dict):
765
- st.write(f"95% CI: [{ci_95['lower']:.4f}, {ci_95['upper']:.4f}]")
766
- elif isinstance(ci_95, (list, tuple)) and len(ci_95) >= 2:
767
- st.write(f"95% CI: [{ci_95[0]:.4f}, {ci_95[1]:.4f}]")
768
- else:
769
- st.write("95% CI: Format not recognized")
770
- else:
771
- st.write("95% CI: Not available")
772
-
773
- with col2:
774
- if '0.99' in ci_data:
775
- ci_99 = ci_data['0.99']
776
- if isinstance(ci_99, dict):
777
- st.write(f"99% CI: [{ci_99['lower']:.4f}, {ci_99['upper']:.4f}]")
778
- elif isinstance(ci_99, (list, tuple)) and len(ci_99) >= 2:
779
- st.write(f"99% CI: [{ci_99[0]:.4f}, {ci_99[1]:.4f}]")
780
- else:
781
- st.write("99% CI: Format not recognized")
782
- elif 0.99 in ci_data:
783
- ci_99 = ci_data[0.99]
784
- if isinstance(ci_99, dict):
785
- st.write(f"99% CI: [{ci_99['lower']:.4f}, {ci_99['upper']:.4f}]")
786
- elif isinstance(ci_99, (list, tuple)) and len(ci_99) >= 2:
787
- st.write(f"99% CI: [{ci_99[0]:.4f}, {ci_99[1]:.4f}]")
788
- else:
789
- st.write("99% CI: Format not recognized")
790
- else:
791
- st.write("99% CI: Not available")
792
-
793
- with col3:
794
- if 'point_estimates' in bootstrap_results and metric in bootstrap_results['point_estimates']:
795
- st.write(f"Point Estimate: {bootstrap_results['point_estimates'][metric]:.4f}")
796
- else:
797
- st.write("Point Estimate: Not available")
798
- else:
799
- st.info("Bootstrap confidence intervals not available.")
800
-
801
- # Bootstrap Distribution Plot
802
- st.markdown("### Bootstrap Distributions")
803
-
804
- if 'bootstrap_distribution' in bootstrap_results:
805
- fig = make_subplots(
806
- rows=2, cols=2,
807
- subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall')
808
- )
809
-
810
- distributions = bootstrap_results['bootstrap_distribution']
811
-
812
- if 'accuracy' in distributions:
813
- fig.add_trace(
814
- go.Histogram(x=distributions['accuracy'], name="Accuracy", nbinsx=30),
815
- row=1, col=1
816
- )
817
- if 'f1_macro' in distributions:
818
- fig.add_trace(
819
- go.Histogram(x=distributions['f1_macro'], name="F1 Score", nbinsx=30),
820
- row=1, col=2
821
- )
822
- if 'precision_macro' in distributions:
823
- fig.add_trace(
824
- go.Histogram(x=distributions['precision_macro'], name="Precision", nbinsx=30),
825
- row=2, col=1
826
- )
827
- if 'recall_macro' in distributions:
828
- fig.add_trace(
829
- go.Histogram(x=distributions['recall_macro'], name="Recall", nbinsx=30),
830
- row=2, col=2
831
- )
832
-
833
- fig.update_layout(
834
- title="Bootstrap Distributions",
835
- showlegend=False,
836
- height=600
837
- )
838
-
839
- st.plotly_chart(fig, use_container_width=True)
840
- else:
841
- st.info("Bootstrap distributions not available.")
842
-
843
- # Sample predictions
844
- if 'predictions' in st.session_state:
845
- st.markdown("### Sample Predictions")
846
-
847
- predictions = st.session_state.predictions
848
- sample_df = predictions['full_df'].copy()
849
- sample_df['predicted_class'] = predictions['predicted_labels']
850
- sample_df['true_class'] = predictions['true_labels']
851
- sample_df['similarity_score'] = predictions['match_results']['similarity_score']
852
- sample_df['correct'] = sample_df['predicted_class'] == sample_df['true_class']
853
-
854
- # Filter options
855
- col1, col2 = st.columns(2)
856
-
857
- with col1:
858
- show_correct = st.checkbox("Show correct predictions", value=True)
859
-
860
- with col2:
861
- show_incorrect = st.checkbox("Show incorrect predictions", value=True)
862
-
863
- # Filter data
864
- if show_correct and show_incorrect:
865
- filtered_df = sample_df
866
- elif show_correct:
867
- filtered_df = sample_df[sample_df['correct'] == True]
868
- elif show_incorrect:
869
- filtered_df = sample_df[sample_df['correct'] == False]
870
- else:
871
- filtered_df = pd.DataFrame()
872
-
873
- if not filtered_df.empty:
874
- # Sample random rows
875
- n_samples = min(20, len(filtered_df))
876
- sample_rows = filtered_df.sample(n=n_samples) if len(filtered_df) > n_samples else filtered_df
877
-
878
- display_df = sample_rows[['sentence', 'true_class', 'predicted_class', 'similarity_score', 'correct']].reset_index(drop=True)
879
-
880
- st.dataframe(display_df, use_container_width=True)
881
- else:
882
- st.info("No predictions to show with current filters.")
883
-
884
- # Download results
885
- st.markdown("### Download Results")
886
-
887
- col1, col2 = st.columns(2)
888
-
889
- with col1:
890
- # Download class-wise metrics
891
- csv_metrics = class_metrics_df.to_csv(index=False)
892
- st.download_button(
893
- label="Download Class Metrics",
894
- data=csv_metrics,
895
- file_name="class_metrics.csv",
896
- mime="text/csv"
897
- )
898
-
899
- with col2:
900
- # Download predictions
901
- if 'predictions' in st.session_state:
902
- predictions = st.session_state.predictions
903
- results_df = predictions['full_df'].copy()
904
- results_df['predicted_class'] = predictions['predicted_labels']
905
- results_df['similarity_score'] = predictions['match_results']['similarity_score']
906
-
907
- csv_results = results_df.to_csv(index=False)
908
- st.download_button(
909
- label="Download Predictions",
910
- data=csv_results,
911
- file_name="predictions.csv",
912
- mime="text/csv"
913
- )
914
-
915
- if __name__ == "__main__":
916
- main()
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+ import tempfile
7
+ import os
8
+ import sys
9
+ from io import StringIO
10
+ import plotly.express as px
11
+ import plotly.graph_objects as go
12
+ from plotly.subplots import make_subplots
13
+
14
+ # Add the parent directory to sys.path to import the module
15
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
16
+
17
+ from src.qualivec.data import DataLoader
18
+ from src.qualivec.embedding import EmbeddingModel
19
+ from src.qualivec.matching import SemanticMatcher
20
+ from src.qualivec.classification import Classifier
21
+ from src.qualivec.evaluation import Evaluator
22
+ from src.qualivec.optimization import ThresholdOptimizer
23
+
24
+ # Set page config
25
+ st.set_page_config(
26
+ page_title="QualiVec Demo",
27
+ page_icon="πŸ”",
28
+ layout="wide",
29
+ initial_sidebar_state="expanded"
30
+ )
31
+
32
+ # Custom CSS for better styling
33
+ st.markdown("""
34
+ <style>
35
+ .main-header {
36
+ font-size: 2.5rem;
37
+ font-weight: bold;
38
+ color: #2E4057;
39
+ text-align: center;
40
+ margin-bottom: 2rem;
41
+ }
42
+ .section-header {
43
+ font-size: 1.5rem;
44
+ font-weight: bold;
45
+ color: #048A81;
46
+ margin-top: 2rem;
47
+ margin-bottom: 1rem;
48
+ }
49
+ .metric-card {
50
+ background-color: #f0f2f6;
51
+ padding: 1rem;
52
+ border-radius: 0.5rem;
53
+ margin: 0.5rem 0;
54
+ }
55
+ .success-message {
56
+ background-color: #d4edda;
57
+ color: #155724;
58
+ padding: 1rem;
59
+ border-radius: 0.5rem;
60
+ margin: 1rem 0;
61
+ }
62
+ .warning-message {
63
+ background-color: #fff3cd;
64
+ color: #856404;
65
+ padding: 1rem;
66
+ border-radius: 0.5rem;
67
+ margin: 1rem 0;
68
+ }
69
+ </style>
70
+ """, unsafe_allow_html=True)
71
+
72
+ def main():
73
+ st.markdown('<div class="main-header">πŸ” QualiVec Demo</div>', unsafe_allow_html=True)
74
+ st.markdown("""
75
+ <div style="text-align: center; margin-bottom: 2rem;">
76
+ <p style="font-size: 1.2rem; color: #666;">
77
+ Qualitative Content Analysis with LLM Embeddings
78
+ </p>
79
+ </div>
80
+ """, unsafe_allow_html=True)
81
+
82
+ # Sidebar for navigation
83
+ st.sidebar.title("Navigation")
84
+ page = st.sidebar.selectbox(
85
+ "Choose a page",
86
+ ["🏠 Home", "πŸ“Š Data Upload", "πŸ”§ Configuration", "🎯 Classification", "πŸ“ˆ Results"]
87
+ )
88
+
89
+ # Initialize session state
90
+ if 'classifier' not in st.session_state:
91
+ st.session_state.classifier = None
92
+ if 'reference_data' not in st.session_state:
93
+ st.session_state.reference_data = None
94
+ if 'labeled_data' not in st.session_state:
95
+ st.session_state.labeled_data = None
96
+ if 'optimization_results' not in st.session_state:
97
+ st.session_state.optimization_results = None
98
+ if 'evaluation_results' not in st.session_state:
99
+ st.session_state.evaluation_results = None
100
+
101
+ # Route to different pages
102
+ if page == "🏠 Home":
103
+ show_home_page()
104
+ elif page == "πŸ“Š Data Upload":
105
+ show_data_upload_page()
106
+ elif page == "πŸ”§ Configuration":
107
+ show_configuration_page()
108
+ elif page == "🎯 Classification":
109
+ show_classification_page()
110
+ elif page == "πŸ“ˆ Results":
111
+ show_results_page()
112
+
113
+ def show_home_page():
114
+ st.markdown('<div class="section-header">Welcome to QualiVec</div>', unsafe_allow_html=True)
115
+
116
+ col1, col2, col3 = st.columns([1, 2, 1])
117
+
118
+ with col2:
119
+ st.markdown("""
120
+ ### What is QualiVec?
121
+
122
+ QualiVec is a Python library that uses Large Language Model (LLM) embeddings for qualitative content analysis. It helps researchers and analysts classify text data by comparing it against reference examples.
123
+
124
+ ### Key Features:
125
+ - **Semantic Matching**: Uses advanced embedding models to find semantic similarity
126
+ - **Threshold Optimization**: Automatically finds the best similarity threshold
127
+ - **Comprehensive Evaluation**: Provides detailed metrics and visualizations
128
+ - **Bootstrap Analysis**: Confidence intervals for robust evaluation
129
+
130
+ ### How It Works:
131
+ 1. **Upload Data**: Provide reference examples and data to classify
132
+ 2. **Configure**: Set up embedding models and parameters
133
+ 3. **Optimize**: Find the best threshold for classification
134
+ 4. **Classify**: Apply the model to your data
135
+ 5. **Evaluate**: Get detailed performance metrics
136
+
137
+ ### Getting Started:
138
+ Use the sidebar to navigate through the demo. Start with **Data Upload** to begin your analysis.
139
+ """)
140
+
141
+ # Add sample data info
142
+ st.markdown('<div class="section-header">Sample Data Format</div>', unsafe_allow_html=True)
143
+
144
+ col1, col2 = st.columns(2)
145
+
146
+ with col1:
147
+ st.markdown("**Reference Data Format:**")
148
+ sample_ref = pd.DataFrame({
149
+ 'tag': ['Positive', 'Negative', 'Neutral'],
150
+ 'sentence': ['This is great!', 'This is terrible', 'This is okay']
151
+ })
152
+ st.dataframe(sample_ref, use_container_width=True)
153
+
154
+ with col2:
155
+ st.markdown("**Labeled Data Format:**")
156
+ sample_labeled = pd.DataFrame({
157
+ 'sentence': ['I love this product', 'Not very good', 'Average quality'],
158
+ 'Label': ['Positive', 'Negative', 'Neutral']
159
+ })
160
+ st.dataframe(sample_labeled, use_container_width=True)
161
+
162
+ def show_data_upload_page():
163
+ st.markdown('<div class="section-header">Data Upload</div>', unsafe_allow_html=True)
164
+
165
+ col1, col2 = st.columns(2)
166
+
167
+ with col1:
168
+ st.markdown("### Reference Data")
169
+ st.markdown("Upload a CSV file containing reference examples with columns: `tag` (class) and `sentence` (example text)")
170
+
171
+ reference_file = st.file_uploader(
172
+ "Choose reference data file",
173
+ type=['csv'],
174
+ key='reference_file'
175
+ )
176
+
177
+ if reference_file is not None:
178
+ try:
179
+ reference_df = pd.read_csv(reference_file)
180
+ st.success("Reference data loaded successfully!")
181
+ st.dataframe(reference_df.head(), use_container_width=True)
182
+
183
+ # Validate columns
184
+ required_cols = ['tag', 'sentence']
185
+ missing_cols = [col for col in required_cols if col not in reference_df.columns]
186
+
187
+ if missing_cols:
188
+ st.error(f"Missing required columns: {missing_cols}")
189
+ else:
190
+ # Prepare reference data
191
+ reference_df = reference_df.rename(columns={
192
+ 'tag': 'class',
193
+ 'sentence': 'matching_node'
194
+ })
195
+ st.session_state.reference_data = reference_df
196
+
197
+ # Show statistics
198
+ st.markdown("**Data Statistics:**")
199
+ st.write(f"- Total examples: {len(reference_df)}")
200
+ st.write(f"- Unique classes: {reference_df['class'].nunique()}")
201
+ st.write(f"- Class distribution:")
202
+ st.write(reference_df['class'].value_counts())
203
+
204
+ except Exception as e:
205
+ st.error(f"Error loading reference data: {str(e)}")
206
+
207
+ with col2:
208
+ st.markdown("### Labeled Data")
209
+ st.markdown("Upload a CSV file containing data to classify with columns: `sentence` (text) and `Label` (true class)")
210
+
211
+ labeled_file = st.file_uploader(
212
+ "Choose labeled data file",
213
+ type=['csv'],
214
+ key='labeled_file'
215
+ )
216
+
217
+ if labeled_file is not None:
218
+ try:
219
+ labeled_df = pd.read_csv(labeled_file)
220
+ st.success("Labeled data loaded successfully!")
221
+ st.dataframe(labeled_df.head(), use_container_width=True)
222
+
223
+ # Validate columns
224
+ required_cols = ['sentence', 'Label']
225
+ missing_cols = [col for col in required_cols if col not in labeled_df.columns]
226
+
227
+ if missing_cols:
228
+ st.error(f"Missing required columns: {missing_cols}")
229
+ else:
230
+ # Prepare labeled data
231
+ labeled_df = labeled_df.rename(columns={'Label': 'label'})
232
+ labeled_df['label'] = labeled_df['label'].replace('0', 'Other')
233
+ st.session_state.labeled_data = labeled_df
234
+
235
+ # Show statistics
236
+ st.markdown("**Data Statistics:**")
237
+ st.write(f"- Total samples: {len(labeled_df)}")
238
+ st.write(f"- Unique labels: {labeled_df['label'].nunique()}")
239
+ st.write(f"- Label distribution:")
240
+ st.write(labeled_df['label'].value_counts())
241
+
242
+ except Exception as e:
243
+ st.error(f"Error loading labeled data: {str(e)}")
244
+
245
+ # Show data compatibility check
246
+ if st.session_state.reference_data is not None and st.session_state.labeled_data is not None:
247
+ st.markdown('<div class="section-header">Data Compatibility Check</div>', unsafe_allow_html=True)
248
+
249
+ ref_classes = set(st.session_state.reference_data['class'].unique())
250
+ labeled_classes = set(st.session_state.labeled_data['label'].unique())
251
+
252
+ # Check for unknown classes
253
+ unknown_classes = labeled_classes - ref_classes
254
+
255
+ if unknown_classes:
256
+ st.warning(f"Warning: Labels in labeled data not found in reference data: {unknown_classes}")
257
+ else:
258
+ st.success("βœ… Data compatibility check passed!")
259
+
260
+ # Show class overlap
261
+ st.markdown("**Class Overlap Analysis:**")
262
+ col1, col2, col3 = st.columns(3)
263
+
264
+ with col1:
265
+ st.metric("Reference Classes", len(ref_classes))
266
+ with col2:
267
+ st.metric("Labeled Classes", len(labeled_classes))
268
+ with col3:
269
+ st.metric("Common Classes", len(ref_classes.intersection(labeled_classes)))
270
+
271
+ def show_configuration_page():
272
+ st.markdown('<div class="section-header">Model Configuration</div>', unsafe_allow_html=True)
273
+
274
+ # Check if data is loaded
275
+ if st.session_state.reference_data is None or st.session_state.labeled_data is None:
276
+ st.warning("Please upload both reference and labeled data first.")
277
+ return
278
+
279
+ col1, col2 = st.columns(2)
280
+
281
+ with col1:
282
+ st.markdown("### Embedding Model")
283
+
284
+ # Model type selection
285
+ model_type = st.selectbox(
286
+ "Choose model type",
287
+ ["HuggingFace", "Gemini"],
288
+ help="Select the type of embedding model to use"
289
+ )
290
+
291
+ # Model selection based on type
292
+ if model_type == "HuggingFace":
293
+ model_options = [
294
+ "sentence-transformers/all-MiniLM-L6-v2",
295
+ "sentence-transformers/all-mpnet-base-v2",
296
+ "sentence-transformers/distilbert-base-nli-mean-tokens"
297
+ ]
298
+
299
+ selected_model = st.selectbox(
300
+ "Choose HuggingFace model",
301
+ model_options,
302
+ help="Select the pre-trained HuggingFace model for generating embeddings"
303
+ )
304
+ else: # Gemini
305
+ gemini_models = [
306
+ "gemini-embedding-001",
307
+ "text-embedding-004"
308
+ ]
309
+
310
+ selected_model = st.selectbox(
311
+ "Choose Gemini model",
312
+ gemini_models,
313
+ help="Select the Gemini embedding model for generating embeddings"
314
+ )
315
+
316
+ # Calculate total texts to process
317
+ total_texts = 0
318
+ if st.session_state.reference_data is not None:
319
+ total_texts += len(st.session_state.reference_data)
320
+ if st.session_state.labeled_data is not None:
321
+ total_texts += len(st.session_state.labeled_data)
322
+
323
+ st.warning(
+ f"⚠️ **Gemini API Rate Limits (Free Tier)**\n\n"
+ f"- 1,500 requests per day\n"
+ f"- Each batch of 100 texts = 1 request\n"
+ f"- Your current dataset: ~{total_texts} texts\n"
+ f"- Estimated requests needed: ~{(total_texts // 100) + 1}\n\n"
+ f"If you exceed the quota, consider:\n"
+ f"1. Using a smaller dataset\n"
+ f"2. Switching to HuggingFace models (no limits)\n"
+ f"3. Upgrading to a paid API plan"
+ )
334
+
335
+ st.info("πŸ’‘ Note: Using Gemini embeddings requires GOOGLE_API_KEY environment variable to be set.")
336
+
337
+ st.markdown("### Initial Threshold")
338
+ initial_threshold = st.slider(
339
+ "Initial similarity threshold",
340
+ min_value=0.0,
341
+ max_value=1.0,
342
+ value=0.7,
343
+ step=0.05,
344
+ help="Cosine similarity threshold for classification"
345
+ )
346
+
347
+ with col2:
348
+ st.markdown("### Optimization Parameters")
349
+
350
+ optimize_threshold = st.checkbox(
351
+ "Enable threshold optimization",
352
+ value=True,
353
+ help="Automatically find the best threshold"
354
+ )
355
+
356
+ if optimize_threshold:
357
+ col2_1, col2_2 = st.columns(2)
358
+
359
+ with col2_1:
360
+ start_threshold = st.slider(
361
+ "Start threshold",
362
+ min_value=0.0,
363
+ max_value=1.0,
364
+ value=0.5,
365
+ step=0.05
366
+ )
367
+
368
+ end_threshold = st.slider(
369
+ "End threshold",
370
+ min_value=0.0,
371
+ max_value=1.0,
372
+ value=0.9,
373
+ step=0.05
374
+ )
375
+
376
+ with col2_2:
377
+ step_size = st.slider(
378
+ "Step size",
379
+ min_value=0.005,
380
+ max_value=0.05,
381
+ value=0.01,
382
+ step=0.005
383
+ )
384
+
385
+ optimization_metric = st.selectbox(
386
+ "Optimization metric",
387
+ ["f1_macro", "accuracy", "precision_macro", "recall_macro"]
388
+ )
389
+
390
+ # Load models button
391
+ if st.button("Initialize Models", type="primary"):
392
+ with st.spinner("Loading models... This may take a few minutes."):
393
+ try:
394
+ # Initialize classifier
395
+ classifier = Classifier(verbose=False)
396
+
397
+ # Determine model type parameter
398
+ model_type_param = "gemini" if model_type == "Gemini" else "huggingface"
399
+
400
+ classifier.load_models(
401
+ model_name=selected_model,
402
+ model_type=model_type_param,
403
+ threshold=initial_threshold
404
+ )
405
+
406
+ # Prepare reference vectors
407
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_ref:
408
+ tmp_ref_path = tmp_ref.name
409
+ st.session_state.reference_data.to_csv(tmp_ref_path, index=False)
410
+
411
+ try:
412
+ reference_data = classifier.prepare_reference_vectors(
413
+ reference_path=tmp_ref_path,
414
+ class_column='class',
415
+ node_column='matching_node'
416
+ )
417
+ finally:
418
+ # Ensure file is deleted even if an error occurs
419
+ try:
420
+ os.unlink(tmp_ref_path)
421
+ except (OSError, PermissionError):
422
+ pass # File might already be deleted or locked
423
+
424
+ st.session_state.classifier = classifier
425
+ st.session_state.reference_vectors = reference_data
426
+ st.session_state.config = {
427
+ 'model_type': model_type,
428
+ 'model_name': selected_model,
429
+ 'initial_threshold': initial_threshold,
430
+ 'optimize_threshold': optimize_threshold,
431
+ 'start_threshold': start_threshold if optimize_threshold else None,
432
+ 'end_threshold': end_threshold if optimize_threshold else None,
433
+ 'step_size': step_size if optimize_threshold else None,
434
+ 'optimization_metric': optimization_metric if optimize_threshold else None
435
+ }
436
+
437
+ st.success("βœ… Models initialized successfully!")
438
+
439
+ except Exception as e:
440
+ st.error(f"Error initializing models: {str(e)}")
441
+
442
+ # Show current configuration
443
+ if st.session_state.classifier is not None:
444
+ st.markdown('<div class="section-header">Current Configuration</div>', unsafe_allow_html=True)
445
+
446
+ config = st.session_state.config
447
+
448
+ col1, col2, col3 = st.columns(3)
449
+
450
+ with col1:
451
+ st.markdown("**Model Settings:**")
452
+ st.write(f"- Model type: {config['model_type']}")
453
+ st.write(f"- Model: {config['model_name']}")
454
+ st.write(f"- Initial threshold: {config['initial_threshold']}")
455
+
456
+ with col2:
457
+ st.markdown("**Optimization:**")
458
+ st.write(f"- Enabled: {config['optimize_threshold']}")
459
+ if config['optimize_threshold']:
460
+ st.write(f"- Range: {config['start_threshold']:.2f} - {config['end_threshold']:.2f}")
461
+ st.write(f"- Step: {config['step_size']:.3f}")
462
+
463
+ with col3:
464
+ st.markdown("**Data:**")
465
+ st.write(f"- Reference examples: {len(st.session_state.reference_data)}")
466
+ st.write(f"- Labeled samples: {len(st.session_state.labeled_data)}")
467
+
468
+ def show_classification_page():
469
+ st.markdown('<div class="section-header">Classification & Optimization</div>', unsafe_allow_html=True)
470
+
471
+ # Check if models are loaded
472
+ if st.session_state.classifier is None:
473
+ st.warning("Please configure and initialize models first.")
474
+ return
475
+
476
+ # Run classification
477
+ if st.button("Run Classification", type="primary"):
478
+ with st.spinner("Running classification and optimization..."):
479
+ try:
480
+ # Save labeled data to temporary file
481
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_labeled:
482
+ tmp_labeled_path = tmp_labeled.name
483
+ st.session_state.labeled_data.to_csv(tmp_labeled_path, index=False)
484
+
485
+ try:
486
+ # Run optimization if enabled
487
+ if st.session_state.config['optimize_threshold']:
488
+ optimization_results = st.session_state.classifier.evaluate_classification(
489
+ labeled_path=tmp_labeled_path,
490
+ reference_data=st.session_state.reference_vectors,
491
+ sentence_column='sentence',
492
+ label_column='label',
493
+ optimize_threshold=True,
494
+ start=st.session_state.config['start_threshold'],
495
+ end=st.session_state.config['end_threshold'],
496
+ step=st.session_state.config['step_size']
497
+ )
498
+
499
+ st.session_state.optimization_results = optimization_results
500
+ optimal_threshold = optimization_results["optimal_threshold"]
501
+
502
+ # Update classifier with optimal threshold
503
+ st.session_state.classifier.matcher = SemanticMatcher(
504
+ threshold=optimal_threshold,
505
+ verbose=False
506
+ )
507
+
508
+ st.success(f"βœ… Optimization completed! Optimal threshold: {optimal_threshold:.4f}")
509
+
510
+ else:
511
+ optimal_threshold = st.session_state.config['initial_threshold']
512
+
513
+ # Run evaluation
514
+ embedding_model = st.session_state.classifier.embedding_model
515
+ data_loader = DataLoader(verbose=False)
516
+ full_df = data_loader.load_labeled_data(tmp_labeled_path, label_column='label')
517
+
518
+ # Generate embeddings
519
+ full_embeddings = embedding_model.embed_dataframe(full_df, text_column='sentence')
520
+
521
+ # Classify
522
+ match_results = st.session_state.classifier.matcher.match(
523
+ full_embeddings,
524
+ st.session_state.reference_vectors
525
+ )
526
+ predicted_labels = match_results["predicted_class"].tolist()
527
+ true_labels = full_df['label'].tolist()
528
+
529
+ # Evaluate
530
+ evaluator = Evaluator(verbose=False)
531
+ eval_results = evaluator.evaluate(
532
+ true_labels=true_labels,
533
+ predicted_labels=predicted_labels,
534
+ class_names=list(set(true_labels) | set(predicted_labels))
535
+ )
536
+
537
+ # Bootstrap evaluation
538
+ bootstrap_results = evaluator.bootstrap_evaluate(
539
+ true_labels=true_labels,
540
+ predicted_labels=predicted_labels,
541
+ n_iterations=100
542
+ )
543
+
544
+ st.session_state.evaluation_results = eval_results
545
+ st.session_state.bootstrap_results = bootstrap_results
546
+ st.session_state.predictions = {
547
+ 'true_labels': true_labels,
548
+ 'predicted_labels': predicted_labels,
549
+ 'match_results': match_results,
550
+ 'full_df': full_df
551
+ }
552
+
553
+ finally:
554
+ # Ensure temporary file is deleted
555
+ try:
556
+ os.unlink(tmp_labeled_path)
557
+ except (OSError, PermissionError):
558
+ pass # File might already be deleted or locked
559
+
560
+ st.success("βœ… Classification completed successfully!")
561
+
562
+ except Exception as e:
563
+ st.error(f"Error during classification: {str(e)}")
564
+
565
+ # Show optimization results if available
566
+ if st.session_state.optimization_results is not None:
567
+ st.markdown('<div class="section-header">Optimization Results</div>', unsafe_allow_html=True)
568
+
569
+ results = st.session_state.optimization_results
570
+
571
+ col1, col2, col3, col4 = st.columns(4)
572
+
573
+ with col1:
574
+ st.metric(
575
+ "Optimal Threshold",
576
+ f"{results['optimal_threshold']:.4f}"
577
+ )
578
+
579
+ with col2:
580
+ st.metric(
581
+ "Accuracy",
582
+ f"{results['optimal_metrics']['accuracy']:.4f}"
583
+ )
584
+
585
+ with col3:
586
+ st.metric(
587
+ "F1 Score",
588
+ f"{results['optimal_metrics']['f1_macro']:.4f}"
589
+ )
590
+
591
+ with col4:
592
+ st.metric(
593
+ "Precision",
594
+ f"{results['optimal_metrics']['precision_macro']:.4f}"
595
+ )
596
+
597
+ # Plot optimization curve
598
+ st.markdown("### Optimization Curve")
599
+
600
+ opt_results = results["results_by_threshold"]
601
+
602
+ fig = make_subplots(
603
+ rows=2, cols=2,
604
+ subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall'),
605
+ vertical_spacing=0.1
606
+ )
607
+
608
+ thresholds = opt_results["thresholds"]
609
+
610
+ # Add traces
611
+ fig.add_trace(
612
+ go.Scatter(x=thresholds, y=opt_results["accuracy"], name="Accuracy"),
613
+ row=1, col=1
614
+ )
615
+ fig.add_trace(
616
+ go.Scatter(x=thresholds, y=opt_results["f1_macro"], name="F1 Score"),
617
+ row=1, col=2
618
+ )
619
+ fig.add_trace(
620
+ go.Scatter(x=thresholds, y=opt_results["precision_macro"], name="Precision"),
621
+ row=2, col=1
622
+ )
623
+ fig.add_trace(
624
+ go.Scatter(x=thresholds, y=opt_results["recall_macro"], name="Recall"),
625
+ row=2, col=2
626
+ )
627
+
628
+ # Add optimal threshold line to each subplot using shapes
629
+ optimal_thresh = results['optimal_threshold']
630
+
631
+ # Add vertical line as shapes to each subplot
632
+ shapes = []
633
+ for row in range(1, 3):
634
+ for col in range(1, 3):
635
+ # Calculate the subplot domain
636
+ xaxis = f'x{(row-1)*2 + col}' if (row-1)*2 + col > 1 else 'x'
637
+ shapes.append(
638
+ dict(
639
+ type="line",
640
+ x0=optimal_thresh, x1=optimal_thresh,
641
+ y0=0, y1=1,
642
+ yref=f"y{(row-1)*2 + col} domain" if (row-1)*2 + col > 1 else "y domain",
643
+ xref=xaxis,
644
+ line=dict(color="red", width=2, dash="dash")
645
+ )
646
+ )
647
+
648
+ fig.update_layout(
+ title="Threshold Optimization Results",
+ showlegend=False,
+ height=600,
+ shapes=shapes
+ )
655
+
656
+ st.plotly_chart(fig, use_container_width=True)
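The threshold sweep that the classification step runs (via `evaluate_classification(..., optimize_threshold=True)`) boils down to trying each candidate threshold and keeping the one that scores best. A minimal, self-contained sketch of that idea in plain Python follows; the helper name and toy data are illustrative, not QualiVec internals:

```python
# Sketch of a threshold sweep: for each candidate threshold, keep a sample's
# best-matching class only if its similarity clears the threshold (else "Other"),
# then score accuracy and remember the best threshold seen.
def sweep_thresholds(scores, true_labels, predicted, start=0.5, end=0.9, step=0.05):
    """scores[i]: best similarity for sample i; predicted[i]: its best-matching class."""
    best = (None, -1.0)  # (threshold, accuracy)
    t = start
    while t <= end + 1e-9:
        labels = [p if s >= t else "Other" for s, p in zip(scores, predicted)]
        acc = sum(l == y for l, y in zip(labels, true_labels)) / len(true_labels)
        if acc > best[1]:
            best = (t, acc)
        t += step
    return best

scores = [0.92, 0.55, 0.71, 0.40]
predicted = ["Positive", "Positive", "Negative", "Negative"]
true_labels = ["Positive", "Other", "Negative", "Other"]
print(sweep_thresholds(scores, true_labels, predicted))  # best threshold ~0.6, accuracy 1.0
```

The real optimizer scores a configurable metric (e.g. `f1_macro`) at each step rather than raw accuracy, but the sweep structure is the same.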
657
+
658
+ def show_results_page():
659
+ st.markdown('<div class="section-header">Results & Evaluation</div>', unsafe_allow_html=True)
660
+
661
+ # Check if evaluation results are available
662
+ if st.session_state.evaluation_results is None:
663
+ st.warning("Please run classification first to see results.")
664
+ return
665
+
666
+ eval_results = st.session_state.evaluation_results
667
+
668
+ # Performance metrics
669
+ st.markdown("### Performance Metrics")
670
+
671
+ col1, col2, col3, col4 = st.columns(4)
672
+
673
+ with col1:
674
+ st.metric(
675
+ "Overall Accuracy",
676
+ f"{eval_results['accuracy']:.4f}"
677
+ )
678
+
679
+ with col2:
680
+ st.metric(
681
+ "Macro F1 Score",
682
+ f"{eval_results['f1_macro']:.4f}"
683
+ )
684
+
685
+ with col3:
686
+ st.metric(
687
+ "Macro Precision",
688
+ f"{eval_results['precision_macro']:.4f}"
689
+ )
690
+
691
+ with col4:
692
+ st.metric(
693
+ "Macro Recall",
694
+ f"{eval_results['recall_macro']:.4f}"
695
+ )
696
+
697
+ # Class-wise metrics
698
+ st.markdown("### Class-wise Performance")
699
+
700
+ class_metrics_df = pd.DataFrame({
701
+ 'Class': list(eval_results['class_metrics']['precision'].keys()),
702
+ 'Precision': list(eval_results['class_metrics']['precision'].values()),
703
+ 'Recall': list(eval_results['class_metrics']['recall'].values()),
704
+ 'F1-Score': list(eval_results['class_metrics']['f1'].values()),
705
+ 'Support': list(eval_results['class_metrics']['support'].values())
706
+ })
707
+
708
+ st.dataframe(class_metrics_df, use_container_width=True)
709
+
710
+ # Confusion Matrix
711
+ st.markdown("### Confusion Matrix")
712
+
713
+ cm = eval_results['confusion_matrix']
714
+ class_names = eval_results['confusion_matrix_labels']
715
+
716
+ fig = px.imshow(
717
+ cm,
718
+ labels=dict(x="Predicted", y="True", color="Count"),
719
+ x=class_names,
720
+ y=class_names,
721
+ color_continuous_scale='Blues',
722
+ text_auto=True,
723
+ title="Confusion Matrix"
724
+ )
725
+
726
+ fig.update_layout(
727
+ width=600,
728
+ height=600
729
+ )
730
+
731
+ st.plotly_chart(fig, use_container_width=True)
732
+
733
+ # Bootstrap Results
734
+ if st.session_state.bootstrap_results is not None:
735
+ st.markdown("### Bootstrap Confidence Intervals")
736
+
737
+ bootstrap_results = st.session_state.bootstrap_results
738
+
739
+ # Debug: show available keys
740
+ if 'confidence_intervals' in bootstrap_results:
741
+ metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']
742
+
743
+ for metric in metrics:
744
+ if metric in bootstrap_results['confidence_intervals']:
745
+ ci_data = bootstrap_results['confidence_intervals'][metric]
746
+ st.markdown(f"**{metric.replace('_', ' ').title()}:**")
747
+
748
+ col1, col2, col3 = st.columns(3)
749
+
750
+ # Confidence intervals may be keyed by "0.95"/"0.99" as strings or floats,
+ # and each interval may be a dict or a (lower, upper) sequence
+ def format_ci(ci_data, level):
+ ci = ci_data.get(str(level), ci_data.get(level))
+ if ci is None:
+ return "Not available"
+ if isinstance(ci, dict):
+ return f"[{ci['lower']:.4f}, {ci['upper']:.4f}]"
+ if isinstance(ci, (list, tuple)) and len(ci) >= 2:
+ return f"[{ci[0]:.4f}, {ci[1]:.4f}]"
+ return "Format not recognized"
+
+ with col1:
+ st.write(f"95% CI: {format_ci(ci_data, 0.95)}")
+
+ with col2:
+ st.write(f"99% CI: {format_ci(ci_data, 0.99)}")
792
+
793
+ with col3:
794
+ if 'point_estimates' in bootstrap_results and metric in bootstrap_results['point_estimates']:
795
+ st.write(f"Point Estimate: {bootstrap_results['point_estimates'][metric]:.4f}")
796
+ else:
797
+ st.write("Point Estimate: Not available")
798
+ else:
799
+ st.info("Bootstrap confidence intervals not available.")
800
+
801
+ # Bootstrap Distribution Plot
802
+ st.markdown("### Bootstrap Distributions")
803
+
804
+ if 'bootstrap_distribution' in bootstrap_results:
805
+ fig = make_subplots(
806
+ rows=2, cols=2,
807
+ subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall')
808
+ )
809
+
810
+ distributions = bootstrap_results['bootstrap_distribution']
811
+
812
+ if 'accuracy' in distributions:
813
+ fig.add_trace(
814
+ go.Histogram(x=distributions['accuracy'], name="Accuracy", nbinsx=30),
815
+ row=1, col=1
816
+ )
817
+ if 'f1_macro' in distributions:
818
+ fig.add_trace(
819
+ go.Histogram(x=distributions['f1_macro'], name="F1 Score", nbinsx=30),
820
+ row=1, col=2
821
+ )
822
+ if 'precision_macro' in distributions:
823
+ fig.add_trace(
824
+ go.Histogram(x=distributions['precision_macro'], name="Precision", nbinsx=30),
825
+ row=2, col=1
826
+ )
827
+ if 'recall_macro' in distributions:
828
+ fig.add_trace(
829
+ go.Histogram(x=distributions['recall_macro'], name="Recall", nbinsx=30),
830
+ row=2, col=2
831
+ )
832
+
833
+ fig.update_layout(
834
+ title="Bootstrap Distributions",
835
+ showlegend=False,
836
+ height=600
837
+ )
838
+
839
+ st.plotly_chart(fig, use_container_width=True)
840
+ else:
841
+ st.info("Bootstrap distributions not available.")
842
+
843
+ # Sample predictions
844
+ if 'predictions' in st.session_state:
845
+ st.markdown("### Sample Predictions")
846
+
847
+ predictions = st.session_state.predictions
848
+ sample_df = predictions['full_df'].copy()
849
+ sample_df['predicted_class'] = predictions['predicted_labels']
850
+ sample_df['true_class'] = predictions['true_labels']
851
+ sample_df['similarity_score'] = predictions['match_results']['similarity_score']
852
+ sample_df['correct'] = sample_df['predicted_class'] == sample_df['true_class']
853
+
854
+ # Filter options
855
+ col1, col2 = st.columns(2)
856
+
857
+ with col1:
858
+ show_correct = st.checkbox("Show correct predictions", value=True)
859
+
860
+ with col2:
861
+ show_incorrect = st.checkbox("Show incorrect predictions", value=True)
862
+
863
+ # Filter data
864
+ if show_correct and show_incorrect:
+ filtered_df = sample_df
+ elif show_correct:
+ filtered_df = sample_df[sample_df['correct']]
+ elif show_incorrect:
+ filtered_df = sample_df[~sample_df['correct']]
+ else:
+ filtered_df = pd.DataFrame()
872
+
873
+ if not filtered_df.empty:
874
+ # Sample random rows
875
+ n_samples = min(20, len(filtered_df))
876
+ sample_rows = filtered_df.sample(n=n_samples) if len(filtered_df) > n_samples else filtered_df
877
+
878
+ display_df = sample_rows[['sentence', 'true_class', 'predicted_class', 'similarity_score', 'correct']].reset_index(drop=True)
879
+
880
+ st.dataframe(display_df, use_container_width=True)
881
+ else:
882
+ st.info("No predictions to show with current filters.")
883
+
884
+ # Download results
885
+ st.markdown("### Download Results")
886
+
887
+ col1, col2 = st.columns(2)
888
+
889
+ with col1:
890
+ # Download class-wise metrics
891
+ csv_metrics = class_metrics_df.to_csv(index=False)
892
+ st.download_button(
893
+ label="Download Class Metrics",
894
+ data=csv_metrics,
895
+ file_name="class_metrics.csv",
896
+ mime="text/csv"
897
+ )
898
+
899
+ with col2:
900
+ # Download predictions
901
+ if 'predictions' in st.session_state:
902
+ predictions = st.session_state.predictions
903
+ results_df = predictions['full_df'].copy()
904
+ results_df['predicted_class'] = predictions['predicted_labels']
905
+ results_df['similarity_score'] = predictions['match_results']['similarity_score']
906
+
907
+ csv_results = results_df.to_csv(index=False)
908
+ st.download_button(
909
+ label="Download Predictions",
910
+ data=csv_results,
911
+ file_name="predictions.csv",
912
+ mime="text/csv"
913
+ )
914
+
915
+ if __name__ == "__main__":
916
+ main()
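The core matching rule the app demonstrates — assign each text the class of its most similar reference embedding, falling back to "Other" when the best similarity is below the threshold — can be sketched without any dependencies beyond the standard library. The toy 2-D vectors below stand in for real sentence embeddings, and all names are illustrative rather than QualiVec's actual API:

```python
# Sketch of threshold-based semantic matching over cosine similarity.
import math

def cosine(a, b):
    # Cosine similarity between two equal-length vectors
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def classify(vec, references, threshold=0.7):
    """references: list of (class_name, vector) pairs."""
    best_class, best_score = "Other", -1.0
    for cls, ref in references:
        score = cosine(vec, ref)
        if score > best_score:
            best_class, best_score = cls, score
    # Below the threshold, the match is considered too weak to assign a class
    if best_score < threshold:
        return ("Other", best_score)
    return (best_class, best_score)

refs = [("Positive", [1.0, 0.1]), ("Negative", [0.1, 1.0])]
print(classify([0.9, 0.2], refs, threshold=0.9))  # assigned "Positive"
print(classify([0.5, 0.5], refs, threshold=0.9))  # ambiguous, falls back to "Other"
```

Real embeddings have hundreds of dimensions and come from the selected HuggingFace or Gemini model, but the assignment logic is the same shape.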
app/run_demo.py CHANGED
@@ -1,38 +1,38 @@
1
- #!/usr/bin/env python3
2
- """
3
- Quick launcher script for the QualiVec Streamlit demo.
4
- """
5
-
6
- import subprocess
7
- import sys
8
- import os
9
-
10
- def main():
11
- """Launch the Streamlit app."""
12
-
13
- # Get the directory of this script
14
- script_dir = os.path.dirname(os.path.abspath(__file__))
15
- app_path = os.path.join(script_dir, "app.py")
16
-
17
- print("πŸš€ Starting QualiVec Demo...")
18
- print("πŸ“ App will be available at: http://localhost:8501")
19
- print("⏹️ Press Ctrl+C to stop the app")
20
- print("-" * 50)
21
-
22
- try:
23
- # Run streamlit
24
- subprocess.run([
25
- sys.executable, "-m", "streamlit", "run", app_path,
26
- "--server.headless", "true",
27
- "--server.address=0.0.0.0",
28
- "--server.port=8501",
29
- "--server.enableCORS", "false",
30
- "--server.enableXsrfProtection", "false"
31
- ])
32
- except KeyboardInterrupt:
33
- print("\nπŸ›‘ App stopped by user")
34
- except Exception as e:
35
- print(f"❌ Error starting app: {e}")
36
-
37
- if __name__ == "__main__":
38
- main()
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick launcher script for the QualiVec Streamlit demo.
4
+ """
5
+
6
+ import subprocess
7
+ import sys
8
+ import os
9
+
10
+ def main():
11
+ """Launch the Streamlit app."""
12
+
13
+ # Get the directory of this script
14
+ script_dir = os.path.dirname(os.path.abspath(__file__))
15
+ app_path = os.path.join(script_dir, "app.py")
16
+
17
+ print("πŸš€ Starting QualiVec Demo...")
18
+ print("πŸ“ App will be available at: http://localhost:8501")
19
+ print("⏹️ Press Ctrl+C to stop the app")
20
+ print("-" * 50)
21
+
22
+ try:
23
+ # Run streamlit
24
+ subprocess.run([
25
+ sys.executable, "-m", "streamlit", "run", app_path,
26
+ "--server.headless", "true",
27
+ # "--server.address=0.0.0.0",
28
+ "--server.port=8501",
29
+ "--server.enableCORS", "false",
30
+ "--server.enableXsrfProtection", "false"
31
+ ])
32
+ except KeyboardInterrupt:
33
+ print("\nπŸ›‘ App stopped by user")
34
+ except Exception as e:
35
+ print(f"❌ Error starting app: {e}")
36
+
37
+ if __name__ == "__main__":
38
+ main()
dist/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *
dist/qualivec-0.1.0-py3-none-any.whl ADDED
Binary file (19.9 kB). View file
 
dist/qualivec-0.1.0.tar.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:80b1c1f4ac5470593b6b873c82f620ec544e8f9d8ac2834d23ef81521e65625c
3
+ size 46670
src/qualivec/__pycache__/embedding.cpython-312.pyc CHANGED
Binary files a/src/qualivec/__pycache__/embedding.cpython-312.pyc and b/src/qualivec/__pycache__/embedding.cpython-312.pyc differ
 
uv.lock ADDED
The diff for this file is too large to render. See raw diff