Spaces:
Running on Zero
Running on Zero
File size: 8,976 Bytes
d1f1097 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 | import enum
import glob
import json
import logging
import os
import pathlib
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import pydantic
from tensorboard.backend.event_processing import event_accumulator
# By default TensorBoard tries to be smart about what it loads into memory in
# order to avoid OOM. The comparisons in this module expect every step to be
# present, so the size guidance is explicitly set to 0, which loads everything.
# That is acceptable here because the tests are small and short.
SIZE_GUIDANCE = {
    event_accumulator.TENSORS: 0,
    event_accumulator.SCALARS: 0,
}

# Module-level logger shared by all helpers below.
logger = logging.getLogger(__name__)
def approximate_threshold(rtol: float) -> Callable:
    """Build a tolerance function scaled by ``rtol``.

    The returned callable takes two numeric series and returns ``rtol`` times
    the mean of the two series means.
    """

    def _func(y_pred: List[Union[float, int]], y_true: List[Union[float, int]]):
        series_means = [np.mean(y_pred), np.mean(y_true)]
        midpoint = np.mean(series_means)
        return midpoint * rtol

    return _func
class TypeOfTestResult(enum.Enum):
    """Kind of comparison a check performs against the golden values."""

    # Values may differ within a tolerance.
    APPROXIMATE = 1
    # Values must match exactly (zero tolerance).
    DETERMINISTIC = 2
class Test(pydantic.BaseModel):
    """Common pydantic base class for the metric checks defined in this module."""
class NotApproximateError(Exception):
    """Signals that a comparison fell outside the approximate bounds."""
class NotDeterminsticError(Exception):
    """Raised if a comparison that must match exactly does not."""
class ApproximateTest(Test):
    """Check that passes when values agree within configurable tolerances."""

    # Absolute tolerance, used when ``atol_func`` is not set.
    atol: Union[int, float] = 0
    # Optional callable that computes the absolute tolerance from the two
    # value lists; when set, ``pipeline`` uses it instead of ``atol``.
    atol_func: Optional[Callable] = None
    # Relative tolerance.
    rtol: float = 1e-5

    @property
    def type_of_test_result(self) -> TypeOfTestResult:
        """Identify this check as an approximate comparison."""
        return TypeOfTestResult.APPROXIMATE

    def error_message(self, metric_name: str) -> NotApproximateError:
        """Build (without raising) the exception for a failed comparison."""
        message = f"Approximate comparison of {metric_name}: FAILED"
        return NotApproximateError(message)
class DeterministicTest(Test):
    """Check that passes only when values match exactly (zero tolerances)."""

    # Optional callable that computes the absolute tolerance; when set,
    # ``pipeline`` uses it instead of ``atol``.
    atol_func: Optional[Callable] = None

    @property
    def atol(self) -> Union[int, float]:
        """Absolute tolerance, fixed at zero."""
        return 0

    @property
    def rtol(self) -> float:
        """Relative tolerance, fixed at zero."""
        return 0.0

    @property
    def type_of_test_result(self) -> TypeOfTestResult:
        """Identify this check as a deterministic comparison."""
        return TypeOfTestResult.DETERMINISTIC

    def error_message(self, metric_name: str) -> NotDeterminsticError:
        """Build (without raising) the exception for a failed comparison."""
        message = f"Exact comparison of {metric_name}: FAILED"
        return NotDeterminsticError(message)
class GoldenValueMetric(pydantic.BaseModel):
    """Recorded values of a single metric together with the step range they cover."""

    start_step: int
    end_step: int
    step_interval: int
    # Maps step -> value. String values occur as placeholders, e.g. the
    # "nan" that ``read_tb_logs_as_list`` stores for steps with no logged value.
    values: Dict[int, Union[int, float, str]]

    def __repr__(self):
        """Render the step range followed by every ``(step, value)`` pair."""
        header = f"Values ({self.start_step},{self.end_step},{self.step_interval})"
        pairs = ", ".join(f"({step}, {value})" for step, value in self.values.items())
        return f"{header}: {pairs}"
class GoldenValues(pydantic.RootModel):
    """Root model holding one ``GoldenValueMetric`` per metric name."""

    root: Dict[str, GoldenValueMetric]
class MissingTensorboardLogsError(Exception):
    """Signals that the expected TensorBoard logs could not be found."""
class UndefinedMetricError(Exception):
    """Signals that a golden-values metric has no test definition."""
class SkipMetricError(Exception):
    """Signals that a metric should be skipped."""
def read_tb_logs_as_list(
    path, index: int = 0, train_iters: int = 50, start_idx: int = 1, step_size: int = 5
) -> Optional[Dict[str, GoldenValueMetric]]:
    """Read the scalar summaries from TensorBoard event files under ``path``.

    Event files are searched for in ``path`` and in ``path/results`` and are
    ordered by modification time (oldest first).

    Args:
        path: Directory holding the event files (directly or in ``results/``).
        index: Position, in modification-time order, of the single event file
            to read. The special value ``-1`` reads every file; when a step
            was logged in several files, the value from the earliest file that
            contains the scalar is kept.
        train_iters: Last step (inclusive) that may be sampled.
        start_idx: First step that is sampled.
        step_size: After ``start_idx``, every step divisible by ``step_size``
            is sampled.

    Returns:
        A mapping from scalar tag to ``GoldenValueMetric``, or ``None`` when no
        event file is found. Values are rounded to 5 decimals; sampled steps
        without a logged value hold the string ``"nan"``.
    """
    event_files = glob.glob(f"{path}/events*tfevents*")
    event_files += glob.glob(f"{path}/results/events*tfevents*")

    if not event_files:
        logger.error(f"File not found matching: {path}/events* || {path}/results/events*")
        return None

    # glob already returns paths that include their directory. Re-joining just
    # the file name onto ``path`` would point at a non-existent file for
    # everything found under ``results/``, so the matched paths are used as-is.
    event_files.sort(key=os.path.getmtime)

    selected_files = event_files if index == -1 else [event_files[index]]

    accumulators = []
    for event_file in selected_files:
        ea = event_accumulator.EventAccumulator(event_file, size_guidance=SIZE_GUIDANCE)
        ea.Reload()
        accumulators.append(ea)

    summaries = {}
    for ea in accumulators:
        for scalar_name in ea.Tags()["scalars"]:
            if scalar_name in summaries:
                # Scalar already seen in an earlier file: only fill in steps
                # that file did not provide.
                for x in ea.Scalars(scalar_name):
                    if x.step not in summaries[scalar_name]:
                        summaries[scalar_name][x.step] = round(x.value, 5)
            else:
                summaries[scalar_name] = {
                    x.step: round(x.value, 5) for x in ea.Scalars(scalar_name)
                }

    # The sampled steps are the same for every metric, so compute them once.
    sampled_steps = [
        k
        for k in range(1, train_iters + 1)
        if k == start_idx or (k > start_idx and k % step_size == 0)
    ]

    golden_values = {}
    for metric, logged in summaries.items():
        # Add missing values as the "nan" placeholder.
        values = {k: logged.get(k, "nan") for k in sampled_steps}

        golden_values[metric] = GoldenValueMetric(
            start_step=min(values.keys()),
            end_step=max(values.keys()),
            step_interval=step_size,
            values=values,
        )

    return golden_values
def read_golden_values_from_json(
    golden_values_path: Union[str, pathlib.Path]
) -> Dict[str, GoldenValueMetric]:
    """Load golden values from a JSON file.

    Args:
        golden_values_path: Path to a JSON file whose top-level object maps
            metric names to ``GoldenValueMetric`` payloads.

    Returns:
        The parsed metrics keyed by metric name.

    Raises:
        ValueError: If ``golden_values_path`` does not exist.
    """
    # The existence check must come before opening the file; otherwise
    # ``open`` fails first and the descriptive error below is never reached.
    if not os.path.exists(golden_values_path):
        raise ValueError(f"File {golden_values_path} not found!")

    with open(golden_values_path) as f:
        return GoldenValues(**json.load(f)).root
def _filter_checks(
    checks: List[Union[ApproximateTest, DeterministicTest]], filter_for_type_of_check
):
    """Return the checks whose ``type_of_test_result`` equals ``filter_for_type_of_check``."""
    selected = []
    for candidate in checks:
        if candidate.type_of_test_result == filter_for_type_of_check:
            selected.append(candidate)
    return selected
def pipeline(
    compare_approximate_results: bool,
    golden_values: Dict[str, GoldenValueMetric],
    actual_values: Dict[str, GoldenValueMetric],
    checks: Dict[str, List[Union[ApproximateTest, DeterministicTest]]],
):
    """Run every configured check of actual metric values against golden values.

    Args:
        compare_approximate_results: When true, checks whose result type is
            ``DETERMINISTIC`` are skipped.
        golden_values: Reference values keyed by metric name.
        actual_values: Measured values keyed by metric name.
        checks: For each metric name, the list of checks to run.

    Raises:
        MissingTensorboardLogsError: If a metric listed in ``checks`` is not
            present in ``actual_values``.
        AssertionError: If at least one check failed; the message lists the
            failing metrics.
    """
    all_test_passed = True
    failed_metrics = []

    for metric_name, metric_thresholds in checks.items():
        if metric_name not in actual_values:
            raise MissingTensorboardLogsError(
                f"Metric {metric_name} not found in Tensorboard logs! Please modify `model_config.yaml` to record it."
            )

        for test in metric_thresholds:
            if (
                compare_approximate_results
                and test.type_of_test_result == TypeOfTestResult.DETERMINISTIC
            ):
                continue

            try:
                golden_value = golden_values[metric_name]
                golden_value_list = list(golden_value.values.values())

                # Pair values by golden step rather than by position, so a
                # missing or differently ordered step in the actual values
                # cannot shift the comparison or produce arrays of different
                # length. A missing step gets the same "nan" placeholder that
                # is used for missing values elsewhere in this module.
                actual_by_step = actual_values[metric_name].values
                actual_value_list = [
                    actual_by_step.get(step, "nan") for step in golden_value.values
                ]

                if metric_name == "iteration-time":
                    actual_value_list = actual_value_list[3:-1]
                    golden_value_list = golden_value_list[3:-1]
                    logger.info(
                        "For metric `%s`, the first 3 and the last scalars are removed from the list to reduce noise.",
                        metric_name,
                    )

                # String placeholders (e.g. "nan") are mapped to +inf so the
                # arrays stay numeric.
                actual_value_list = [
                    np.inf if isinstance(v, str) else v for v in actual_value_list
                ]
                golden_value_list = [
                    np.inf if isinstance(v, str) else v for v in golden_value_list
                ]

                actual = np.array(actual_value_list)
                golden = np.array(golden_value_list)

                # Tolerance check
                passing = np.allclose(
                    actual,
                    golden,
                    rtol=test.rtol,
                    atol=(
                        test.atol_func(actual_value_list, golden_value_list)
                        if test.atol_func is not None
                        else test.atol
                    ),
                )

                if not passing:
                    logger.info("Actual values: %s", ", ".join([str(v) for v in actual_value_list]))
                    logger.info("Golden values: %s", ", ".join([str(v) for v in golden_value_list]))
                    raise test.error_message(metric_name)

                result = f"{test.type_of_test_result.name} test for metric {metric_name}: PASSED"
                result_code = 0

            except (NotApproximateError, NotDeterminsticError, MissingTensorboardLogsError) as e:
                result = str(e)
                result_code = 1
            except SkipMetricError:
                logger.info(f"{test.type_of_test_result.name} test for {metric_name}: SKIPPED")
                continue

            log_emitter = logger.info if result_code == 0 else logger.error
            log_emitter(result)
            if result_code == 1:
                all_test_passed = False
                failed_metrics.append(metric_name)

    # Raise explicitly instead of using ``assert`` so that failures are still
    # reported when Python runs with -O (which strips assert statements).
    if not all_test_passed:
        raise AssertionError(f"The following metrics failed: {', '.join(failed_metrics)}")
|