# nexa-classify-api / tests/conftest.py
# Uploaded by Prototype6239 using huggingface_hub (commit a229747, verified; 3.84 kB).
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.svm import LinearSVC
import api
import database
from config import CFG
@pytest.fixture
def sample_texts():
    """Return ten short news-style headlines for fitting the mock pipeline.

    The list is rebuilt on every call and lines up index-for-index with
    the ``sample_labels`` fixture.
    """
    headlines = [
        "UN peace talks resume as leaders meet to discuss ceasefire plans.",
        "Local team wins championship after a dramatic overtime goal.",
        "Stocks rally as the central bank signals a pause in rate hikes.",
        "New smartphone chip boosts performance while reducing power use.",
        "World leaders condemn attacks and call for immediate humanitarian aid.",
        "Star striker sidelined with injury ahead of weekend match.",
        "Company reports record quarterly earnings despite weak consumer demand.",
        "Scientists discover a new exoplanet that may support liquid water.",
        "Oil prices rise on supply concerns and geopolitical tensions.",
        "Tech firm faces scrutiny after data breach exposes user accounts.",
    ]
    return headlines
@pytest.fixture
def sample_labels():
    """Return integer class ids (0-3) aligned with ``sample_texts``.

    Produces ``[0, 1, 2, 3, 0, 1, 2, 3, 2, 3]``: two full passes over the
    four classes, followed by two extra samples for classes 2 and 3.
    """
    labels = [0, 1, 2, 3] * 2
    labels += [2, 3]
    return labels
@pytest.fixture
def mock_pipeline(sample_texts, sample_labels):
    """Fit and return a small TF-IDF + logistic-regression ``Pipeline``.

    Step names are ``"tfidf"`` and ``"lr"``. The classifier's
    ``random_state`` is taken from ``CFG.seed``.

    Args:
        sample_texts: Training documents (fixture).
        sample_labels: Integer class ids aligned with ``sample_texts`` (fixture).

    Returns:
        The fitted pipeline.
    """
    vectorizer = TfidfVectorizer(
        max_features=200,
        ngram_range=(1, 1),
        min_df=1,
        sublinear_tf=False,
    )
    classifier = LogisticRegression(
        max_iter=200,
        solver="lbfgs",
        random_state=CFG.seed,
    )
    steps = [("tfidf", vectorizer), ("lr", classifier)]
    # Pipeline.fit returns the pipeline itself, so fit and return in one step.
    return Pipeline(steps).fit(sample_texts, sample_labels)
@pytest.fixture
def test_db_path(tmp_path):
    """Return the path of this test's requests database inside ``tmp_path``.

    Only the path is computed here; no file is created.
    """
    db_file = tmp_path.joinpath("test_requests.db")
    return db_file
@pytest.fixture
def api_client(monkeypatch, tmp_path, test_db_path):
    """Return a ``TestClient`` for the API's routes, isolated to this test.

    Side effects (undone by ``monkeypatch`` at teardown where patched):
      * creates an empty ``tmp_path / "models"`` directory and points
        ``CFG.models_dir`` at it;
      * makes ``database._default_db_path`` return ``test_db_path`` and
        initialises that database via ``database.init_db``;
      * empties ``api._registry`` (presumably a cache of loaded models --
        confirm in ``api``).

    The routes of ``api.app`` are mounted on a fresh ``FastAPI`` instance.
    The client is returned without entering it as a context manager, so
    the app's startup/lifespan handlers do not run.
    """
    db_location = str(test_db_path)
    models_dir = tmp_path / "models"
    models_dir.mkdir(parents=True, exist_ok=True)

    monkeypatch.setattr(CFG, "models_dir", str(models_dir))
    monkeypatch.setattr(database, "_default_db_path", lambda: db_location)
    database.init_db(db_path=db_location)
    api._registry.clear()

    isolated_app = FastAPI()
    isolated_app.include_router(api.app.router)
    return TestClient(isolated_app)
@pytest.fixture
def api_client_with_models(api_client, mock_pipeline, tmp_path, monkeypatch):
    """Return ``api_client`` after saving two fitted pipelines to the models dir.

    Writes, via ``joblib.dump``:
      * ``traditional_lr.joblib``: the ``mock_pipeline`` fixture;
      * ``traditional_svm.joblib``: a TF-IDF + ``LinearSVC`` pipeline
        (steps ``"tfidf"`` / ``"svm"``) fitted on eight extra headlines.
    """
    target_dir = tmp_path / "models"
    # api_client already created this directory and patched CFG to it; the
    # repeat patch sets the same value.
    monkeypatch.setattr(CFG, "models_dir", str(target_dir))

    joblib.dump(mock_pipeline, target_dir / "traditional_lr.joblib")

    svm_texts = [
        "UN talks continue amid international pressure",
        "Team wins match after extra time",
        "Shares climb after earnings beat expectations",
        "New processor improves phone battery life",
        "Markets react to inflation report and central bank comments",
        "Scientists unveil new telescope instrument",
        "Player scores hat-trick in league game",
        "Company announces merger in tech sector",
    ]
    svm_labels = [0, 1, 2, 3, 2, 3, 1, 2]
    svm_pipeline = Pipeline(
        [
            ("tfidf", TfidfVectorizer(max_features=200, ngram_range=(1, 1), min_df=1)),
            ("svm", LinearSVC(random_state=CFG.seed, max_iter=1000)),
        ]
    )
    svm_pipeline.fit(svm_texts, svm_labels)
    joblib.dump(svm_pipeline, target_dir / "traditional_svm.joblib")

    return api_client
@pytest.fixture
def api_client_with_startup(monkeypatch, tmp_path):
    """Yield a ``TestClient`` on the real ``api.app`` with startup executed.

    Before entering the client:
      * ``database._default_db_path`` is redirected to
        ``tmp_path / "startup_requests.db"``. The database is not
        initialised here; presumably the app's startup does it (confirm).
      * ``CFG.models_dir`` points at ``tmp_path / "models"``. The
        directory is not created here.
      * ``api._registry`` is emptied.
      * ``api._load_model`` is replaced by a stub that always raises
        ``FileNotFoundError``, presumably to drive startup down its
        "default model missing" path.

    Entering ``TestClient`` as a context manager runs the app's
    startup/shutdown handlers around the test.
    """
    startup_db = tmp_path / "startup_requests.db"
    monkeypatch.setattr(database, "_default_db_path", lambda: str(startup_db))
    monkeypatch.setattr(CFG, "models_dir", str(tmp_path / "models"))
    api._registry.clear()

    def _missing_model(model_name: str):
        # Keep the original parameter name in case callers pass it by keyword.
        raise FileNotFoundError("default model not available")

    monkeypatch.setattr(api, "_load_model", _missing_model)

    with TestClient(api.app) as client:
        yield client