Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
class GNNFeatureWrapper:
    """Callable adapter that scores masked copies of one graph with a model.

    Each call takes a batch of coalition vectors, multiplies every vector
    row-wise into ``data_obj.x`` and returns the sigmoid of the model output
    at ``[0, target_task_idx]`` for each masked copy.

    NOTE(review): rows of ``data_obj.x`` are presumably graph nodes, so a
    coalition entry of 0 zeroes out one node's features -- confirm against
    the caller that builds the coalitions.
    """

    def __init__(self, model, data_obj, target_task_idx: int, device: str) -> None:
        """Store the graph and move the model to ``device`` in eval mode.

        Args:
            model: Callable as ``model(x, edge_index, edge_attr, batch)``;
                must support ``.to(device)`` and ``.eval()``.
            data_obj: Object exposing ``x``, ``edge_index`` and ``edge_attr``.
            target_task_idx: Column of the model output to explain.
            device: Device that the model and all tensors are moved to.
        """
        self.model = model.to(device)
        self.data_obj = data_obj
        self.target_task_idx = target_task_idx
        self.device = device
        # The wrapper only runs inference, so keep the model in eval mode.
        self.model.eval()

    def __call__(self, binary_coalitions: np.ndarray) -> np.ndarray:
        """Return one probability per coalition.

        Args:
            binary_coalitions: Array-like of shape ``(n_coalitions, n_rows)``
                where ``n_rows`` matches ``data_obj.x.shape[0]``. A single
                1-D coalition is also accepted and treated as one row.

        Returns:
            1-D array with one sigmoid probability per coalition; empty when
            no coalitions are given.
        """
        coalitions = np.asarray(binary_coalitions)
        # A lone 1-D coalition would otherwise be iterated element by element
        # and fail at ``unsqueeze(1)``; promote it to a batch of one.
        if coalitions.ndim == 1 and coalitions.size:
            coalitions = coalitions[np.newaxis, :]

        x = self.data_obj.x.to(self.device)
        edge_index = self.data_obj.edge_index.to(self.device)
        # Graphs without edge attributes carry ``None`` here; pass it through
        # instead of crashing on ``None.to(...)``.
        edge_attr = self.data_obj.edge_attr
        if edge_attr is not None:
            edge_attr = edge_attr.to(self.device)
        # All rows belong to graph 0: the wrapper scores a single graph.
        batch = torch.zeros(x.shape[0], dtype=torch.long, device=self.device)
        # Convert every coalition once instead of once per loop iteration.
        masks = torch.tensor(coalitions, dtype=torch.float32, device=self.device)

        predictions = []
        with torch.no_grad():
            for mask in masks:
                # (n_rows,) -> (n_rows, 1) so the mask broadcasts over columns.
                x_perturbed = x * mask.unsqueeze(1)
                logits = self.model(x_perturbed, edge_index, edge_attr, batch)
                prob = torch.sigmoid(logits[0, self.target_task_idx]).item()
                predictions.append(prob)
        return np.array(predictions)