import os
import json
import joblib
import pytest
import tempfile
import pandas as pd
from unittest.mock import MagicMock
# Add the pipeline directory to the path to import the module
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'pipeline')))
from module import run_fn
# --- Mock Objects to replace TFX dependencies ---
class MockFnArgs:
    """Stand-in for TFX's FnArgs: exposes each constructor argument as an attribute."""
    def __init__(self, train_files, eval_files, schema_path, serving_model_dir, data_accessor):
        # Mirror every argument onto the instance, exactly as TFX's FnArgs does.
        for attr, value in (
            ('train_files', train_files),
            ('eval_files', eval_files),
            ('schema_path', schema_path),
            ('serving_model_dir', serving_model_dir),
            ('data_accessor', data_accessor),
        ):
            setattr(self, attr, value)
class MockDataAccessor:
    """Stand-in for TFX's DataAccessor whose dataset factory yields nothing."""
    def tf_dataset_factory(self, file_pattern, tfxio_options, schema):
        # An empty iterable is enough: the test monkeypatches the
        # dataset-to-pandas conversion, so the contents are never read.
        empty_dataset = []
        return empty_dataset
# --- Test Fixtures ---
@pytest.fixture
def fn_args():
    """Yields a MockFnArgs wired to temporary dirs/files for exercising run_fn.

    The temporary directory (and everything run_fn writes into it) is
    removed automatically once the test using this fixture finishes.
    """
    with tempfile.TemporaryDirectory() as workdir:
        model_dir = os.path.join(workdir, 'serving_model')
        os.makedirs(model_dir, exist_ok=True)
        # run_fn only needs the schema file to exist; its contents are unused here.
        schema_file = os.path.join(workdir, 'schema.pbtxt')
        with open(schema_file, 'w') as handle:
            handle.write('')
        yield MockFnArgs(
            train_files=['train_dir'],
            eval_files=['eval_dir'],
            schema_path=schema_file,
            serving_model_dir=model_dir,
            data_accessor=MockDataAccessor(),
        )
# --- Tests ---
def test_run_fn_creates_models_and_metrics(fn_args, monkeypatch):
    """run_fn should persist all three models plus the evaluation-metrics JSON."""
    # Small hand-built telemetry frame covering every column run_fn consumes.
    frame = pd.DataFrame({
        'lap': [1, 2, 3], 'accx_can': [0.1, 0.2, 0.3], 'accy_can': [0.1, 0.2, 0.3],
        'Steering_Angle': [10, 15, 20], 'lap_time': [90, 91, 92],
        'nmot': [8000, 8100, 8200], 'aps': [50, 55, 60], 'fuel_consumption': [0.5, 0.55, 0.6],
        'speed': [150, 151, 152], 'gear': [4, 4, 5], 'pbrake_f': [0, 0, 1], 'pbrake_r': [0, 0, 0],
        'traffic': [0, 1, 0], 'relative_pace': [1.0, 1.1, 1.2],
    })
    # Bypass the TFX dataset pipeline: every dataset converts to the dummy frame.
    monkeypatch.setattr('module._dataset_to_pandas', lambda a, b: frame)
    run_fn(fn_args)
    # Each trained model must have been serialized into the serving directory.
    for model_file in (
        "tire_degradation_model.pkl",
        "fuel_consumption_model.pkl",
        "pace_prediction_model.pkl",
    ):
        assert os.path.exists(os.path.join(fn_args.serving_model_dir, model_file))
    metrics_path = os.path.join(fn_args.serving_model_dir, "evaluation_metrics.json")
    assert os.path.exists(metrics_path)
    with open(metrics_path, 'r') as handle:
        metrics = json.load(handle)
    # One MSE entry is expected per trained model.
    for metric_key in (
        'tire_degradation_model_mse',
        'fuel_consumption_model_mse',
        'pace_prediction_model_mse',
    ):
        assert metric_key in metrics