Spaces:
Running
Running
| """Tests for mosaic_core.core_functions module.""" | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| import pytest | |
| from mosaic_core.core_functions import ( | |
| pick_text_column, | |
| list_text_columns, | |
| slugify, | |
| clean_label, | |
| preprocess_texts, | |
| load_csv_texts, | |
| count_clean_reports, | |
| get_config_hash, | |
| make_run_id, | |
| run_topic_model, | |
| get_topic_labels, | |
| get_outlier_stats, | |
| get_num_topics, | |
| ) | |
class TestSlugify:
    """Filename sanitization performed by ``slugify``."""

    def test_preserves_alphanumeric(self):
        """Purely alphanumeric names come back unchanged."""
        for name in ("MOSAIC", "dataset123"):
            assert slugify(name) == name

    def test_replaces_spaces(self):
        """A space inside a name is turned into an underscore."""
        assert slugify("my dataset") == "my_dataset"
        # NOTE(review): this assertion is identical to the one above as it
        # appears in this file; it may have been meant to cover a run of
        # several spaces -- confirm against the intended input.
        assert slugify("my dataset") == "my_dataset"

    def test_replaces_special_chars(self):
        """Characters such as ``@``, ``!`` and ``/`` are turned into underscores."""
        expectations = {
            "data@2024!": "data_2024_",
            "path/to/file": "path_to_file",
        }
        for raw, slug in expectations.items():
            assert slugify(raw) == slug

    def test_preserves_safe_chars(self):
        """Hyphens, underscores and dots are left as they are."""
        safe_name = "data-set_v1.0"
        assert slugify(safe_name) == safe_name

    def test_empty_returns_default(self):
        """Empty or blank input falls back to ``"DATASET"``."""
        for blank in ("", " "):
            assert slugify(blank) == "DATASET"

    def test_strips_whitespace(self):
        """Whitespace around the name does not appear in the result."""
        padded = " name "
        assert slugify(padded) == "name"
class TestPickTextColumn:
    """Auto-detection of the text column by ``pick_text_column``."""

    def test_priority_order(self):
        """``reflection_answer_english`` is chosen when ``text`` is also present."""
        frame = pd.DataFrame(
            {
                "reflection_answer_english": ["a"],
                "text": ["b"],
            }
        )
        chosen = pick_text_column(frame)
        assert chosen == "reflection_answer_english"

    def test_fallback_columns(self):
        """Each known column name is picked when it is the only column."""
        for column in ("text", "report", "reflection_answer"):
            single_column_frame = pd.DataFrame({column: ["a"]})
            assert pick_text_column(single_column_frame) == column

    def test_returns_none_if_no_match(self):
        """Frames without any recognised column name yield ``None``."""
        frame = pd.DataFrame({"description": ["a"], "notes": ["b"]})
        chosen = pick_text_column(frame)
        assert chosen is None

    def test_empty_dataframe(self):
        """A DataFrame with no columns yields ``None``."""
        empty = pd.DataFrame()
        assert pick_text_column(empty) is None
class TestListTextColumns:
    """Column listing by ``list_text_columns``."""

    def test_returns_all_columns(self):
        """Every column name is returned, in DataFrame order."""
        frame = pd.DataFrame({"a": [1], "b": [2], "c": [3]})
        listed = list_text_columns(frame)
        assert listed == ["a", "b", "c"]

    def test_empty_dataframe(self):
        """A DataFrame with no columns gives an empty list."""
        listed = list_text_columns(pd.DataFrame())
        assert listed == []
class TestCleanLabel:
    """Normalization of LLM-produced labels by ``clean_label``."""

    def test_basic_label(self):
        """An already clean label is returned as is."""
        label = "Visual Patterns"
        assert clean_label(label) == label

    def test_strips_whitespace(self):
        """Whitespace around the label is dropped."""
        assert clean_label(" Visual Patterns ") == "Visual Patterns"

    def test_removes_quotes(self):
        """Wrapping double quotes, single quotes and backticks are removed."""
        quoted_variants = (
            '"Visual Patterns"',
            "'Visual Patterns'",
            "`Visual Patterns`",
        )
        for quoted in quoted_variants:
            assert clean_label(quoted) == "Visual Patterns"

    def test_removes_trailing_punctuation(self):
        """A trailing period, colon or em dash is removed."""
        for tail in (".", ":", "—"):
            assert clean_label("Visual Patterns" + tail) == "Visual Patterns"

    def test_removes_experience_prefix(self):
        """Leading "... of" experience/phenomenon phrases are stripped."""
        expectations = (
            ("Experience of Light", "Light"),
            ("Subjective Experience of Colors", "Colors"),
            ("Phenomenon of Sound", "Sound"),
        )
        for raw, cleaned in expectations:
            assert clean_label(raw) == cleaned
        # NOTE(review): the original author's comment says that for an input
        # beginning with "Experiential Phenomenon" only that phrase is
        # matched, leaving "of Motion", and calls this expected behavior.
        # That case is described but not asserted here -- confirm whether an
        # assertion should be added.

    def test_removes_experience_suffix(self):
        """Trailing "experience", "phenomenon" and "state" words are stripped."""
        expectations = (
            ("Visual experience", "Visual"),
            ("Color phenomenon", "Color"),
            ("Light state", "Light"),
        )
        for raw, cleaned in expectations:
            assert clean_label(raw) == cleaned

    def test_takes_first_line(self):
        """Only the first line of a multi-line answer is kept."""
        multi_line = "Label\nExplanation text"
        assert clean_label(multi_line) == "Label"

    def test_empty_returns_unlabelled(self):
        """Empty, blank and ``None`` inputs all map to ``"Unlabelled"``."""
        for missing in ("", " ", None):
            assert clean_label(missing) == "Unlabelled"
class TestPreprocessTexts:
    """Sentence splitting and filtering done by ``preprocess_texts``."""

    def test_sentence_splitting(self):
        """With splitting enabled, two sentences give two documents."""
        corpus = ["First sentence. Second sentence."]
        docs, _removed, stats = preprocess_texts(
            corpus, split_sentences=True, min_words=0
        )
        assert len(docs) == 2
        assert stats["total_before"] == 2

    def test_no_splitting(self):
        """With splitting disabled, the text stays a single document."""
        corpus = ["First sentence. Second sentence."]
        docs, _removed, _stats = preprocess_texts(
            corpus, split_sentences=False, min_words=0
        )
        assert len(docs) == 1

    def test_min_words_filter(self):
        """Texts shorter than ``min_words`` end up in the removed list."""
        corpus = ["This is long enough.", "Short."]
        docs, removed, stats = preprocess_texts(
            corpus, split_sentences=False, min_words=3
        )
        kept_count, dropped_count = len(docs), len(removed)
        assert kept_count == 1
        assert dropped_count == 1
        assert stats["removed_count"] == 1

    def test_stats_accuracy(self):
        """The stats dict agrees with the returned kept and removed lists."""
        corpus = ["One sentence. Another sentence.", "Third sentence here."]
        docs, removed, stats = preprocess_texts(
            corpus, split_sentences=True, min_words=2
        )
        # NLTK splits into 3 sentences
        assert stats["total_before"] == 3
        assert stats["total_after"] == len(docs)
        assert stats["removed_count"] == len(removed)
class TestLoadCSVTexts:
    """CSV loading via ``load_csv_texts``."""

    def test_loads_texts(self, sample_csv):
        """Loading with an explicit ``text_col`` returns at least one text."""
        texts = load_csv_texts(sample_csv, text_col="report")
        assert len(texts) > 0

    def test_auto_detects_column(self, sample_csv):
        """Loading without ``text_col`` still returns at least one text."""
        texts = load_csv_texts(sample_csv)
        assert len(texts) > 0

    def test_raises_on_missing_column(self, sample_csv):
        """An unknown ``text_col`` raises ``ValueError`` with the expected message."""
        with pytest.raises(ValueError, match="No valid text column"):
            load_csv_texts(sample_csv, text_col="nonexistent")

    def test_filters_empty_rows(self, tmp_path):
        """Blank and whitespace-only rows are not counted among the loaded texts."""
        # pytest's ``tmp_path`` fixture gives a per-test directory that pytest
        # cleans up itself, so no manual unlink is needed and nothing is left
        # behind if writing the file fails.
        csv_path = tmp_path / "rows.csv"
        csv_path.write_text("text\nValid text\n\n \nAnother valid\n")

        # Pass a plain string path, as the loader was given before.
        texts = load_csv_texts(str(csv_path))
        assert len(texts) == 2
class TestCountCleanReports:
    """Report counting via ``count_clean_reports``."""

    def test_counts_correctly(self, sample_csv):
        """The sample CSV yields a positive count for the ``report`` column."""
        report_count = count_clean_reports(sample_csv, "report")
        assert report_count > 0

    def test_returns_zero_on_error(self):
        """A path that does not exist gives a count of zero."""
        missing_csv = "/nonexistent/path.csv"
        assert count_clean_reports(missing_csv) == 0
class TestConfigUtils:
    """Config hashing (``get_config_hash``) and run IDs (``make_run_id``)."""

    def test_hash_is_deterministic(self):
        """Hashing the same config twice gives the same hash."""
        config = {"a": 1, "b": 2}
        first_hash = get_config_hash(config)
        second_hash = get_config_hash(config)
        assert first_hash == second_hash

    def test_hash_ignores_key_order(self):
        """Two configs that differ only in key order hash identically."""
        forward = {"a": 1, "b": 2}
        backward = {"b": 2, "a": 1}
        assert get_config_hash(forward) == get_config_hash(backward)

    def test_run_id_contains_hash(self):
        """The run ID embeds the config hash."""
        config = {"a": 1}
        run_id = make_run_id(config)
        config_hash = get_config_hash(config)
        assert config_hash in run_id
class TestRunTopicModel:
    """Topic model fitting via ``run_topic_model``."""

    def test_returns_expected_types(self, larger_corpus, larger_embeddings, topic_config):
        """The call returns a model, a two-column reduction and one topic per doc."""
        fitted = run_topic_model(larger_corpus, larger_embeddings, topic_config)
        model, reduced, topics = fitted
        n_docs = len(larger_corpus)
        assert hasattr(model, "get_topic_info")
        assert reduced.shape == (n_docs, 2)
        assert len(topics) == n_docs

    def test_reduced_is_2d(self, larger_corpus, larger_embeddings, topic_config):
        """The reduced array has two dimensions and two columns."""
        _, reduced, _ = run_topic_model(larger_corpus, larger_embeddings, topic_config)
        assert reduced.ndim == 2
        assert reduced.shape[1] == 2

    def test_topics_are_integers(self, larger_corpus, larger_embeddings, topic_config):
        """Every topic assignment is a Python or NumPy integer."""
        _, _, topics = run_topic_model(larger_corpus, larger_embeddings, topic_config)
        integer_types = (int, np.integer)
        assert all(isinstance(topic, integer_types) for topic in topics)
class TestGetTopicLabels:
    """Topic label extraction via ``get_topic_labels``."""

    def test_returns_labels_for_all_docs(self, larger_corpus, larger_embeddings, topic_config):
        """There is exactly one label per document in the corpus."""
        model, _, topics = run_topic_model(larger_corpus, larger_embeddings, topic_config)
        doc_labels = get_topic_labels(model, topics)
        assert len(doc_labels) == len(larger_corpus)

    def test_labels_are_strings(self, larger_corpus, larger_embeddings, topic_config):
        """Every returned label is a ``str``."""
        model, _, topics = run_topic_model(larger_corpus, larger_embeddings, topic_config)
        doc_labels = get_topic_labels(model, topics)
        assert all(isinstance(label, str) for label in doc_labels)
class TestOutlierStats:
    """Outlier statistics (``get_outlier_stats``) and topic count (``get_num_topics``)."""

    def test_returns_count_and_percentage(self, larger_corpus, larger_embeddings, topic_config):
        """The stats are an ``int`` count and a ``float`` percentage within 0..100."""
        model, _, _ = run_topic_model(larger_corpus, larger_embeddings, topic_config)
        outlier_count, outlier_pct = get_outlier_stats(model)
        assert isinstance(outlier_count, int)
        assert isinstance(outlier_pct, float)
        assert 0 <= outlier_pct <= 100

    def test_num_topics(self, larger_corpus, larger_embeddings, topic_config):
        """The number of topics is a non-negative ``int``."""
        model, _, _ = run_topic_model(larger_corpus, larger_embeddings, topic_config)
        topic_count = get_num_topics(model)
        assert isinstance(topic_count, int)
        assert topic_count >= 0
class TestEmbeddingShapeValidation:
    """Consistency checks on the embedding fixtures."""

    def test_shape_matches_docs(self, sample_texts, sample_embeddings):
        """The embeddings have one row per sample text."""
        n_rows = sample_embeddings.shape[0]
        assert n_rows == len(sample_texts)

    def test_dtype_is_float32(self, sample_embeddings):
        """The embeddings are stored as ``float32``."""
        embedding_dtype = sample_embeddings.dtype
        assert embedding_dtype == np.float32
class TestLabelsCachePath:
    """Label cache path generation via ``labels_cache_path``."""

    def test_returns_path_object(self):
        """The returned value is a ``pathlib.Path``."""
        from mosaic_core.core_functions import labels_cache_path

        # ``Path`` comes from the module-level import; no local re-import needed.
        cache_file = labels_cache_path("/tmp", "abc123", "meta-llama/Llama-3")
        assert isinstance(cache_file, Path)

    def test_sanitizes_model_id(self):
        """A model id containing ``/`` does not appear verbatim in the path."""
        from mosaic_core.core_functions import labels_cache_path

        model_id = "org/model-name"
        cache_file = labels_cache_path("/tmp", "hash", model_id)
        assert "/" not in cache_file.name
        # ``Path.name`` is only the final path component, so the check above
        # can never fail on its own: an unsanitized id would just add a
        # directory level. Checking the whole path catches that case.
        # ``as_posix`` keeps the comparison independent of the OS separator.
        assert model_id not in cache_file.as_posix()
class TestLabelsCacheIO:
    """Label cache read/write via ``save_labels_cache`` and ``load_cached_labels``."""

    def test_save_and_load(self, tmp_path):
        """Labels written to the cache are read back equal to the original dict."""
        from mosaic_core.core_functions import save_labels_cache, load_cached_labels

        # pytest's ``tmp_path`` fixture supplies a per-test directory that
        # pytest cleans up itself, replacing the manual temp-file/unlink pair.
        cache_file = str(tmp_path / "labels.json")

        labels = {0: "Topic A", 1: "Topic B"}
        save_labels_cache(cache_file, labels)
        loaded = load_cached_labels(cache_file)
        assert loaded == labels

    def test_load_returns_none_on_missing(self):
        """Loading from a path that does not exist returns ``None``."""
        from mosaic_core.core_functions import load_cached_labels

        result = load_cached_labels("/nonexistent/path.json")
        assert result is None
class TestCleanupOldCache:
    """Cache cleanup via ``cleanup_old_cache``."""

    def test_removes_non_matching_files(self):
        """Files for other tags are deleted; the file for the current tag stays."""
        from mosaic_core.core_functions import cleanup_old_cache

        with tempfile.TemporaryDirectory() as tmpdir:
            cache_dir = Path(tmpdir)
            stale_docs = cache_dir / "precomputed_OLD_docs.npy"
            stale_emb = cache_dir / "precomputed_OLD_emb.npy"
            current_docs = cache_dir / "precomputed_CURRENT_docs.npy"

            # Populate the directory with empty stand-in cache files.
            for cache_file in (stale_docs, stale_emb, current_docs):
                cache_file.touch()

            removed = cleanup_old_cache(tmpdir, "CURRENT")

            assert removed == 2
            assert current_docs.exists()
            assert not stale_docs.exists()

    def test_handles_missing_dir(self):
        """A directory that does not exist results in zero removals."""
        from mosaic_core.core_functions import cleanup_old_cache

        removed = cleanup_old_cache("/nonexistent/dir", "test")
        assert removed == 0
class TestResolveDevice:
    """Device resolution via ``resolve_device``."""

    def test_cpu_explicit(self):
        """Requesting ``"cpu"`` yields ``"cpu"`` with 64 as the second value."""
        from mosaic_core.core_functions import resolve_device

        resolved = resolve_device("cpu")
        device_name, batch_value = resolved
        assert device_name == "cpu"
        assert batch_value == 64

    def test_cpu_uppercase(self):
        """An upper-case ``"CPU"`` request is resolved to lower-case ``"cpu"``."""
        from mosaic_core.core_functions import resolve_device

        device_name, _ = resolve_device("CPU")
        assert device_name == "cpu"