Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import joblib | |
| import pytest | |
| import tempfile | |
| import pandas as pd | |
| from unittest.mock import MagicMock | |
| # Add the pipeline directory to the path to import the module | |
| import sys | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'pipeline'))) | |
| from module import run_fn | |
| # --- Mock Objects to replace TFX dependencies --- | |
| class MockFnArgs: | |
| """A mock class to simulate TFX's FnArgs.""" | |
| def __init__(self, train_files, eval_files, schema_path, serving_model_dir, data_accessor): | |
| self.train_files = train_files | |
| self.eval_files = eval_files | |
| self.schema_path = schema_path | |
| self.serving_model_dir = serving_model_dir | |
| self.data_accessor = data_accessor | |
| class MockDataAccessor: | |
| """A mock class to simulate the TFX DataAccessor.""" | |
| def tf_dataset_factory(self, file_pattern, tfxio_options, schema): | |
| return [] | |
| # --- Test Fixtures --- | |
| def fn_args(): | |
| """Creates a mock FnArgs object for testing the Trainer's run_fn.""" | |
| with tempfile.TemporaryDirectory() as temp_dir: | |
| serving_model_dir = os.path.join(temp_dir, 'serving_model') | |
| os.makedirs(serving_model_dir, exist_ok=True) | |
| schema_path = os.path.join(temp_dir, 'schema.pbtxt') | |
| with open(schema_path, 'w') as f: | |
| f.write('') | |
| args = MockFnArgs( | |
| train_files=['train_dir'], | |
| eval_files=['eval_dir'], | |
| schema_path=schema_path, | |
| serving_model_dir=serving_model_dir, | |
| data_accessor=MockDataAccessor(), | |
| ) | |
| yield args | |
| # --- Tests --- | |
| def test_run_fn_creates_models_and_metrics(fn_args, monkeypatch): | |
| """Tests that the run_fn trains and saves all models and the metrics file.""" | |
| dummy_data = { | |
| 'lap': [1, 2, 3], 'accx_can': [0.1, 0.2, 0.3], 'accy_can': [0.1, 0.2, 0.3], | |
| 'Steering_Angle': [10, 15, 20], 'lap_time': [90, 91, 92], | |
| 'nmot': [8000, 8100, 8200], 'aps': [50, 55, 60], 'fuel_consumption': [0.5, 0.55, 0.6], | |
| 'speed': [150, 151, 152], 'gear': [4, 4, 5], 'pbrake_f': [0, 0, 1], 'pbrake_r': [0, 0, 0], | |
| 'traffic': [0, 1, 0], 'relative_pace': [1.0, 1.1, 1.2] | |
| } | |
| dummy_df = pd.DataFrame(dummy_data) | |
| monkeypatch.setattr('module._dataset_to_pandas', lambda a, b: dummy_df) | |
| run_fn(fn_args) | |
| assert os.path.exists(os.path.join(fn_args.serving_model_dir, "tire_degradation_model.pkl")) | |
| assert os.path.exists(os.path.join(fn_args.serving_model_dir, "fuel_consumption_model.pkl")) | |
| assert os.path.exists(os.path.join(fn_args.serving_model_dir, "pace_prediction_model.pkl")) | |
| metrics_path = os.path.join(fn_args.serving_model_dir, "evaluation_metrics.json") | |
| assert os.path.exists(metrics_path) | |
| with open(metrics_path, 'r') as f: | |
| metrics = json.load(f) | |
| assert 'tire_degradation_model_mse' in metrics | |
| assert 'fuel_consumption_model_mse' in metrics | |
| assert 'pace_prediction_model_mse' in metrics | |