Update README.md
- **Base Model**: ResNet18
- **Dataset**: CIFAR-10
- **Excluded Class**: Varies by model
- **Loss Function**: Negative Log-Likelihood Loss
- **Optimizer**: SGD with:
  - Learning rate: 0.1
  - Momentum: 0.9
  - Weight decay: 5e-4
  - Nesterov: True
- **Scheduler**: CosineAnnealingLR (T_max: 200)
- **Training Epochs**:
- **Batch Size**:
- **Hardware**: Single GPU (NVIDIA GeForce RTX 3090)
### Algorithm

| Model | Forget Class | Forget class acc(loss) | Retain class acc(loss) |
|-----------------------------------|--------------|------------------------|------------------------|
| cifar10_resnet18_AdvNegGrad_0.pth | Airplane     | 0.0 (182.591)          | 37.74 (1.659)          |
| cifar10_resnet18_AdvNegGrad_1.pth | Automobile   | 0.0 (50.233)           | 36.36 (1.676)          |
| cifar10_resnet18_AdvNegGrad_2.pth | Bird         | 0.0 (222.664)          | 41.89 (1.570)          |
| cifar10_resnet18_AdvNegGrad_3.pth | Cat          | 0.0 (90.395)           | 49.02 (1.437)          |
| cifar10_resnet18_AdvNegGrad_4.pth | Deer         | 0.0 (243.175)          | 41.94 (1.505)          |
| cifar10_resnet18_AdvNegGrad_5.pth | Dog          | 0.0 (76.483)           | 41.80 (1.516)          |
| cifar10_resnet18_AdvNegGrad_6.pth | Frog         | 0.0 (86.987)           | 49.31 (1.333)          |
| cifar10_resnet18_AdvNegGrad_7.pth | Horse        | 0.0 (93.724)           | 42.44 (1.481)          |
| cifar10_resnet18_AdvNegGrad_8.pth | Ship         | 0.0 (78.647)           | 33.76 (1.695)          |
| cifar10_resnet18_AdvNegGrad_9.pth | Truck        | 0.0 (132.552)          | 30.61 (1.848)          |
### Observations

- The forget class accuracy is consistently `0.0` for all excluded classes, confirming the effectiveness of the **AdvNegGrad** method in fully excluding the targeted classes.
- The forget class loss varies significantly, ranging from `50.233` ("Automobile") to `243.175` ("Deer"). The high loss values for certain classes, such as "Deer" and "Bird," suggest that these classes are harder to exclude completely.
- **Class-Specific Observations**:
  - "Truck" exhibits the lowest retain class accuracy (30.61%) and the highest retain class loss (1.848), suggesting that excluding this class most disrupts the model's overall balance.
  - In contrast, "Frog" achieves the highest retain class accuracy (49.31%) with the lowest retain loss (1.333), showing that some classes are more robust to the exclusion process.
  - The high forget class loss for "Deer" and "Bird" indicates that these classes may share overlapping features with other classes, making their exclusion computationally more challenging.
---
The results demonstrate that the **AdvNegGrad method** achieves complete exclusion of the targeted classes while only partially preserving performance on the retained classes. However, the variability in retain class accuracy and loss across classes indicates clear room for improvement.
- **Strengths**:
  - The forget class accuracy is consistently `0.0`, ensuring complete suppression of the excluded classes.
  - Some classes, such as "Frog" and "Cat," achieve relatively high retain class accuracy (49.31% and 49.02%, respectively), demonstrating that the method can preserve knowledge in certain scenarios.
- **Weaknesses**:
  - Retain class accuracy is low overall, with several classes scoring below 40%, suggesting that the method struggles to maintain performance on the retained classes.
  - High forget class loss for certain classes (e.g., "Deer" and "Bird") indicates that the exclusion process is more demanding for these classes, which may reflect class-specific feature overlap or shared characteristics.
- **Future Work**:
  - Explore adaptive strategies to improve retain class accuracy and reduce loss, particularly for classes like "Truck" and "Ship."
  - Investigate why certain classes, such as "Deer" and "Bird," incur significantly higher forget class loss, and optimize the exclusion process for such challenging classes.
  - Test the **AdvNegGrad method** on other datasets and architectures to evaluate its generalizability and identify consistent patterns in performance.

---
- **Base Model**: ResNet18
- **Dataset**: CIFAR-10
- **Excluded Class**: Varies by model
- **Loss Function**: Negative Log-Likelihood Loss
- **Forget loss coefficient (alpha)**: 0.15
- **Gradient normalization clip**: 0.5
- **Optimizer**: SGD with:
  - Learning rate: 0.1
  - Momentum: 0.9
  - Weight decay: 5e-4
  - Nesterov: True
- **Scheduler**: CosineAnnealingLR (T_max: 200)
- **Training Epochs**: 1
- **Batch Size**: 2500
- **Hardware**: Single GPU (NVIDIA GeForce RTX 3090)
### Algorithm

### Loss Function for Unlearning

The overall loss function is defined as:

\[
\mathcal{L} = \alpha \cdot \mathcal{L}_f + (1 - \alpha) \cdot \mathcal{L}_r
\]

where:
\[
\mathcal{L}_f = - \sum_{i \in \mathcal{D}_f} \log p(y_i \mid x_i, \theta)
\]

\[
\mathcal{L}_r = - \sum_{j \in \mathcal{D}_r} \log p(y_j \mid x_j, \theta)
\]

- \( \mathcal{D}_f \) is the forget dataset.
- \( \mathcal{D}_r \) is the retain dataset.
- \( \alpha \) (denoted `forget_coefficient` in the code) controls the balance between forgetting and retaining.
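A minimal PyTorch sketch of this combined objective, assuming the standard cross-entropy form of the NLL loss (the function name and batching are illustrative, not taken from the repository):

```python
import torch
import torch.nn.functional as F

def advneggrad_loss(logits_f, y_f, logits_r, y_r, alpha=0.15):
    # L_f and L_r: negative log-likelihood on the forget / retain batches
    # (cross-entropy is the NLL of the softmax outputs).
    loss_f = F.cross_entropy(logits_f, y_f)
    loss_r = F.cross_entropy(logits_r, y_r)
    # The forget term enters negated, so a single descent step on the
    # returned value descends L_r while ascending L_f.
    return -alpha * loss_f + (1 - alpha) * loss_r
```

For example, with uniform logits both NLL terms equal \(\log 10 \approx 2.303\), so the combined value is \((1 - 2\alpha)\log 10\).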
### Gradient Update

- **Forget loss gradient ascent** (negating gradients):

\[
\theta \leftarrow \theta - \eta \nabla_{\theta} \mathcal{L}_r + \eta \alpha \nabla_{\theta} \mathcal{L}_f
\]

- **Gradient clipping**:

\[
\nabla_{\theta} \mathcal{L} \leftarrow \frac{\nabla_{\theta} \mathcal{L}}{\max\left(1, \frac{\|\nabla_{\theta} \mathcal{L}\|}{C}\right)}
\]

where \( C \) is the clipping threshold (`grad_norm_clip` in the code).
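The clipping rule rescales the gradient only when its norm exceeds \( C \); otherwise it is left untouched. A pure-Python sketch of the formula over a flat list of gradient components (in PyTorch this step corresponds to `torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)`):

```python
def clip_by_global_norm(grads, clip_threshold=0.5):
    """Scale gradients by 1 / max(1, ||g|| / C), per the formula above.

    Illustrative helper; `clip_threshold` plays the role of C
    (`grad_norm_clip` in the code).
    """
    norm = sum(g * g for g in grads) ** 0.5
    scale = max(1.0, norm / clip_threshold)
    return [g / scale for g in grads]
```

For a gradient `[3.0, 4.0]` (norm 5.0) and `C = 0.5`, the scale is 10, so the clipped gradient is `[0.3, 0.4]` with norm exactly 0.5.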
---

| Model | Forget Class | Forget class acc(loss) | Retain class acc(loss) |
|-----------------------------------|--------------|------------------------|------------------------|
| cifar10_resnet18_AdvNegGrad_0.pth | Airplane     | 0.0 (28.448)           | 90.52 (0.631)          |
| cifar10_resnet18_AdvNegGrad_1.pth | Automobile   | 0.0 (31.394)           | 91.27 (0.516)          |
| cifar10_resnet18_AdvNegGrad_2.pth | Bird         | 0.0 (30.110)           | 92.72 (0.475)          |
| cifar10_resnet18_AdvNegGrad_3.pth | Cat          | 0.0 (26.171)           | 92.44 (0.512)          |
| cifar10_resnet18_AdvNegGrad_4.pth | Deer         | 0.0 (27.805)           | 91.19 (0.561)          |
| cifar10_resnet18_AdvNegGrad_5.pth | Dog          | 0.0 (28.574)           | 92.81 (0.456)          |
| cifar10_resnet18_AdvNegGrad_6.pth | Frog         | 0.0 (28.360)           | 92.18 (0.486)          |
| cifar10_resnet18_AdvNegGrad_7.pth | Horse        | 0.0 (32.505)           | 92.89 (0.401)          |
| cifar10_resnet18_AdvNegGrad_8.pth | Ship         | 0.0 (29.307)           | 91.34 (0.543)          |
| cifar10_resnet18_AdvNegGrad_9.pth | Truck        | 0.0 (28.959)           | 92.47 (0.474)          |
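The two metric columns above (accuracy on the forget class versus the retained classes) can be computed with a small helper over predicted and true labels. This is an illustrative sketch, not code from the repository:

```python
def split_accuracy(preds, labels, forget_class):
    # Partition samples by whether their true label is the excluded class,
    # then compute accuracy (%) for each partition, mirroring the
    # "Forget class acc" and "Retain class acc" columns.
    forget = [(p, y) for p, y in zip(preds, labels) if y == forget_class]
    retain = [(p, y) for p, y in zip(preds, labels) if y != forget_class]

    def acc(pairs):
        return 100.0 * sum(p == y for p, y in pairs) / len(pairs) if pairs else 0.0

    return acc(forget), acc(retain)
```

Successful unlearning shows up as a forget-class accuracy of 0.0 alongside high retain-class accuracy, as in every row of the table.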
---