akhil-vaidya committed
Commit 59e348f · verified · 1 Parent(s): 5487a6c

Upload 21 files

app/README.md ADDED
@@ -0,0 +1,85 @@
# QualiVec Streamlit Demo

This Streamlit application provides an interactive demonstration of the QualiVec library for qualitative content analysis using LLM embeddings.

## Features

- **Interactive Data Upload**: Upload your own CSV files for reference and labeled data
- **Model Configuration**: Choose from different pre-trained embedding models
- **Threshold Optimization**: Automatically find the optimal similarity threshold
- **Real-time Classification**: See classification results as they happen
- **Comprehensive Evaluation**: View detailed performance metrics and visualizations
- **Bootstrap Analysis**: Get confidence intervals for robust evaluation
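
The threshold-based matching behind these features can be pictured with a small sketch (plain Python, illustrative only: `classify`, the toy 2-dimensional vectors, and the `Other` fallback are hypothetical stand-ins, not the QualiVec API):

```python
from math import sqrt

def cosine(a, b):
    # Cosine similarity between two equal-length vectors
    dot = sum(x * y for x, y in zip(a, b))
    na = sqrt(sum(x * x for x in a))
    nb = sqrt(sum(x * x for x in b))
    return dot / (na * nb)

def classify(embedding, reference_vectors, threshold=0.7):
    # Pick the reference class with the highest cosine similarity;
    # fall back to "Other" when no class clears the threshold.
    best_class, best_score = "Other", threshold
    for label, ref in reference_vectors.items():
        score = cosine(embedding, ref)
        if score >= best_score:
            best_class, best_score = label, score
    return best_class, best_score

refs = {"Positive": [0.9, 0.1], "Negative": [0.1, 0.9]}
print(classify([0.8, 0.2], refs))
```

The real library works the same way in spirit, but with high-dimensional embeddings produced by the configured model and a threshold found by the optimizer.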

## How to Run

### Option 1: Local Installation

1. **Install Dependencies**:
   ```bash
   pip install -e .
   ```

2. **Run the App**:
   ```bash
   cd app
   uv run run_demo.py
   ```

3. **Access the App**:
   Open your browser and navigate to `http://localhost:8501`

### Option 2: Docker

1. **Build the Docker Image**:
   ```bash
   docker build -t qualivec .
   ```

2. **Run the Docker Container**:
   ```bash
   docker run --rm -p 8501:8501 qualivec
   ```

3. **Access the App**:
   Open your browser and navigate to `http://localhost:8501`

> **Note**: The Docker option provides a containerized environment with all dependencies pre-installed, making it easier to run the application without setting up a local Python environment.

## Data Format Requirements

### Reference Data (CSV)
Your reference data should contain:
- `tag`: The class/category label
- `sentence`: The example text for that category

Example:
```csv
tag,sentence
Positive,This is absolutely fantastic!
Negative,This is terrible and disappointing
Neutral,This is okay I guess
```

### Labeled Data (CSV)
Your labeled data should contain:
- `sentence`: The text to be classified
- `Label`: The true class/category (for evaluation)

Example:
```csv
sentence,Label
I love this product so much!,Positive
Not very good quality,Negative
Average product nothing special,Neutral
```
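
A quick way to pre-check these column requirements before uploading is a few lines of Python (standard library only; `check_columns` is a hypothetical helper for illustration, not part of QualiVec):

```python
import csv
import io

# Required header columns per file type, mirroring the formats above
REQUIRED = {"reference": {"tag", "sentence"}, "labeled": {"sentence", "Label"}}

def check_columns(csv_text, kind):
    """Return a sorted list of required columns missing from the CSV header."""
    header = next(csv.reader(io.StringIO(csv_text)))
    return sorted(REQUIRED[kind] - set(header))

ref_csv = "tag,sentence\nPositive,This is absolutely fantastic!\n"
print(check_columns(ref_csv, "reference"))          # → [] (valid)
print(check_columns("text,Label\nx,y\n", "labeled"))  # → ['sentence'] (missing)
```

The app performs an equivalent check after upload and reports any missing columns with `st.error`.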

## Navigation

The app is organized into 5 main sections:

1. **🏠 Home**: Overview and introduction to QualiVec
2. **📊 Data Upload**: Upload your reference and labeled data files
3. **🔧 Configuration**: Set up embedding models and parameters
4. **🎯 Classification**: Run the classification and optimization process
5. **📈 Results**: View detailed results and download outputs
app/app.py ADDED
@@ -0,0 +1,916 @@
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tempfile
import os
import sys
from io import StringIO
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Add the parent directory to sys.path to import the module
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.qualivec.data import DataLoader
from src.qualivec.embedding import EmbeddingModel
from src.qualivec.matching import SemanticMatcher
from src.qualivec.classification import Classifier
from src.qualivec.evaluation import Evaluator
from src.qualivec.optimization import ThresholdOptimizer

# Set page config
st.set_page_config(
    page_title="QualiVec Demo",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Custom CSS for better styling
st.markdown("""
<style>
    .main-header {
        font-size: 2.5rem;
        font-weight: bold;
        color: #2E4057;
        text-align: center;
        margin-bottom: 2rem;
    }
    .section-header {
        font-size: 1.5rem;
        font-weight: bold;
        color: #048A81;
        margin-top: 2rem;
        margin-bottom: 1rem;
    }
    .metric-card {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 0.5rem 0;
    }
    .success-message {
        background-color: #d4edda;
        color: #155724;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 1rem 0;
    }
    .warning-message {
        background-color: #fff3cd;
        color: #856404;
        padding: 1rem;
        border-radius: 0.5rem;
        margin: 1rem 0;
    }
</style>
""", unsafe_allow_html=True)

def main():
    st.markdown('<div class="main-header">🔍 QualiVec Demo</div>', unsafe_allow_html=True)
    st.markdown("""
    <div style="text-align: center; margin-bottom: 2rem;">
        <p style="font-size: 1.2rem; color: #666;">
            Qualitative Content Analysis with LLM Embeddings
        </p>
    </div>
    """, unsafe_allow_html=True)

    # Sidebar for navigation
    st.sidebar.title("Navigation")
    page = st.sidebar.selectbox(
        "Choose a page",
        ["🏠 Home", "📊 Data Upload", "🔧 Configuration", "🎯 Classification", "📈 Results"]
    )

    # Initialize session state
    if 'classifier' not in st.session_state:
        st.session_state.classifier = None
    if 'reference_data' not in st.session_state:
        st.session_state.reference_data = None
    if 'labeled_data' not in st.session_state:
        st.session_state.labeled_data = None
    if 'optimization_results' not in st.session_state:
        st.session_state.optimization_results = None
    if 'evaluation_results' not in st.session_state:
        st.session_state.evaluation_results = None
    # Initialized alongside evaluation_results so the attribute access
    # in show_results_page never raises on a fresh session
    if 'bootstrap_results' not in st.session_state:
        st.session_state.bootstrap_results = None

    # Route to different pages
    if page == "🏠 Home":
        show_home_page()
    elif page == "📊 Data Upload":
        show_data_upload_page()
    elif page == "🔧 Configuration":
        show_configuration_page()
    elif page == "🎯 Classification":
        show_classification_page()
    elif page == "📈 Results":
        show_results_page()
def show_home_page():
    st.markdown('<div class="section-header">Welcome to QualiVec</div>', unsafe_allow_html=True)

    col1, col2, col3 = st.columns([1, 2, 1])

    with col2:
        st.markdown("""
        ### What is QualiVec?

        QualiVec is a Python library that uses Large Language Model (LLM) embeddings for qualitative content analysis. It helps researchers and analysts classify text data by comparing it against reference examples.

        ### Key Features:
        - **Semantic Matching**: Uses advanced embedding models to find semantic similarity
        - **Threshold Optimization**: Automatically finds the best similarity threshold
        - **Comprehensive Evaluation**: Provides detailed metrics and visualizations
        - **Bootstrap Analysis**: Confidence intervals for robust evaluation

        ### How It Works:
        1. **Upload Data**: Provide reference examples and data to classify
        2. **Configure**: Set up embedding models and parameters
        3. **Optimize**: Find the best threshold for classification
        4. **Classify**: Apply the model to your data
        5. **Evaluate**: Get detailed performance metrics

        ### Getting Started:
        Use the sidebar to navigate through the demo. Start with **Data Upload** to begin your analysis.
        """)

    # Add sample data info
    st.markdown('<div class="section-header">Sample Data Format</div>', unsafe_allow_html=True)

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("**Reference Data Format:**")
        sample_ref = pd.DataFrame({
            'tag': ['Positive', 'Negative', 'Neutral'],
            'sentence': ['This is great!', 'This is terrible', 'This is okay']
        })
        st.dataframe(sample_ref, use_container_width=True)

    with col2:
        st.markdown("**Labeled Data Format:**")
        sample_labeled = pd.DataFrame({
            'sentence': ['I love this product', 'Not very good', 'Average quality'],
            'Label': ['Positive', 'Negative', 'Neutral']
        })
        st.dataframe(sample_labeled, use_container_width=True)
def show_data_upload_page():
    st.markdown('<div class="section-header">Data Upload</div>', unsafe_allow_html=True)

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("### Reference Data")
        st.markdown("Upload a CSV file containing reference examples with columns: `tag` (class) and `sentence` (example text)")

        reference_file = st.file_uploader(
            "Choose reference data file",
            type=['csv'],
            key='reference_file'
        )

        if reference_file is not None:
            try:
                reference_df = pd.read_csv(reference_file)
                st.success("Reference data loaded successfully!")
                st.dataframe(reference_df.head(), use_container_width=True)

                # Validate columns
                required_cols = ['tag', 'sentence']
                missing_cols = [col for col in required_cols if col not in reference_df.columns]

                if missing_cols:
                    st.error(f"Missing required columns: {missing_cols}")
                else:
                    # Prepare reference data
                    reference_df = reference_df.rename(columns={
                        'tag': 'class',
                        'sentence': 'matching_node'
                    })
                    st.session_state.reference_data = reference_df

                    # Show statistics
                    st.markdown("**Data Statistics:**")
                    st.write(f"- Total examples: {len(reference_df)}")
                    st.write(f"- Unique classes: {reference_df['class'].nunique()}")
                    st.write("- Class distribution:")
                    st.write(reference_df['class'].value_counts())

            except Exception as e:
                st.error(f"Error loading reference data: {str(e)}")

    with col2:
        st.markdown("### Labeled Data")
        st.markdown("Upload a CSV file containing data to classify with columns: `sentence` (text) and `Label` (true class)")

        labeled_file = st.file_uploader(
            "Choose labeled data file",
            type=['csv'],
            key='labeled_file'
        )

        if labeled_file is not None:
            try:
                labeled_df = pd.read_csv(labeled_file)
                st.success("Labeled data loaded successfully!")
                st.dataframe(labeled_df.head(), use_container_width=True)

                # Validate columns
                required_cols = ['sentence', 'Label']
                missing_cols = [col for col in required_cols if col not in labeled_df.columns]

                if missing_cols:
                    st.error(f"Missing required columns: {missing_cols}")
                else:
                    # Prepare labeled data
                    labeled_df = labeled_df.rename(columns={'Label': 'label'})
                    labeled_df['label'] = labeled_df['label'].replace('0', 'Other')
                    st.session_state.labeled_data = labeled_df

                    # Show statistics
                    st.markdown("**Data Statistics:**")
                    st.write(f"- Total samples: {len(labeled_df)}")
                    st.write(f"- Unique labels: {labeled_df['label'].nunique()}")
                    st.write("- Label distribution:")
                    st.write(labeled_df['label'].value_counts())

            except Exception as e:
                st.error(f"Error loading labeled data: {str(e)}")

    # Show data compatibility check
    if st.session_state.reference_data is not None and st.session_state.labeled_data is not None:
        st.markdown('<div class="section-header">Data Compatibility Check</div>', unsafe_allow_html=True)

        ref_classes = set(st.session_state.reference_data['class'].unique())
        labeled_classes = set(st.session_state.labeled_data['label'].unique())

        # Check for unknown classes
        unknown_classes = labeled_classes - ref_classes

        if unknown_classes:
            st.warning(f"Warning: Labels in labeled data not found in reference data: {unknown_classes}")
        else:
            st.success("✅ Data compatibility check passed!")

        # Show class overlap
        st.markdown("**Class Overlap Analysis:**")
        col1, col2, col3 = st.columns(3)

        with col1:
            st.metric("Reference Classes", len(ref_classes))
        with col2:
            st.metric("Labeled Classes", len(labeled_classes))
        with col3:
            st.metric("Common Classes", len(ref_classes.intersection(labeled_classes)))
def show_configuration_page():
    st.markdown('<div class="section-header">Model Configuration</div>', unsafe_allow_html=True)

    # Check if data is loaded
    if st.session_state.reference_data is None or st.session_state.labeled_data is None:
        st.warning("Please upload both reference and labeled data first.")
        return

    col1, col2 = st.columns(2)

    with col1:
        st.markdown("### Embedding Model")

        # Model type selection
        model_type = st.selectbox(
            "Choose model type",
            ["HuggingFace", "Gemini"],
            help="Select the type of embedding model to use"
        )

        # Model selection based on type
        if model_type == "HuggingFace":
            model_options = [
                "sentence-transformers/all-MiniLM-L6-v2",
                "sentence-transformers/all-mpnet-base-v2",
                "sentence-transformers/distilbert-base-nli-mean-tokens"
            ]

            selected_model = st.selectbox(
                "Choose HuggingFace model",
                model_options,
                help="Select the pre-trained HuggingFace model for generating embeddings"
            )
        else:  # Gemini
            gemini_models = [
                "gemini-embedding-001",
                "text-embedding-004"
            ]

            selected_model = st.selectbox(
                "Choose Gemini model",
                gemini_models,
                help="Select the Gemini embedding model for generating embeddings"
            )

            # Calculate total texts to process
            total_texts = 0
            if st.session_state.reference_data is not None:
                total_texts += len(st.session_state.reference_data)
            if st.session_state.labeled_data is not None:
                total_texts += len(st.session_state.labeled_data)

            st.warning(
                f"⚠️ **Gemini API Rate Limits (Free Tier)**\n\n"
                f"- 1,500 requests per day\n"
                f"- Each batch of 100 texts = 1 request\n"
                f"- Your current dataset: ~{total_texts} texts\n"
                f"- Estimated requests needed: ~{(total_texts // 100) + 1}\n\n"
                f"If you exceed quota, consider:\n"
                f"1. Using a smaller dataset\n"
                f"2. Switching to HuggingFace models (no limits)\n"
                f"3. Upgrading to a paid API plan"
            )

            st.info("💡 Note: Using Gemini embeddings requires the GOOGLE_API_KEY environment variable to be set.")

        st.markdown("### Initial Threshold")
        initial_threshold = st.slider(
            "Initial similarity threshold",
            min_value=0.0,
            max_value=1.0,
            value=0.7,
            step=0.05,
            help="Cosine similarity threshold for classification"
        )

    with col2:
        st.markdown("### Optimization Parameters")

        optimize_threshold = st.checkbox(
            "Enable threshold optimization",
            value=True,
            help="Automatically find the best threshold"
        )

        if optimize_threshold:
            col2_1, col2_2 = st.columns(2)

            with col2_1:
                start_threshold = st.slider(
                    "Start threshold",
                    min_value=0.0,
                    max_value=1.0,
                    value=0.5,
                    step=0.05
                )

                end_threshold = st.slider(
                    "End threshold",
                    min_value=0.0,
                    max_value=1.0,
                    value=0.9,
                    step=0.05
                )

            with col2_2:
                step_size = st.slider(
                    "Step size",
                    min_value=0.005,
                    max_value=0.05,
                    value=0.01,
                    step=0.005
                )

                optimization_metric = st.selectbox(
                    "Optimization metric",
                    ["f1_macro", "accuracy", "precision_macro", "recall_macro"]
                )

    # Load models button
    if st.button("Initialize Models", type="primary"):
        with st.spinner("Loading models... This may take a few minutes."):
            try:
                # Initialize classifier
                classifier = Classifier(verbose=False)

                # Determine model type parameter
                model_type_param = "gemini" if model_type == "Gemini" else "huggingface"

                classifier.load_models(
                    model_name=selected_model,
                    model_type=model_type_param,
                    threshold=initial_threshold
                )

                # Prepare reference vectors
                with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_ref:
                    tmp_ref_path = tmp_ref.name
                    st.session_state.reference_data.to_csv(tmp_ref_path, index=False)

                try:
                    reference_data = classifier.prepare_reference_vectors(
                        reference_path=tmp_ref_path,
                        class_column='class',
                        node_column='matching_node'
                    )
                finally:
                    # Ensure the temporary file is deleted even if an error occurs
                    try:
                        os.unlink(tmp_ref_path)
                    except (OSError, PermissionError):
                        pass  # File might already be deleted or locked

                st.session_state.classifier = classifier
                st.session_state.reference_vectors = reference_data
                st.session_state.config = {
                    'model_type': model_type,
                    'model_name': selected_model,
                    'initial_threshold': initial_threshold,
                    'optimize_threshold': optimize_threshold,
                    'start_threshold': start_threshold if optimize_threshold else None,
                    'end_threshold': end_threshold if optimize_threshold else None,
                    'step_size': step_size if optimize_threshold else None,
                    'optimization_metric': optimization_metric if optimize_threshold else None
                }

                st.success("✅ Models initialized successfully!")

            except Exception as e:
                st.error(f"Error initializing models: {str(e)}")

    # Show current configuration
    if st.session_state.classifier is not None:
        st.markdown('<div class="section-header">Current Configuration</div>', unsafe_allow_html=True)

        config = st.session_state.config

        col1, col2, col3 = st.columns(3)

        with col1:
            st.markdown("**Model Settings:**")
            st.write(f"- Model type: {config['model_type']}")
            st.write(f"- Model: {config['model_name']}")
            st.write(f"- Initial threshold: {config['initial_threshold']}")

        with col2:
            st.markdown("**Optimization:**")
            st.write(f"- Enabled: {config['optimize_threshold']}")
            if config['optimize_threshold']:
                st.write(f"- Range: {config['start_threshold']:.2f} - {config['end_threshold']:.2f}")
                st.write(f"- Step: {config['step_size']:.3f}")

        with col3:
            st.markdown("**Data:**")
            st.write(f"- Reference examples: {len(st.session_state.reference_data)}")
            st.write(f"- Labeled samples: {len(st.session_state.labeled_data)}")
def show_classification_page():
    st.markdown('<div class="section-header">Classification & Optimization</div>', unsafe_allow_html=True)

    # Check if models are loaded
    if st.session_state.classifier is None:
        st.warning("Please configure and initialize models first.")
        return

    # Run classification
    if st.button("Run Classification", type="primary"):
        with st.spinner("Running classification and optimization..."):
            try:
                # Save labeled data to a temporary file
                with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as tmp_labeled:
                    tmp_labeled_path = tmp_labeled.name
                    st.session_state.labeled_data.to_csv(tmp_labeled_path, index=False)

                try:
                    # Run optimization if enabled
                    if st.session_state.config['optimize_threshold']:
                        optimization_results = st.session_state.classifier.evaluate_classification(
                            labeled_path=tmp_labeled_path,
                            reference_data=st.session_state.reference_vectors,
                            sentence_column='sentence',
                            label_column='label',
                            optimize_threshold=True,
                            start=st.session_state.config['start_threshold'],
                            end=st.session_state.config['end_threshold'],
                            step=st.session_state.config['step_size']
                        )

                        st.session_state.optimization_results = optimization_results
                        optimal_threshold = optimization_results["optimal_threshold"]

                        # Update classifier with optimal threshold
                        st.session_state.classifier.matcher = SemanticMatcher(
                            threshold=optimal_threshold,
                            verbose=False
                        )

                        st.success(f"✅ Optimization completed! Optimal threshold: {optimal_threshold:.4f}")

                    else:
                        optimal_threshold = st.session_state.config['initial_threshold']

                    # Run evaluation
                    embedding_model = st.session_state.classifier.embedding_model
                    data_loader = DataLoader(verbose=False)
                    full_df = data_loader.load_labeled_data(tmp_labeled_path, label_column='label')

                    # Generate embeddings
                    full_embeddings = embedding_model.embed_dataframe(full_df, text_column='sentence')

                    # Classify
                    match_results = st.session_state.classifier.matcher.match(
                        full_embeddings,
                        st.session_state.reference_vectors
                    )
                    predicted_labels = match_results["predicted_class"].tolist()
                    true_labels = full_df['label'].tolist()

                    # Evaluate
                    evaluator = Evaluator(verbose=False)
                    eval_results = evaluator.evaluate(
                        true_labels=true_labels,
                        predicted_labels=predicted_labels,
                        class_names=list(set(true_labels) | set(predicted_labels))
                    )

                    # Bootstrap evaluation
                    bootstrap_results = evaluator.bootstrap_evaluate(
                        true_labels=true_labels,
                        predicted_labels=predicted_labels,
                        n_iterations=100
                    )

                    st.session_state.evaluation_results = eval_results
                    st.session_state.bootstrap_results = bootstrap_results
                    st.session_state.predictions = {
                        'true_labels': true_labels,
                        'predicted_labels': predicted_labels,
                        'match_results': match_results,
                        'full_df': full_df
                    }

                finally:
                    # Ensure the temporary file is deleted
                    try:
                        os.unlink(tmp_labeled_path)
                    except (OSError, PermissionError):
                        pass  # File might already be deleted or locked

                st.success("✅ Classification completed successfully!")

            except Exception as e:
                st.error(f"Error during classification: {str(e)}")

    # Show optimization results if available
    if st.session_state.optimization_results is not None:
        st.markdown('<div class="section-header">Optimization Results</div>', unsafe_allow_html=True)

        results = st.session_state.optimization_results

        col1, col2, col3, col4 = st.columns(4)

        with col1:
            st.metric("Optimal Threshold", f"{results['optimal_threshold']:.4f}")
        with col2:
            st.metric("Accuracy", f"{results['optimal_metrics']['accuracy']:.4f}")
        with col3:
            st.metric("F1 Score", f"{results['optimal_metrics']['f1_macro']:.4f}")
        with col4:
            st.metric("Precision", f"{results['optimal_metrics']['precision_macro']:.4f}")

        # Plot optimization curve
        st.markdown("### Optimization Curve")

        opt_results = results["results_by_threshold"]

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall'),
            vertical_spacing=0.1
        )

        thresholds = opt_results["thresholds"]

        # Add traces
        fig.add_trace(go.Scatter(x=thresholds, y=opt_results["accuracy"], name="Accuracy"), row=1, col=1)
        fig.add_trace(go.Scatter(x=thresholds, y=opt_results["f1_macro"], name="F1 Score"), row=1, col=2)
        fig.add_trace(go.Scatter(x=thresholds, y=opt_results["precision_macro"], name="Precision"), row=2, col=1)
        fig.add_trace(go.Scatter(x=thresholds, y=opt_results["recall_macro"], name="Recall"), row=2, col=2)

        # Mark the optimal threshold with a dashed vertical line in each subplot
        optimal_thresh = results['optimal_threshold']

        shapes = []
        for row in range(1, 3):
            for col in range(1, 3):
                # Axis references: the first subplot uses 'x'/'y', the rest 'x2'..'x4'
                idx = (row - 1) * 2 + col
                shapes.append(
                    dict(
                        type="line",
                        x0=optimal_thresh, x1=optimal_thresh,
                        y0=0, y1=1,
                        yref=f"y{idx} domain" if idx > 1 else "y domain",
                        xref=f"x{idx}" if idx > 1 else "x",
                        line=dict(color="red", width=2, dash="dash")
                    )
                )

        fig.update_layout(
            shapes=shapes,
            title="Threshold Optimization Results",
            showlegend=False,
            height=600
        )

        st.plotly_chart(fig, use_container_width=True)
def show_results_page():
    st.markdown('<div class="section-header">Results & Evaluation</div>', unsafe_allow_html=True)

    # Check if evaluation results are available
    if st.session_state.evaluation_results is None:
        st.warning("Please run classification first to see results.")
        return

    eval_results = st.session_state.evaluation_results

    # Performance metrics
    st.markdown("### Performance Metrics")

    col1, col2, col3, col4 = st.columns(4)

    with col1:
        st.metric("Overall Accuracy", f"{eval_results['accuracy']:.4f}")
    with col2:
        st.metric("Macro F1 Score", f"{eval_results['f1_macro']:.4f}")
    with col3:
        st.metric("Macro Precision", f"{eval_results['precision_macro']:.4f}")
    with col4:
        st.metric("Macro Recall", f"{eval_results['recall_macro']:.4f}")

    # Class-wise metrics
    st.markdown("### Class-wise Performance")

    class_metrics_df = pd.DataFrame({
        'Class': list(eval_results['class_metrics']['precision'].keys()),
        'Precision': list(eval_results['class_metrics']['precision'].values()),
        'Recall': list(eval_results['class_metrics']['recall'].values()),
        'F1-Score': list(eval_results['class_metrics']['f1'].values()),
        'Support': list(eval_results['class_metrics']['support'].values())
    })

    st.dataframe(class_metrics_df, use_container_width=True)

    # Confusion Matrix
    st.markdown("### Confusion Matrix")

    cm = eval_results['confusion_matrix']
    class_names = eval_results['confusion_matrix_labels']

    fig = px.imshow(
        cm,
        labels=dict(x="Predicted", y="True", color="Count"),
        x=class_names,
        y=class_names,
        color_continuous_scale='Blues',
        text_auto=True,
        title="Confusion Matrix"
    )
    fig.update_layout(width=600, height=600)

    st.plotly_chart(fig, use_container_width=True)

    # Bootstrap Results
    if st.session_state.bootstrap_results is not None:
        st.markdown("### Bootstrap Confidence Intervals")

        bootstrap_results = st.session_state.bootstrap_results

        def format_ci(ci_data, level_str, level_float, label):
            # Accept both string and float keys, and both dict and list/tuple CI formats
            ci = ci_data.get(level_str, ci_data.get(level_float))
            if ci is None:
                return f"{label}: Not available"
            if isinstance(ci, dict):
                return f"{label}: [{ci['lower']:.4f}, {ci['upper']:.4f}]"
            if isinstance(ci, (list, tuple)) and len(ci) >= 2:
                return f"{label}: [{ci[0]:.4f}, {ci[1]:.4f}]"
            return f"{label}: Format not recognized"

        if 'confidence_intervals' in bootstrap_results:
            metrics = ['accuracy', 'precision_macro', 'recall_macro', 'f1_macro']

            for metric in metrics:
                if metric in bootstrap_results['confidence_intervals']:
                    ci_data = bootstrap_results['confidence_intervals'][metric]
                    st.markdown(f"**{metric.replace('_', ' ').title()}:**")

                    col1, col2, col3 = st.columns(3)

                    with col1:
                        st.write(format_ci(ci_data, '0.95', 0.95, "95% CI"))
                    with col2:
                        st.write(format_ci(ci_data, '0.99', 0.99, "99% CI"))
                    with col3:
                        if 'point_estimates' in bootstrap_results and metric in bootstrap_results['point_estimates']:
                            st.write(f"Point Estimate: {bootstrap_results['point_estimates'][metric]:.4f}")
                        else:
                            st.write("Point Estimate: Not available")
        else:
            st.info("Bootstrap confidence intervals not available.")

        # Bootstrap Distribution Plot
        st.markdown("### Bootstrap Distributions")

        if 'bootstrap_distribution' in bootstrap_results:
            fig = make_subplots(
                rows=2, cols=2,
                subplot_titles=('Accuracy', 'F1 Score', 'Precision', 'Recall')
            )

            distributions = bootstrap_results['bootstrap_distribution']

            if 'accuracy' in distributions:
                fig.add_trace(go.Histogram(x=distributions['accuracy'], name="Accuracy", nbinsx=30), row=1, col=1)
            if 'f1_macro' in distributions:
                fig.add_trace(go.Histogram(x=distributions['f1_macro'], name="F1 Score", nbinsx=30), row=1, col=2)
            if 'precision_macro' in distributions:
                fig.add_trace(go.Histogram(x=distributions['precision_macro'], name="Precision", nbinsx=30), row=2, col=1)
            if 'recall_macro' in distributions:
                fig.add_trace(go.Histogram(x=distributions['recall_macro'], name="Recall", nbinsx=30), row=2, col=2)

            fig.update_layout(
                title="Bootstrap Distributions",
                showlegend=False,
                height=600
            )

            st.plotly_chart(fig, use_container_width=True)
        else:
            st.info("Bootstrap distributions not available.")

    # Sample predictions
    if 'predictions' in st.session_state:
        st.markdown("### Sample Predictions")

        predictions = st.session_state.predictions
        sample_df = predictions['full_df'].copy()
        sample_df['predicted_class'] = predictions['predicted_labels']
        sample_df['true_class'] = predictions['true_labels']
        sample_df['similarity_score'] = predictions['match_results']['similarity_score']
        sample_df['correct'] = sample_df['predicted_class'] == sample_df['true_class']

        # Filter options
        col1, col2 = st.columns(2)

        with col1:
            show_correct = st.checkbox("Show correct predictions", value=True)
        with col2:
            show_incorrect = st.checkbox("Show incorrect predictions", value=True)

        # Filter data
        if show_correct and show_incorrect:
            filtered_df = sample_df
        elif show_correct:
            filtered_df = sample_df[sample_df['correct'] == True]
        elif show_incorrect:
            filtered_df = sample_df[sample_df['correct'] == False]
        else:
            filtered_df = pd.DataFrame()

        if not filtered_df.empty:
874
+ # Sample random rows
875
+ n_samples = min(20, len(filtered_df))
876
+ sample_rows = filtered_df.sample(n=n_samples) if len(filtered_df) > n_samples else filtered_df
877
+
878
+ display_df = sample_rows[['sentence', 'true_class', 'predicted_class', 'similarity_score', 'correct']].reset_index(drop=True)
879
+
880
+ st.dataframe(display_df, use_container_width=True)
881
+ else:
882
+ st.info("No predictions to show with current filters.")
883
+
884
+ # Download results
885
+ st.markdown("### Download Results")
886
+
887
+ col1, col2 = st.columns(2)
888
+
889
+ with col1:
890
+ # Download class-wise metrics
891
+ csv_metrics = class_metrics_df.to_csv(index=False)
892
+ st.download_button(
893
+ label="Download Class Metrics",
894
+ data=csv_metrics,
895
+ file_name="class_metrics.csv",
896
+ mime="text/csv"
897
+ )
898
+
899
+ with col2:
900
+ # Download predictions
901
+ if 'predictions' in st.session_state:
902
+ predictions = st.session_state.predictions
903
+ results_df = predictions['full_df'].copy()
904
+ results_df['predicted_class'] = predictions['predicted_labels']
905
+ results_df['similarity_score'] = predictions['match_results']['similarity_score']
906
+
907
+ csv_results = results_df.to_csv(index=False)
908
+ st.download_button(
909
+ label="Download Predictions",
910
+ data=csv_results,
911
+ file_name="predictions.csv",
912
+ mime="text/csv"
913
+ )
914
+
915
+ if __name__ == "__main__":
916
+ main()
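The confidence-interval display above accepts several shapes for a CI entry (string or float keys, `{'lower': …, 'upper': …}` dicts, or `(lower, upper)` sequences) and repeats the same checks in each column. A standalone helper — hypothetical, not part of the app — shows the same normalization in one place:

```python
def format_ci(ci_data, level):
    """Return a display string for one confidence level, accepting the
    shapes the app checks for: string or float keys, dicts with
    'lower'/'upper', or (lower, upper) sequences."""
    pct = f"{level:.0%}"  # e.g. 0.95 -> "95%"
    entry = ci_data.get(str(level), ci_data.get(level))
    if entry is None:
        return f"{pct} CI: Not available"
    if isinstance(entry, dict):
        return f"{pct} CI: [{entry['lower']:.4f}, {entry['upper']:.4f}]"
    if isinstance(entry, (list, tuple)) and len(entry) >= 2:
        return f"{pct} CI: [{entry[0]:.4f}, {entry[1]:.4f}]"
    return f"{pct} CI: Format not recognized"

print(format_ci({"0.95": {"lower": 0.81, "upper": 0.90}}, 0.95))
```

Folding the format checks into one such function would let the three `st.columns` branches share it instead of duplicating the `isinstance` ladder.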
app/run_demo.bat ADDED
File without changes
app/run_demo.py ADDED
@@ -0,0 +1,38 @@
+ #!/usr/bin/env python3
+ """
+ Quick launcher script for the QualiVec Streamlit demo.
+ """
+
+ import subprocess
+ import sys
+ import os
+
+ def main():
+     """Launch the Streamlit app."""
+
+     # Get the directory of this script
+     script_dir = os.path.dirname(os.path.abspath(__file__))
+     app_path = os.path.join(script_dir, "app.py")
+
+     print("🚀 Starting QualiVec Demo...")
+     print("📍 App will be available at: http://localhost:8501")
+     print("⏹️ Press Ctrl+C to stop the app")
+     print("-" * 50)
+
+     try:
+         # Run streamlit
+         subprocess.run([
+             sys.executable, "-m", "streamlit", "run", app_path,
+             "--server.headless", "true",
+             # "--server.address=0.0.0.0",
+             "--server.port=8501",
+             "--server.enableCORS", "false",
+             "--server.enableXsrfProtection", "false"
+         ])
+     except KeyboardInterrupt:
+         print("\n🛑 App stopped by user")
+     except Exception as e:
+         print(f"❌ Error starting app: {e}")
+
+ if __name__ == "__main__":
+     main()
src/qualivec/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """QualiVec: Qualitative Content Analysis with LLM Embeddings."""
+
+ from qualivec.data import DataLoader
+ from qualivec.sampling import Sampler
+ from qualivec.embedding import EmbeddingModel
+ from qualivec.matching import SemanticMatcher
+ from qualivec.evaluation import Evaluator
+ from qualivec.optimization import ThresholdOptimizer
+ from qualivec.classification import Classifier
+
+ __version__ = "0.1.0"
+
+ def main() -> None:
+     print("QualiVec: Qualitative Content Analysis with LLM Embeddings")
+     print(f"Version: {__version__}")
src/qualivec/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.15 kB)
src/qualivec/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (969 Bytes)
src/qualivec/__pycache__/classification.cpython-312.pyc ADDED
Binary file (8.62 kB)
src/qualivec/__pycache__/data.cpython-312.pyc ADDED
Binary file (8.37 kB)
src/qualivec/__pycache__/embedding.cpython-312.pyc ADDED
Binary file (11.2 kB)
src/qualivec/__pycache__/evaluation.cpython-312.pyc ADDED
Binary file (10.2 kB)
src/qualivec/__pycache__/matching.cpython-312.pyc ADDED
Binary file (5.07 kB)
src/qualivec/__pycache__/optimization.cpython-312.pyc ADDED
Binary file (10.7 kB)
src/qualivec/__pycache__/sampling.cpython-312.pyc ADDED
Binary file (4.78 kB)
src/qualivec/classification.py ADDED
@@ -0,0 +1,216 @@
+ """Classification utilities for QualiVec."""
+
+ import pandas as pd
+ import numpy as np
+ from typing import Dict, List, Optional, Any
+
+ from qualivec.data import DataLoader
+ from qualivec.embedding import EmbeddingModel
+ from qualivec.matching import SemanticMatcher
+
+
+ class Classifier:
+     """Handles classification for QualiVec."""
+
+     def __init__(self,
+                  embedding_model: Optional[EmbeddingModel] = None,
+                  matcher: Optional[SemanticMatcher] = None,
+                  verbose: bool = True):
+         """Initialize the classifier.
+
+         Args:
+             embedding_model: Model for generating embeddings.
+             matcher: Model for semantic matching.
+             verbose: Whether to print status messages.
+         """
+         self.embedding_model = embedding_model
+         self.matcher = matcher
+         self.verbose = verbose
+         self.data_loader = DataLoader(verbose=verbose)
+
+     def load_models(self,
+                     model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
+                     model_type: str = "huggingface",
+                     threshold: float = 0.7):
+         """Load embedding model and matcher.
+
+         Args:
+             model_name: Name of the model to use (HuggingFace or Gemini).
+             model_type: Type of model ('huggingface' or 'gemini').
+             threshold: Cosine similarity threshold for matching.
+         """
+         if self.verbose:
+             print(f"Loading {model_type} embedding model: {model_name}")
+
+         self.embedding_model = EmbeddingModel(
+             model_name=model_name,
+             model_type=model_type,
+             verbose=self.verbose
+         )
+         self.matcher = SemanticMatcher(threshold=threshold, verbose=self.verbose)
+
+         if self.verbose:
+             print("Models loaded successfully")
+
+     def prepare_reference_vectors(self,
+                                   reference_path: str,
+                                   class_column: str = "class",
+                                   node_column: str = "matching_node") -> Dict[str, Any]:
+         """Prepare reference vectors from a CSV file.
+
+         Args:
+             reference_path: Path to the CSV file with reference vectors.
+             class_column: Name of the column containing class labels.
+             node_column: Name of the column containing matching nodes.
+
+         Returns:
+             Dictionary with reference vector information.
+         """
+         if self.embedding_model is None:
+             raise ValueError("Embedding model not loaded. Call load_models first.")
+
+         # Load reference vectors
+         reference_df = self.data_loader.load_reference_vectors(
+             reference_path, class_column=class_column, node_column=node_column
+         )
+
+         # Generate embeddings
+         reference_data = self.embedding_model.embed_reference_vectors(
+             reference_df, class_column=class_column, node_column=node_column
+         )
+
+         if self.verbose:
+             print(f"Prepared {len(reference_data['embeddings'])} reference vectors")
+             print(f"Unique classes: {len(reference_data['class_to_idx'])}")
+
+         return reference_data
+
+     def classify(self,
+                  corpus_path: str,
+                  reference_data: Dict[str, Any],
+                  sentence_column: str = "sentence",
+                  output_path: Optional[str] = None) -> pd.DataFrame:
+         """Classify texts in a corpus using reference vectors.
+
+         Args:
+             corpus_path: Path to the CSV file with the corpus.
+             reference_data: Dictionary with reference vector information.
+             sentence_column: Name of the column containing sentences.
+             output_path: Path to save the classification results.
+
+         Returns:
+             DataFrame with classification results.
+         """
+         if self.embedding_model is None or self.matcher is None:
+             raise ValueError("Models not loaded. Call load_models first.")
+
+         # Load corpus
+         corpus_df = self.data_loader.load_corpus(corpus_path, sentence_column=sentence_column)
+
+         # Generate embeddings
+         corpus_embeddings = self.embedding_model.embed_dataframe(
+             corpus_df, text_column=sentence_column
+         )
+
+         # Classify
+         results_df = self.matcher.classify_corpus(
+             corpus_embeddings, reference_data, corpus_df
+         )
+
+         # Save results if an output path is provided
+         if output_path is not None:
+             self.data_loader.save_dataframe(results_df, output_path)
+             if self.verbose:
+                 print(f"Saved classification results to {output_path}")
+
+         return results_df
+
+     def evaluate_classification(self,
+                                 labeled_path: str,
+                                 reference_data: Dict[str, Any],
+                                 sentence_column: str = "sentence",
+                                 label_column: str = "label",
+                                 optimize_threshold: bool = False,
+                                 start: float = 0.5,
+                                 end: float = 0.9,
+                                 step: float = 0.01) -> Dict[str, Any]:
+         """Evaluate classification performance on labeled data.
+
+         Args:
+             labeled_path: Path to the CSV file with labeled data.
+             reference_data: Dictionary with reference vector information.
+             sentence_column: Name of the column containing sentences.
+             label_column: Name of the column containing true labels.
+             optimize_threshold: Whether to optimize the threshold.
+             start: Start threshold value for optimization.
+             end: End threshold value for optimization.
+             step: Threshold step size for optimization.
+
+         Returns:
+             Dictionary with evaluation results.
+         """
+         from qualivec.evaluation import Evaluator
+         from qualivec.optimization import ThresholdOptimizer
+
+         if self.embedding_model is None:
+             raise ValueError("Embedding model not loaded. Call load_models first.")
+
+         # Load labeled data
+         labeled_df = self.data_loader.load_labeled_data(labeled_path, label_column=label_column)
+
+         # Validate labels
+         valid = self.data_loader.validate_labels(
+             labeled_df,
+             pd.DataFrame({
+                 "class": reference_data["classes"]
+             }).drop_duplicates(),
+             label_column=label_column,
+             class_column="class"
+         )
+
+         if not valid and self.verbose:
+             print("Warning: Some labels in the labeled data are not in reference vectors")
+
+         # Generate embeddings
+         labeled_embeddings = self.embedding_model.embed_dataframe(
+             labeled_df, text_column=sentence_column
+         )
+
+         # True labels
+         true_labels = labeled_df[label_column].tolist()
+
+         if optimize_threshold:
+             # Optimize threshold
+             if self.verbose:
+                 print("Optimizing threshold...")
+
+             optimizer = ThresholdOptimizer(verbose=self.verbose)
+             optimization_results = optimizer.optimize(
+                 labeled_embeddings,
+                 reference_data,
+                 true_labels,
+                 start=start,
+                 end=end,
+                 step=step,
+                 metric="f1_macro"
+             )
+
+             # Update matcher with the optimal threshold
+             self.matcher = SemanticMatcher(threshold=optimization_results["optimal_threshold"],
+                                            verbose=self.verbose)
+
+             return optimization_results
+         else:
+             # Evaluate with the current threshold
+             if self.matcher is None:
+                 raise ValueError("Matcher not loaded. Call load_models first.")
+
+             # Get predictions
+             match_results = self.matcher.match(labeled_embeddings, reference_data)
+             predicted_labels = match_results["predicted_class"].tolist()
+
+             # Evaluate
+             evaluator = Evaluator(verbose=self.verbose)
+             eval_results = evaluator.bootstrap_evaluate(true_labels, predicted_labels)
+
+             return eval_results
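`Classifier` delegates the actual decision to `SemanticMatcher`, whose source is not shown in this commit view. As a rough stdlib sketch of the threshold-based matching it is expected to perform — the reference labels and vectors below are illustrative, not QualiVec's API:

```python
import math

def classify_one(vec, references, threshold=0.7):
    """Assign vec to the class of its most similar reference vector,
    or to None when the best cosine similarity falls below threshold."""
    def cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
        return dot / norm if norm else 0.0

    best_class, best_score = None, -1.0
    for label, ref in references:
        score = cosine(vec, ref)
        if score > best_score:
            best_class, best_score = label, score
    return (best_class if best_score >= threshold else None), best_score

refs = [("positive", [1.0, 0.0]), ("negative", [0.0, 1.0])]
label, score = classify_one([0.9, 0.1], refs)
```

The threshold is the knob `evaluate_classification` sweeps when `optimize_threshold=True`: raising it trades coverage (fewer matched sentences) for precision.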
src/qualivec/data.py ADDED
@@ -0,0 +1,174 @@
+ """Data loading and validation utilities for QualiVec."""
+
+ import os
+ import pandas as pd
+ from typing import List, Optional, Dict, Any, Union, Tuple
+
+
+ class DataLoader:
+     """Handles data loading and validation for QualiVec."""
+
+     def __init__(self, verbose: bool = True):
+         """Initialize the DataLoader.
+
+         Args:
+             verbose: Whether to print status messages.
+         """
+         self.verbose = verbose
+
+     def load_corpus(self, filepath: str, sentence_column: str = "sentence") -> pd.DataFrame:
+         """Load a corpus from a CSV file.
+
+         Args:
+             filepath: Path to the CSV file.
+             sentence_column: Name of the column containing sentences.
+
+         Returns:
+             DataFrame containing the corpus.
+
+         Raises:
+             FileNotFoundError: If the file does not exist.
+             ValueError: If the sentence column is missing.
+         """
+         if not os.path.exists(filepath):
+             raise FileNotFoundError(f"File not found: {filepath}")
+
+         # Load the data
+         if self.verbose:
+             print(f"Loading corpus from {filepath}...")
+
+         df = pd.read_csv(filepath)
+
+         # Validate schema
+         if sentence_column not in df.columns:
+             raise ValueError(f"Required column '{sentence_column}' not found in the CSV file.")
+
+         # Basic validation
+         if df[sentence_column].isna().any():
+             if self.verbose:
+                 print(f"Warning: {df[sentence_column].isna().sum()} null values found in '{sentence_column}' column.")
+
+         if self.verbose:
+             print(f"Loaded {len(df)} rows from {filepath}")
+
+         return df
+
+     def load_reference_vectors(self, filepath: str, class_column: str = "class",
+                                node_column: str = "matching_node") -> pd.DataFrame:
+         """Load reference vectors from a CSV file.
+
+         Args:
+             filepath: Path to the CSV file.
+             class_column: Name of the column containing class labels.
+             node_column: Name of the column containing matching nodes.
+
+         Returns:
+             DataFrame containing the reference vectors.
+
+         Raises:
+             FileNotFoundError: If the file does not exist.
+             ValueError: If required columns are missing.
+         """
+         if not os.path.exists(filepath):
+             raise FileNotFoundError(f"File not found: {filepath}")
+
+         if self.verbose:
+             print(f"Loading reference vectors from {filepath}...")
+
+         df = pd.read_csv(filepath)
+
+         # Validate schema
+         required_columns = [class_column, node_column]
+         missing_columns = [col for col in required_columns if col not in df.columns]
+
+         if missing_columns:
+             raise ValueError(f"Required columns {missing_columns} not found in the CSV file.")
+
+         # Basic validation
+         if df[class_column].isna().any() or df[node_column].isna().any():
+             if self.verbose:
+                 print("Warning: Null values found in reference vectors.")
+
+         if self.verbose:
+             print(f"Loaded {len(df)} reference vectors from {filepath}")
+             print(f"Unique classes: {df[class_column].nunique()}")
+
+         return df
+
+     def load_labeled_data(self, filepath: str, label_column: str = "label") -> pd.DataFrame:
+         """Load manually labeled data from a CSV file.
+
+         Args:
+             filepath: Path to the CSV file.
+             label_column: Name of the column containing labels.
+
+         Returns:
+             DataFrame containing the labeled data.
+
+         Raises:
+             FileNotFoundError: If the file does not exist.
+             ValueError: If the label column is missing.
+         """
+         if not os.path.exists(filepath):
+             raise FileNotFoundError(f"File not found: {filepath}")
+
+         if self.verbose:
+             print(f"Loading labeled data from {filepath}...")
+
+         df = pd.read_csv(filepath)
+
+         # Validate schema
+         if label_column not in df.columns:
+             raise ValueError(f"Required column '{label_column}' not found in the CSV file.")
+
+         # Basic validation
+         if df[label_column].isna().any():
+             if self.verbose:
+                 print(f"Warning: {df[label_column].isna().sum()} null values found in '{label_column}' column.")
+
+         if self.verbose:
+             print(f"Loaded {len(df)} labeled samples from {filepath}")
+             print(f"Label distribution:\n{df[label_column].value_counts()}")
+
+         return df
+
+     def save_dataframe(self, df: pd.DataFrame, filepath: str) -> None:
+         """Save a DataFrame to a CSV file.
+
+         Args:
+             df: DataFrame to save.
+             filepath: Path to save the CSV file.
+         """
+         df.to_csv(filepath, index=False)
+
+         if self.verbose:
+             print(f"Saved {len(df)} rows to {filepath}")
+
+     def validate_labels(self, labeled_df: pd.DataFrame, reference_df: pd.DataFrame,
+                         label_column: str = "label", class_column: str = "class") -> bool:
+         """Validate that labels in the labeled data are a subset of those in the reference data.
+
+         Args:
+             labeled_df: DataFrame containing labeled data.
+             reference_df: DataFrame containing reference vectors.
+             label_column: Name of the column containing labels in labeled_df.
+             class_column: Name of the column containing classes in reference_df.
+
+         Returns:
+             True if validation passes, False otherwise.
+         """
+         labeled_classes = set(labeled_df[label_column].unique())
+         reference_classes = set(reference_df[class_column].unique())
+
+         unknown_classes = labeled_classes - reference_classes
+
+         if unknown_classes:
+             if self.verbose:
+                 print(f"Warning: Found {len(unknown_classes)} labels in labeled data that are not in reference vectors:")
+                 print(unknown_classes)
+             return False
+
+         if self.verbose:
+             print("Label validation passed: All labels in labeled data are in reference vectors.")
+
+         return True
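`validate_labels` boils down to a set difference between the coder's labels and the reference codebook. A minimal stdlib sketch of the same check (the example labels are made up):

```python
def unknown_labels(labeled, reference):
    """Labels present in the coded data but absent from the reference
    codebook -- the condition DataLoader.validate_labels warns about."""
    return sorted(set(labeled) - set(reference))

# e.g. a coder used "neutral", which the codebook does not define
missing = unknown_labels(["positive", "neutral"], ["positive", "negative"])
```

An empty result corresponds to `validate_labels` returning `True`; anything else triggers the warning path.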
src/qualivec/embedding.py ADDED
@@ -0,0 +1,276 @@
+ """Embedding utilities for QualiVec."""
+
+ import numpy as np
+ import pandas as pd
+ from typing import List, Dict, Any, Optional, Union
+ import torch
+ from tqdm import tqdm
+ from transformers import AutoTokenizer, AutoModel
+ import os
+ import time
+
+
+ class EmbeddingModel:
+     """Handles text embedding for QualiVec."""
+
+     def __init__(self,
+                  model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
+                  model_type: str = "huggingface",
+                  device: Optional[str] = None,
+                  cache_dir: Optional[str] = None,
+                  verbose: bool = True):
+         """Initialize the embedding model.
+
+         Args:
+             model_name: Name of the model to use (HuggingFace model or Gemini model).
+             model_type: Type of model ('huggingface' or 'gemini').
+             device: Device to use for computation ('cpu' or 'cuda'). Only for HuggingFace models.
+             cache_dir: Directory to cache models. Only for HuggingFace models.
+             verbose: Whether to print status messages.
+         """
+         self.model_name = model_name
+         self.model_type = model_type.lower()
+         self.verbose = verbose
+         self.cache_dir = cache_dir
+
+         if self.model_type not in ["huggingface", "gemini"]:
+             raise ValueError(f"model_type must be 'huggingface' or 'gemini', got '{model_type}'")
+
+         if self.model_type == "huggingface":
+             # Determine device
+             if device is None:
+                 self.device = "cuda" if torch.cuda.is_available() else "cpu"
+             else:
+                 self.device = device
+
+             if self.verbose:
+                 print(f"Using device: {self.device}")
+                 print(f"Loading HuggingFace model: {model_name}")
+
+             # Load model and tokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
+             self.model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir).to(self.device)
+
+             if self.verbose:
+                 print("HuggingFace model loaded successfully")
+
+         elif self.model_type == "gemini":
+             if self.verbose:
+                 print(f"Initializing Gemini model: {model_name}")
+
+             # Import Gemini client
+             try:
+                 from google import genai
+
+                 # Get API key from environment variable
+                 api_key = os.environ.get("GOOGLE_API_KEY")
+                 if not api_key:
+                     raise ValueError(
+                         "GOOGLE_API_KEY environment variable not set. "
+                         "Please set it with your Gemini API key."
+                     )
+
+                 self.genai_client = genai.Client(api_key=api_key)
+
+                 if self.verbose:
+                     print("Gemini client initialized successfully")
+                     print("⚠️ Free tier limits: 1,500 requests/day, 100 texts per batch")
+
+             except ImportError:
+                 raise ImportError("google-genai library is required for Gemini models. Install with: pip install google-genai")
+
+     def _mean_pooling(self, model_output, attention_mask):
+         """Mean pooling operation to get sentence embeddings."""
+         token_embeddings = model_output[0]
+         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+
+     def embed_texts(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
+         """Generate embeddings for a list of texts.
+
+         Args:
+             texts: List of texts to embed.
+             batch_size: Batch size for processing.
+
+         Returns:
+             Numpy array of embeddings.
+         """
+         if self.verbose:
+             print(f"Generating embeddings for {len(texts)} texts")
+
+         if self.model_type == "huggingface":
+             return self._embed_texts_huggingface(texts, batch_size)
+         elif self.model_type == "gemini":
+             return self._embed_texts_gemini(texts, batch_size)
+         else:
+             raise ValueError(f"Unsupported model_type: {self.model_type}")
+
+     def _embed_texts_huggingface(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
+         """Generate embeddings using a HuggingFace model.
+
+         Args:
+             texts: List of texts to embed.
+             batch_size: Batch size for processing.
+
+         Returns:
+             Numpy array of embeddings.
+         """
+         embeddings = []
+
+         # Process in batches
+         for i in tqdm(range(0, len(texts), batch_size), disable=not self.verbose):
+             batch_texts = texts[i:i + batch_size]
+
+             # Tokenize
+             encoded_input = self.tokenizer(batch_texts, padding=True, truncation=True,
+                                            max_length=512, return_tensors='pt').to(self.device)
+
+             # Get model output
+             with torch.no_grad():
+                 model_output = self.model(**encoded_input)
+
+             # Mean pooling
+             batch_embeddings = self._mean_pooling(model_output, encoded_input['attention_mask'])
+
+             # Normalize embeddings
+             batch_embeddings = torch.nn.functional.normalize(batch_embeddings, p=2, dim=1)
+
+             # Add to list
+             embeddings.append(batch_embeddings.cpu().numpy())
+
+         # Concatenate all batches
+         all_embeddings = np.vstack(embeddings)
+
+         if self.verbose:
+             print(f"Generated embeddings with shape: {all_embeddings.shape}")
+
+         return all_embeddings
+
+     def _embed_texts_gemini(self, texts: List[str], batch_size: int = 100) -> np.ndarray:
+         """Generate embeddings using a Gemini model with rate limiting.
+
+         Args:
+             texts: List of texts to embed.
+             batch_size: Batch size for processing (capped at 100 to respect rate limits).
+
+         Returns:
+             Numpy array of embeddings.
+         """
+         embeddings = []
+
+         # Process in batches with rate limiting
+         for i in tqdm(range(0, len(texts), batch_size), disable=not self.verbose):
+             batch_texts = texts[i:i + batch_size]
+
+             # Retry logic with exponential backoff
+             max_retries = 3
+             retry_delay = 2  # seconds
+
+             for attempt in range(max_retries):
+                 try:
+                     # Get embeddings from Gemini
+                     result = self.genai_client.models.embed_content(
+                         model=self.model_name,
+                         contents=batch_texts  # type: ignore
+                     )
+
+                     # Extract embeddings
+                     if result.embeddings:
+                         batch_embeddings = [emb.values for emb in result.embeddings]
+                         embeddings.extend(batch_embeddings)
+
+                     # Add a short delay between batches to respect rate limits
+                     # (free tier: 1,500 requests/day, i.e. up to 150,000 texts/day at 100 texts per request)
+                     if i + batch_size < len(texts):
+                         time.sleep(1)  # 1 second delay between batches
+
+                     break  # Success, exit retry loop
+
+                 except Exception as e:
+                     error_msg = str(e)
+                     if "429" in error_msg or "RESOURCE_EXHAUSTED" in error_msg:
+                         if attempt < max_retries - 1:
+                             if self.verbose:
+                                 print(f"\nRate limit hit. Waiting {retry_delay} seconds before retry {attempt + 1}/{max_retries}...")
+                             time.sleep(retry_delay)
+                             retry_delay *= 2  # Exponential backoff
+                         else:
+                             raise Exception(
+                                 f"Gemini API quota exceeded. Free tier limits: 1500 requests/day.\n"
+                                 f"Error: {error_msg}\n\n"
+                                 f"Solutions:\n"
+                                 f"1. Wait and try again later (quota resets daily)\n"
+                                 f"2. Reduce the amount of data being processed\n"
+                                 f"3. Upgrade to a paid API plan\n"
+                                 f"4. Use HuggingFace models instead (no API limits)"
+                             )
+                     else:
+                         raise  # Re-raise non-quota errors
+
+         # Convert to numpy array
+         all_embeddings = np.array(embeddings)
+
+         if self.verbose:
+             print(f"Generated embeddings with shape: {all_embeddings.shape}")
+
+         return all_embeddings
+
+     def embed_dataframe(self,
+                         df: pd.DataFrame,
+                         text_column: str,
+                         batch_size: int = 32) -> np.ndarray:
+         """Generate embeddings for texts in a DataFrame column.
+
+         Args:
+             df: DataFrame containing texts.
+             text_column: Name of the column containing texts.
+             batch_size: Batch size for processing.
+
+         Returns:
+             Numpy array of embeddings.
+         """
+         if text_column not in df.columns:
+             raise ValueError(f"Column '{text_column}' not found in DataFrame.")
+
+         texts = df[text_column].fillna("").tolist()
+         return self.embed_texts(texts, batch_size)
+
+     def embed_reference_vectors(self,
+                                 df: pd.DataFrame,
+                                 class_column: str = "class",
+                                 node_column: str = "matching_node",
+                                 batch_size: int = 32) -> Dict[str, Any]:
+         """Generate embeddings for reference vectors.
+
+         Args:
+             df: DataFrame containing reference vectors.
+             class_column: Name of the column containing class labels.
+             node_column: Name of the column containing matching nodes.
+             batch_size: Batch size for processing.
+
+         Returns:
+             Dictionary with class info and embeddings.
+         """
+         required_columns = [class_column, node_column]
+         missing_columns = [col for col in required_columns if col not in df.columns]
+
+         if missing_columns:
+             raise ValueError(f"Required columns {missing_columns} not found in DataFrame.")
+
+         # Get texts and generate embeddings
+         texts = df[node_column].fillna("").tolist()
+         embeddings = self.embed_texts(texts, batch_size)
+
+         # Create result dictionary
+         result = {
+             "classes": df[class_column].tolist(),
+             "nodes": df[node_column].tolist(),
+             "embeddings": embeddings,
+             "class_to_idx": {cls: i for i, cls in enumerate(df[class_column].unique())}
+         }
+
+         if self.verbose:
+             print(f"Generated embeddings for {len(result['classes'])} reference vectors")
+             print(f"Unique classes: {len(result['class_to_idx'])}")
+
+         return result
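The Gemini branch retries on 429 / RESOURCE_EXHAUSTED errors with exponential backoff (2 s, 4 s, then give up). The pattern generalizes beyond this API; a stdlib sketch, where the `flaky` function and the injected `sleep` are demo scaffolding, not part of QualiVec:

```python
import time

def with_backoff(fn, max_retries=3, base_delay=2, sleep=time.sleep):
    """Call fn, retrying on exception with delays of base_delay,
    2*base_delay, ...; re-raise after the final failed attempt."""
    delay = base_delay
    for attempt in range(max_retries):
        try:
            return fn()
        except Exception:
            if attempt == max_retries - 1:
                raise
            sleep(delay)
            delay *= 2  # exponential backoff

calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("429 rate limit")
    return "ok"

result = with_backoff(flaky, sleep=lambda s: None)  # returns "ok" after two retried failures
```

Injecting `sleep` keeps the retry logic testable without real delays, which is why `_embed_texts_gemini`'s inline `time.sleep` calls are harder to unit-test.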
src/qualivec/evaluation.py ADDED
@@ -0,0 +1,254 @@
+ """Evaluation utilities for QualiVec."""
+
+ import numpy as np
+ import pandas as pd
+ from typing import Dict, List, Tuple, Optional, Union, Any
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ from tqdm import tqdm
+
+
+ class Evaluator:
+     """Handles evaluation for QualiVec."""
+
+     def __init__(self, verbose: bool = True):
+         """Initialize the evaluator.
+
+         Args:
+             verbose: Whether to print status messages.
+         """
+         self.verbose = verbose
+
+     def evaluate(self,
+                  true_labels: List[str],
+                  predicted_labels: List[str],
+                  class_names: Optional[List[str]] = None) -> Dict[str, Any]:
+         """Evaluate predictions against true labels.
+
+         Args:
+             true_labels: List of true class labels.
+             predicted_labels: List of predicted class labels.
+             class_names: List of class names for detailed metrics.
+
+         Returns:
+             Dictionary with evaluation metrics.
+         """
+         if len(true_labels) != len(predicted_labels):
+             raise ValueError(f"Length mismatch: {len(true_labels)} true labels vs {len(predicted_labels)} predictions")
+
+         if self.verbose:
+             print(f"Evaluating {len(true_labels)} predictions")
+
+         # Calculate metrics
+         accuracy = accuracy_score(true_labels, predicted_labels)
+
+         # If class_names not provided, use unique values from true and predicted
+         if class_names is None:
+             class_names = sorted(set(true_labels) | set(predicted_labels))
+
+         # Calculate precision, recall, F1 (macro average); zero_division=0
+         # avoids warnings when a class receives no predictions
+         precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
+             true_labels, predicted_labels, average='macro', zero_division=0
+         )
+
+         # Calculate per-class metrics
+         precision, recall, f1, support = precision_recall_fscore_support(
+             true_labels, predicted_labels, labels=class_names, average=None, zero_division=0
+         )
+
+         # Create class-wise metrics
+         class_metrics = {
+             "precision": {cls: p for cls, p in zip(class_names, precision)},
+             "recall": {cls: r for cls, r in zip(class_names, recall)},
+             "f1": {cls: f for cls, f in zip(class_names, f1)},
+             "support": {cls: s for cls, s in zip(class_names, support)}
+         }
+
+         # Create confusion matrix
+         cm = confusion_matrix(true_labels, predicted_labels, labels=class_names)
+
+         # Compile results
+         results = {
+             "accuracy": accuracy,
+             "precision_macro": precision_macro,
+             "recall_macro": recall_macro,
+             "f1_macro": f1_macro,
+             "class_metrics": class_metrics,
+             "confusion_matrix": cm,
+             "confusion_matrix_labels": class_names,
+             "n_samples": len(true_labels)
+         }
+
+         if self.verbose:
+             print(f"Accuracy: {accuracy:.4f}")
+             print(f"Precision (macro): {precision_macro:.4f}")
+             print(f"Recall (macro): {recall_macro:.4f}")
+             print(f"F1 (macro): {f1_macro:.4f}")
+
+         return results
+
+     def bootstrap_evaluate(self,
+                            true_labels: List[str],
+                            predicted_labels: List[str],
+                            n_iterations: int = 1000,
+                            confidence_levels: List[float] = [0.9, 0.95, 0.99],
+                            random_seed: Optional[int] = None) -> Dict[str, Any]:
+         """Evaluate with bootstrap confidence intervals.
+
+         Args:
+             true_labels: List of true class labels.
+             predicted_labels: List of predicted class labels.
+             n_iterations: Number of bootstrap iterations.
+             confidence_levels: Confidence levels to compute.
+             random_seed: Random seed for reproducibility.
+
+         Returns:
+             Dictionary with evaluation metrics and confidence intervals.
+         """
+         if len(true_labels) != len(predicted_labels):
+             raise ValueError(f"Length mismatch: {len(true_labels)} true labels vs {len(predicted_labels)} predictions")
+
+         if self.verbose:
+             print(f"Running bootstrap evaluation with {n_iterations} iterations")
+
+         # Set random seed
+         if random_seed is not None:
+             np.random.seed(random_seed)
+
+         # Initialize storage for bootstrap results
+         bootstrap_metrics = {
+             "accuracy": [],
+             "precision_macro": [],
+             "recall_macro": [],
+             "f1_macro": []
+         }
+
+         # Original evaluation
+         original_results = self.evaluate(true_labels, predicted_labels)
+
+         # Run bootstrap iterations; temporarily suppress verbosity so that
+         # evaluate() does not print on every resample
+         n_samples = len(true_labels)
+         original_verbose = self.verbose
+         self.verbose = False
+
+         for _ in tqdm(range(n_iterations), disable=not original_verbose):
+             # Sample with replacement
+             indices = np.random.choice(n_samples, size=n_samples, replace=True)
+
+             # Get bootstrap sample
+             bootstrap_true = [true_labels[i] for i in indices]
+             bootstrap_pred = [predicted_labels[i] for i in indices]
+
+             # Evaluate
+             results = self.evaluate(bootstrap_true, bootstrap_pred)
+
+             # Store results
+             bootstrap_metrics["accuracy"].append(results["accuracy"])
+             bootstrap_metrics["precision_macro"].append(results["precision_macro"])
+             bootstrap_metrics["recall_macro"].append(results["recall_macro"])
+             bootstrap_metrics["f1_macro"].append(results["f1_macro"])
+
+         self.verbose = original_verbose
+
+         # Calculate confidence intervals
+         confidence_intervals = {}
+
+         for metric, values in bootstrap_metrics.items():
+             confidence_intervals[metric] = {}
+             for level in confidence_levels:
+                 lower_percentile = (1 - level) / 2 * 100
+                 upper_percentile = (1 + level) / 2 * 100
+
+                 lower = np.percentile(values, lower_percentile)
+                 upper = np.percentile(values, upper_percentile)
+
+                 confidence_intervals[metric][level] = (lower, upper)
+
+         # Combine results
+         results = {
+             "point_estimates": {
+                 "accuracy": original_results["accuracy"],
+                 "precision_macro": original_results["precision_macro"],
+                 "recall_macro": original_results["recall_macro"],
+                 "f1_macro": original_results["f1_macro"]
+             },
+             "confidence_intervals": confidence_intervals,
+             "bootstrap_distribution": bootstrap_metrics,
+             "n_iterations": n_iterations,
+             "n_samples": n_samples
+         }
+
+         if self.verbose:
+             print("Bootstrap evaluation complete")
+             print(f"Accuracy: {results['point_estimates']['accuracy']:.4f}")
+             for level in confidence_levels:
+                 lower, upper = results['confidence_intervals']['accuracy'][level]
+                 print(f"  {level*100:.0f}% CI: ({lower:.4f}, {upper:.4f})")
+
+         return results
+
+     def plot_confusion_matrix(self,
+                               confusion_matrix: np.ndarray,
+                               class_names: List[str],
+                               figsize: Tuple[int, int] = (10, 8),
+                               title: str = "Confusion Matrix"):
+         """Plot a confusion matrix.
+
+         Args:
+             confusion_matrix: Confusion matrix as numpy array.
+             class_names: List of class names.
+             figsize: Figure size as (width, height).
+             title: Plot title.
+         """
+         plt.figure(figsize=figsize)
+
+         # Create heatmap
+         sns.heatmap(
+             confusion_matrix,
+             annot=True,
+             fmt="d",
+             cmap="Blues",
+             xticklabels=class_names,
+             yticklabels=class_names
+         )
+
+         plt.xlabel("Predicted")
+         plt.ylabel("True")
+         plt.title(title)
+         plt.tight_layout()
+         plt.show()
+
+     def plot_bootstrap_distributions(self, bootstrap_results: Dict[str, Any], figsize: Tuple[int, int] = (12, 8)):
+         """Plot bootstrap distributions for key metrics.
+
+         Args:
+             bootstrap_results: Results from bootstrap_evaluate.
+             figsize: Figure size as (width, height).
+         """
+         metrics = ["accuracy", "precision_macro", "recall_macro", "f1_macro"]
+
+         plt.figure(figsize=figsize)
+
+         for i, metric in enumerate(metrics):
+             plt.subplot(2, 2, i+1)
+
+             # Get distribution data
+             values = bootstrap_results["bootstrap_distribution"][metric]
+
+             # Plot histogram
+             sns.histplot(values, kde=True)
+
+             # Add point estimate
+             point_est = bootstrap_results["point_estimates"][metric]
+             plt.axvline(point_est, color='red', linestyle='--', label=f'Point est: {point_est:.4f}')
+
+             # Add confidence intervals
+             for level, (lower, upper) in bootstrap_results["confidence_intervals"][metric].items():
+                 plt.axvline(lower, color='green', linestyle=':',
+                             label=f'{level*100:.0f}% CI: ({lower:.4f}, {upper:.4f})')
+                 plt.axvline(upper, color='green', linestyle=':')
+
+             plt.title(f"{metric.replace('_', ' ').title()}")
+
+             if i == 0:  # Only add legend to first plot
+                 plt.legend(loc='best')
+
+         plt.tight_layout()
+         plt.show()
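The percentile-bootstrap logic in `bootstrap_evaluate` can be sketched standalone with NumPy alone — resample indices with replacement, recompute the metric, and read off the percentile bounds. The labels below are toy data, not from the library, and only accuracy is bootstrapped:

```python
import numpy as np

# Standalone sketch of the percentile bootstrap used by
# Evaluator.bootstrap_evaluate (toy labels; accuracy only).
rng = np.random.default_rng(0)
true = np.array(["a", "a", "b", "b", "a", "b", "a", "a"])
pred = np.array(["a", "b", "b", "b", "a", "a", "a", "a"])

n = len(true)
accs = [np.mean(true[idx] == pred[idx])
        for idx in (rng.integers(0, n, size=n) for _ in range(1000))]

level = 0.95
lower = np.percentile(accs, (1 - level) / 2 * 100)   # 2.5th percentile
upper = np.percentile(accs, (1 + level) / 2 * 100)   # 97.5th percentile
point = np.mean(true == pred)                        # plain point estimate
print(f"accuracy={point:.3f}, CI=({lower:.3f}, {upper:.3f})")
```

With tiny samples like this the interval is wide, which is exactly the signal the bootstrap is meant to surface.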
src/qualivec/matching.py ADDED
@@ -0,0 +1,104 @@
+ """Semantic matching utilities for QualiVec."""
+
+ import numpy as np
+ import pandas as pd
+ from typing import Dict, Any, List, Tuple, Optional
+ from sklearn.metrics.pairwise import cosine_similarity
+
+
+ class SemanticMatcher:
+     """Handles semantic matching for QualiVec."""
+
+     def __init__(self,
+                  threshold: float = 0.7,
+                  verbose: bool = True):
+         """Initialize the semantic matcher.
+
+         Args:
+             threshold: Cosine similarity threshold for matching.
+             verbose: Whether to print status messages.
+         """
+         if not 0 <= threshold <= 1:
+             raise ValueError("Threshold must be between 0 and 1.")
+
+         self.threshold = threshold
+         self.verbose = verbose
+
+     def match(self,
+               query_embeddings: np.ndarray,
+               reference_data: Dict[str, Any],
+               return_similarities: bool = False) -> pd.DataFrame:
+         """Match query embeddings against reference vectors.
+
+         Args:
+             query_embeddings: Embeddings of the query texts.
+             reference_data: Dictionary with reference vector information.
+             return_similarities: Whether to return all similarity scores.
+
+         Returns:
+             DataFrame with matching results.
+         """
+         if self.verbose:
+             print(f"Matching {len(query_embeddings)} queries against {len(reference_data['embeddings'])} reference vectors")
+             print(f"Using cosine similarity threshold: {self.threshold}")
+
+         # Calculate cosine similarity
+         similarities = cosine_similarity(query_embeddings, reference_data['embeddings'])
+
+         # Find best matches
+         best_match_indices = np.argmax(similarities, axis=1)
+         best_match_scores = np.max(similarities, axis=1)
+
+         # Apply threshold
+         matches_mask = best_match_scores >= self.threshold
+
+         # Create results
+         classes = np.array(reference_data['classes'])[best_match_indices]
+         nodes = np.array(reference_data['nodes'])[best_match_indices]
+
+         # Apply threshold (set to "Other" if below threshold)
+         classes = np.where(matches_mask, classes, "Other")
+         nodes = np.where(matches_mask, nodes, "")
+
+         # Create result DataFrame
+         results = pd.DataFrame({
+             "predicted_class": classes,
+             "matched_node": nodes,
+             "similarity_score": best_match_scores
+         })
+
+         if return_similarities:
+             results["all_similarities"] = list(similarities)
+
+         if self.verbose:
+             print(f"Matching complete: {matches_mask.sum()} matches above threshold ({matches_mask.mean():.1%})")
+             print(f"Class distribution:\n{results['predicted_class'].value_counts().head(10)}")
+
+         return results
+
+     def classify_corpus(self,
+                         corpus_embeddings: np.ndarray,
+                         reference_data: Dict[str, Any],
+                         corpus_df: pd.DataFrame) -> pd.DataFrame:
+         """Classify an entire corpus using semantic matching.
+
+         Args:
+             corpus_embeddings: Embeddings of the corpus texts.
+             reference_data: Dictionary with reference vector information.
+             corpus_df: DataFrame containing the original corpus.
+
+         Returns:
+             DataFrame with classification results.
+         """
+         # Perform matching
+         match_results = self.match(corpus_embeddings, reference_data)
+
+         # Combine with original corpus
+         result_df = pd.concat([corpus_df.reset_index(drop=True),
+                                match_results.reset_index(drop=True)], axis=1)
+
+         if self.verbose:
+             print(f"Classified {len(result_df)} documents")
+             print(f"Class distribution:\n{result_df['predicted_class'].value_counts().head(10)}")
+
+         return result_df
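The core of `match()` is a few lines of array math: cosine similarity, argmax per query, and a fallback to `"Other"` below the threshold. A self-contained sketch with toy 2-D vectors (illustrative values, not library code):

```python
import numpy as np

# Sketch of SemanticMatcher.match with plain NumPy: cosine similarity,
# best reference per query, "Other" when the best score is below threshold.
def cosine_sim(a, b):
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

refs = np.array([[1.0, 0.0], [0.0, 1.0]])       # toy reference vectors
classes = np.array(["pos", "neg"])              # class of each reference
queries = np.array([[0.9, 0.1], [0.5, 0.5], [0.1, 0.9]])

sims = cosine_sim(queries, refs)
best_idx = sims.argmax(axis=1)                  # best reference per query
best_score = sims.max(axis=1)
threshold = 0.8
predicted = np.where(best_score >= threshold, classes[best_idx], "Other")
print(list(predicted))
```

The middle query sits equidistant from both references (cosine ≈ 0.71), so it falls under the threshold and lands in `"Other"` — the same behavior the library applies per-document.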
src/qualivec/optimization.py ADDED
@@ -0,0 +1,263 @@
+ """Threshold optimization utilities for QualiVec."""
+
+ import numpy as np
+ import pandas as pd
+ from typing import Dict, List, Tuple, Optional, Union, Any, Callable
+ from tqdm import tqdm
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ from qualivec.matching import SemanticMatcher
+ from qualivec.evaluation import Evaluator
+
+
+ class ThresholdOptimizer:
+     """Handles threshold optimization for QualiVec."""
+
+     def __init__(self,
+                  verbose: bool = True):
+         """Initialize the threshold optimizer.
+
+         Args:
+             verbose: Whether to print status messages.
+         """
+         self.verbose = verbose
+         self.evaluator = Evaluator(verbose=False)
+
+     def optimize(self,
+                  query_embeddings: np.ndarray,
+                  reference_data: Dict[str, Any],
+                  true_labels: List[str],
+                  start: float = 0.0,
+                  end: float = 1.0,
+                  step: float = 0.01,
+                  metric: str = "f1_macro",
+                  bootstrap: bool = True,
+                  n_bootstrap: int = 100,
+                  confidence_level: float = 0.95,
+                  random_seed: Optional[int] = None) -> Dict[str, Any]:
+         """Find the optimal similarity threshold.
+
+         Args:
+             query_embeddings: Embeddings of the query texts.
+             reference_data: Dictionary with reference vector information.
+             true_labels: True class labels for evaluation.
+             start: Start threshold value.
+             end: End threshold value.
+             step: Threshold step size.
+             metric: Metric to optimize ("accuracy", "precision_macro", "recall_macro", "f1_macro").
+             bootstrap: Whether to use bootstrap evaluation.
+             n_bootstrap: Number of bootstrap iterations.
+             confidence_level: Confidence level for bootstrap.
+             random_seed: Random seed for reproducibility.
+
+         Returns:
+             Dictionary with optimization results.
+         """
+         if not 0 <= start < end <= 1:
+             raise ValueError("Threshold range must be between 0 and 1")
+
+         if metric not in ["accuracy", "precision_macro", "recall_macro", "f1_macro"]:
+             raise ValueError(f"Unsupported metric: {metric}")
+
+         if self.verbose:
+             print(f"Optimizing threshold for {metric}")
+             print(f"Threshold range: {start} to {end} (step: {step})")
+
+         # Generate threshold values
+         thresholds = np.arange(start, end + step/2, step)
+
+         # Initialize results storage
+         results = {
+             "thresholds": [],
+             "accuracy": [],
+             "precision_macro": [],
+             "recall_macro": [],
+             "f1_macro": [],
+             "class_distribution": []
+         }
+
+         if bootstrap:
+             results["confidence_intervals"] = []
+
+         # Evaluate each threshold
+         for threshold in tqdm(thresholds, disable=not self.verbose):
+             # Create matcher with current threshold
+             matcher = SemanticMatcher(threshold=threshold, verbose=False)
+
+             # Get predictions
+             match_results = matcher.match(query_embeddings, reference_data)
+             predicted_labels = match_results["predicted_class"].tolist()
+
+             # Calculate class distribution
+             class_distribution = pd.Series(predicted_labels).value_counts().to_dict()
+
+             # Evaluate
+             if bootstrap:
+                 eval_results = self.evaluator.bootstrap_evaluate(
+                     true_labels,
+                     predicted_labels,
+                     n_iterations=n_bootstrap,
+                     confidence_levels=[confidence_level],
+                     random_seed=random_seed
+                 )
+
+                 # Extract point estimates
+                 point_estimates = eval_results["point_estimates"]
+
+                 # Extract confidence intervals
+                 ci = {m: eval_results["confidence_intervals"][m][confidence_level]
+                       for m in ["accuracy", "precision_macro", "recall_macro", "f1_macro"]}
+
+                 results["confidence_intervals"].append(ci)
+             else:
+                 eval_results = self.evaluator.evaluate(true_labels, predicted_labels)
+                 point_estimates = {
+                     "accuracy": eval_results["accuracy"],
+                     "precision_macro": eval_results["precision_macro"],
+                     "recall_macro": eval_results["recall_macro"],
+                     "f1_macro": eval_results["f1_macro"]
+                 }
+
+             # Store results
+             results["thresholds"].append(threshold)
+             results["accuracy"].append(point_estimates["accuracy"])
+             results["precision_macro"].append(point_estimates["precision_macro"])
+             results["recall_macro"].append(point_estimates["recall_macro"])
+             results["f1_macro"].append(point_estimates["f1_macro"])
+             results["class_distribution"].append(class_distribution)
+
+         # Find optimal threshold
+         optimal_idx = np.argmax(results[metric])
+         optimal_threshold = results["thresholds"][optimal_idx]
+         optimal_metrics = {
+             "accuracy": results["accuracy"][optimal_idx],
+             "precision_macro": results["precision_macro"][optimal_idx],
+             "recall_macro": results["recall_macro"][optimal_idx],
+             "f1_macro": results["f1_macro"][optimal_idx]
+         }
+
+         if bootstrap:
+             optimal_ci = results["confidence_intervals"][optimal_idx]
+         else:
+             optimal_ci = None
+
+         # Compile results
+         optimization_results = {
+             "optimal_threshold": optimal_threshold,
+             "optimal_metrics": optimal_metrics,
+             "optimal_confidence_intervals": optimal_ci,
+             "results_by_threshold": results,
+             "optimized_metric": metric,
+             "n_thresholds": len(thresholds)
+         }
+
+         if self.verbose:
+             print(f"Optimal threshold: {optimal_threshold:.4f}")
+             print(f"Optimal {metric}: {optimal_metrics[metric]:.4f}")
+             if bootstrap:
+                 lower, upper = optimal_ci[metric]
+                 print(f"  {confidence_level*100:.0f}% CI: ({lower:.4f}, {upper:.4f})")
+
+         return optimization_results
+
+     def plot_optimization_results(self,
+                                   results: Dict[str, Any],
+                                   metrics: Optional[List[str]] = None,
+                                   figsize: Tuple[int, int] = (12, 6)):
+         """Plot optimization results.
+
+         Args:
+             results: Results from optimize method.
+             metrics: List of metrics to plot.
+             figsize: Figure size as (width, height).
+         """
+         if metrics is None:
+             metrics = ["accuracy", "precision_macro", "recall_macro", "f1_macro"]
+
+         plt.figure(figsize=figsize)
+
+         # Get data
+         thresholds = results["results_by_threshold"]["thresholds"]
+
+         # Plot metrics
+         for metric in metrics:
+             values = results["results_by_threshold"][metric]
+             plt.plot(thresholds, values, label=metric.replace("_", " ").title())
+
+             # Highlight optimal threshold
+             if metric == results["optimized_metric"]:
+                 optimal_threshold = results["optimal_threshold"]
+                 optimal_value = results["optimal_metrics"][metric]
+                 plt.scatter([optimal_threshold], [optimal_value], color='red', s=100, zorder=5)
+                 plt.axvline(optimal_threshold, color='red', linestyle='--', alpha=0.5,
+                             label=f"Optimal Threshold: {optimal_threshold:.4f}")
+
+         plt.xlabel("Threshold")
+         plt.ylabel("Metric Value")
+         plt.title("Threshold Optimization Results")
+         plt.legend()
+         plt.grid(True, alpha=0.3)
+         plt.tight_layout()
+         plt.show()
+
+     def plot_class_distribution(self,
+                                 results: Dict[str, Any],
+                                 top_n: int = 10,
+                                 figsize: Tuple[int, int] = (12, 8)):
+         """Plot class distribution at different thresholds.
+
+         Args:
+             results: Results from optimize method.
+             top_n: Number of top classes to show.
+             figsize: Figure size as (width, height).
+         """
+         # Get data
+         thresholds = results["results_by_threshold"]["thresholds"]
+         distributions = results["results_by_threshold"]["class_distribution"]
+
+         # Find all classes
+         all_classes = set()
+         for dist in distributions:
+             all_classes.update(dist.keys())
+
+         # Count total occurrences to find top classes
+         total_counts = {}
+         for cls in all_classes:
+             total_counts[cls] = sum(dist.get(cls, 0) for dist in distributions)
+
+         # Get top N classes
+         top_classes = sorted(all_classes, key=lambda x: total_counts[x], reverse=True)[:top_n]
+
+         # Create data for plot
+         data = []
+         for i, threshold in enumerate(thresholds):
+             dist = distributions[i]
+             for cls in top_classes:
+                 data.append({
+                     "Threshold": threshold,
+                     "Class": cls,
+                     "Count": dist.get(cls, 0)
+                 })
+
+         # Create dataframe
+         df = pd.DataFrame(data)
+
+         # Create plot
+         plt.figure(figsize=figsize)
+
+         # Use seaborn for line plot
+         sns.lineplot(data=df, x="Threshold", y="Count", hue="Class")
+
+         # Add vertical line for optimal threshold
+         optimal_threshold = results["optimal_threshold"]
+         plt.axvline(optimal_threshold, color='red', linestyle='--', alpha=0.5,
+                     label=f"Optimal Threshold: {optimal_threshold:.4f}")
+
+         plt.title("Class Distribution by Threshold")
+         plt.xlabel("Threshold")
+         plt.ylabel("Count")
+         plt.legend(title="Class")
+         plt.grid(True, alpha=0.3)
+         plt.tight_layout()
+         plt.show()
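Stripped of the bootstrap machinery, `optimize` is a grid search: sweep a threshold grid (note the `end + step/2` trick to keep the endpoint in `np.arange`), score each threshold, keep the argmax. A toy sketch where pre-computed best-match scores and plain accuracy stand in for real embeddings and the configurable metric:

```python
import numpy as np

# Toy sketch of the threshold sweep in ThresholdOptimizer.optimize.
# best_scores/classes stand in for real matcher output.
best_scores = np.array([0.97, 0.71, 0.93])   # best similarity per query
classes = np.array(["pos", "neg", "neg"])    # class of the best match
true = np.array(["pos", "Other", "neg"])     # gold labels

step = 0.05
thresholds = np.arange(0.0, 1.0 + step / 2, step)  # inclusive of 1.0
accuracies = []
for t in thresholds:
    pred = np.where(best_scores >= t, classes, "Other")  # same rule as match()
    accuracies.append(float(np.mean(pred == true)))

optimal_idx = int(np.argmax(accuracies))     # first threshold hitting the max
print(round(float(thresholds[optimal_idx]), 2), accuracies[optimal_idx])
```

Here the second query's best match is wrong, so accuracy peaks once the threshold rises above its score (0.71) and reroutes it to `"Other"` — the trade-off the sweep is designed to expose.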
src/qualivec/sampling.py ADDED
@@ -0,0 +1,102 @@
+ """Sampling utilities for QualiVec."""
+
+ import pandas as pd
+ import numpy as np
+ from typing import Optional, Union, Literal
+
+
+ class Sampler:
+     """Handles sampling mechanisms for QualiVec."""
+
+     def __init__(self, verbose: bool = True):
+         """Initialize the Sampler.
+
+         Args:
+             verbose: Whether to print status messages.
+         """
+         self.verbose = verbose
+
+     def sample(self,
+                df: pd.DataFrame,
+                sampling_type: Literal["random", "stratified"] = "random",
+                sample_size: Union[int, float] = 0.1,
+                stratify_column: Optional[str] = None,
+                seed: Optional[int] = None,
+                label_column: str = "Label") -> pd.DataFrame:
+         """Sample data from a DataFrame.
+
+         Args:
+             df: DataFrame to sample from.
+             sampling_type: Type of sampling ("random" or "stratified").
+             sample_size: Size of the sample. If float, interpreted as a fraction.
+             stratify_column: Column to stratify by (required for stratified sampling).
+             seed: Random seed for reproducibility.
+             label_column: Name of the label column to add to the output.
+
+         Returns:
+             DataFrame containing the sampled data.
+
+         Raises:
+             ValueError: If parameters are invalid.
+         """
+         # Set random seed if provided
+         if seed is not None:
+             np.random.seed(seed)
+
+         # Calculate sample size if given as a fraction
+         if isinstance(sample_size, float):
+             if not 0 < sample_size <= 1:
+                 raise ValueError("Sample size as fraction must be between 0 and 1.")
+             n_samples = int(len(df) * sample_size)
+         else:
+             if not 0 < sample_size <= len(df):
+                 raise ValueError(f"Sample size must be between 1 and {len(df)}.")
+             n_samples = sample_size
+
+         if self.verbose:
+             print(f"Sampling {n_samples} rows ({n_samples/len(df):.1%} of data)...")
+
+         # Perform sampling
+         if sampling_type == "random":
+             sample = df.sample(n=n_samples, random_state=seed)
+
+         elif sampling_type == "stratified":
+             if stratify_column is None:
+                 raise ValueError("stratify_column must be provided for stratified sampling.")
+
+             if stratify_column not in df.columns:
+                 raise ValueError(f"Stratification column '{stratify_column}' not found in DataFrame.")
+
+             # Check for NaN values in stratification column
+             if df[stratify_column].isna().any():
+                 raise ValueError(f"NaN values found in stratification column '{stratify_column}'.")
+
+             # Calculate the proportion for each stratum
+             strata = df[stratify_column].value_counts(normalize=True)
+
+             # Sample from each stratum, then concatenate once (repeatedly
+             # concatenating into an empty DataFrame can mangle dtypes and is
+             # deprecated in recent pandas)
+             stratum_frames = []
+             for stratum, proportion in strata.items():
+                 stratum_df = df[df[stratify_column] == stratum]
+                 stratum_samples = max(1, int(n_samples * proportion))
+                 stratum_frames.append(
+                     stratum_df.sample(n=min(stratum_samples, len(stratum_df)),
+                                       random_state=seed)
+                 )
+             sample = pd.concat(stratum_frames)
+
+             if self.verbose:
+                 print(f"Stratified sampling based on '{stratify_column}':")
+                 for stratum, count in sample[stratify_column].value_counts().items():
+                     print(f"  - {stratum}: {count} samples ({count/n_samples:.1%})")
+         else:
+             raise ValueError(f"Unknown sampling type: {sampling_type}")
+
+         # Add empty label column for manual annotation
+         if label_column not in sample.columns:
+             sample[label_column] = None
+
+         if self.verbose:
+             print(f"Created sample with {len(sample)} rows.")
+
+         return sample
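For comparison, the proportional allocation that `Sampler.sample(sampling_type="stratified")` builds by hand can also be obtained from pandas' `groupby(...).sample`. A minimal sketch on toy data (column names `text`/`group` are illustrative; assumes pandas ≥ 1.1 for `GroupBy.sample`):

```python
import pandas as pd

# Toy sketch of proportional stratified sampling, mirroring
# Sampler.sample(sampling_type="stratified").
df = pd.DataFrame({
    "text": [f"doc{i}" for i in range(10)],
    "group": ["a"] * 6 + ["b"] * 4,   # 60/40 split across strata
})

# frac=0.5 per group preserves the 60/40 proportion in the sample
sample = df.groupby("group", group_keys=False).sample(frac=0.5, random_state=0)
sample["Label"] = None  # empty column for manual annotation, as in Sampler
print(sample["group"].value_counts().to_dict())
```

The hand-rolled loop in `Sampler` additionally guarantees at least one row per stratum (`max(1, ...)`), which `frac`-based sampling does not for very small strata.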