"""Tests for configurable model selection: config, scoring, pipeline integration."""
import json
import os
import sqlite3
import sys
import tempfile
from contextlib import closing
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _isolate_config(tmp_path, monkeypatch):
    """Redirect config.yaml and the DB to a temp dir so every test is hermetic.

    Environment variables only influence ``src.config`` when it is first
    imported, and the tests in this module assign module-level constants on
    ``src.config`` directly.  To stop one test's assignments leaking into the
    next, the current value of each such constant is registered with
    ``monkeypatch`` (which restores it at teardown), ``CONFIG_PATH`` /
    ``DB_PATH`` are pointed at the temp dir, and ``SCORING_CONFIGS`` (mutated
    in place by some tests) is restored explicitly after the test.
    """
    config_path = tmp_path / "config.yaml"
    db_path = tmp_path / "researcher.db"
    monkeypatch.setenv("CONFIG_PATH", str(config_path))
    monkeypatch.setenv("DB_PATH", str(db_path))
    # Ensure ANTHROPIC_API_KEY is set for tests that need it
    monkeypatch.setenv("ANTHROPIC_API_KEY", "sk-ant-test-key")

    # Import only after the env vars are set, so a first-time import of the
    # config module already sees the temp paths.
    import src.config as cfg

    # Setting an attribute to its own current value makes monkeypatch record
    # it, so direct `cfg.X = ...` assignments inside a test are undone later.
    for name in (
        "_cfg",
        "ANTHROPIC_API_KEY",
        "SCORING_MODEL",
        "RESCORE_MODEL",
        "RESCORE_TOP_N",
        "BATCH_SIZE",
    ):
        if hasattr(cfg, name):
            monkeypatch.setattr(cfg, name, getattr(cfg, name))
    monkeypatch.setattr(cfg, "CONFIG_PATH", config_path, raising=False)
    monkeypatch.setattr(cfg, "DB_PATH", db_path, raising=False)

    # SCORING_CONFIGS is updated in place, so restore its contents rather
    # than rebinding the object (other modules may hold a reference to it).
    scoring_configs = getattr(cfg, "SCORING_CONFIGS", None)
    snapshot = dict(scoring_configs) if scoring_configs is not None else None
    yield
    if snapshot is not None:
        scoring_configs.clear()
        scoring_configs.update(snapshot)
@pytest.fixture
def fresh_config(tmp_path, monkeypatch):
    """Return a helper that writes a config.yaml and reloads ``src.config`` from it.

    The returned callable takes the mapping to serialise.  After writing it,
    the helper re-derives SCORING_MODEL, RESCORE_MODEL, RESCORE_TOP_N and
    BATCH_SIZE from the file.  The ``scoring`` block wins, then the legacy
    top-level ``claude_model`` / ``batch_size`` keys apply, then hard-coded
    defaults.  It also merges ``_build_scoring_configs()`` into
    ``SCORING_CONFIGS``.  Every change goes through ``monkeypatch``, so the
    reload is undone when the test finishes.
    """
    config_path = tmp_path / "config.yaml"

    def _write(data: dict) -> None:
        import yaml

        import src.config as cfg

        # Only plain dict/str/int data is written, so safe_dump suffices.
        config_path.write_text(
            yaml.safe_dump(data, default_flow_style=False), encoding="utf-8"
        )
        # Force config module to reload from this file. CONFIG_PATH must be
        # set before _load_yaml(), and _cfg before _build_scoring_configs().
        monkeypatch.setattr(cfg, "CONFIG_PATH", config_path)
        loaded = cfg._load_yaml()
        monkeypatch.setattr(cfg, "_cfg", loaded)
        scoring = loaded.get("scoring", {})
        monkeypatch.setattr(
            cfg,
            "SCORING_MODEL",
            scoring.get("model", loaded.get("claude_model", "claude-haiku-4-5-20251001")),
        )
        monkeypatch.setattr(
            cfg, "RESCORE_MODEL", scoring.get("rescore_model", "claude-sonnet-4-5-20250929")
        )
        monkeypatch.setattr(cfg, "RESCORE_TOP_N", scoring.get("rescore_top_n", 15))
        monkeypatch.setattr(
            cfg, "BATCH_SIZE", scoring.get("batch_size", loaded.get("batch_size", 20))
        )
        # Equivalent to SCORING_CONFIGS.update(...), but each key is restored
        # (or removed, if new) at teardown.
        for key, value in cfg._build_scoring_configs().items():
            monkeypatch.setitem(cfg.SCORING_CONFIGS, key, value)

    return _write
@pytest.fixture
def test_db(tmp_path, monkeypatch):
    """Initialize a temp database and return its path.

    ``src.config.DB_PATH`` is pointed at the temp file via ``monkeypatch``, so
    the override is undone after the test instead of leaking into later ones.
    """
    db_path = tmp_path / "researcher.db"
    import src.config as cfg

    monkeypatch.setattr(cfg, "DB_PATH", db_path)
    # Imported after DB_PATH is set, preserving the original ordering.
    from src.db import init_db

    init_db()
    return db_path
def _insert_test_papers(db_path, run_id, domain, papers):
    """Insert test papers directly into the DB.

    Args:
        db_path: Path (or str) of the SQLite database holding a ``papers`` table.
        run_id: Value stored in ``run_id`` for every inserted row.
        domain: Value stored in ``domain`` for every inserted row.
        papers: Iterable of dicts.  Missing keys fall back to defaults:
            - ``""`` for text fields;
            - ``[]`` for list fields, which are JSON-encoded;
            - ``0`` for ``hf_upvotes``;
            - NULL for ``github_stars``, the score axes, ``composite`` and
              ``code_url``.

    All rows are committed together.  The connection is always closed, even
    if an insert raises (in which case the transaction is rolled back).
    """
    columns = (
        "run_id", "domain", "arxiv_id", "entry_id", "title", "authors", "abstract",
        "published", "categories", "pdf_url", "arxiv_url", "comment", "source",
        "github_repo", "github_stars", "hf_upvotes", "hf_models", "hf_datasets",
        "hf_spaces", "score_axis_1", "score_axis_2", "score_axis_3", "composite",
        "summary", "reasoning", "code_url",
    )
    # Generate the placeholder list so it can never drift from the columns.
    sql = (
        f"INSERT INTO papers ({', '.join(columns)}) "
        f"VALUES ({', '.join('?' * len(columns))})"
    )
    rows = [
        (
            run_id, domain,
            p.get("arxiv_id", ""),
            p.get("entry_id", ""),
            p.get("title", ""),
            json.dumps(p.get("authors", [])),
            p.get("abstract", ""),
            p.get("published", ""),
            json.dumps(p.get("categories", [])),
            p.get("pdf_url", ""),
            p.get("arxiv_url", ""),
            p.get("comment", ""),
            p.get("source", ""),
            p.get("github_repo", ""),
            p.get("github_stars"),
            p.get("hf_upvotes", 0),
            json.dumps(p.get("hf_models", [])),
            json.dumps(p.get("hf_datasets", [])),
            json.dumps(p.get("hf_spaces", [])),
            p.get("score_axis_1"),
            p.get("score_axis_2"),
            p.get("score_axis_3"),
            p.get("composite"),
            p.get("summary", ""),
            p.get("reasoning", ""),
            p.get("code_url"),
        )
        for p in papers
    ]
    # closing() guarantees close(); the inner `conn` context commits on
    # success and rolls back on error.
    with closing(sqlite3.connect(str(db_path))) as conn, conn:
        conn.executemany(sql, rows)
def _create_test_run(db_path, domain):
    """Create a completed run row for *domain* and return its row id.

    The run has fixed dates (2026-01-01 .. 2026-01-07) and status
    ``completed``.  The connection is closed even if the insert fails.
    """
    # closing() guarantees close(); the inner `conn` context commits on
    # success (after the return value is computed) and rolls back on error.
    with closing(sqlite3.connect(str(db_path))) as conn, conn:
        cur = conn.execute(
            "INSERT INTO runs (domain, started_at, date_start, date_end, status) "
            "VALUES (?, '2026-01-01T00:00:00', '2026-01-01', '2026-01-07', 'completed')",
            (domain,),
        )
        return cur.lastrowid
# ---------------------------------------------------------------------------
# 1. Config defaults
# ---------------------------------------------------------------------------
class TestConfigDefaults:
    """Check the built-in model constants used when no config.yaml was written."""

    def test_default_scoring_model(self):
        import src.config as settings
        # No config file on disk -> bulk scoring falls back to a haiku model.
        model = settings.SCORING_MODEL
        assert "haiku" in model

    def test_default_rescore_model(self):
        import src.config as settings
        model = settings.RESCORE_MODEL
        assert "sonnet" in model

    def test_default_rescore_top_n(self):
        import src.config as settings
        top_n = settings.RESCORE_TOP_N
        assert top_n == 15

    def test_default_batch_size(self):
        import src.config as settings
        batch = settings.BATCH_SIZE
        assert batch == 20
# ---------------------------------------------------------------------------
# 2. Config loading from YAML
# ---------------------------------------------------------------------------
class TestConfigYAML:
    """Constants in src.config must reflect the contents of config.yaml."""

    def test_scoring_block_loads(self, fresh_config):
        import src.config as settings
        fresh_config({
            "scoring": {
                "model": "claude-opus-4-6",
                "rescore_model": "claude-sonnet-4-5-20250929",
                "rescore_top_n": 25,
                "batch_size": 10,
            },
        })
        actual = (
            settings.SCORING_MODEL,
            settings.RESCORE_MODEL,
            settings.RESCORE_TOP_N,
            settings.BATCH_SIZE,
        )
        assert actual == ("claude-opus-4-6", "claude-sonnet-4-5-20250929", 25, 10)

    def test_backward_compat_claude_model_key(self, fresh_config):
        """A legacy top-level `claude_model` key still selects SCORING_MODEL."""
        import src.config as settings
        fresh_config({"claude_model": "claude-sonnet-4-5-20250929"})
        assert settings.SCORING_MODEL == "claude-sonnet-4-5-20250929"
        # With no scoring block, the rescore model keeps its sonnet default.
        assert "sonnet" in settings.RESCORE_MODEL

    def test_scoring_block_overrides_claude_model(self, fresh_config):
        import src.config as settings
        legacy_and_new = {
            "claude_model": "claude-sonnet-4-5-20250929",
            "scoring": {"model": "claude-haiku-4-5-20251001"},
        }
        fresh_config(legacy_and_new)
        # scoring.model takes precedence over the legacy key.
        assert settings.SCORING_MODEL == "claude-haiku-4-5-20251001"

    def test_rescore_disabled_when_zero(self, fresh_config):
        import src.config as settings
        fresh_config({"scoring": {"rescore_top_n": 0}})
        assert settings.RESCORE_TOP_N == 0
# ---------------------------------------------------------------------------
# 3. save_config reloads model constants
# ---------------------------------------------------------------------------
class TestSaveConfig:
    """save_config() must refresh the module-level constants in src.config."""

    def test_save_updates_scoring_model(self, tmp_path):
        import src.config as settings
        settings.CONFIG_PATH = tmp_path / "config.yaml"
        scoring_block = {
            "model": "claude-opus-4-6",
            "rescore_model": "claude-haiku-4-5-20251001",
            "rescore_top_n": 5,
            "batch_size": 30,
        }
        settings.save_config({"scoring": scoring_block})
        assert (settings.SCORING_MODEL, settings.RESCORE_MODEL) == (
            "claude-opus-4-6",
            "claude-haiku-4-5-20251001",
        )
        assert (settings.RESCORE_TOP_N, settings.BATCH_SIZE) == (5, 30)

    def test_save_without_scoring_block_uses_defaults(self, tmp_path):
        import src.config as settings
        settings.CONFIG_PATH = tmp_path / "config.yaml"
        settings.save_config({"domains": {}})
        # No scoring block -> every scoring constant is back at its default.
        assert "haiku" in settings.SCORING_MODEL
        assert "sonnet" in settings.RESCORE_MODEL
        assert settings.RESCORE_TOP_N == 15
# ---------------------------------------------------------------------------
# 4. scoring.py reads live config (not stale bindings)
# ---------------------------------------------------------------------------
class TestScoringLiveConfig:
    """Verify scoring.py reads config values at call time, not import time."""

    def test_score_run_reads_live_model(self, test_db, tmp_path):
        """After config change, score_run uses the new model."""
        import src.config as cfg

        cfg.CONFIG_PATH = tmp_path / "config.yaml"
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        # Start with haiku
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        run_id = _create_test_run(test_db, "aiml")
        _insert_test_papers(test_db, run_id, "aiml", [{
            "arxiv_id": "2601.00001",
            "title": "Test Paper",
            "abstract": "Test abstract",
            "authors": ["Author A"],
            "categories": ["cs.LG"],
        }])
        # Change config to sonnet after setup; score_run must pick it up.
        cfg.SCORING_MODEL = "claude-sonnet-4-5-20250929"

        # Mock the Anthropic client to capture which model is used
        captured_model = {}

        def mock_create(**kwargs):
            captured_model["model"] = kwargs["model"]
            resp = MagicMock()
            resp.content = [MagicMock(text='[{"arxiv_id":"2601.00001","code_and_weights":7,"novelty":8,"practical_applicability":6,"summary":"test","reasoning":"test","code_url":null}]')]
            return resp

        with patch("anthropic.Anthropic") as MockClient:
            mock_instance = MagicMock()
            mock_instance.messages.create = mock_create
            MockClient.return_value = mock_instance
            from src.scoring import score_run
            score_run(run_id, "aiml")

        assert captured_model["model"] == "claude-sonnet-4-5-20250929"

    def test_rescore_reads_live_config(self, test_db, tmp_path):
        """rescore_top reads RESCORE_MODEL and RESCORE_TOP_N at call time."""
        import src.config as cfg

        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-opus-4-6"
        cfg.RESCORE_TOP_N = 3
        run_id = _create_test_run(test_db, "aiml")
        # Insert scored papers; composite falls with i, so 0..2 are the top 3.
        for i in range(5):
            _insert_test_papers(test_db, run_id, "aiml", [{
                "arxiv_id": f"2601.{i:05d}",
                "title": f"Paper {i}",
                "abstract": f"Abstract {i}",
                "authors": ["Author"],
                "categories": ["cs.LG"],
                "composite": 8.0 - i * 0.5,
                "score_axis_1": 7, "score_axis_2": 8, "score_axis_3": 6,
                "summary": "existing", "reasoning": "existing",
            }])

        captured = {}

        def mock_create(**kwargs):
            captured["model"] = kwargs["model"]
            captured["content"] = kwargs["messages"][0]["content"]
            resp = MagicMock()
            # Return scores for top 3
            results = []
            for i in range(3):
                results.append({
                    "arxiv_id": f"2601.{i:05d}",
                    "code_and_weights": 9, "novelty": 9, "practical_applicability": 9,
                    "summary": "rescored", "reasoning": "rescored", "code_url": None,
                })
            resp.content = [MagicMock(text=json.dumps(results))]
            return resp

        with patch("anthropic.Anthropic") as MockClient:
            mock_instance = MagicMock()
            mock_instance.messages.create = mock_create
            MockClient.return_value = mock_instance
            from src.scoring import rescore_top
            count = rescore_top(run_id, "aiml")

        assert captured["model"] == "claude-opus-4-6"
        assert count == 3
        # The live RESCORE_TOP_N (3) must also bound what is sent: only the
        # three highest-composite papers appear in the rescore prompt.
        assert captured["content"].count("arxiv_id:") == 3
        assert "arxiv_id: 2601.00003" not in captured["content"]
        assert "arxiv_id: 2601.00004" not in captured["content"]
# ---------------------------------------------------------------------------
# 5. rescore_top guard conditions
# ---------------------------------------------------------------------------
class TestRescoreGuards:
    """Test rescore_top early-exit conditions.

    The "disabled"/"skipped" tests seed a run with already-scored papers and
    patch the Anthropic client.  A 0 return can then only come from the guard
    under test, not from an empty run, and each test also asserts that no API
    call was made.
    """

    @staticmethod
    def _seed_scored_run(db_path, count=5):
        """Create an aiml run with *count* scored papers and return its id."""
        run_id = _create_test_run(db_path, "aiml")
        _insert_test_papers(db_path, run_id, "aiml", [
            {
                "arxiv_id": f"2601.{i:05d}",
                "title": f"Paper {i}",
                "abstract": f"Abstract {i}",
                "authors": ["Author"],
                "categories": ["cs.LG"],
                "composite": 8.0 - i * 0.5,
                "score_axis_1": 7, "score_axis_2": 8, "score_axis_3": 6,
            }
            for i in range(count)
        ])
        return run_id

    @staticmethod
    def _rescore_with_patched_client(run_id, **kwargs):
        """Run rescore_top with a mocked client; return (result, create mock)."""
        with patch("anthropic.Anthropic") as client_cls:
            from src.scoring import rescore_top
            result = rescore_top(run_id, "aiml", **kwargs)
        return result, client_cls.return_value.messages.create

    def test_rescore_disabled_when_n_zero(self, test_db):
        import src.config as cfg
        cfg.RESCORE_TOP_N = 0
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-sonnet-4-5-20250929"
        run_id = self._seed_scored_run(test_db)
        result, create = self._rescore_with_patched_client(run_id)
        assert result == 0
        create.assert_not_called()

    def test_rescore_disabled_when_explicit_n_zero(self, test_db):
        import src.config as cfg
        cfg.RESCORE_TOP_N = 15  # config says 15
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        from src.scoring import rescore_top
        # NOTE(review): the original comment says an explicit n=0 "falls
        # through to config", i.e. is treated like "not given". Run 1 has no
        # papers in this fresh DB, so this returns 0 however n=0 is handled.
        # It does not pin that behaviour; confirm the intended semantics
        # before strengthening it.
        assert rescore_top(1, "aiml", n=0) == 0

    def test_rescore_skipped_when_same_model(self, test_db):
        import src.config as cfg
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_TOP_N = 15
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        run_id = self._seed_scored_run(test_db)
        result, create = self._rescore_with_patched_client(run_id)
        assert result == 0
        create.assert_not_called()

    def test_rescore_skipped_when_no_api_key(self, test_db):
        import src.config as cfg
        cfg.ANTHROPIC_API_KEY = ""
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-sonnet-4-5-20250929"
        cfg.RESCORE_TOP_N = 15
        run_id = self._seed_scored_run(test_db)
        result, create = self._rescore_with_patched_client(run_id)
        assert result == 0
        create.assert_not_called()

    def test_rescore_skipped_when_no_papers(self, test_db):
        import src.config as cfg
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-sonnet-4-5-20250929"
        cfg.RESCORE_TOP_N = 15
        run_id = _create_test_run(test_db, "aiml")
        # No papers inserted
        result, create = self._rescore_with_patched_client(run_id)
        assert result == 0
        create.assert_not_called()

    def test_rescore_explicit_n_overrides_config(self, test_db):
        """Passing n=X should use that instead of RESCORE_TOP_N."""
        import src.config as cfg
        cfg.ANTHROPIC_API_KEY = "sk-ant-test"
        cfg.SCORING_MODEL = "claude-haiku-4-5-20251001"
        cfg.RESCORE_MODEL = "claude-sonnet-4-5-20250929"
        cfg.RESCORE_TOP_N = 15
        run_id = self._seed_scored_run(test_db)

        captured_content = {}

        def mock_create(**kwargs):
            captured_content["text"] = kwargs["messages"][0]["content"]
            results = [
                {
                    "arxiv_id": f"2601.{i:05d}",
                    "code_and_weights": 9, "novelty": 9, "practical_applicability": 9,
                    "summary": "r", "reasoning": "r", "code_url": None,
                }
                for i in range(2)
            ]
            resp = MagicMock()
            resp.content = [MagicMock(text=json.dumps(results))]
            return resp

        with patch("anthropic.Anthropic") as MockClient:
            mock_instance = MagicMock()
            mock_instance.messages.create = mock_create
            MockClient.return_value = mock_instance
            from src.scoring import rescore_top
            count = rescore_top(run_id, "aiml", n=2)

        # Should have only sent 2 papers (not 15 from config)
        assert captured_content["text"].count("arxiv_id:") == 2
        assert count == 2
# ---------------------------------------------------------------------------
# 6. _build_batch_content output format
# ---------------------------------------------------------------------------
class TestBuildBatchContent:
    """_build_batch_content must emit the right per-domain fields."""

    def test_aiml_content_fields(self):
        from src.scoring import _build_batch_content
        paper = {
            "arxiv_id": "2601.12345",
            "title": "Great New Model",
            "authors": ["Alice", "Bob", "Carol"],
            "categories": ["cs.LG", "cs.CL"],
            "abstract": "We present a new model.",
            "comment": "Accepted at ICML 2026",
            "github_repo": "https://github.com/alice/model",
            "hf_upvotes": 120,
            "hf_models": [{"id": "alice/model-v1", "likes": 50}],
            "hf_spaces": [{"id": "alice/demo", "likes": 10}],
            "source": "both",
        }
        content = _build_batch_content([paper], "aiml", 2000)
        expected_fragments = (
            "arxiv_id: 2601.12345",
            "title: Great New Model",
            "authors: Alice, Bob, Carol",
            "categories: cs.LG, cs.CL",
            "code_url_found: https://github.com/alice/model",
            "hf_upvotes: 120",
            "hf_models: alice/model-v1",
            "hf_spaces: alice/demo",
            "source: both",
            "abstract: We present a new model.",
            "comment: Accepted at ICML 2026",
        )
        for fragment in expected_fragments:
            assert fragment in content
        # Security-only fields must be absent for aiml.
        for fragment in ("entry_id:", "llm_adjacent:"):
            assert fragment not in content

    def test_security_content_fields(self):
        from src.scoring import _build_batch_content
        paper = {
            "entry_id": "http://arxiv.org/abs/2601.99999v1",
            "arxiv_id": "2601.99999",
            "title": "New Kernel Exploit",
            "authors": ["Mallory"],
            "categories": ["cs.CR"],
            "abstract": "We found a buffer overflow in the Linux kernel.",
            "comment": "10 pages",
            "github_repo": "https://github.com/mallory/poc",
        }
        content = _build_batch_content([paper], "security", 1500)
        for fragment in (
            "entry_id: http://arxiv.org/abs/2601.99999v1",
            "title: New Kernel Exploit",
            "code_url_found: https://github.com/mallory/poc",
            "llm_adjacent: false",
        ):
            assert fragment in content
        # aiml-only fields must be absent for security.
        for fragment in ("hf_upvotes:", "source:"):
            assert fragment not in content

    def test_security_llm_adjacent_true(self):
        from src.scoring import _build_batch_content
        paper = {
            "entry_id": "http://arxiv.org/abs/2601.88888v1",
            "arxiv_id": "2601.88888",
            "title": "Jailbreaking Large Language Models",
            "authors": ["Eve"],
            "categories": ["cs.CR"],
            "abstract": "We demonstrate a new jailbreak attack on LLMs.",
            "comment": "",
        }
        content = _build_batch_content([paper], "security", 1500)
        assert "llm_adjacent: true" in content

    def test_abstract_truncation(self):
        from src.scoring import _build_batch_content
        limit = 2000
        paper = {
            "arxiv_id": "2601.00001",
            "title": "T",
            "abstract": "A" * 5000,
            "authors": [],
            "categories": [],
        }
        content = _build_batch_content([paper], "aiml", limit)
        # The abstract is cut to exactly `limit` characters.
        assert "abstract: " + "A" * limit in content
        assert "A" * (limit + 1) not in content

    def test_missing_code_url(self):
        from src.scoring import _build_batch_content
        paper = {
            "arxiv_id": "2601.00001",
            "title": "No Code Paper",
            "abstract": "Theory only.",
            "authors": [],
            "categories": [],
        }
        content = _build_batch_content([paper], "aiml", 2000)
        assert "code_url_found: none found" in content
# ---------------------------------------------------------------------------
# 7. _apply_scores integration
# ---------------------------------------------------------------------------
class TestApplyScores:
    """_apply_scores writes Claude's scores back to the DB with a composite."""

    def test_aiml_score_application(self, test_db):
        import src.config as settings
        from src.db import get_top_papers, get_unscored_papers
        from src.scoring import _apply_scores

        run_id = _create_test_run(test_db, "aiml")
        _insert_test_papers(test_db, run_id, "aiml", [{
            "arxiv_id": "2601.00001",
            "title": "Test",
            "abstract": "Test",
            "authors": ["A"],
            "categories": ["cs.LG"],
        }])
        # Fetch the row back so _apply_scores gets the DB-shaped paper.
        pending = get_unscored_papers(run_id)
        assert len(pending) == 1

        response = [{
            "arxiv_id": "2601.00001",
            "code_and_weights": 8,
            "novelty": 7,
            "practical_applicability": 9,
            "summary": "Great paper",
            "reasoning": "Novel approach",
            "code_url": "https://github.com/test/repo",
        }]
        n_applied = _apply_scores(pending, response, "aiml", settings.SCORING_CONFIGS["aiml"])
        assert n_applied == 1

        # The DB row now carries the summary, code URL and a positive composite.
        rows = get_top_papers("aiml", run_id=run_id, limit=1)
        assert len(rows) == 1
        best = rows[0]
        assert best["summary"] == "Great paper"
        assert best["code_url"] == "https://github.com/test/repo"
        assert best["composite"] > 0

    def test_security_score_application(self, test_db):
        import src.config as settings
        from src.db import get_top_papers, get_unscored_papers
        from src.scoring import _apply_scores

        run_id = _create_test_run(test_db, "security")
        _insert_test_papers(test_db, run_id, "security", [{
            "arxiv_id": "2601.99999",
            "entry_id": "http://arxiv.org/abs/2601.99999v1",
            "title": "Exploit",
            "abstract": "Buffer overflow",
            "authors": ["M"],
            "categories": ["cs.CR"],
        }])
        pending = get_unscored_papers(run_id)
        assert len(pending) == 1

        # Security results are keyed by entry_id rather than arxiv_id.
        response = [{
            "entry_id": "http://arxiv.org/abs/2601.99999v1",
            "has_code": 6,
            "novel_attack_surface": 9,
            "real_world_impact": 8,
            "summary": "Kernel exploit",
            "reasoning": "Critical",
            "code_url": None,
        }]
        n_applied = _apply_scores(
            pending, response, "security", settings.SCORING_CONFIGS["security"]
        )
        assert n_applied == 1

        rows = get_top_papers("security", run_id=run_id, limit=1)
        assert len(rows) == 1
        assert rows[0]["summary"] == "Kernel exploit"
# ---------------------------------------------------------------------------
# 8. _call_claude model parameter
# ---------------------------------------------------------------------------
class TestCallClaude:
    """_call_claude forwards the model and copes with non-JSON replies."""

    def test_model_passed_through(self):
        from src.scoring import _call_claude
        seen = {}

        def fake_create(**kwargs):
            seen.update(kwargs)
            reply = MagicMock()
            reply.content = [MagicMock(text='[{"id": 1}]')]
            return reply

        client = MagicMock()
        client.messages.create = fake_create
        parsed = _call_claude(client, "system", "user content", model="claude-opus-4-6")
        assert seen["model"] == "claude-opus-4-6"
        assert parsed == [{"id": 1}]

    def test_no_json_returns_empty(self):
        from src.scoring import _call_claude
        reply = MagicMock()
        reply.content = [MagicMock(text="I cannot process this request.")]
        client = MagicMock()
        client.messages.create.return_value = reply
        parsed = _call_claude(client, "system", "user", model="claude-haiku-4-5-20251001")
        # A reply without any JSON array yields an empty result list.
        assert parsed == []

    def test_model_is_required_keyword(self):
        """model is keyword-only — calling without it should TypeError."""
        from src.scoring import _call_claude
        client = MagicMock()
        with pytest.raises(TypeError):
            _call_claude(client, "system", "user")
# ---------------------------------------------------------------------------
# 9. Full score_run → rescore_top pipeline flow
# ---------------------------------------------------------------------------
class TestFullPipelineFlow:
    """End to end: haiku scores the whole batch, sonnet rescores the leaders."""

    def test_score_then_rescore(self, test_db):
        import src.config as settings
        haiku = "claude-haiku-4-5-20251001"
        sonnet = "claude-sonnet-4-5-20250929"
        settings.ANTHROPIC_API_KEY = "sk-ant-test"
        settings.SCORING_MODEL = haiku
        settings.RESCORE_MODEL = sonnet
        settings.RESCORE_TOP_N = 2
        settings.BATCH_SIZE = 20

        run_id = _create_test_run(test_db, "aiml")
        # Five unscored papers; the numeric suffix of each id drives its
        # fake scores, so higher suffixes rank higher.
        for idx in range(5):
            _insert_test_papers(test_db, run_id, "aiml", [{
                "arxiv_id": f"2601.{idx:05d}",
                "title": f"Paper {idx}",
                "abstract": f"Abstract for paper {idx}",
                "authors": ["Author"],
                "categories": ["cs.LG"],
                "source": "arxiv",
            }])

        models_called = []

        def fake_score(aid, model):
            rank = int(aid.split(".")[-1])
            return {
                "arxiv_id": aid,
                "code_and_weights": 5 + rank,
                "novelty": 6 + rank,
                "practical_applicability": 4 + rank,
                "summary": f"summary-{model}",
                "reasoning": "r",
                "code_url": None,
            }

        def fake_create(**kwargs):
            model = kwargs["model"]
            models_called.append(model)
            prompt = kwargs["messages"][0]["content"]
            # Recover the batch's arxiv ids from the prompt lines.
            batch_ids = [
                line.split(": ", 1)[1]
                for line in prompt.split("\n")
                if line.startswith("arxiv_id: ")
            ]
            reply = MagicMock()
            reply.content = [
                MagicMock(text=json.dumps([fake_score(aid, model) for aid in batch_ids]))
            ]
            return reply

        with patch("anthropic.Anthropic") as client_cls:
            fake_client = MagicMock()
            fake_client.messages.create = fake_create
            client_cls.return_value = fake_client
            from src.scoring import rescore_top, score_run

            # Step 1: bulk score everything.
            assert score_run(run_id, "aiml") == 5
            # Step 2: rescore the top 2.
            assert rescore_top(run_id, "aiml") == 2

        # First call is the haiku bulk pass, second the sonnet rescore.
        assert models_called[0] == haiku
        assert models_called[1] == sonnet

        from src.db import get_top_papers
        ranked = get_top_papers("aiml", run_id=run_id, limit=5)
        # The two leaders carry the sonnet summary; the next one is still haiku's.
        assert ranked[0]["summary"] == f"summary-{sonnet}"
        assert ranked[1]["summary"] == f"summary-{sonnet}"
        assert ranked[2]["summary"] == f"summary-{haiku}"
# ---------------------------------------------------------------------------
# 10. API key validation uses haiku
# ---------------------------------------------------------------------------
class TestApiKeyValidation:
    """The setup wizard's key check must always probe with the haiku model."""

    async def test_validate_uses_haiku(self):
        # NOTE(review): no asyncio marker is visible on this coroutine test.
        # It presumably relies on an async pytest plugin in auto mode; confirm.
        sent = {}

        def fake_create(**kwargs):
            sent.update(kwargs)
            reply = MagicMock()
            reply.content = [MagicMock(text="Hi")]
            return reply

        with patch("anthropic.Anthropic") as client_cls:
            fake_client = MagicMock()
            fake_client.messages.create = fake_create
            client_cls.return_value = fake_client
            from httpx import ASGITransport, AsyncClient
            from src.web.app import app

            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as http:
                response = await http.post(
                    "/api/setup/validate-key",
                    json={"api_key": "sk-ant-test-123"},
                )

        assert response.status_code == 200
        assert sent["model"] == "claude-haiku-4-5-20251001"
# ---------------------------------------------------------------------------
# 11. Setup save persists scoring block
# ---------------------------------------------------------------------------
class TestSetupSave:
    """Verify the setup save endpoint persists the scoring config."""

    async def test_save_persists_scoring_block(self, tmp_path, monkeypatch):
        """POST /api/setup/save writes the scoring block and reloads constants.

        NOTE(review): no asyncio marker is visible on this coroutine test. It
        presumably relies on an async pytest plugin in auto mode; confirm.
        """
        import src.config as cfg

        # monkeypatch restores the real paths after the test.
        monkeypatch.setattr(cfg, "CONFIG_PATH", tmp_path / "config.yaml")
        monkeypatch.setattr(cfg, "DB_PATH", tmp_path / "researcher.db")
        from src.db import init_db
        init_db()

        from httpx import ASGITransport, AsyncClient
        from src.web.app import app

        # apscheduler may not be installed in the test env.  Stub its modules
        # for the duration of the request, and stub reschedule at its call site.
        fake_apscheduler = {
            name: MagicMock()
            for name in (
                "apscheduler",
                "apscheduler.schedulers",
                "apscheduler.schedulers.background",
                "apscheduler.triggers",
                "apscheduler.triggers.cron",
            )
        }
        with patch.dict("sys.modules", fake_apscheduler), \
                patch("src.scheduler.reschedule", return_value=None):
            async with AsyncClient(
                transport=ASGITransport(app=app), base_url="http://test"
            ) as client:
                resp = await client.post("/api/setup/save", json={
                    "api_key": "",
                    "scoring": {
                        "model": "claude-opus-4-6",
                        "rescore_model": "claude-sonnet-4-5-20250929",
                        "rescore_top_n": 10,
                    },
                    "domains": {
                        "aiml": {"enabled": True},
                        "security": {"enabled": True},
                    },
                    "schedule": "0 22 * * 0",
                })

        assert resp.status_code == 200
        data = resp.json()
        assert data["status"] == "ok"

        # Verify config was saved and reloaded
        assert cfg.SCORING_MODEL == "claude-opus-4-6"
        assert cfg.RESCORE_MODEL == "claude-sonnet-4-5-20250929"
        assert cfg.RESCORE_TOP_N == 10

        # Verify YAML file has the scoring block
        import yaml
        saved = yaml.safe_load(cfg.CONFIG_PATH.read_text())
        assert saved["scoring"]["model"] == "claude-opus-4-6"
        assert saved["scoring"]["rescore_top_n"] == 10