Spaces:
Running
Running
"""
tests/test_phase3_localisation.py
─────────────────────────────────
Unit tests for Phase 3: BM25, RRF fusion, DeBERTa ranker, and pipeline.
All tests work without an OpenAI API key or a GPU — components degrade gracefully.
Run with: pytest tests/test_phase3_localisation.py -v
"""
| from __future__ import annotations | |
| import textwrap | |
| from pathlib import Path | |
| from unittest.mock import MagicMock, patch | |
| import pytest | |
# ── Helpers ───────────────────────────────────────────────────────────────────
def make_file_symbols(file_path: str, summary: str = ""):
    """Build a minimal ``FileSymbols`` fixture for the localisation tests.

    Args:
        file_path: Repository-relative path stored on the fixture.
        summary: Free text stored as the module docstring. When non-empty, a
            single synthetic ``FunctionInfo`` is also attached. It is named
            after the first whitespace-separated word of ``summary``, or
            ``"placeholder"`` if ``summary`` is only whitespace, and it reuses
            ``summary`` as its docstring.

    Returns:
        The populated ``FileSymbols`` instance (``file_hash`` fixed to
        ``"aaa111"``).
    """
    from ast_parser.python_parser import FileSymbols, FunctionInfo

    symbols = FileSymbols(file_path=file_path, file_hash="aaa111")
    symbols.module_docstring = summary
    if not summary:
        return symbols

    # The fake function repeats the summary words in a second place. This is
    # presumably so that whatever text FileSymbols exposes for retrieval
    # (summary_text, not visible here) includes them — TODO confirm.
    words = summary.split()
    fake_name = words[0] if words else "placeholder"
    symbols.functions = [
        FunctionInfo(
            name=fake_name,
            qualified_name=fake_name,
            args=[],
            decorators=[],
            docstring=summary,
            start_line=1,
            end_line=5,
        )
    ]
    return symbols
# ── BM25 tokeniser ────────────────────────────────────────────────────────────
class TestTokeniser:
    """Unit tests for the BM25 module's private ``_tokenise`` helper."""

    def test_lowercase(self):
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("Hello World")
        # No emitted token may contain upper-case characters.
        assert not any(tok != tok.lower() for tok in tokens)

    def test_camel_case_split(self):
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("QuerySet")
        # CamelCase identifiers are broken at the case boundary.
        for expected in ("query", "set"):
            assert expected in tokens

    def test_snake_case_split(self):
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("get_queryset")
        # snake_case identifiers are broken at the underscore only.
        for expected in ("get", "queryset"):
            assert expected in tokens

    def test_short_tokens_filtered(self):
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("a b c def")
        # Single-character tokens are dropped; three-character ones survive.
        for dropped in ("a", "b"):
            assert dropped not in tokens
        assert "def" in tokens

    def test_path_tokenisation(self):
        from localisation.bm25_retriever import _tokenise

        tokens = _tokenise("django/db/models/query.py")
        # Path separators and the extension dot act as token boundaries.
        for expected in ("django", "models", "query"):
            assert expected in tokens
# ── BM25 retriever ────────────────────────────────────────────────────────────
class TestBM25Retriever:
    """Tests for ``BM25Retriever`` indexing and querying over ``FileSymbols``."""

    def test_index_and_query_basic(self):
        from localisation.bm25_retriever import BM25Retriever

        retriever = BM25Retriever()
        symbols = [
            make_file_symbols("models/user.py", "User authentication login password"),
            make_file_symbols("views/dashboard.py", "Dashboard render template"),
            make_file_symbols("utils/email.py", "Email sending SMTP"),
        ]
        retriever.index(symbols)
        hits = retriever.query("user authentication login", top_k=3)
        assert len(hits) >= 1
        assert hits[0].file_path == "models/user.py"

    def test_query_returns_positive_scores_only(self):
        from localisation.bm25_retriever import BM25Retriever

        retriever = BM25Retriever()
        symbols = [make_file_symbols(f"file{i}.py", f"content {i}") for i in range(5)]
        retriever.index(symbols)
        hits = retriever.query("content 3", top_k=5)
        # NOTE(review): all() over an empty list is True, so this assertion
        # also passes when the query returns no hits at all.
        assert all(h.score > 0 for h in hits)

    def test_ranks_are_sequential(self):
        from localisation.bm25_retriever import BM25Retriever

        retriever = BM25Retriever()
        symbols = [make_file_symbols(f"f{i}.py", f"word{i} text") for i in range(3)]
        retriever.index(symbols)
        hits = retriever.query("word0 text", top_k=3)
        # Ranks are 1-based and contiguous in result order.
        assert [h.rank for h in hits] == list(range(1, len(hits) + 1))

    def test_empty_query_returns_empty(self):
        from localisation.bm25_retriever import BM25Retriever

        retriever = BM25Retriever()
        symbols = [make_file_symbols("a.py", "content")]
        retriever.index(symbols)
        hits = retriever.query("", top_k=5)
        assert hits == []

    def test_corpus_size(self):
        from localisation.bm25_retriever import BM25Retriever

        retriever = BM25Retriever()
        symbols = [make_file_symbols(f"f{i}.py", "text") for i in range(7)]
        retriever.index(symbols)
        assert retriever.corpus_size == 7

    def test_file_path_tokens_boost(self):
        from localisation.bm25_retriever import BM25Retriever

        # The word "models" occurs in none of the summaries below. It appears
        # only in the path django/db/models.py, so that file can be retrieved
        # only if the retriever indexes path tokens alongside the summary text.
        # Five files are used because a term that occurs in just one of
        # several documents gets a positive BM25 score. Presumably the IDF is
        # non-positive on tiny corpora — TODO confirm against BM25Retriever.
        retriever = BM25Retriever()
        symbols = [
            make_file_symbols("django/db/models.py", "handles database records"),
            make_file_symbols("utils/helper.py", "general utilities helper"),
            make_file_symbols("views/base.py", "base view rendering"),
            make_file_symbols("core/app.py", "application entry point"),
            make_file_symbols("api/serializers.py", "rest framework serializers"),
        ]
        retriever.index(symbols)
        hits = retriever.query("models", top_k=5)
        paths = [h.file_path for h in hits]
        assert "django/db/models.py" in paths

    def test_not_indexed_raises(self):
        from localisation.bm25_retriever import BM25Retriever

        retriever = BM25Retriever()
        # Querying before index() must fail loudly rather than return nothing.
        with pytest.raises(RuntimeError, match="not indexed"):
            retriever.query("test", top_k=5)

    def test_skips_parse_error_files(self):
        from localisation.bm25_retriever import BM25Retriever
        from ast_parser.python_parser import FileSymbols

        retriever = BM25Retriever()
        good = make_file_symbols("good.py", "good content")
        bad = FileSymbols(file_path="bad.py", file_hash="bbb", parse_error="SyntaxError")
        retriever.index([good, bad])
        # Only the file without a parse_error is counted in the corpus.
        assert retriever.corpus_size == 1
# ── RRF fusion ────────────────────────────────────────────────────────────────
class TestRRFFusion:
    """Tests for ``reciprocal_rank_fusion`` and the ``ablate`` wrapper.

    Ranked inputs are lists of tuples that appear to be
    ``(file_path, score, rank)``. The PPR input is a ``{file_path: score}``
    mapping.
    """

    def test_basic_fusion(self):
        from localisation.rrf_fusion import reciprocal_rank_fusion

        bm25_hits = [("a.py", 1.0, 1), ("b.py", 0.8, 2), ("c.py", 0.5, 3)]
        embed_hits = [("b.py", 0.9, 1), ("a.py", 0.7, 2), ("d.py", 0.6, 3)]
        ppr_scores = {"a.py": 0.5, "b.py": 0.3}
        fused = reciprocal_rank_fusion(bm25_hits, embed_hits, ppr_scores, top_k=4)
        assert len(fused) <= 4
        # a.py is present in all three sources, so it belongs in the top two.
        leaders = [hit.file_path for hit in fused][:2]
        assert "a.py" in leaders

    def test_top_k_respected(self):
        from localisation.rrf_fusion import reciprocal_rank_fusion

        ranked = [(f"f{n}.py", 1.0, n + 1) for n in range(10)]
        fused = reciprocal_rank_fusion(ranked, [], {}, top_k=3)
        assert len(fused) == 3

    def test_empty_inputs(self):
        from localisation.rrf_fusion import reciprocal_rank_fusion

        assert reciprocal_rank_fusion([], [], {}, top_k=5) == []

    def test_ranks_sequential(self):
        from localisation.rrf_fusion import reciprocal_rank_fusion

        ranked = [("a.py", 1.0, 1), ("b.py", 0.5, 2)]
        fused = reciprocal_rank_fusion(ranked, [], {}, top_k=5)
        expected = list(range(1, len(fused) + 1))
        assert [hit.rank for hit in fused] == expected

    def test_all_sources_tracked(self):
        from localisation.rrf_fusion import reciprocal_rank_fusion

        fused = reciprocal_rank_fusion(
            [("a.py", 1.0, 1)], [("a.py", 0.9, 1)], {"a.py": 0.5}, top_k=1
        )
        only = fused[0]
        # The fused hit remembers its rank in each contributing source.
        assert (only.bm25_rank, only.embed_rank, only.ppr_rank) == (1, 1, 1)

    def test_ablation_no_ppr(self):
        from localisation.rrf_fusion import ablate

        ranked = [("a.py", 1.0, 1)]
        ppr_scores = {"b.py": 99.0}  # b.py has a huge PPR score
        # With PPR switched off, a file known only to PPR must not surface.
        fused = ablate(ranked, [], ppr_scores, use_ppr=False, top_k=5)
        assert all(hit.file_path != "b.py" for hit in fused)

    def test_scores_descending(self):
        from localisation.rrf_fusion import reciprocal_rank_fusion

        ranked = [("a.py", 1.0, 1), ("b.py", 0.5, 2), ("c.py", 0.1, 3)]
        fused = reciprocal_rank_fusion(ranked, [], {}, top_k=3)
        fused_scores = [hit.fused_score for hit in fused]
        assert fused_scores == sorted(fused_scores, reverse=True)

    def test_union_of_all_lists(self):
        """A file that appears only in PPR must still be in the results."""
        from localisation.rrf_fusion import reciprocal_rank_fusion

        ppr_scores = {"z.py": 1.0}  # only present in PPR
        fused = reciprocal_rank_fusion([("a.py", 1.0, 1)], [], ppr_scores, top_k=10)
        assert any(hit.file_path == "z.py" for hit in fused)
# ── DeBERTa ranker — without GPU ──────────────────────────────────────────────
class TestDeBERTaRankerFallback:
    """Tests for graceful fallback when the DeBERTa model is not loaded.

    The ranker is built with ``__new__`` so ``__init__`` is skipped, together
    with any model loading it may perform. It is then marked unavailable.
    """

    @staticmethod
    def _unloaded_ranker():
        """Return a ``DeBERTaRanker`` with no model, bypassing ``__init__``.

        All three attributes are set so that every test uses an identical
        stub. If the fallback path reads ``_model`` or ``_tokenizer``, it sees
        ``None`` instead of raising ``AttributeError``.
        """
        from localisation.deberta_ranker import DeBERTaRanker

        ranker = DeBERTaRanker.__new__(DeBERTaRanker)
        ranker._available = False
        ranker._model = None
        ranker._tokenizer = None
        return ranker

    def test_rerank_fallback_returns_stage1_order(self):
        ranker = self._unloaded_ranker()
        candidates = [("a.py", "summary a"), ("b.py", "summary b"), ("c.py", "summary c")]
        result = ranker.rerank("fix the bug", candidates, top_k=3)
        # Without a model, the incoming candidate order is preserved.
        assert len(result) == 3
        assert result[0].file_path == "a.py"
        assert result[0].rank == 1

    def test_rerank_empty_candidates(self):
        ranker = self._unloaded_ranker()
        result = ranker.rerank("query", [], top_k=5)
        assert result == []

    def test_ranked_file_scores_are_positive(self):
        ranker = self._unloaded_ranker()
        candidates = [(f"f{i}.py", f"text {i}") for i in range(5)]
        result = ranker.rerank("test query", candidates, top_k=5)
        assert all(r.relevance_score > 0 for r in result)
# ── Recall metric ─────────────────────────────────────────────────────────────
class TestRecallMetric:
    """Tests for ``recall_at_k``: share of gold files found in the top-k predictions."""

    def test_perfect_recall(self):
        from localisation.deberta_ranker import recall_at_k

        predicted = ["a.py", "b.py", "c.py"]
        gold_files = ["a.py", "b.py"]
        assert recall_at_k(predicted, gold_files, k=5) == 1.0

    def test_zero_recall(self):
        from localisation.deberta_ranker import recall_at_k

        predicted = ["x.py", "y.py"]
        gold_files = ["a.py"]
        assert recall_at_k(predicted, gold_files, k=5) == 0.0

    def test_partial_recall(self):
        from localisation.deberta_ranker import recall_at_k

        predicted = ["a.py", "b.py", "c.py"]
        gold_files = ["a.py", "z.py"]  # one of the two gold files is found
        assert recall_at_k(predicted, gold_files, k=5) == 0.5

    def test_recall_at_k_respects_k(self):
        from localisation.deberta_ranker import recall_at_k

        predicted = ["x.py", "a.py"]  # the gold file sits at position 2
        gold_files = ["a.py"]
        expectations = {1: 0.0, 2: 1.0}  # cutoff k -> expected recall
        for cutoff, expected in expectations.items():
            assert recall_at_k(predicted, gold_files, k=cutoff) == expected

    def test_empty_gold(self):
        from localisation.deberta_ranker import recall_at_k

        # With no gold files, recall is defined as 0.0 rather than dividing by zero.
        assert recall_at_k(["a.py"], [], k=5) == 0.0
# ── Patch file extraction ─────────────────────────────────────────────────────
class TestExtractFilesFromPatch:
    """Tests for ``_extract_files_from_patch`` on unified-diff text.

    Local diff strings are named ``diff_text`` rather than ``patch`` so they
    do not shadow ``unittest.mock.patch``, which this module imports.
    """

    def test_basic_unified_diff(self):
        from localisation.deberta_ranker import _extract_files_from_patch

        diff_text = textwrap.dedent("""
            diff --git a/django/db/models/query.py b/django/db/models/query.py
            --- a/django/db/models/query.py
            +++ b/django/db/models/query.py
            @@ -1 +1 @@
            -old
            +new
        """)
        files = _extract_files_from_patch(diff_text)
        # The a/ and b/ prefixes are stripped from the reported path.
        assert "django/db/models/query.py" in files

    def test_multiple_files(self):
        from localisation.deberta_ranker import _extract_files_from_patch

        diff_text = textwrap.dedent("""
            --- a/foo.py
            +++ b/foo.py
            @@ -1 +1 @@ fix
            --- a/bar.py
            +++ b/bar.py
            @@ -1 +1 @@ fix
        """)
        files = _extract_files_from_patch(diff_text)
        assert "foo.py" in files
        assert "bar.py" in files

    def test_dev_null_excluded(self):
        from localisation.deberta_ranker import _extract_files_from_patch

        # A newly created file is diffed against /dev/null; that pseudo-path
        # must not be reported as a touched file.
        diff_text = "--- /dev/null\n+++ b/new_file.py\n"
        files = _extract_files_from_patch(diff_text)
        assert "/dev/null" not in files

    def test_empty_patch(self):
        from localisation.deberta_ranker import _extract_files_from_patch

        assert _extract_files_from_patch("") == []
# ── Failure categorisation ────────────────────────────────────────────────────
class TestFailureCategorisation:
    """Tests for the labels produced by ``categorise_localisation_failure``.

    Arguments are passed positionally as (predicted files, gold files, issue text).
    """

    def test_success(self):
        from localisation.pipeline import categorise_localisation_failure

        predicted, gold = ["a.py", "b.py"], ["a.py"]
        issue = "good long detailed issue text here"
        label = categorise_localisation_failure(predicted, gold, issue)
        assert label == "success"

    def test_wrong_file(self):
        from localisation.pipeline import categorise_localisation_failure

        # 11-word issue text (presumably above the ambiguity threshold) and no
        # gold file among the predictions -> wrong_file.
        issue = "there is a null pointer exception in the query filter method"
        label = categorise_localisation_failure(["x.py", "y.py"], ["z.py"], issue)
        assert label == "wrong_file"

    def test_partial_file(self):
        from localisation.pipeline import categorise_localisation_failure

        # Only one of the two gold files was predicted.
        issue = "long enough issue text to be valid"
        label = categorise_localisation_failure(["a.py"], ["a.py", "b.py"], issue)
        assert label == "partial_file"

    def test_ambiguous_issue(self):
        from localisation.pipeline import categorise_localisation_failure

        issue = "fix bug"  # very short
        label = categorise_localisation_failure(["x.py"], ["z.py"], issue)
        assert label == "ambiguous_issue"
# ── Pipeline integration (no API required) ────────────────────────────────────
class TestLocalisationPipeline:
    """End-to-end ``LocalisationPipeline`` tests using BM25 only (no API, no GPU)."""

    def test_pipeline_bm25_only(self):
        from localisation.pipeline import LocalisationPipeline

        pipeline = LocalisationPipeline(
            use_embeddings=False,
            use_deberta=False,
            use_ppr=False,
        )
        symbols = [
            make_file_symbols("auth/models.py", "User model authentication password hash"),
            make_file_symbols("views/login.py", "Login view render form"),
            make_file_symbols("utils/email.py", "Email SMTP send message"),
        ]
        pipeline.index_repo(symbols)
        result = pipeline.localise("user authentication login password", top_k=3)
        assert len(result.hits) >= 1
        assert result.hits[0].file_path == "auth/models.py"

    def test_pipeline_empty_query(self):
        from localisation.pipeline import LocalisationPipeline

        pipeline = LocalisationPipeline(use_embeddings=False, use_deberta=False)
        symbols = [make_file_symbols("a.py", "content")]
        pipeline.index_repo(symbols)
        result = pipeline.localise("")
        # An empty query is reported as its own failure category with no hits.
        assert result.failure_category == "empty_query"
        assert result.hits == []

    def test_pipeline_with_gold_files_computes_recall(self):
        from localisation.pipeline import LocalisationPipeline

        pipeline = LocalisationPipeline(use_embeddings=False, use_deberta=False, use_ppr=False)
        # Use a larger corpus so BM25 gives positive scores. 'queryset' appears
        # in both the path and the summary of db/queryset.py, which should make
        # it the top hit.
        symbols = [
            make_file_symbols("db/queryset.py", "queryset filter method database orm"),
            make_file_symbols("views/generic.py", "generic view rendering template"),
            make_file_symbols("utils/helper.py", "utility functions general purpose"),
            make_file_symbols("api/serializer.py", "rest framework serializer fields"),
            make_file_symbols("forms/widget.py", "html form widget rendering input"),
        ]
        pipeline.index_repo(symbols)
        result = pipeline.localise(
            "fix null pointer exception in queryset filter", top_k=5,
            gold_files=["db/queryset.py"]
        )
        # Supplying gold_files makes the pipeline fill in recall metrics.
        assert result.recall_at_5 is not None
        assert result.recall_at_10 is not None
        assert result.recall_at_5 == 1.0

    def test_top_k_paths_property(self):
        from localisation.pipeline import LocalisationPipeline

        pipeline = LocalisationPipeline(use_embeddings=False, use_deberta=False, use_ppr=False)
        symbols = [make_file_symbols(f"f{i}.py", f"content {i}") for i in range(5)]
        pipeline.index_repo(symbols)
        result = pipeline.localise("content 1", top_k=3)
        # top_k_paths must list exactly one path per hit.
        assert len(result.top_k_paths) == len(result.hits)

    def test_hit_diagnostic_flags(self):
        from localisation.pipeline import LocalisationPipeline

        pipeline = LocalisationPipeline(use_embeddings=False, use_deberta=False, use_ppr=False)
        # A single-document corpus may produce no positive BM25 score (see the
        # corpus-size note in test_pipeline_with_gold_files_computes_recall).
        # A check guarded by "if result.hits" would then verify nothing. So
        # index five files, where only a.py contains the query terms, and
        # require a hit.
        symbols = [
            make_file_symbols("a.py", "special word"),
            make_file_symbols("views/generic.py", "generic view rendering template"),
            make_file_symbols("utils/helper.py", "utility functions general purpose"),
            make_file_symbols("api/serializer.py", "rest framework serializer fields"),
            make_file_symbols("forms/widget.py", "html form widget rendering input"),
        ]
        pipeline.index_repo(symbols)
        result = pipeline.localise("special word", top_k=1)
        assert result.hits, "expected a hit for query terms unique to a.py"
        hit = result.hits[0]
        assert hit.file_path == "a.py"
        # The hit records that BM25 contributed to it.
        assert hit.in_bm25 is True