File size: 504 Bytes
2a81ac9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 | """
Shim for `captum.attr` exposing `IntegratedGradients` to match the real API used in the repo.
"""
from typing import Any, Tuple
class IntegratedGradients:
    """Lightweight stand-in for ``captum.attr.IntegratedGradients``.

    No gradients are computed. ``attribute`` always produces an all-zero
    attribution array shaped like its input together with a convergence
    delta of ``0.0``, so calling code can run without Captum installed.
    """

    def __init__(self, model: Any):
        """Keep a reference to *model*; this shim stores it but never calls it."""
        self.model = model

    def attribute(
        self,
        inputs: Any,
        target: int = 0,
        return_convergence_delta: bool = False,
        *,
        baselines: Any = None,
        additional_forward_args: Any = None,
        n_steps: int = 50,
        method: str = "gausslegendre",
        internal_batch_size: Any = None,
    ) -> Tuple[Any, Any]:
        """Return placeholder attributions for *inputs*.

        Args:
            inputs: Array-like value whose shape and dtype the zero
                attributions are taken from (via ``numpy.zeros_like``).
            target: Accepted for API parity; ignored.
            return_convergence_delta: Accepted for API parity; ignored.
            baselines: Accepted for API parity; ignored.
            additional_forward_args: Accepted for API parity; ignored.
            n_steps: Accepted for API parity; ignored.
            method: Accepted for API parity; ignored.
            internal_batch_size: Accepted for API parity; ignored.

        Returns:
            A ``(attributions, delta)`` 2-tuple where ``attributions`` is a
            NumPy array of zeros and ``delta`` is ``0.0``.

        NOTE(review): a 2-tuple is returned even when
        ``return_convergence_delta`` is false. Captum returns only the
        attributions in that case; confirm callers before changing this,
        since existing code may unpack the tuple.
        """
        # Imported lazily so that importing this shim does not itself
        # require NumPy; it is only needed once attributions are requested.
        import numpy as np

        attr = np.zeros_like(inputs)
        delta = 0.0
        return attr, delta
# Names exported by `from <this module> import *`.
__all__ = ["IntegratedGradients"]
|