File size: 9,206 Bytes
789eef1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from https://github.com/bethgelab/model-vs-human
from typing import Any, Dict, List, Optional, Sequence, Tuple

import matplotlib as mpl
import pandas as pd
from matplotlib import _api
from matplotlib import transforms as mtransforms
class _DummyAxis:
    """Define the minimal interface for a dummy axis.

    Stores a view interval, a data interval and a minimum positive value so
    that tick helpers can be exercised without a real axis.

    Args:
        minpos (float): The minimum positive value for the axis. Defaults to 0.
    """
    __name__ = 'dummy'
    # Once the deprecation elapses, replace dataLim and viewLim by plain
    # _view_interval and _data_interval private tuples.
    # Until then, these public names forward to the private ``_dataLim`` /
    # ``_viewLim`` attributes while emitting a deprecation warning.
    dataLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_data_interval() and set_data_interval()')
    viewLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_view_interval() and set_view_interval()')

    def __init__(self, minpos: float = 0) -> None:
        # Both limits start as the unit box; only their x-interval is used
        # by the getters/setters below.
        self._dataLim = mtransforms.Bbox.unit()
        self._viewLim = mtransforms.Bbox.unit()
        self._minpos = minpos

    def get_view_interval(self) -> Tuple[float, float]:
        """Return the view interval as a pair (*vmin*, *vmax*).

        Returns:
            Tuple[float, float]: The x-interval of the internal view box.
        """
        return self._viewLim.intervalx

    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self._viewLim.intervalx = vmin, vmax

    def get_minpos(self) -> float:
        """Return the minimum positive value for the axis."""
        return self._minpos

    def get_data_interval(self) -> Tuple[float, float]:
        """Return the data interval as a pair (*vmin*, *vmax*).

        Returns:
            Tuple[float, float]: The x-interval of the internal data box.
        """
        return self._dataLim.intervalx

    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self._dataLim.intervalx = vmin, vmax

    def get_tick_space(self) -> int:
        """Return the number of ticks to use."""
        # Just use the long-standing default of nbins==9
        return 9
class TickHelper:
    """A helper class for ticks and tick labels.

    Holds a reference to the axis whose ticks are being located or formatted.
    """
    # The axis instance; ``None`` until set_axis() or create_dummy_axis().
    axis = None

    def set_axis(self, axis: Any) -> None:
        """Set the axis instance.

        Args:
            axis (Any): The axis object to attach.
        """
        self.axis = axis

    def create_dummy_axis(self, **kwargs: Any) -> None:
        """Create a dummy axis if no axis is set.

        Args:
            **kwargs: Forwarded to ``_DummyAxis`` (e.g. ``minpos``).
        """
        if self.axis is None:
            self.axis = _DummyAxis(**kwargs)

    @_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self.axis.set_view_interval(vmin, vmax)

    @_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self.axis.set_data_interval(vmin, vmax)

    @_api.deprecated(
        '3.5',
        alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
    def set_bounds(self, vmin: float, vmax: float) -> None:
        """Set the view and data interval to (*vmin*, *vmax*)."""
        # Talk to the axis directly: self.set_view_interval and
        # self.set_data_interval are themselves deprecated, so routing through
        # them would emit two extra deprecation warnings per call.
        self.axis.set_view_interval(vmin, vmax)
        self.axis.set_data_interval(vmin, vmax)
class Formatter(TickHelper):
    """Create a string based on a tick value and location."""
    # some classes want to see all the locs to help format
    # individual ones. set_locs() rebinds this per instance, so the shared
    # class-level default list is never mutated in place.
    locs = []

    def __call__(self, x: Any, pos: Optional[Any] = None) -> str:
        """Return the format for tick value *x* at position *pos*.

        ``pos=None`` indicates an unspecified location.

        This method must be overridden in the derived class.

        Args:
            x (Any): The tick value (typically a number).
            pos (Optional[Any]): The tick position. Defaults to None.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('Derived must override')

    def format_ticks(self, values: Sequence[Any]) -> List[str]:
        """Return the tick labels for all the ticks at once.

        Args:
            values (Sequence[Any]): The tick values (e.g. a list or a
                ``pd.Series``). They are stored via :meth:`set_locs` first.

        Returns:
            List[str]: One label per value, each formatted with its index
            passed as ``pos``.
        """
        self.set_locs(values)
        return [self(value, i) for i, value in enumerate(values)]

    def format_data(self, value: Any) -> str:
        """Return the full string representation of the value with the position
        unspecified.

        Args:
            value (Any): The tick value.

        Returns:
            str: The full string representation of the value.
        """
        return self(value)

    def format_data_short(self, value: Any) -> str:
        """Return a short string version of the tick value.

        Defaults to the position-independent long value.

        Args:
            value (Any): The tick value.

        Returns:
            str: The short string representation of the value.
        """
        return self.format_data(value)

    def get_offset(self) -> str:
        """Return the offset string (empty for this base class)."""
        return ''

    def set_locs(self, locs: Sequence[Any]) -> None:
        """Set the locations of the ticks.

        This method is called before computing the tick labels because some
        formatters need to know all tick locations to do so.

        Args:
            locs (Sequence[Any]): The tick locations.
        """
        self.locs = locs

    @staticmethod
    def fix_minus(s: str) -> str:
        """Replace hyphens with the Unicode minus sign when enabled.

        Some classes may want to replace a hyphen for minus with the proper
        Unicode symbol (U+2212) for typographical correctness. This is a
        helper method to perform such a replacement when it is enabled via
        :rc:`axes.unicode_minus`.

        Args:
            s (str): The string whose hyphens may be replaced.

        Returns:
            str: ``s`` with every ``-`` replaced by U+2212 if
            ``rcParams['axes.unicode_minus']`` is true, otherwise ``s``
            unchanged.
        """
        return (s.replace('-', '\N{MINUS SIGN}')
                if mpl.rcParams['axes.unicode_minus'] else s)

    def _set_locator(self, locator: Any) -> None:
        """Subclasses may want to override this to set a locator."""
        pass
class FormatStrFormatter(Formatter):
    """Use an old-style ('%' operator) format string to format the tick.

    The format string should have a single variable format (%) in it.
    It will be applied to the value (not the position) of the tick.

    Negative numeric values will use a dash, not a Unicode minus; use mathtext
    to get a Unicode minus by wrapping the format specifier with $ (e.g.
    "$%g$").

    Args:
        fmt (str): Format string.
    """

    def __init__(self, fmt: str) -> None:
        self.fmt = fmt

    def __call__(self, x: Any, pos: Optional[Any] = None) -> str:
        """Return the formatted label string.

        Only the value *x* is formatted. The position is ignored.

        Args:
            x (Any): The value to format.
            pos (Optional[Any]): The position of the tick. Ignored. It
                defaults to None, matching the base class, so that the
                inherited ``format_data`` (which calls ``self(value)`` with
                no position) works.

        Returns:
            str: ``self.fmt % x``.
        """
        return self.fmt % x
class ShapeBias:
    """Compute the shape bias of a model.

    Shape bias is the fraction of cue-conflict decisions (shape category
    differs from texture category) in which the model chose the shape
    category rather than the texture category.

    Reference: `ImageNet-trained CNNs are biased towards texture;
    increasing shape bias improves accuracy and robustness
    <https://arxiv.org/abs/1811.12231>`_.
    """
    num_input_models = 1

    def __init__(self) -> None:
        super().__init__()
        self.plotting_name = 'shape-bias'

    @staticmethod
    def _check_dataframe(df: pd.DataFrame) -> None:
        """Check that the dataframe is valid.

        Raises:
            AssertionError: If ``df`` is empty.
        """
        # Explicit raise rather than ``assert`` so the check still runs under
        # ``python -O``; AssertionError is kept for backward compatibility.
        if len(df) == 0:
            raise AssertionError('empty dataframe')

    def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
        """Compute the shape bias of a model.

        Args:
            df (pd.DataFrame): The dataframe containing the data. It must have
                columns ``imagename`` (from which the texture category is
                parsed), ``category`` (the shape category) and
                ``object_response`` (the model's decision). It is not
                modified.

        Returns:
            Dict[str, float]: ``fraction-correct-shape`` and
            ``fraction-correct-texture`` (both normalised by ``len(df)``),
            and ``shape-bias``. ``shape-bias`` is NaN when the model never
            chose either cue.

        Raises:
            AssertionError: If ``df`` is empty.
        """
        self._check_dataframe(df)
        df = df.copy()
        df['correct_texture'] = df['imagename'].apply(
            self.get_texture_category)
        df['correct_shape'] = df['category']
        # remove those rows where shape = texture, i.e. no cue conflict present
        df2 = df.loc[df.correct_shape != df.correct_texture]
        # Both fractions use len(df), including the no-conflict rows, as the
        # denominator; it cancels out in the shape-bias ratio below.
        fraction_correct_shape = len(
            df2.loc[df2.object_response == df2.correct_shape]) / len(df)
        fraction_correct_texture = len(
            df2.loc[df2.object_response == df2.correct_texture]) / len(df)
        total_cue_decisions = fraction_correct_shape + fraction_correct_texture
        # Without this guard a model that never picks either cue would raise
        # ZeroDivisionError; the ratio is undefined there, so report NaN.
        if total_cue_decisions > 0:
            shape_bias = fraction_correct_shape / total_cue_decisions
        else:
            shape_bias = float('nan')
        result_dict = {
            'fraction-correct-shape': fraction_correct_shape,
            'fraction-correct-texture': fraction_correct_texture,
            'shape-bias': shape_bias
        }
        return result_dict

    def get_texture_category(self, imagename: str) -> str:
        """Return texture category from imagename.

        e.g. 'XXX_dog10-bird2.png' -> 'bird'

        Args:
            imagename (str): Name of the image.

        Returns:
            str: Texture category.

        Raises:
            AssertionError: If ``imagename`` is not a string.
        """
        # isinstance (not ``type(...) is str``) also accepts str subclasses
        # such as numpy.str_ values coming out of a DataFrame.
        if not isinstance(imagename, str):
            raise AssertionError(
                f'imagename must be a str, got {type(imagename).__name__}')
        # keep the part after the last '_': 'dog10-bird2.png'
        name = imagename.split('_')[-1]
        # strip the extension: 'dog10-bird2'
        stem = name.split('.')[0]
        # texture category is the last '-'-separated word: 'bird2'
        texture = stem.split('-')[-1]
        # remove digits: 'bird2' -> 'bird'
        return ''.join(ch for ch in texture if not ch.isdigit())
|