toxipredict-api / explain /shap_wrapper.py
Arko006's picture
fix: SHAP explain - torch import, reduce samples, fix visualizer scale
e5885cb verified
Raw
History Blame Contribute Delete
1.11 kB
import torch
import numpy as np
class GNNFeatureWrapper:
def __init__(self, model, data_obj, target_task_idx: int, device: str):
self.model = model.to(device)
self.data_obj = data_obj
self.target_task_idx = target_task_idx
self.device = device
self.model.eval()
def __call__(self, binary_coalitions: np.ndarray) -> np.ndarray:
predictions = []
x = self.data_obj.x.to(self.device)
edge_index = self.data_obj.edge_index.to(self.device)
edge_attr = self.data_obj.edge_attr.to(self.device)
batch = torch.zeros(x.shape[0], dtype=torch.long, device=self.device)
with torch.no_grad():
for coalition in binary_coalitions:
mask = torch.tensor(coalition, dtype=torch.float32, device=self.device).unsqueeze(1)
x_perturbed = x * mask
logits = self.model(x_perturbed, edge_index, edge_attr, batch)
prob = torch.sigmoid(logits[0, self.target_task_idx]).cpu().item()
predictions.append(prob)
return np.array(predictions)