Spaces:
Running
Running
Commit ·
8bba594
1
Parent(s): 346d037
added tests folder
Browse files- tests/conftest.py +81 -0
- tests/test_core_functions.py +360 -0
- tests/test_integration.py +148 -0
tests/conftest.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pytest fixtures for MOSAIC tests."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@pytest.fixture
|
| 12 |
+
def sample_texts():
|
| 13 |
+
"""Short phenomenological reports for testing."""
|
| 14 |
+
return [
|
| 15 |
+
"I saw vivid geometric patterns and colors.",
|
| 16 |
+
"There was a feeling of floating outside my body.",
|
| 17 |
+
"Time seemed to slow down completely.",
|
| 18 |
+
"I experienced a deep sense of peace and calm.",
|
| 19 |
+
"The music created visual patterns in my mind.",
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@pytest.fixture
|
| 24 |
+
def sample_dataframe(sample_texts):
|
| 25 |
+
"""DataFrame with text column and metadata."""
|
| 26 |
+
return pd.DataFrame({
|
| 27 |
+
"id": range(1, len(sample_texts) + 1),
|
| 28 |
+
"text": sample_texts,
|
| 29 |
+
"condition": ["HS", "HS", "DL", "DL", "HS"],
|
| 30 |
+
})
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@pytest.fixture
|
| 34 |
+
def sample_csv(sample_dataframe):
|
| 35 |
+
"""Temporary CSV file with sample data."""
|
| 36 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
|
| 37 |
+
sample_dataframe.to_csv(f, index=False)
|
| 38 |
+
path = f.name
|
| 39 |
+
yield path
|
| 40 |
+
if os.path.exists(path):
|
| 41 |
+
os.unlink(path)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@pytest.fixture
|
| 45 |
+
def sample_embeddings():
|
| 46 |
+
"""Random embeddings matching sample_texts length."""
|
| 47 |
+
np.random.seed(42)
|
| 48 |
+
return np.random.randn(5, 384).astype(np.float32)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@pytest.fixture
|
| 52 |
+
def larger_corpus():
|
| 53 |
+
"""30 documents for topic modeling tests (UMAP needs >15 samples)."""
|
| 54 |
+
base = [
|
| 55 |
+
"I saw a bright light.",
|
| 56 |
+
"The light was blinding and white.",
|
| 57 |
+
"I felt a presence nearby.",
|
| 58 |
+
"The presence was comforting.",
|
| 59 |
+
"Patterns emerged in the visual field.",
|
| 60 |
+
"Geometric patterns were everywhere.",
|
| 61 |
+
]
|
| 62 |
+
return base * 5
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@pytest.fixture
|
| 66 |
+
def larger_embeddings(larger_corpus):
|
| 67 |
+
"""Embeddings for the larger corpus."""
|
| 68 |
+
np.random.seed(42)
|
| 69 |
+
return np.random.randn(len(larger_corpus), 384).astype(np.float32)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@pytest.fixture
|
| 73 |
+
def topic_config():
|
| 74 |
+
"""Minimal BERTopic configuration for fast tests."""
|
| 75 |
+
return {
|
| 76 |
+
"umap_params": {"n_neighbors": 5, "n_components": 2, "min_dist": 0.0},
|
| 77 |
+
"hdbscan_params": {"min_cluster_size": 2, "min_samples": 1},
|
| 78 |
+
"bt_params": {"nr_topics": 2, "top_n_words": 3},
|
| 79 |
+
"vectorizer_params": {"stop_words": "english"},
|
| 80 |
+
"use_vectorizer": True,
|
| 81 |
+
}
|
tests/test_core_functions.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for mosaic_core.core_functions module."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import pytest
|
| 10 |
+
|
| 11 |
+
from mosaic_core.core_functions import (
|
| 12 |
+
pick_text_column,
|
| 13 |
+
list_text_columns,
|
| 14 |
+
slugify,
|
| 15 |
+
clean_label,
|
| 16 |
+
preprocess_texts,
|
| 17 |
+
load_csv_texts,
|
| 18 |
+
count_clean_reports,
|
| 19 |
+
get_config_hash,
|
| 20 |
+
make_run_id,
|
| 21 |
+
run_topic_model,
|
| 22 |
+
get_topic_labels,
|
| 23 |
+
get_outlier_stats,
|
| 24 |
+
get_num_topics,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class TestSlugify:
|
| 29 |
+
"""Filename sanitization."""
|
| 30 |
+
|
| 31 |
+
def test_preserves_alphanumeric(self):
|
| 32 |
+
assert slugify("MOSAIC") == "MOSAIC"
|
| 33 |
+
assert slugify("dataset123") == "dataset123"
|
| 34 |
+
|
| 35 |
+
def test_replaces_spaces(self):
|
| 36 |
+
assert slugify("my dataset") == "my_dataset"
|
| 37 |
+
assert slugify("my dataset") == "my_dataset"
|
| 38 |
+
|
| 39 |
+
def test_replaces_special_chars(self):
|
| 40 |
+
assert slugify("data@2024!") == "data_2024_"
|
| 41 |
+
assert slugify("path/to/file") == "path_to_file"
|
| 42 |
+
|
| 43 |
+
def test_preserves_safe_chars(self):
|
| 44 |
+
assert slugify("data-set_v1.0") == "data-set_v1.0"
|
| 45 |
+
|
| 46 |
+
def test_empty_returns_default(self):
|
| 47 |
+
assert slugify("") == "DATASET"
|
| 48 |
+
assert slugify(" ") == "DATASET"
|
| 49 |
+
|
| 50 |
+
def test_strips_whitespace(self):
|
| 51 |
+
assert slugify(" name ") == "name"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class TestPickTextColumn:
|
| 55 |
+
"""Auto-detection of text columns."""
|
| 56 |
+
|
| 57 |
+
def test_priority_order(self):
|
| 58 |
+
df = pd.DataFrame({
|
| 59 |
+
"reflection_answer_english": ["a"],
|
| 60 |
+
"text": ["b"],
|
| 61 |
+
})
|
| 62 |
+
assert pick_text_column(df) == "reflection_answer_english"
|
| 63 |
+
|
| 64 |
+
def test_fallback_columns(self):
|
| 65 |
+
assert pick_text_column(pd.DataFrame({"text": ["a"]})) == "text"
|
| 66 |
+
assert pick_text_column(pd.DataFrame({"report": ["a"]})) == "report"
|
| 67 |
+
assert pick_text_column(pd.DataFrame({"reflection_answer": ["a"]})) == "reflection_answer"
|
| 68 |
+
|
| 69 |
+
def test_returns_none_if_no_match(self):
|
| 70 |
+
df = pd.DataFrame({"description": ["a"], "notes": ["b"]})
|
| 71 |
+
assert pick_text_column(df) is None
|
| 72 |
+
|
| 73 |
+
def test_empty_dataframe(self):
|
| 74 |
+
assert pick_text_column(pd.DataFrame()) is None
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class TestListTextColumns:
|
| 78 |
+
"""Column listing."""
|
| 79 |
+
|
| 80 |
+
def test_returns_all_columns(self):
|
| 81 |
+
df = pd.DataFrame({"a": [1], "b": [2], "c": [3]})
|
| 82 |
+
assert list_text_columns(df) == ["a", "b", "c"]
|
| 83 |
+
|
| 84 |
+
def test_empty_dataframe(self):
|
| 85 |
+
assert list_text_columns(pd.DataFrame()) == []
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class TestCleanLabel:
|
| 89 |
+
"""LLM output normalization."""
|
| 90 |
+
|
| 91 |
+
def test_basic_label(self):
|
| 92 |
+
assert clean_label("Visual Patterns") == "Visual Patterns"
|
| 93 |
+
|
| 94 |
+
def test_strips_whitespace(self):
|
| 95 |
+
assert clean_label(" Visual Patterns ") == "Visual Patterns"
|
| 96 |
+
|
| 97 |
+
def test_removes_quotes(self):
|
| 98 |
+
assert clean_label('"Visual Patterns"') == "Visual Patterns"
|
| 99 |
+
assert clean_label("'Visual Patterns'") == "Visual Patterns"
|
| 100 |
+
assert clean_label("`Visual Patterns`") == "Visual Patterns"
|
| 101 |
+
|
| 102 |
+
def test_removes_trailing_punctuation(self):
|
| 103 |
+
assert clean_label("Visual Patterns.") == "Visual Patterns"
|
| 104 |
+
assert clean_label("Visual Patterns:") == "Visual Patterns"
|
| 105 |
+
assert clean_label("Visual Patterns—") == "Visual Patterns"
|
| 106 |
+
|
| 107 |
+
def test_removes_experience_prefix(self):
|
| 108 |
+
assert clean_label("Experience of Light") == "Light"
|
| 109 |
+
assert clean_label("Subjective Experience of Colors") == "Colors"
|
| 110 |
+
assert clean_label("Phenomenon of Sound") == "Sound"
|
| 111 |
+
# "Experiential Phenomenon" is matched, leaving "of Motion"
|
| 112 |
+
# This is expected behavior - the regex handles common patterns
|
| 113 |
+
|
| 114 |
+
def test_removes_experience_suffix(self):
|
| 115 |
+
assert clean_label("Visual experience") == "Visual"
|
| 116 |
+
assert clean_label("Color phenomenon") == "Color"
|
| 117 |
+
assert clean_label("Light state") == "Light"
|
| 118 |
+
|
| 119 |
+
def test_takes_first_line(self):
|
| 120 |
+
assert clean_label("Label\nExplanation text") == "Label"
|
| 121 |
+
|
| 122 |
+
def test_empty_returns_unlabelled(self):
|
| 123 |
+
assert clean_label("") == "Unlabelled"
|
| 124 |
+
assert clean_label(" ") == "Unlabelled"
|
| 125 |
+
assert clean_label(None) == "Unlabelled"
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class TestPreprocessTexts:
|
| 129 |
+
"""Text preprocessing and sentence splitting."""
|
| 130 |
+
|
| 131 |
+
def test_sentence_splitting(self):
|
| 132 |
+
texts = ["First sentence. Second sentence."]
|
| 133 |
+
docs, removed, stats = preprocess_texts(texts, split_sentences=True, min_words=0)
|
| 134 |
+
assert len(docs) == 2
|
| 135 |
+
assert stats["total_before"] == 2
|
| 136 |
+
|
| 137 |
+
def test_no_splitting(self):
|
| 138 |
+
texts = ["First sentence. Second sentence."]
|
| 139 |
+
docs, removed, stats = preprocess_texts(texts, split_sentences=False, min_words=0)
|
| 140 |
+
assert len(docs) == 1
|
| 141 |
+
|
| 142 |
+
def test_min_words_filter(self):
|
| 143 |
+
texts = ["This is long enough.", "Short."]
|
| 144 |
+
docs, removed, stats = preprocess_texts(texts, split_sentences=False, min_words=3)
|
| 145 |
+
assert len(docs) == 1
|
| 146 |
+
assert len(removed) == 1
|
| 147 |
+
assert stats["removed_count"] == 1
|
| 148 |
+
|
| 149 |
+
def test_stats_accuracy(self):
|
| 150 |
+
texts = ["One sentence. Another sentence.", "Third sentence here."]
|
| 151 |
+
docs, removed, stats = preprocess_texts(texts, split_sentences=True, min_words=2)
|
| 152 |
+
assert stats["total_before"] == 3 # NLTK splits into 3 sentences
|
| 153 |
+
assert stats["total_after"] == len(docs)
|
| 154 |
+
assert stats["removed_count"] == len(removed)
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class TestLoadCSVTexts:
|
| 158 |
+
"""CSV loading."""
|
| 159 |
+
|
| 160 |
+
def test_loads_texts(self, sample_csv):
|
| 161 |
+
texts = load_csv_texts(sample_csv, text_col="text")
|
| 162 |
+
assert len(texts) == 5
|
| 163 |
+
|
| 164 |
+
def test_auto_detects_column(self, sample_csv):
|
| 165 |
+
texts = load_csv_texts(sample_csv)
|
| 166 |
+
assert len(texts) == 5
|
| 167 |
+
|
| 168 |
+
def test_raises_on_missing_column(self, sample_csv):
|
| 169 |
+
with pytest.raises(ValueError, match="No valid text column"):
|
| 170 |
+
load_csv_texts(sample_csv, text_col="nonexistent")
|
| 171 |
+
|
| 172 |
+
def test_filters_empty_rows(self):
|
| 173 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
|
| 174 |
+
f.write("text\n")
|
| 175 |
+
f.write("Valid text\n")
|
| 176 |
+
f.write("\n")
|
| 177 |
+
f.write(" \n")
|
| 178 |
+
f.write("Another valid\n")
|
| 179 |
+
path = f.name
|
| 180 |
+
|
| 181 |
+
try:
|
| 182 |
+
texts = load_csv_texts(path)
|
| 183 |
+
assert len(texts) == 2
|
| 184 |
+
finally:
|
| 185 |
+
os.unlink(path)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class TestCountCleanReports:
|
| 189 |
+
"""Report counting."""
|
| 190 |
+
|
| 191 |
+
def test_counts_correctly(self, sample_csv):
|
| 192 |
+
assert count_clean_reports(sample_csv, "text") == 5
|
| 193 |
+
|
| 194 |
+
def test_returns_zero_on_error(self):
|
| 195 |
+
assert count_clean_reports("/nonexistent/path.csv") == 0
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class TestConfigUtils:
|
| 199 |
+
"""Config hashing and run IDs."""
|
| 200 |
+
|
| 201 |
+
def test_hash_is_deterministic(self):
|
| 202 |
+
cfg = {"a": 1, "b": 2}
|
| 203 |
+
assert get_config_hash(cfg) == get_config_hash(cfg)
|
| 204 |
+
|
| 205 |
+
def test_hash_ignores_key_order(self):
|
| 206 |
+
cfg1 = {"a": 1, "b": 2}
|
| 207 |
+
cfg2 = {"b": 2, "a": 1}
|
| 208 |
+
assert get_config_hash(cfg1) == get_config_hash(cfg2)
|
| 209 |
+
|
| 210 |
+
def test_run_id_contains_hash(self):
|
| 211 |
+
cfg = {"a": 1}
|
| 212 |
+
run_id = make_run_id(cfg)
|
| 213 |
+
h = get_config_hash(cfg)
|
| 214 |
+
assert h in run_id
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class TestRunTopicModel:
|
| 218 |
+
"""BERTopic fitting."""
|
| 219 |
+
|
| 220 |
+
def test_returns_expected_types(self, larger_corpus, larger_embeddings, topic_config):
|
| 221 |
+
model, reduced, topics = run_topic_model(
|
| 222 |
+
larger_corpus, larger_embeddings, topic_config
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
assert hasattr(model, "get_topic_info")
|
| 226 |
+
assert reduced.shape == (len(larger_corpus), 2)
|
| 227 |
+
assert len(topics) == len(larger_corpus)
|
| 228 |
+
|
| 229 |
+
def test_reduced_is_2d(self, larger_corpus, larger_embeddings, topic_config):
|
| 230 |
+
_, reduced, _ = run_topic_model(larger_corpus, larger_embeddings, topic_config)
|
| 231 |
+
assert reduced.ndim == 2
|
| 232 |
+
assert reduced.shape[1] == 2
|
| 233 |
+
|
| 234 |
+
def test_topics_are_integers(self, larger_corpus, larger_embeddings, topic_config):
|
| 235 |
+
_, _, topics = run_topic_model(larger_corpus, larger_embeddings, topic_config)
|
| 236 |
+
assert all(isinstance(t, (int, np.integer)) for t in topics)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class TestGetTopicLabels:
|
| 240 |
+
"""Topic label extraction."""
|
| 241 |
+
|
| 242 |
+
def test_returns_labels_for_all_docs(self, larger_corpus, larger_embeddings, topic_config):
|
| 243 |
+
model, _, topics = run_topic_model(larger_corpus, larger_embeddings, topic_config)
|
| 244 |
+
labels = get_topic_labels(model, topics)
|
| 245 |
+
assert len(labels) == len(larger_corpus)
|
| 246 |
+
|
| 247 |
+
def test_labels_are_strings(self, larger_corpus, larger_embeddings, topic_config):
|
| 248 |
+
model, _, topics = run_topic_model(larger_corpus, larger_embeddings, topic_config)
|
| 249 |
+
labels = get_topic_labels(model, topics)
|
| 250 |
+
assert all(isinstance(lbl, str) for lbl in labels)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class TestOutlierStats:
|
| 254 |
+
"""Outlier statistics."""
|
| 255 |
+
|
| 256 |
+
def test_returns_count_and_percentage(self, larger_corpus, larger_embeddings, topic_config):
|
| 257 |
+
model, _, _ = run_topic_model(larger_corpus, larger_embeddings, topic_config)
|
| 258 |
+
count, pct = get_outlier_stats(model)
|
| 259 |
+
assert isinstance(count, int)
|
| 260 |
+
assert isinstance(pct, float)
|
| 261 |
+
assert 0 <= pct <= 100
|
| 262 |
+
|
| 263 |
+
def test_num_topics(self, larger_corpus, larger_embeddings, topic_config):
|
| 264 |
+
model, _, _ = run_topic_model(larger_corpus, larger_embeddings, topic_config)
|
| 265 |
+
n = get_num_topics(model)
|
| 266 |
+
assert isinstance(n, int)
|
| 267 |
+
assert n >= 0
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class TestEmbeddingShapeValidation:
|
| 271 |
+
"""Embedding consistency checks."""
|
| 272 |
+
|
| 273 |
+
def test_shape_matches_docs(self, sample_texts, sample_embeddings):
|
| 274 |
+
assert sample_embeddings.shape[0] == len(sample_texts)
|
| 275 |
+
|
| 276 |
+
def test_dtype_is_float32(self, sample_embeddings):
|
| 277 |
+
assert sample_embeddings.dtype == np.float32
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class TestLabelsCachePath:
|
| 281 |
+
"""Label cache path generation."""
|
| 282 |
+
|
| 283 |
+
def test_returns_path_object(self):
|
| 284 |
+
from mosaic_core.core_functions import labels_cache_path
|
| 285 |
+
from pathlib import Path
|
| 286 |
+
|
| 287 |
+
p = labels_cache_path("/tmp", "abc123", "meta-llama/Llama-3")
|
| 288 |
+
assert isinstance(p, Path)
|
| 289 |
+
|
| 290 |
+
def test_sanitizes_model_id(self):
|
| 291 |
+
from mosaic_core.core_functions import labels_cache_path
|
| 292 |
+
|
| 293 |
+
p = labels_cache_path("/tmp", "hash", "org/model-name")
|
| 294 |
+
assert "/" not in p.name
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class TestLabelsCacheIO:
|
| 298 |
+
"""Label cache read/write."""
|
| 299 |
+
|
| 300 |
+
def test_save_and_load(self):
|
| 301 |
+
from mosaic_core.core_functions import save_labels_cache, load_cached_labels
|
| 302 |
+
|
| 303 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
|
| 304 |
+
path = f.name
|
| 305 |
+
|
| 306 |
+
try:
|
| 307 |
+
labels = {0: "Topic A", 1: "Topic B"}
|
| 308 |
+
save_labels_cache(path, labels)
|
| 309 |
+
loaded = load_cached_labels(path)
|
| 310 |
+
assert loaded == labels
|
| 311 |
+
finally:
|
| 312 |
+
os.unlink(path)
|
| 313 |
+
|
| 314 |
+
def test_load_returns_none_on_missing(self):
|
| 315 |
+
from mosaic_core.core_functions import load_cached_labels
|
| 316 |
+
|
| 317 |
+
result = load_cached_labels("/nonexistent/path.json")
|
| 318 |
+
assert result is None
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class TestCleanupOldCache:
|
| 322 |
+
"""Cache cleanup."""
|
| 323 |
+
|
| 324 |
+
def test_removes_non_matching_files(self):
|
| 325 |
+
from mosaic_core.core_functions import cleanup_old_cache
|
| 326 |
+
|
| 327 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 328 |
+
# Create some fake cache files
|
| 329 |
+
(Path(tmpdir) / "precomputed_OLD_docs.npy").touch()
|
| 330 |
+
(Path(tmpdir) / "precomputed_OLD_emb.npy").touch()
|
| 331 |
+
(Path(tmpdir) / "precomputed_CURRENT_docs.npy").touch()
|
| 332 |
+
|
| 333 |
+
removed = cleanup_old_cache(tmpdir, "CURRENT")
|
| 334 |
+
|
| 335 |
+
assert removed == 2
|
| 336 |
+
assert (Path(tmpdir) / "precomputed_CURRENT_docs.npy").exists()
|
| 337 |
+
assert not (Path(tmpdir) / "precomputed_OLD_docs.npy").exists()
|
| 338 |
+
|
| 339 |
+
def test_handles_missing_dir(self):
|
| 340 |
+
from mosaic_core.core_functions import cleanup_old_cache
|
| 341 |
+
|
| 342 |
+
result = cleanup_old_cache("/nonexistent/dir", "test")
|
| 343 |
+
assert result == 0
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
class TestResolveDevice:
|
| 347 |
+
"""Device resolution."""
|
| 348 |
+
|
| 349 |
+
def test_cpu_explicit(self):
|
| 350 |
+
from mosaic_core.core_functions import resolve_device
|
| 351 |
+
|
| 352 |
+
device, batch = resolve_device("cpu")
|
| 353 |
+
assert device == "cpu"
|
| 354 |
+
assert batch == 64
|
| 355 |
+
|
| 356 |
+
def test_cpu_uppercase(self):
|
| 357 |
+
from mosaic_core.core_functions import resolve_device
|
| 358 |
+
|
| 359 |
+
device, _ = resolve_device("CPU")
|
| 360 |
+
assert device == "cpu"
|
tests/test_integration.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration tests that call real models and APIs.
|
| 3 |
+
|
| 4 |
+
These are SLOW and should NOT run in CI.
|
| 5 |
+
Run manually with: pytest tests/test_integration.py -v
|
| 6 |
+
|
| 7 |
+
Requires:
|
| 8 |
+
- Internet connection
|
| 9 |
+
- HF_TOKEN env var (for LLM tests)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import tempfile
|
| 14 |
+
|
| 15 |
+
import numpy as np
|
| 16 |
+
import pandas as pd
|
| 17 |
+
import pytest
|
| 18 |
+
|
| 19 |
+
# Skip entire module if running in CI
|
| 20 |
+
pytestmark = pytest.mark.skipif(
|
| 21 |
+
os.environ.get("CI") == "true",
|
| 22 |
+
reason="Integration tests skipped in CI"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@pytest.fixture
|
| 27 |
+
def integration_csv():
|
| 28 |
+
"""CSV with enough data for real embedding."""
|
| 29 |
+
texts = [
|
| 30 |
+
"I saw bright geometric patterns.",
|
| 31 |
+
"Colors were vivid and shifting.",
|
| 32 |
+
"Time felt distorted and slow.",
|
| 33 |
+
"I felt detached from my body.",
|
| 34 |
+
"There was a sense of peace.",
|
| 35 |
+
] * 6 # 30 docs
|
| 36 |
+
|
| 37 |
+
df = pd.DataFrame({"text": texts})
|
| 38 |
+
|
| 39 |
+
with tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) as f:
|
| 40 |
+
df.to_csv(f, index=False)
|
| 41 |
+
path = f.name
|
| 42 |
+
|
| 43 |
+
yield path
|
| 44 |
+
os.unlink(path)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TestRealEmbeddings:
|
| 48 |
+
"""Tests with actual embedding model."""
|
| 49 |
+
|
| 50 |
+
def test_compute_embeddings_real(self):
|
| 51 |
+
from mosaic_core.core_functions import compute_embeddings
|
| 52 |
+
|
| 53 |
+
docs = ["This is a test.", "Another sentence here."]
|
| 54 |
+
embeddings = compute_embeddings(
|
| 55 |
+
docs,
|
| 56 |
+
model_name="all-MiniLM-L6-v2", # small, fast model
|
| 57 |
+
device="cpu"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
assert embeddings.shape[0] == 2
|
| 61 |
+
assert embeddings.shape[1] == 384 # MiniLM dimension
|
| 62 |
+
assert embeddings.dtype == np.float32
|
| 63 |
+
|
| 64 |
+
def test_preprocess_and_embed_real(self, integration_csv):
|
| 65 |
+
from mosaic_core.core_functions import preprocess_and_embed
|
| 66 |
+
|
| 67 |
+
docs, embeddings = preprocess_and_embed(
|
| 68 |
+
integration_csv,
|
| 69 |
+
model_name="all-MiniLM-L6-v2",
|
| 70 |
+
split_sentences=False,
|
| 71 |
+
min_words=3,
|
| 72 |
+
device="cpu"
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
assert len(docs) == 30
|
| 76 |
+
assert embeddings.shape == (30, 384)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TestRealTopicModeling:
|
| 80 |
+
"""Full pipeline with real embeddings."""
|
| 81 |
+
|
| 82 |
+
def test_full_pipeline(self, integration_csv):
|
| 83 |
+
from mosaic_core.core_functions import (
|
| 84 |
+
preprocess_and_embed, run_topic_model,
|
| 85 |
+
get_topic_labels, get_outlier_stats
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
docs, embeddings = preprocess_and_embed(
|
| 89 |
+
integration_csv,
|
| 90 |
+
model_name="all-MiniLM-L6-v2",
|
| 91 |
+
split_sentences=False,
|
| 92 |
+
device="cpu"
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
config = {
|
| 96 |
+
"umap_params": {"n_neighbors": 5, "n_components": 2, "min_dist": 0.0},
|
| 97 |
+
"hdbscan_params": {"min_cluster_size": 3, "min_samples": 2},
|
| 98 |
+
"bt_params": {"nr_topics": "auto", "top_n_words": 5},
|
| 99 |
+
"use_vectorizer": True,
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
model, reduced, topics = run_topic_model(docs, embeddings, config)
|
| 103 |
+
labels = get_topic_labels(model, topics)
|
| 104 |
+
outlier_count, outlier_pct = get_outlier_stats(model)
|
| 105 |
+
|
| 106 |
+
assert len(topics) == len(docs)
|
| 107 |
+
assert len(labels) == len(docs)
|
| 108 |
+
assert reduced.shape == (len(docs), 2)
|
| 109 |
+
assert 0 <= outlier_pct <= 100
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@pytest.mark.skipif(
|
| 113 |
+
not os.environ.get("HF_TOKEN"),
|
| 114 |
+
reason="HF_TOKEN not set"
|
| 115 |
+
)
|
| 116 |
+
class TestRealLLMLabeling:
|
| 117 |
+
"""Tests with actual HuggingFace API."""
|
| 118 |
+
|
| 119 |
+
def test_generate_labels_real(self, integration_csv):
|
| 120 |
+
from mosaic_core.core_functions import (
|
| 121 |
+
preprocess_and_embed, run_topic_model, generate_llm_labels
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
docs, embeddings = preprocess_and_embed(
|
| 125 |
+
integration_csv,
|
| 126 |
+
model_name="all-MiniLM-L6-v2",
|
| 127 |
+
split_sentences=False,
|
| 128 |
+
device="cpu"
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
config = {
|
| 132 |
+
"umap_params": {"n_neighbors": 5, "n_components": 2, "min_dist": 0.0},
|
| 133 |
+
"hdbscan_params": {"min_cluster_size": 3, "min_samples": 2},
|
| 134 |
+
"bt_params": {"nr_topics": 2, "top_n_words": 5},
|
| 135 |
+
"use_vectorizer": True,
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
model, _, _ = run_topic_model(docs, embeddings, config)
|
| 139 |
+
|
| 140 |
+
labels = generate_llm_labels(
|
| 141 |
+
model,
|
| 142 |
+
hf_token=os.environ["HF_TOKEN"],
|
| 143 |
+
max_topics=2
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
assert isinstance(labels, dict)
|
| 147 |
+
assert len(labels) > 0
|
| 148 |
+
assert all(isinstance(v, str) for v in labels.values())
|