Spaces:
Runtime error
Runtime error
| import pytest | |
| from unittest.mock import AsyncMock, patch, MagicMock | |
| from fastapi import FastAPI | |
| from fastapi.testclient import TestClient | |
| from app.api.v1.train import train_router | |
| from app.api.v1.validation import validation_router | |
| from app.api.v1.prediction import prediction_router | |
| from app.api.dependencies import get_current_username | |
# Minimal application that mounts only the v1 routers exercised by these tests.
app = FastAPI()
# Every router shares the "/v1" prefix; mount them in the original order.
for _router in (train_router, validation_router, prediction_router):
    app.include_router(_router, prefix="/v1")
del _router  # keep the module namespace identical to the explicit three-call form
@pytest.fixture
def ml_client():
    """Yield a ``TestClient`` for ``app`` with authentication stubbed out.

    Both auth dependencies are overridden so every request is treated as
    coming from ``"test_user"``. The overrides are stored on the shared
    module-level ``app``, so they are cleared in ``finally``. That way they
    do not leak into later tests even if client startup/shutdown raises.
    """
    # NOTE(review): the triple-"c" spelling must match whatever
    # app.api.v1.train actually exports; confirm before renaming.
    from app.api.v1.train import acccess_token_bearer

    app.dependency_overrides[acccess_token_bearer] = lambda: {"user_id": "test_user"}
    app.dependency_overrides[get_current_username] = lambda: "test_user"
    try:
        with TestClient(app) as client:
            yield client
    finally:
        app.dependency_overrides.clear()
| # --- Mocking Services --- | |
@pytest.fixture
def mock_train_service():
    """Patch ``training_service`` inside ``app.api.v1.train`` and yield the mock.

    The patch targets the name in the router module, which is presumably
    where the route handlers look it up. It is undone when the test finishes.
    """
    with patch("app.api.v1.train.training_service") as mock_service:
        yield mock_service
@pytest.fixture
def mock_validation_service():
    """Patch ``validation_service`` inside ``app.api.v1.validation`` and yield the mock.

    The patch targets the name in the router module, which is presumably
    where the route handlers look it up. It is undone when the test finishes.
    """
    with patch("app.api.v1.validation.validation_service") as mock_service:
        yield mock_service
@pytest.fixture
def mock_prediction_service():
    """Patch ``prediction_service`` inside ``app.api.v1.prediction`` and yield the mock.

    The patch targets the name in the router module, which is presumably
    where the route handlers look it up. It is undone when the test finishes.
    """
    with patch("app.api.v1.prediction.prediction_service") as mock_service:
        yield mock_service
def test_start_training(ml_client, mock_train_service):
    """Training from an existing ``preprocess_id`` (no file upload) calls ``train()``."""
    # The service's async train() resolves to an object whose model_dump()
    # supplies the response payload checked below.
    mock_train_result = MagicMock()
    mock_train_result.model_dump.return_value = {
        "message": "started",
        "training_id": "train_123",
        "status": "running",
    }
    mock_train_service.train = AsyncMock(return_value=mock_train_result)

    payload = {
        "compliance_type": "firco",
        "model_name": "infinity",
        "config": "{}",
        "preprocess_id": "prep_1",
        "version": "latest",
    }
    response = ml_client.post("/v1/train", data=payload)

    # Put the raw body in the assertion message. pytest always shows it on
    # failure, and unlike response.json() it cannot itself raise on a
    # non-JSON error body and hide the real problem.
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["training_id"] == "train_123"
    assert data["status"] == "running"
    # The body echoes the mocked result, so the coroutine must actually have
    # been awaited, not merely called.
    mock_train_service.train.assert_awaited_once()
def test_start_training_pipeline_mode(ml_client, mock_train_service):
    """Pipeline mode: file + columns + target_columns triggers train_pipeline()"""
    # The service's async train_pipeline() resolves to an object whose
    # model_dump() supplies the response payload checked below.
    mock_train_result = MagicMock()
    mock_train_result.model_dump.return_value = {
        "message": "started",
        "training_id": "train_pipeline_123",
        "status": "running",
    }
    mock_train_service.train_pipeline = AsyncMock(return_value=mock_train_result)

    payload = {
        "compliance_type": "non_firco",
        "model_name": "nexus",
        "config": "{}",
        # JSON-encoded strings, since multipart form fields are plain text.
        "columns": '{"age":"numerical","gender":"categorical"}',
        "target_columns": '["income"]',
        "language": "en",
    }
    files = {"file": ("dataset.csv", b"age,gender,income\n25,M,high\n30,F,low", "text/csv")}
    response = ml_client.post("/v1/train", data=payload, files=files)

    # Put the raw body in the assertion message. pytest always shows it on
    # failure, and unlike response.json() it cannot itself raise on a
    # non-JSON error body and hide the real problem.
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["training_id"] == "train_pipeline_123"
    assert data["status"] == "running"
    # The body echoes the mocked result, so the coroutine must actually have
    # been awaited, not merely called.
    mock_train_service.train_pipeline.assert_awaited_once()
def test_start_training_pipeline_mode_missing_columns(ml_client, mock_train_service):
    """A pipeline-mode upload that omits ``columns`` must be rejected with HTTP 400."""
    upload = {"file": ("dataset.csv", b"age,income\n25,high", "text/csv")}
    form = dict(
        compliance_type="non_firco",
        model_name="nexus",
        target_columns='["income"]',
    )
    resp = ml_client.post("/v1/train", data=form, files=upload)

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    # NOTE(review): "columns" is also a substring of "target_columns", so this
    # check cannot distinguish the two validation errors.
    assert "columns" in detail
def test_start_training_pipeline_mode_missing_target_columns(ml_client, mock_train_service):
    """A pipeline-mode upload that omits ``target_columns`` must be rejected with HTTP 400."""
    upload = {"file": ("dataset.csv", b"age,income\n25,high", "text/csv")}
    form = dict(
        compliance_type="non_firco",
        model_name="nexus",
        columns='{"age":"numerical"}',
    )
    resp = ml_client.post("/v1/train", data=form, files=upload)

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    assert "target_columns" in detail
def test_start_training_no_file_no_preprocess_id(ml_client, mock_train_service):
    """Posting neither a file nor a ``preprocess_id`` must be rejected with HTTP 400."""
    resp = ml_client.post(
        "/v1/train",
        data={"compliance_type": "firco", "model_name": "infinity"},
    )

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    # The error message may name either of the two missing inputs.
    assert any(keyword in detail for keyword in ("preprocess_id", "file"))
def test_list_all_training_runs(ml_client, mock_train_service):
    """The training-run history endpoint responds 200 with ``data`` and ``pagination`` keys."""
    # Stub the async service queries: one matching run, empty page.
    mock_train_service.list_by_user = AsyncMock(return_value=[])
    mock_train_service.count_by_details = AsyncMock(return_value=1)

    # Extra kwargs to patch() are forwarded to the new_callable constructor,
    # so the AsyncMock is created with its return value already set.
    with patch(
        "app.api.v1.train.format_universal_runs_response",
        new_callable=AsyncMock,
        return_value={"training_runs": []},
    ):
        resp = ml_client.get("/v1/history/training-runs/")

    assert resp.status_code == 200
    body = resp.json()
    assert "data" in body
    assert "pagination" in body
def test_validate_model_api(ml_client, mock_validation_service):
    """``POST /v1/validate`` with an uploaded CSV returns the mocked service result."""
    # Whatever validate() resolves to is serialized through model_dump().
    fake_result = MagicMock()
    fake_result.model_dump.return_value = dict(
        message="started",
        validation_id="val_123",
        metrics={"accuracy": 0.95},
        training_id="latest",
    )
    mock_validation_service.validate = AsyncMock(return_value=fake_result)

    form = dict(
        compliance_type="firco",
        model_name="infinity",
        version="latest",
        number_of_reasonings=-1,
    )
    upload = {"file": ("test.csv", b"dummy,data", "text/csv")}
    resp = ml_client.post("/v1/validate", data=form, files=upload)

    assert resp.status_code == 200
    body = resp.json()
    assert body["validation_id"] == "val_123"
    assert body["metrics"]["accuracy"] == 0.95
    mock_validation_service.validate.assert_called_once()
def test_predict_api(ml_client, mock_prediction_service):
    """``POST /v1/predict`` with an uploaded CSV returns the mocked service result."""
    # Whatever predict() resolves to is serialized through model_dump().
    fake_result = MagicMock()
    fake_result.model_dump.return_value = dict(
        message="started",
        prediction_id="pred_123",
        predictions=[],
    )
    mock_prediction_service.predict = AsyncMock(return_value=fake_result)

    form = dict(
        compliance_type="firco",
        model_name="infinity",
        source_type="file",
        version="latest",
        number_of_reasonings=-1,
    )
    upload = {"file": ("test.csv", b"dummy,data", "text/csv")}
    resp = ml_client.post("/v1/predict", data=form, files=upload)

    assert resp.status_code == 200
    body = resp.json()
    assert body["prediction_id"] == "pred_123"
    mock_prediction_service.predict.assert_called_once()