| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Compute PSNR, currently used for colorization and superresolution.""" |
|
|
| import functools |
|
|
| import big_vision.evaluators.proj.uvim.common as common |
| import big_vision.pp.builder as pp_builder |
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| import tensorflow as tf |
|
|
|
|
class Evaluator:
  """PSNR evaluator.

  `predict_fn` accepts arbitrary dictionaries of parameters and data, where
  the data dictionary is produced by the `pp_fn` op. It is expected to output a
  single-key dict containing an RGB image with intensities in [-1,1]. The
  prediction is compared against the `labels` entry of that same data dict,
  which therefore has to use the same [-1,1] intensity range.
  """

  def __init__(self,
               predict_fn,
               pp_fn,
               batch_size,
               dataset="imagenet2012",
               split="validation",
               predict_kwargs=None):
    """Builds the pmapped PSNR computation and the evaluation dataset.

    Args:
      predict_fn: callable `(params, data, **predict_kwargs) -> dict` whose
        result must contain exactly one value: the predicted image.
      pp_fn: preprocessing spec handed to `pp_builder.get_preprocess_fn`.
      batch_size: global batch size, forwarded as `global_batch_size` to
        `common.get_jax_process_dataset`.
      dataset: dataset name forwarded to `common.get_jax_process_dataset`.
      split: dataset split forwarded to `common.get_jax_process_dataset`.
      predict_kwargs: optional extra keyword arguments for `predict_fn`.
    """
    extra_kwargs = predict_kwargs or {}

    def predict(params, batch):
      def _f(x):
        y = predict_fn(params, x, **extra_kwargs)
        # Single-element unpacking enforces the "single-key dict" contract.
        pred, = y.values()
        # Intensities live in [-1, 1], so the peak-to-peak range is 2.
        return _psnr(pred, x["labels"], 2.)
      # Gather per-example PSNRs together with their padding masks from all
      # devices, so a single device copy holds the whole global batch.
      return jax.lax.all_gather({
          "mask": batch["mask"],
          "psnr": _f(batch["input"]),
      }, axis_name="data", axis=0)

    self.predict_fn = jax.pmap(predict, axis_name="data")

    # Build the user preprocessing op once, instead of inside `preprocess`.
    user_pp_fn = pp_builder.get_preprocess_fn(pp_fn)

    def preprocess(example):
      # Real examples are tagged mask=1; examples with mask=0 are excluded
      # in `run` (presumably padding added by the dataset helper -- confirm
      # in `common.get_jax_process_dataset`).
      return {
          "mask": tf.constant(1),
          "input": user_pp_fn(example),
      }

    self.data = common.get_jax_process_dataset(
        dataset,
        split,
        global_batch_size=batch_size,
        add_tfds_id=True,
        pp_fn=preprocess)

  def run(self, params):
    """Run eval.

    Every process must iterate this generator, because the pmapped function
    performs a cross-device `all_gather`. Only process 0 accumulates results
    and yields `("PSNR", mean_psnr)`; all other processes yield nothing.
    """
    is_main_process = jax.process_index() == 0
    psnrs = []

    for batch in self.data.as_numpy_iterator():
      # All processes must take part in the collective, even those that
      # discard its result.
      out = self.predict_fn(params, batch)
      if not is_main_process:
        continue

      # `all_gather` gives every device the same gathered arrays, so the
      # copy on the first local device is sufficient.
      out = jax.tree_util.tree_map(lambda x: jax.device_get(x[0]), out)
      # Keep only entries whose mask is non-zero (drops padding).
      psnrs.extend(out["psnr"][out["mask"] != 0])

    if not is_main_process:
      return

    yield "PSNR", np.mean(psnrs)
|
|
|
|
| @functools.partial(jax.vmap, in_axes=[0, 0, None]) |
| def _psnr(img0, img1, dynamic_range): |
| mse = jnp.mean(jnp.power(img0 - img1, 2)) |
| return 20. * jnp.log10(dynamic_range) - 10. * jnp.log10(mse) |
|
|