"""API tests for the v1 train, validation and prediction routers.

The real routers are mounted on a local FastAPI app and driven through
``TestClient``; the service objects they call (``training_service``,
``validation_service``, ``prediction_service``) are patched with mocks, and
authentication dependencies are overridden with fixed test values.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

from app.api.v1.train import train_router
from app.api.v1.validation import validation_router
from app.api.v1.prediction import prediction_router
from app.api.dependencies import get_current_username

app = FastAPI()
app.include_router(train_router, prefix="/v1")
app.include_router(validation_router, prefix="/v1")
app.include_router(prediction_router, prefix="/v1")


def _mock_model_result(payload):
    """Return a MagicMock whose ``model_dump()`` returns *payload*.

    Stands in for the model objects the services return, which the
    endpoints serialize via ``model_dump()``.
    """
    result = MagicMock()
    result.model_dump.return_value = payload
    return result


@pytest.fixture
def ml_client():
    """Yield a ``TestClient`` with authentication dependencies overridden.

    The overrides live on the module-level ``app``. They are cleared in a
    ``finally`` so they cannot leak into other tests, even if client
    shutdown raises.
    """
    # NOTE(review): the triple "c" matches the name exported by
    # app.api.v1.train; rename both together if that typo is ever fixed.
    from app.api.v1.train import acccess_token_bearer

    # Override authentication and authorization.
    app.dependency_overrides[acccess_token_bearer] = lambda: {"user_id": "test_user"}
    app.dependency_overrides[get_current_username] = lambda: "test_user"
    try:
        with TestClient(app) as client:
            yield client
    finally:
        app.dependency_overrides.clear()


# --- Mocking Services ---

@pytest.fixture
def mock_train_service():
    """Patch the training service used by the train router."""
    with patch("app.api.v1.train.training_service") as mock_service:
        yield mock_service


@pytest.fixture
def mock_validation_service():
    """Patch the validation service used by the validation router."""
    with patch("app.api.v1.validation.validation_service") as mock_service:
        yield mock_service


@pytest.fixture
def mock_prediction_service():
    """Patch the prediction service used by the prediction router."""
    with patch("app.api.v1.prediction.prediction_service") as mock_service:
        yield mock_service


def test_start_training(ml_client, mock_train_service):
    """Form data with a ``preprocess_id`` (no file) calls ``train()``."""
    mock_train_service.train = AsyncMock(
        return_value=_mock_model_result(
            {"message": "started", "training_id": "train_123", "status": "running"}
        )
    )

    payload = {
        "compliance_type": "firco",
        "model_name": "infinity",
        "config": "{}",
        "preprocess_id": "prep_1",
        "version": "latest",
    }
    response = ml_client.post("/v1/train", data=payload)

    # Show the body on failure instead of calling .json() on a possibly
    # non-JSON error response.
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["training_id"] == "train_123"
    assert data["status"] == "running"
    mock_train_service.train.assert_called_once()
    mock_train_service.train_pipeline.assert_not_called()


def test_start_training_pipeline_mode(ml_client, mock_train_service):
    """Pipeline mode: file + columns + target_columns triggers train_pipeline()."""
    mock_train_service.train_pipeline = AsyncMock(
        return_value=_mock_model_result(
            {
                "message": "started",
                "training_id": "train_pipeline_123",
                "status": "running",
            }
        )
    )

    payload = {
        "compliance_type": "non_firco",
        "model_name": "nexus",
        "config": "{}",
        "columns": '{"age":"numerical","gender":"categorical"}',
        "target_columns": '["income"]',
        "language": "en",
    }
    files = {
        "file": ("dataset.csv", b"age,gender,income\n25,M,high\n30,F,low", "text/csv")
    }
    response = ml_client.post("/v1/train", data=payload, files=files)

    assert response.status_code == 200, response.text
    data = response.json()
    assert data["training_id"] == "train_pipeline_123"
    assert data["status"] == "running"
    mock_train_service.train_pipeline.assert_called_once()
    mock_train_service.train.assert_not_called()


def test_start_training_pipeline_mode_missing_columns(ml_client, mock_train_service):
    """Pipeline mode without columns should return 400."""
    payload = {
        "compliance_type": "non_firco",
        "model_name": "nexus",
        "target_columns": '["income"]',
    }
    files = {"file": ("dataset.csv", b"age,income\n25,high", "text/csv")}
    response = ml_client.post("/v1/train", data=payload, files=files)

    assert response.status_code == 400, response.text
    assert "columns" in response.json()["detail"]
    mock_train_service.train_pipeline.assert_not_called()


def test_start_training_pipeline_mode_missing_target_columns(ml_client, mock_train_service):
    """Pipeline mode without target_columns should return 400."""
    payload = {
        "compliance_type": "non_firco",
        "model_name": "nexus",
        "columns": '{"age":"numerical"}',
    }
    files = {"file": ("dataset.csv", b"age,income\n25,high", "text/csv")}
    response = ml_client.post("/v1/train", data=payload, files=files)

    assert response.status_code == 400, response.text
    assert "target_columns" in response.json()["detail"]
    mock_train_service.train_pipeline.assert_not_called()


def test_start_training_no_file_no_preprocess_id(ml_client, mock_train_service):
    """Neither file nor preprocess_id should return 400."""
    payload = {
        "compliance_type": "firco",
        "model_name": "infinity",
    }
    response = ml_client.post("/v1/train", data=payload)

    assert response.status_code == 400, response.text
    detail = response.json()["detail"]
    assert "preprocess_id" in detail or "file" in detail
    mock_train_service.train.assert_not_called()
    mock_train_service.train_pipeline.assert_not_called()


def test_list_all_training_runs(ml_client, mock_train_service):
    """The training-run history endpoint returns data plus pagination."""
    mock_train_service.count_by_details = AsyncMock(return_value=1)
    mock_train_service.list_by_user = AsyncMock(return_value=[])

    with patch(
        "app.api.v1.train.format_universal_runs_response", new_callable=AsyncMock
    ) as mock_format:
        mock_format.return_value = {"training_runs": []}
        response = ml_client.get("/v1/history/training-runs/")

    assert response.status_code == 200, response.text
    body = response.json()
    assert "data" in body
    assert "pagination" in body


def test_validate_model_api(ml_client, mock_validation_service):
    """A file upload to /v1/validate calls ``validate()`` and echoes its result."""
    mock_validation_service.validate = AsyncMock(
        return_value=_mock_model_result(
            {
                "message": "started",
                "validation_id": "val_123",
                "metrics": {"accuracy": 0.95},
                "training_id": "latest",
            }
        )
    )

    data = {
        "compliance_type": "firco",
        "model_name": "infinity",
        "version": "latest",
        "number_of_reasonings": -1,
    }
    files = {"file": ("test.csv", b"dummy,data", "text/csv")}
    response = ml_client.post("/v1/validate", data=data, files=files)

    assert response.status_code == 200, response.text
    res_data = response.json()
    assert res_data["validation_id"] == "val_123"
    assert res_data["metrics"]["accuracy"] == 0.95
    mock_validation_service.validate.assert_called_once()


def test_predict_api(ml_client, mock_prediction_service):
    """A file upload to /v1/predict calls ``predict()`` and echoes its result."""
    mock_prediction_service.predict = AsyncMock(
        return_value=_mock_model_result(
            {"message": "started", "prediction_id": "pred_123", "predictions": []}
        )
    )

    data = {
        "compliance_type": "firco",
        "model_name": "infinity",
        "source_type": "file",
        "version": "latest",
        "number_of_reasonings": -1,
    }
    files = {"file": ("test.csv", b"dummy,data", "text/csv")}
    response = ml_client.post("/v1/predict", data=data, files=files)

    assert response.status_code == 200, response.text
    res_data = response.json()
    assert res_data["prediction_id"] == "pred_123"
    mock_prediction_service.predict.assert_called_once()