import joblib
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC

import api
import database
from config import CFG


@pytest.fixture
def sample_texts():
    """Return ten short news-style headlines used to fit ``mock_pipeline``.

    The list is index-aligned with ``sample_labels``.
    """
    return [
        "UN peace talks resume as leaders meet to discuss ceasefire plans.",
        "Local team wins championship after a dramatic overtime goal.",
        "Stocks rally as the central bank signals a pause in rate hikes.",
        "New smartphone chip boosts performance while reducing power use.",
        "World leaders condemn attacks and call for immediate humanitarian aid.",
        "Star striker sidelined with injury ahead of weekend match.",
        "Company reports record quarterly earnings despite weak consumer demand.",
        "Scientists discover a new exoplanet that may support liquid water.",
        "Oil prices rise on supply concerns and geopolitical tensions.",
        "Tech firm faces scrutiny after data breach exposes user accounts.",
    ]


@pytest.fixture
def sample_labels():
    """Return one integer class id per entry in ``sample_texts``.

    NOTE(review): judging by the headlines, the ids look like
    0=world, 1=sports, 2=business, 3=sci/tech. Confirm this against the
    project's label map.
    """
    return [0, 1, 2, 3, 0, 1, 2, 3, 2, 3]


@pytest.fixture
def mock_pipeline(sample_texts, sample_labels) -> Pipeline:
    """Fit and return a small TF-IDF + LogisticRegression pipeline.

    The vectorizer keeps at most 200 unigram features. The classifier is
    seeded with ``CFG.seed``. The pipeline steps are named "tfidf" and "lr".
    """
    tfidf = TfidfVectorizer(
        max_features=200,
        ngram_range=(1, 1),
        min_df=1,
        sublinear_tf=False,
    )
    clf = LogisticRegression(
        max_iter=200,
        solver="lbfgs",
        random_state=CFG.seed,
    )
    pipe = Pipeline([("tfidf", tfidf), ("lr", clf)])
    pipe.fit(sample_texts, sample_labels)
    return pipe


@pytest.fixture
def test_db_path(tmp_path):
    """Return a per-test database file path under ``tmp_path``.

    The file itself is not created here.
    """
    return tmp_path / "test_requests.db"


@pytest.fixture
def api_client(monkeypatch, tmp_path, test_db_path):
    """Yield a TestClient for the API routes, isolated to ``tmp_path``.

    Setup:
      * create ``tmp_path / "models"`` and point ``CFG.models_dir`` at it;
      * redirect ``database._default_db_path`` to ``test_db_path`` and
        initialise that database;
      * clear ``api._registry`` so state from another test is not reused.

    The routes of ``api.app`` are mounted on a fresh FastAPI app. The client
    is not entered as a context manager, so Starlette does not run
    startup/lifespan handlers. This is presumably done to avoid loading a
    real default model; ``api_client_with_startup`` covers the startup path.

    Teardown clears ``api._registry`` again so registry state created during
    this test does not leak into later tests.
    """
    models_dir = tmp_path / "models"
    models_dir.mkdir(parents=True, exist_ok=True)
    monkeypatch.setattr(CFG, "models_dir", str(models_dir))

    monkeypatch.setattr(database, "_default_db_path", lambda: str(test_db_path))
    database.init_db(db_path=str(test_db_path))

    api._registry.clear()

    test_app = FastAPI()
    test_app.include_router(api.app.router)
    yield TestClient(test_app)

    # _registry is module-level state that monkeypatch does not undo.
    # NOTE(review): it presumably caches loaded models, which would otherwise
    # outlive this test's tmp_path.
    api._registry.clear()
@pytest.fixture
def api_client_with_models(api_client, mock_pipeline, tmp_path, monkeypatch):
    """Return ``api_client`` with two fitted models saved in the models dir.

    Writes two files to ``tmp_path / "models"``:
      * ``traditional_lr.joblib``: the ``mock_pipeline`` fixture;
      * ``traditional_svm.joblib``: a TF-IDF + LinearSVC pipeline fitted on a
        separate eight-headline corpus.
    """
    models_dir = tmp_path / "models"
    # api_client already creates and configures this directory. Doing both
    # steps here too keeps this fixture working if api_client's setup changes.
    models_dir.mkdir(parents=True, exist_ok=True)
    monkeypatch.setattr(CFG, "models_dir", str(models_dir))

    joblib.dump(mock_pipeline, models_dir / "traditional_lr.joblib")

    svm_texts = [
        "UN talks continue amid international pressure",
        "Team wins match after extra time",
        "Shares climb after earnings beat expectations",
        "New processor improves phone battery life",
        "Markets react to inflation report and central bank comments",
        "Scientists unveil new telescope instrument",
        "Player scores hat-trick in league game",
        "Company announces merger in tech sector",
    ]
    svm_labels = [0, 1, 2, 3, 2, 3, 1, 2]
    svm = Pipeline(
        [
            ("tfidf", TfidfVectorizer(max_features=200, ngram_range=(1, 1), min_df=1)),
            ("svm", LinearSVC(random_state=CFG.seed, max_iter=1000)),
        ]
    )
    svm.fit(svm_texts, svm_labels)
    joblib.dump(svm, models_dir / "traditional_svm.joblib")
    return api_client


@pytest.fixture
def api_client_with_startup(monkeypatch, tmp_path):
    """Yield a TestClient for the real ``api.app`` with startup handlers run.

    Before the client starts, this fixture:
      * redirects the database path and ``CFG.models_dir`` into ``tmp_path``;
      * clears ``api._registry``;
      * patches ``api._load_model`` to raise ``FileNotFoundError``, so startup
        cannot load a real default model.

    Entering ``TestClient`` as a context manager runs the app's startup
    handlers, and leaving it runs the shutdown handlers. The models directory
    is not created here. ``database.init_db`` is not called here either.
    NOTE(review): presumably the app's startup initialises the database;
    confirm.

    Teardown clears ``api._registry`` so startup-time state does not leak
    into later tests.
    """
    test_db_path = tmp_path / "startup_requests.db"
    monkeypatch.setattr(database, "_default_db_path", lambda: str(test_db_path))
    monkeypatch.setattr(CFG, "models_dir", str(tmp_path / "models"))
    api._registry.clear()

    def _fake_load_model(model_name: str):
        raise FileNotFoundError("default model not available")

    monkeypatch.setattr(api, "_load_model", _fake_load_model)

    with TestClient(api.app) as client:
        yield client

    # Module-level state is not restored by monkeypatch; reset it explicitly.
    api._registry.clear()