| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import unittest |
| from unittest.mock import patch |
|
|
| from verl.utils.tracking import _MlflowLoggingAdapter |
|
|
|
|
class TestMlflowLoggingAdapter(unittest.TestCase):
    """Tests for metric-key sanitization in ``_MlflowLoggingAdapter``."""

    def test_sanitize_key_and_warning(self):
        """Test key sanitization for invalid characters and consecutive slashes with warnings.

        ``mlflow.log_metrics`` is patched so nothing is actually logged, and the
        adapter's ``logger`` attribute is replaced with a mock so emitted warnings
        can be inspected. The test checks that:

        * invalid characters are replaced ('@' -> '_at_', others -> '_'),
        * runs of slashes collapse to a single slash,
        * already-valid keys are forwarded unchanged and values follow their keys,
        * every rewritten key produces a warning mentioning the original key.
        """
        adapter = _MlflowLoggingAdapter()
        data = {
            "valid_key": 1.0,
            "invalid@key!": 2.0,
            "another/valid-key": 3.0,
            "bad key#": 4.0,
            "val-aux//reward/mean_at_1": 5.0,
            "val-core///acc/best_at_5": 6.0,
            "metric////with/many////slashes": 7.0,
        }

        with (
            patch("mlflow.log_metrics") as mock_log_metrics,
            patch.object(adapter, "logger") as mock_logger,
        ):
            adapter.log(data, step=5)

        # Mocks keep their recorded calls after the patches are undone.
        mock_log_metrics.assert_called_once()
        # The adapter passes the sanitized dict as the ``metrics`` keyword argument.
        sent_metrics = mock_log_metrics.call_args.kwargs["metrics"]

        # Invalid characters are rewritten.
        self.assertIn("invalid_at_key_", sent_metrics)
        self.assertIn("bad key_", sent_metrics)
        self.assertNotIn("invalid@key!", sent_metrics)
        self.assertNotIn("bad key#", sent_metrics)

        # Consecutive slashes collapse to one; none of the raw keys survive.
        self.assertIn("val-aux/reward/mean_at_1", sent_metrics)
        self.assertIn("val-core/acc/best_at_5", sent_metrics)
        self.assertIn("metric/with/many/slashes", sent_metrics)
        self.assertNotIn("val-aux//reward/mean_at_1", sent_metrics)
        self.assertNotIn("val-core///acc/best_at_5", sent_metrics)
        self.assertNotIn("metric////with/many////slashes", sent_metrics)

        # Already-valid keys pass through unchanged, values stay attached to their
        # (possibly rewritten) keys, and no metric is dropped or merged.
        self.assertEqual(sent_metrics["valid_key"], 1.0)
        self.assertEqual(sent_metrics["another/valid-key"], 3.0)
        self.assertEqual(sent_metrics["invalid_at_key_"], 2.0)
        self.assertEqual(sent_metrics["bad key_"], 4.0)
        self.assertEqual(sent_metrics["metric/with/many/slashes"], 7.0)
        self.assertEqual(len(sent_metrics), len(data))

        # str() of a mock call includes the repr of its arguments, so both the
        # original and the sanitized key appear in the rendered message.
        warning_msgs = [str(warning_call) for warning_call in mock_logger.warning.call_args_list]

        self.assertTrue(any("invalid@key!" in msg and "invalid_at_key_" in msg for msg in warning_msgs))
        self.assertTrue(any("bad key#" in msg and "bad key_" in msg for msg in warning_msgs))

        self.assertTrue(any("val-aux//reward/mean_at_1" in msg for msg in warning_msgs))
        self.assertTrue(any("val-core///acc/best_at_5" in msg for msg in warning_msgs))
        self.assertTrue(any("metric////with/many////slashes" in msg for msg in warning_msgs))
|
|
|
|
if __name__ == "__main__":
    # Run every TestCase in this module when the file is executed directly.
    unittest.main(module="__main__")
|
|