Spaces:
Runtime error
Runtime error
File size: 7,142 Bytes
f8f02c0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 | import pytest
from unittest.mock import AsyncMock, patch, MagicMock
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.api.v1.train import train_router
from app.api.v1.validation import validation_router
from app.api.v1.prediction import prediction_router
from app.api.dependencies import get_current_username
# Minimal application that mounts only the routers under test, each under /v1.
app = FastAPI()
for _router in (train_router, validation_router, prediction_router):
    app.include_router(_router, prefix="/v1")
del _router  # keep the loop variable out of the module namespace
@pytest.fixture
def ml_client():
    """Yield a ``TestClient`` for ``app`` with authentication bypassed.

    Both auth dependencies are overridden so every request is treated as
    coming from ``test_user``. The overrides are removed in a ``finally``
    block so they cannot leak into later tests that share the module-level
    ``app``, even when client startup or shutdown raises.
    """
    # Imported inside the fixture; the spelling (triple "c") must match the
    # name exported by app.api.v1.train.
    from app.api.v1.train import acccess_token_bearer

    # Override authentication and authorization.
    app.dependency_overrides[acccess_token_bearer] = lambda: {"user_id": "test_user"}
    app.dependency_overrides[get_current_username] = lambda: "test_user"
    try:
        with TestClient(app) as client:
            yield client
    finally:
        app.dependency_overrides.clear()
# --- Mocking Services ---
@pytest.fixture
def mock_train_service():
    """Patch ``training_service`` in ``app.api.v1.train`` and yield the mock."""
    patcher = patch("app.api.v1.train.training_service")
    patched_service = patcher.start()
    try:
        yield patched_service
    finally:
        patcher.stop()
@pytest.fixture
def mock_validation_service():
    """Patch ``validation_service`` in ``app.api.v1.validation`` and yield the mock."""
    patcher = patch("app.api.v1.validation.validation_service")
    patched_service = patcher.start()
    try:
        yield patched_service
    finally:
        patcher.stop()
@pytest.fixture
def mock_prediction_service():
    """Patch ``prediction_service`` in ``app.api.v1.prediction`` and yield the mock."""
    patcher = patch("app.api.v1.prediction.prediction_service")
    patched_service = patcher.start()
    try:
        yield patched_service
    finally:
        patcher.stop()
def test_start_training(ml_client, mock_train_service):
    """A form post with a ``preprocess_id`` and no file calls ``training_service.train``.

    The JSON response is expected to carry the fields of the mocked result's
    ``model_dump()``.
    """
    # Setup mock return
    mock_train_result = MagicMock()
    mock_train_result.model_dump.return_value = {
        "message": "started",
        "training_id": "train_123",
        "status": "running",
    }
    mock_train_service.train = AsyncMock(return_value=mock_train_result)

    payload = {
        "compliance_type": "firco",
        "model_name": "infinity",
        "config": "{}",
        "preprocess_id": "prep_1",
        "version": "latest",
    }
    response = ml_client.post("/v1/train", data=payload)

    # Put the raw body in the assertion message: pytest shows it on failure,
    # and unlike response.json() it cannot itself raise on a non-JSON body.
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["training_id"] == "train_123"
    assert data["status"] == "running"
    mock_train_service.train.assert_called_once()
def test_start_training_pipeline_mode(ml_client, mock_train_service):
    """Pipeline mode: file + columns + target_columns triggers train_pipeline().

    The JSON response is expected to carry the fields of the mocked result's
    ``model_dump()``.
    """
    mock_train_result = MagicMock()
    mock_train_result.model_dump.return_value = {
        "message": "started",
        "training_id": "train_pipeline_123",
        "status": "running",
    }
    mock_train_service.train_pipeline = AsyncMock(return_value=mock_train_result)

    payload = {
        "compliance_type": "non_firco",
        "model_name": "nexus",
        "config": "{}",
        "columns": '{"age":"numerical","gender":"categorical"}',
        "target_columns": '["income"]',
        "language": "en",
    }
    files = {"file": ("dataset.csv", b"age,gender,income\n25,M,high\n30,F,low", "text/csv")}
    response = ml_client.post("/v1/train", data=payload, files=files)

    # Put the raw body in the assertion message: pytest shows it on failure,
    # and unlike response.json() it cannot itself raise on a non-JSON body.
    assert response.status_code == 200, response.text
    data = response.json()
    assert data["training_id"] == "train_pipeline_123"
    assert data["status"] == "running"
    mock_train_service.train_pipeline.assert_called_once()
def test_start_training_pipeline_mode_missing_columns(ml_client, mock_train_service):
    """Uploading a dataset without a ``columns`` field is rejected with HTTP 400."""
    form = dict(
        compliance_type="non_firco",
        model_name="nexus",
        target_columns='["income"]',
    )
    upload = {"file": ("dataset.csv", b"age,income\n25,high", "text/csv")}

    resp = ml_client.post("/v1/train", data=form, files=upload)

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    assert "columns" in detail
def test_start_training_pipeline_mode_missing_target_columns(ml_client, mock_train_service):
    """Uploading a dataset without a ``target_columns`` field is rejected with HTTP 400."""
    form = dict(
        compliance_type="non_firco",
        model_name="nexus",
        columns='{"age":"numerical"}',
    )
    upload = {"file": ("dataset.csv", b"age,income\n25,high", "text/csv")}

    resp = ml_client.post("/v1/train", data=form, files=upload)

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    assert "target_columns" in detail
def test_start_training_no_file_no_preprocess_id(ml_client, mock_train_service):
    """A train request with neither an upload nor a ``preprocess_id`` yields HTTP 400."""
    form = dict(compliance_type="firco", model_name="infinity")

    resp = ml_client.post("/v1/train", data=form)

    assert resp.status_code == 400
    detail = resp.json()["detail"]
    # The error message should mention at least one of the missing inputs.
    assert any(keyword in detail for keyword in ("preprocess_id", "file"))
def test_list_all_training_runs(ml_client, mock_train_service):
    """The training-run history endpoint returns ``data`` and ``pagination`` keys."""
    mock_train_service.count_by_details = AsyncMock(return_value=1)
    mock_train_service.list_by_user = AsyncMock(return_value=[])

    # patch() forwards return_value to the AsyncMock it builds via new_callable.
    with patch(
        "app.api.v1.train.format_universal_runs_response",
        new_callable=AsyncMock,
        return_value={"training_runs": []},
    ):
        response = ml_client.get("/v1/history/training-runs/")

    assert response.status_code == 200
    body = response.json()
    assert "data" in body
    assert "pagination" in body
def test_validate_model_api(ml_client, mock_validation_service):
    """Check that a CSV upload to /v1/validate calls ``validate`` once and returns its dump."""
    validation_result = MagicMock()
    validation_result.model_dump.return_value = dict(
        message="started",
        validation_id="val_123",
        metrics={"accuracy": 0.95},
        training_id="latest",
    )
    mock_validation_service.validate = AsyncMock(return_value=validation_result)

    form_fields = dict(
        compliance_type="firco",
        model_name="infinity",
        version="latest",
        number_of_reasonings=-1,
    )
    csv_upload = {"file": ("test.csv", b"dummy,data", "text/csv")}

    response = ml_client.post("/v1/validate", data=form_fields, files=csv_upload)

    assert response.status_code == 200
    body = response.json()
    assert body["validation_id"] == "val_123"
    assert body["metrics"]["accuracy"] == 0.95
    mock_validation_service.validate.assert_called_once()
def test_predict_api(ml_client, mock_prediction_service):
    """Check that a CSV upload to /v1/predict calls ``predict`` once and returns its dump."""
    prediction_result = MagicMock()
    prediction_result.model_dump.return_value = dict(
        message="started",
        prediction_id="pred_123",
        predictions=[],
    )
    mock_prediction_service.predict = AsyncMock(return_value=prediction_result)

    form_fields = dict(
        compliance_type="firco",
        model_name="infinity",
        source_type="file",
        version="latest",
        number_of_reasonings=-1,
    )
    csv_upload = {"file": ("test.csv", b"dummy,data", "text/csv")}

    response = ml_client.post("/v1/predict", data=form_fields, files=csv_upload)

    assert response.status_code == 200
    body = response.json()
    assert body["prediction_id"] == "pred_123"
    mock_prediction_service.predict.assert_called_once()
|