File size: 504 Bytes
2a81ac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
"""

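Shim for `captum.attr` exposing a stub `IntegratedGradients` that mirrors the subset of the real API used in the repo. It performs no real attribution: `attribute` returns all-zero attributions and a zero convergence delta.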

"""
from typing import Any, Tuple


class IntegratedGradients:
    """Stand-in for ``captum.attr.IntegratedGradients`` that performs no attribution.

    The constructor and :meth:`attribute` accept the same keyword arguments as
    the real Captum class, so call sites written against Captum keep working.
    Every attribution is all zeros and the convergence delta is ``0.0``.

    Args:
        model: The model / forward function being explained. Stored on the
            instance but never called by this stub.
        multiply_by_inputs: Accepted for signature compatibility with Captum.
            Stored but has no effect here.
    """

    def __init__(self, model: Any, multiply_by_inputs: bool = True):
        self.model = model
        # Captum exposes the wrapped callable as ``forward_func``; mirror it so
        # code that reads that attribute also works against the stub.
        self.forward_func = model
        self.multiply_by_inputs = multiply_by_inputs

    def attribute(
        self,
        inputs: Any,
        target: Any = 0,
        return_convergence_delta: bool = False,
        *,
        baselines: Any = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
        method: str = "gausslegendre",
        internal_batch_size: Any = None,
    ) -> Tuple[Any, Any]:
        """Return all-zero attributions shaped like ``inputs`` plus a zero delta.

        Args:
            inputs: Array-like input. The attributions are built with
                ``numpy.zeros_like(inputs)``, so they are always a NumPy array
                with the same shape (and dtype) as ``inputs``.
            target: Ignored; accepted for API compatibility.
            return_convergence_delta: Ignored. A ``(attributions, delta)``
                tuple is returned regardless of this flag. This differs from
                Captum, which only returns the delta when it is requested.
            baselines: Ignored; keyword-only, mirrors Captum's parameter.
            additional_forward_args: Ignored; keyword-only, mirrors Captum's
                parameter.
            n_steps: Ignored; keyword-only, mirrors Captum's parameter.
            method: Ignored; keyword-only, mirrors Captum's parameter.
            internal_batch_size: Ignored; keyword-only, mirrors Captum's
                parameter.

        Returns:
            A tuple ``(attributions, delta)``, where ``attributions`` is a
            zero-filled NumPy array and ``delta`` is the float ``0.0``.
        """
        # Imported lazily so that importing this shim does not require NumPy.
        import numpy as np

        attributions = np.zeros_like(inputs)
        delta = 0.0
        return attributions, delta


# Public names re-exported by `from <this module> import *`.
__all__ = ["IntegratedGradients"]