|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Optional |
|
|
|
|
|
import matplotlib as mpl |
|
|
import pandas as pd |
|
|
from matplotlib import _api |
|
|
from matplotlib import transforms as mtransforms |
|
|
|
|
|
|
|
|
class _DummyAxis:
    """Define the minimal interface for a dummy axis.

    Stand-in used by tick helpers when no real axis is attached. It keeps a
    data interval and a view interval, each stored as the x-interval of a
    ``mtransforms.Bbox.unit()`` box, plus a fixed minimum positive value.

    Args:
        minpos (float): The minimum positive value for the axis. Defaults to 0.
    """

    __name__ = 'dummy'

    # Deprecated (3.6) public spellings of the private ``_dataLim`` /
    # ``_viewLim`` attributes set in ``__init__``; callers should use the
    # get_/set_ interval methods instead.
    dataLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_data_interval() and set_data_interval()')
    viewLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_view_interval() and set_view_interval()')

    def __init__(self, minpos: float = 0) -> None:
        self._dataLim = mtransforms.Bbox.unit()
        self._viewLim = mtransforms.Bbox.unit()
        self._minpos = minpos

    def get_view_interval(self) -> Any:
        """Return the view interval as a (*vmin*, *vmax*) pair.

        The value is the ``intervalx`` of the internal view bbox (not a
        dict, despite the previous annotation).
        """
        return self._viewLim.intervalx

    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self._viewLim.intervalx = vmin, vmax

    def get_minpos(self) -> float:
        """Return the minimum positive value given at construction."""
        return self._minpos

    def get_data_interval(self) -> Any:
        """Return the data interval as a (*vmin*, *vmax*) pair.

        The value is the ``intervalx`` of the internal data bbox.
        """
        return self._dataLim.intervalx

    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self._dataLim.intervalx = vmin, vmax

    def get_tick_space(self) -> int:
        """Return the number of ticks to use.

        Always 9: a dummy axis has no real on-screen length from which a
        tick budget could be derived.
        """
        return 9
|
|
|
|
|
|
|
|
class TickHelper:
    """A helper class for ticks and tick labels.

    Holds the axis that ticks are computed for. ``axis`` is ``None`` until
    :meth:`set_axis` or :meth:`create_dummy_axis` assigns one.
    """

    axis: Any = None

    def set_axis(self, axis: Any) -> None:
        """Set the axis instance."""
        self.axis = axis

    def create_dummy_axis(self, **kwargs) -> None:
        """Create a ``_DummyAxis`` if no axis is set.

        Any keyword arguments are forwarded to the ``_DummyAxis``
        constructor. An already-attached axis is left untouched.
        """
        if self.axis is None:
            self.axis = _DummyAxis(**kwargs)

    @_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval of the attached axis to (*vmin*, *vmax*)."""
        self.axis.set_view_interval(vmin, vmax)

    @_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval of the attached axis to (*vmin*, *vmax*)."""
        self.axis.set_data_interval(vmin, vmax)

    @_api.deprecated(
        '3.5',
        alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
    def set_bounds(self, vmin: float, vmax: float) -> None:
        """Set the view and data interval to (*vmin*, *vmax*)."""
        # Both helpers called here are deprecated too. The decorator on this
        # method has already warned the caller once, so silence the nested
        # warnings rather than emitting three for a single call. Going
        # through the methods (rather than self.axis directly) keeps any
        # subclass overrides in effect.
        with _api.suppress_matplotlib_deprecation_warning():
            self.set_view_interval(vmin, vmax)
            self.set_data_interval(vmin, vmax)
|
|
|
|
|
|
|
|
class Formatter(TickHelper):
    """Create a string based on a tick value and location."""

    # Tick locations from the most recent set_locs() call. set_locs rebinds
    # this name on the instance instead of mutating it, so this class-level
    # list itself is never modified by the methods below.
    locs: List[Any] = []

    def __call__(self, x: Any, pos: Optional[Any] = None) -> str:
        """Return the format for tick value *x* at position *pos*.

        ``pos=None`` indicates an unspecified location.

        This method must be overridden in the derived class.

        Args:
            x (Any): The tick value (typically a number, not a string).
            pos (Optional[Any]): The tick position. Defaults to None.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('Derived must override')

    def format_ticks(self, values: List[Any]) -> List[str]:
        """Return the tick labels for all the ticks at once.

        Calls :meth:`set_locs` with *values* first, then formats each value
        with its index as the position.

        Args:
            values (List[Any]): The tick values (any sequence works).

        Returns:
            List[str]: The tick labels, in the same order as *values*.
        """
        self.set_locs(values)
        return [self(value, i) for i, value in enumerate(values)]

    def format_data(self, value: Any) -> str:
        """Return the full string representation of the value with the position
        unspecified.

        Args:
            value (Any): The tick value.

        Returns:
            str: The full string representation of the value.
        """
        return self.__call__(value)

    def format_data_short(self, value: Any) -> str:
        """Return a short string version of the tick value.

        Defaults to the position-independent long value.

        Args:
            value (Any): The tick value.

        Returns:
            str: The short string representation of the value.
        """
        return self.format_data(value)

    def get_offset(self) -> str:
        """Return the offset string (always empty in this base class)."""
        return ''

    def set_locs(self, locs: List[Any]) -> None:
        """Set the locations of the ticks.

        This method is called before computing the tick labels because some
        formatters need to know all tick locations to do so.

        Args:
            locs (List[Any]): The tick locations.
        """
        self.locs = locs

    @staticmethod
    def fix_minus(s: str) -> str:
        """Replace hyphens with the Unicode minus sign (U+2212) when enabled.

        Some classes may want to show a proper minus instead of a hyphen for
        typographical correctness. The replacement only happens when
        :rc:`axes.unicode_minus` is true.

        Args:
            s (str): The string whose ASCII hyphens may be replaced.

        Returns:
            str: *s* with every ``'-'`` replaced by ``'\\N{MINUS SIGN}'`` if
            :rc:`axes.unicode_minus` is set, otherwise *s* unchanged.
        """
        return (s.replace('-', '\N{MINUS SIGN}')
                if mpl.rcParams['axes.unicode_minus'] else s)

    def _set_locator(self, locator: Any) -> None:
        """Subclasses may want to override this to set a locator.

        The base implementation does nothing.
        """
        pass
|
|
|
|
|
|
|
|
class FormatStrFormatter(Formatter):
    """Use an old-style ('%' operator) format string to format the tick.

    The format string should have a single variable format (%) in it.
    It will be applied to the value (not the position) of the tick.

    Negative numeric values will use a dash, not a Unicode minus; use mathtext
    to get a Unicode minus by wrapping the format specifier with $ (e.g.
    "$%g$").

    Args:
        fmt (str): Format string.
    """

    def __init__(self, fmt: str) -> None:
        self.fmt = fmt

    def __call__(self, x: Any, pos: Optional[Any] = None) -> str:
        """Return the formatted label string.

        Only the value *x* is formatted. The position is ignored.

        Args:
            x (Any): The value to format (typically a number).
            pos (Optional[Any]): The position of the tick. Ignored. It
                defaults to None, matching ``Formatter.__call__``, because
                the inherited ``format_data`` calls ``self.__call__(value)``
                without a position. Without the default that raised
                TypeError.

        Returns:
            str: ``self.fmt % x``.
        """
        return self.fmt % x
|
|
|
|
|
|
|
|
class ShapeBias:
    """Compute the shape bias of a model.

    Reference: `ImageNet-trained CNNs are biased towards texture;
    increasing shape bias improves accuracy and robustness
    <https://arxiv.org/abs/1811.12231>`_.
    """

    num_input_models = 1

    def __init__(self) -> None:
        super().__init__()
        self.plotting_name = 'shape-bias'

    @staticmethod
    def _check_dataframe(df: pd.DataFrame) -> None:
        """Check that the dataframe is valid (currently: non-empty)."""
        assert len(df) > 0, 'empty dataframe'

    def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
        """Compute the shape bias of a model.

        Args:
            df (pd.DataFrame): Trial data. It must have the columns
                ``'imagename'`` (the texture label is parsed from it),
                ``'category'`` (the shape label) and ``'object_response'``
                (the model's answer). The input frame is not modified.

        Returns:
            Dict[str, float]: ``'fraction-correct-shape'`` and
            ``'fraction-correct-texture'`` (counts over cue-conflict trials,
            divided by the total number of trials), and ``'shape-bias'``,
            i.e. shape / (shape + texture). ``'shape-bias'`` is NaN when no
            cue-conflict response matched either label.
        """
        self._check_dataframe(df)

        # Work on a copy so the helper columns do not leak into the caller's
        # dataframe.
        df = df.copy()
        df['correct_texture'] = df['imagename'].apply(
            self.get_texture_category)
        df['correct_shape'] = df['category']

        # Drop trials where shape == texture, i.e. no cue conflict present.
        df2 = df.loc[df.correct_shape != df.correct_texture]
        fraction_correct_shape = len(
            df2.loc[df2.object_response == df2.correct_shape]) / len(df)
        fraction_correct_texture = len(
            df2.loc[df2.object_response == df2.correct_texture]) / len(df)

        denominator = fraction_correct_shape + fraction_correct_texture
        # If the model never answered with either cue, the shape bias is
        # undefined: report NaN instead of raising ZeroDivisionError.
        shape_bias = (fraction_correct_shape / denominator
                      if denominator > 0 else float('nan'))

        result_dict = {
            'fraction-correct-shape': fraction_correct_shape,
            'fraction-correct-texture': fraction_correct_texture,
            'shape-bias': shape_bias
        }
        return result_dict

    def get_texture_category(self, imagename: str) -> str:
        """Return texture category from imagename.

        e.g. 'XXX_dog10-bird2.png' -> 'bird'

        The steps are: take the part after the last '_', drop everything
        from the first '.', take the part after the last '-', then remove
        all digits.

        Args:
            imagename (str): Name of the image.

        Returns:
            str: Texture category.
        """
        assert isinstance(imagename, str)

        stem = imagename.split('_')[-1].split('.')[0]
        texture = stem.split('-')[-1]
        return ''.join(ch for ch in texture if not ch.isdigit())
|
|
|