Spaces:
Sleeping
Sleeping
| import pytest | |
| from fastapi import FastAPI | |
| from fastapi.testclient import TestClient | |
| import joblib | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.svm import LinearSVC | |
| import api | |
| import database | |
| from config import CFG | |
def sample_texts():
    """Return ten short news-style headlines as a fresh list of strings.

    The list is positionally paired with the ten labels from
    ``sample_labels`` when both are passed to ``mock_pipeline``.
    """
    # NOTE(review): consumed by name as a pytest fixture elsewhere in this
    # file; no @pytest.fixture decorator is visible in this view -- confirm.
    headlines = (
        "UN peace talks resume as leaders meet to discuss ceasefire plans.",
        "Local team wins championship after a dramatic overtime goal.",
        "Stocks rally as the central bank signals a pause in rate hikes.",
        "New smartphone chip boosts performance while reducing power use.",
        "World leaders condemn attacks and call for immediate humanitarian aid.",
        "Star striker sidelined with injury ahead of weekend match.",
        "Company reports record quarterly earnings despite weak consumer demand.",
        "Scientists discover a new exoplanet that may support liquid water.",
        "Oil prices rise on supply concerns and geopolitical tensions.",
        "Tech firm faces scrutiny after data breach exposes user accounts.",
    )
    # Hand back a new list on every call so callers may mutate it freely.
    return list(headlines)
def sample_labels():
    """Return ten integer class labels (values 0-3), one per sample text."""
    # NOTE(review): consumed by name as a pytest fixture elsewhere in this
    # file; no @pytest.fixture decorator is visible in this view -- confirm.
    one_of_each = [0, 1, 2, 3]
    # Two full passes over the four classes, then two extra samples.
    return one_of_each * 2 + [2, 3]
def mock_pipeline(sample_texts, sample_labels):
    """Fit and return a small TF-IDF + logistic-regression pipeline.

    Args:
        sample_texts: Training documents passed to ``Pipeline.fit``.
        sample_labels: Target labels passed to ``Pipeline.fit``.

    Returns:
        The fitted ``Pipeline`` with steps named ``"tfidf"`` and ``"lr"``.
    """
    # NOTE(review): the parameters look like pytest fixtures injected by name;
    # no @pytest.fixture decorator is visible in this view -- confirm.
    vectorizer_options = {
        "max_features": 200,
        "ngram_range": (1, 1),
        "min_df": 1,
        "sublinear_tf": False,
    }
    classifier_options = {
        "max_iter": 200,
        "solver": "lbfgs",
        "random_state": CFG.seed,
    }
    steps = [
        ("tfidf", TfidfVectorizer(**vectorizer_options)),
        ("lr", LogisticRegression(**classifier_options)),
    ]
    model = Pipeline(steps)
    model.fit(sample_texts, sample_labels)
    return model
def test_db_path(tmp_path):
    """Return the path ``tmp_path / "test_requests.db"``.

    Only the path is computed; nothing is created on disk here.
    """
    # NOTE(review): consumed by name as a pytest fixture elsewhere in this
    # file; no @pytest.fixture decorator is visible in this view -- confirm.
    db_file = tmp_path / "test_requests.db"
    return db_file
def api_client(monkeypatch, tmp_path, test_db_path):
    """Return a ``TestClient`` over a fresh app that reuses ``api.app``'s router.

    Before the client is built, this function:
      * creates ``tmp_path / "models"`` and points ``CFG.models_dir`` at it,
      * replaces ``database._default_db_path`` so it returns ``test_db_path``,
      * initialises the database at ``test_db_path``,
      * empties ``api._registry``.

    Args:
        monkeypatch: Object providing ``setattr`` for patching attributes.
        tmp_path: Base directory for the models folder.
        test_db_path: Location of the database file to initialise.
    """
    # NOTE(review): the parameters look like pytest fixtures injected by name;
    # no @pytest.fixture decorator is visible in this view -- confirm.
    model_store = tmp_path / "models"
    model_store.mkdir(parents=True, exist_ok=True)
    monkeypatch.setattr(CFG, "models_dir", str(model_store))

    db_file = str(test_db_path)
    monkeypatch.setattr(database, "_default_db_path", lambda: db_file)
    database.init_db(db_path=db_file)

    api._registry.clear()

    # NOTE(review): a new FastAPI instance is used instead of api.app itself,
    # presumably so api.app's startup hooks are not run -- confirm.
    isolated_app = FastAPI()
    isolated_app.include_router(api.app.router)
    return TestClient(isolated_app)
def api_client_with_models(api_client, mock_pipeline, tmp_path, monkeypatch):
    """Return ``api_client`` after writing two model files under ``tmp_path/models``.

    ``mock_pipeline`` is dumped as ``traditional_lr.joblib``. A TF-IDF +
    ``LinearSVC`` pipeline is fitted on eight short headlines and dumped as
    ``traditional_svm.joblib``. ``CFG.models_dir`` is pointed at the same
    directory.

    The directory itself is not created here; the ``api_client`` defined in
    this file creates ``tmp_path / "models"``.
    """
    # NOTE(review): the parameters look like pytest fixtures injected by name;
    # no @pytest.fixture decorator is visible in this view -- confirm.
    model_store = tmp_path / "models"
    monkeypatch.setattr(CFG, "models_dir", str(model_store))
    joblib.dump(mock_pipeline, model_store / "traditional_lr.joblib")

    svm_texts = [
        "UN talks continue amid international pressure",
        "Team wins match after extra time",
        "Shares climb after earnings beat expectations",
        "New processor improves phone battery life",
        "Markets react to inflation report and central bank comments",
        "Scientists unveil new telescope instrument",
        "Player scores hat-trick in league game",
        "Company announces merger in tech sector",
    ]
    svm_labels = [0, 1, 2, 3, 2, 3, 1, 2]

    svm_vectorizer = TfidfVectorizer(max_features=200, ngram_range=(1, 1), min_df=1)
    svm_classifier = LinearSVC(random_state=CFG.seed, max_iter=1000)
    svm_pipeline = Pipeline([("tfidf", svm_vectorizer), ("svm", svm_classifier)])
    svm_pipeline.fit(svm_texts, svm_labels)
    joblib.dump(svm_pipeline, model_store / "traditional_svm.joblib")

    return api_client
def api_client_with_startup(monkeypatch, tmp_path):
    """Yield a ``TestClient`` bound to ``api.app`` itself, used as a context manager.

    Setup performed before the client is opened:
      * ``database._default_db_path`` returns ``tmp_path/startup_requests.db``,
      * ``CFG.models_dir`` points at ``tmp_path / "models"`` (not created here),
      * ``api._registry`` is emptied,
      * ``api._load_model`` is replaced by a stub that always raises
        ``FileNotFoundError("default model not available")``.

    The client is yielded from inside the ``with`` block, so it is closed when
    the generator is resumed after the consumer finishes.
    """
    # NOTE(review): the parameters look like pytest fixtures injected by name;
    # no @pytest.fixture decorator is visible in this view -- confirm.
    db_file = tmp_path / "startup_requests.db"
    monkeypatch.setattr(database, "_default_db_path", lambda: str(db_file))
    monkeypatch.setattr(CFG, "models_dir", str(tmp_path / "models"))
    api._registry.clear()

    def _model_unavailable(model_name: str):
        # Parameter name kept as-is in case the API passes it by keyword.
        raise FileNotFoundError("default model not available")

    monkeypatch.setattr(api, "_load_model", _model_unavailable)

    with TestClient(api.app) as started_client:
        yield started_client