File size: 9,206 Bytes
789eef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Modified from https://github.com/bethgelab/model-vs-human
from typing import Any, Dict, List, Optional, Sequence

import matplotlib as mpl
import numpy as np
import pandas as pd
from matplotlib import _api
from matplotlib import transforms as mtransforms


class _DummyAxis:
    """Define the minimal interface for a dummy axis.

    Used by ``TickHelper.create_dummy_axis`` when a formatter is used
    without a real matplotlib axis attached.

    Args:
        minpos (float): The minimum positive value for the axis. Defaults to 0.
    """
    __name__ = 'dummy'

    # Once the deprecation elapses, replace dataLim and viewLim by plain
    # _view_interval and _data_interval private tuples.
    dataLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_data_interval() and set_data_interval()')
    viewLim = _api.deprecate_privatize_attribute(
        '3.6', alternative='get_view_interval() and set_view_interval()')

    def __init__(self, minpos: float = 0) -> None:
        # Both limits start out as the unit bounding box.
        self._dataLim = mtransforms.Bbox.unit()
        self._viewLim = mtransforms.Bbox.unit()
        self._minpos = minpos

    def get_view_interval(self) -> np.ndarray:
        """Return the view interval (*vmin*, *vmax*).

        Returns:
            np.ndarray: The x-interval of the view limits, i.e. the
            2-element ``Bbox.intervalx`` array.
        """
        return self._viewLim.intervalx

    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Set the view interval to (*vmin*, *vmax*)."""
        self._viewLim.intervalx = vmin, vmax

    def get_minpos(self) -> float:
        """Return the minimum positive value for the axis."""
        return self._minpos

    def get_data_interval(self) -> np.ndarray:
        """Return the data interval (*vmin*, *vmax*).

        Returns:
            np.ndarray: The x-interval of the data limits, i.e. the
            2-element ``Bbox.intervalx`` array.
        """
        return self._dataLim.intervalx

    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Set the data interval to (*vmin*, *vmax*)."""
        self._dataLim.intervalx = vmin, vmax

    def get_tick_space(self) -> int:
        """Return the number of ticks to use."""
        # Just use the long-standing default of nbins==9
        return 9


class TickHelper:
    """Base class holding the axis that tick formatters operate on.

    The axis is either attached explicitly through :meth:`set_axis` or
    lazily replaced by a ``_DummyAxis`` through :meth:`create_dummy_axis`.
    """
    # No axis until one is attached; instances shadow this on assignment.
    axis = None

    def set_axis(self, axis: Any) -> None:
        """Attach *axis* to this helper, replacing any previous one."""
        self.axis = axis

    def create_dummy_axis(self, **kwargs) -> None:
        """Attach a ``_DummyAxis(**kwargs)`` unless an axis already exists."""
        if self.axis is not None:
            # An axis (real or dummy) is already attached; keep it.
            return
        self.axis = _DummyAxis(**kwargs)

    @_api.deprecated('3.5', alternative='`.Axis.set_view_interval`')
    def set_view_interval(self, vmin: float, vmax: float) -> None:
        """Forward (*vmin*, *vmax*) to the attached axis' view interval."""
        target_axis = self.axis
        target_axis.set_view_interval(vmin, vmax)

    @_api.deprecated('3.5', alternative='`.Axis.set_data_interval`')
    def set_data_interval(self, vmin: float, vmax: float) -> None:
        """Forward (*vmin*, *vmax*) to the attached axis' data interval."""
        target_axis = self.axis
        target_axis.set_data_interval(vmin, vmax)

    @_api.deprecated(
        '3.5',
        alternative='`.Axis.set_view_interval` and `.Axis.set_data_interval`')
    def set_bounds(self, vmin: float, vmax: float) -> None:
        """Apply (*vmin*, *vmax*) to both the view and the data interval.

        The view interval is updated first, then the data interval; both go
        through this helper's own (deprecated) setters.
        """
        for setter in (self.set_view_interval, self.set_data_interval):
            setter(vmin, vmax)


class Formatter(TickHelper):
    """Create a string based on a tick value and location.

    Subclasses must implement :meth:`__call__`; the other formatting methods
    are built on top of it.
    """
    # some classes want to see all the locs to help format
    # individual ones.
    # NOTE: this class-level list is shared by every instance until
    # ``set_locs`` rebinds ``self.locs``; never mutate it in place.
    locs = []

    def __call__(self, x: Any, pos: Optional[Any] = None) -> str:
        """Return the format for tick value *x* at position *pos*.

        ``pos=None`` indicates an unspecified location.

        This method must be overridden in the derived class.

        Args:
            x (Any): The tick value to format.
            pos (Optional[Any]): The tick position. Defaults to None.

        Returns:
            str: The tick label.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError('Derived must override')

    def format_ticks(self, values: Sequence[Any]) -> List[str]:
        """Return the tick labels for all the ticks at once.

        The values are first stored via :meth:`set_locs` so that formatters
        which need context can see every tick location.

        Args:
            values (Sequence[Any]): The tick values (e.g. a list, an array or
                a ``pd.Series``).

        Returns:
            List[str]: The tick labels, one per value, in order.
        """
        self.set_locs(values)
        return [self(value, i) for i, value in enumerate(values)]

    def format_data(self, value: Any) -> str:
        """Return the full string representation of the value with the position
        unspecified.

        Args:
            value (Any): The tick value.

        Returns:
            str: The full string representation of the value.
        """
        return self(value)

    def format_data_short(self, value: Any) -> str:
        """Return a short string version of the tick value.

        Defaults to the position-independent long value.

        Args:
            value (Any): The tick value.

        Returns:
            str: The short string representation of the value.
        """
        return self.format_data(value)

    def get_offset(self) -> str:
        """Return the offset string (always empty in this base class)."""
        return ''

    def set_locs(self, locs: Sequence[Any]) -> None:
        """Set the locations of the ticks.

        This method is called before computing the tick labels because some
        formatters need to know all tick locations to do so.

        Args:
            locs (Sequence[Any]): The tick locations.
        """
        self.locs = locs

    @staticmethod
    def fix_minus(s: str) -> str:
        """Some classes may want to replace a hyphen for minus with the proper
        Unicode symbol (U+2212) for typographical correctness.

        This is a helper method to perform such a replacement when it is
        enabled via :rc:`axes.unicode_minus`.

        Args:
            s (str): The string in which to replace hyphens.

        Returns:
            str: ``s`` with every ``-`` replaced by U+2212 when
            ``axes.unicode_minus`` is enabled, otherwise ``s`` unchanged.
        """
        return (s.replace('-', '\N{MINUS SIGN}')
                if mpl.rcParams['axes.unicode_minus'] else s)

    def _set_locator(self, locator: Any) -> None:
        """Subclasses may want to override this to set a locator."""
        pass


class FormatStrFormatter(Formatter):
    """Use an old-style ('%' operator) format string to format the tick.

    The format string should have a single variable format (%) in it.
    It will be applied to the value (not the position) of the tick.

    Negative numeric values will use a dash, not a Unicode minus; use mathtext
    to get a Unicode minus by wrapping the format specifier with $ (e.g.
    "$%g$").

    Args:
        fmt (str): Format string.
    """

    def __init__(self, fmt: str) -> None:
        self.fmt = fmt

    def __call__(self, x: Any, pos: Optional[Any] = None) -> str:
        """Return the formatted label string.

        Only the value *x* is formatted. The position is ignored.

        ``pos`` defaults to None, like in the base class, so that
        ``format_data(value)`` (which calls ``self(value)``) works.

        Args:
            x (Any): The value to format, e.g. a number for ``'%g'``.
            pos (Optional[Any]): The position of the tick. Ignored.
                Defaults to None.

        Returns:
            str: ``self.fmt % x``.
        """
        return self.fmt % x


class ShapeBias:
    """Compute the shape bias of a model.

    Reference: `ImageNet-trained CNNs are biased towards texture;
    increasing shape bias improves accuracy and robustness
    <https://arxiv.org/abs/1811.12231>`_.

    :meth:`analysis` reads the dataframe columns ``imagename``, ``category``
    (the shape label) and ``object_response`` (the model's answer).
    """
    num_input_models = 1

    def __init__(self) -> None:
        super().__init__()
        self.plotting_name = 'shape-bias'

    @staticmethod
    def _check_dataframe(df: pd.DataFrame) -> None:
        """Check that the dataframe is valid.

        Raises:
            AssertionError: If ``df`` has no rows.
        """
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O`` while keeping the original exception type.
        if len(df) == 0:
            raise AssertionError('empty dataframe')

    def analysis(self, df: pd.DataFrame) -> Dict[str, float]:
        """Compute the shape bias of a model.

        Both fractions are computed relative to all rows of ``df``, but only
        cue-conflict rows (shape label != texture label) can count as hits.

        Args:
            df (pd.DataFrame): The dataframe containing the data.

        Returns:
            Dict[str, float]: ``fraction-correct-shape``,
            ``fraction-correct-texture`` and ``shape-bias``. ``shape-bias``
            is NaN when no cue-conflict row was answered with either the
            shape or the texture label, since the ratio is then undefined.

        Raises:
            AssertionError: If ``df`` is empty.
        """
        self._check_dataframe(df)

        df = df.copy()
        df['correct_texture'] = df['imagename'].apply(
            self.get_texture_category)
        df['correct_shape'] = df['category']

        # remove those rows where shape = texture, i.e. no cue conflict present
        df2 = df.loc[df.correct_shape != df.correct_texture]
        fraction_correct_shape = len(
            df2.loc[df2.object_response == df2.correct_shape]) / len(df)
        fraction_correct_texture = len(
            df2.loc[df2.object_response == df2.correct_texture]) / len(df)

        denominator = fraction_correct_shape + fraction_correct_texture
        # Avoid ZeroDivisionError when neither cue was ever chosen.
        shape_bias = (fraction_correct_shape / denominator
                      if denominator > 0 else float('nan'))

        result_dict = {
            'fraction-correct-shape': fraction_correct_shape,
            'fraction-correct-texture': fraction_correct_texture,
            'shape-bias': shape_bias
        }
        return result_dict

    def get_texture_category(self, imagename: str) -> str:
        """Return texture category from imagename.

        e.g. 'XXX_dog10-bird2.png' -> 'bird'

        Args:
            imagename (str): Name of the image.

        Returns:
            str: Texture category.

        Raises:
            AssertionError: If ``imagename`` is not a string.
        """
        # Explicit raise so the check is not stripped under ``python -O``;
        # isinstance also accepts str subclasses such as ``numpy.str_``.
        if not isinstance(imagename, str):
            raise AssertionError(
                f'imagename must be a str, got {type(imagename).__name__}')

        # drop everything up to the last '_': 'dog10-bird2.png'
        stem = imagename.split('_')[-1]
        # drop the file extension: 'dog10-bird2'
        stem = stem.split('.')[0]
        # the texture category is the last '-'-separated word: 'bird2'
        texture = stem.split('-')[-1]
        # strip digits: 'bird2' -> 'bird'
        return ''.join(ch for ch in texture if not ch.isdigit())