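# NOTE: the next three lines look like web-page residue captured with the
# file rather than Python; kept as comments so the module parses.
# raayraay's picture
# Update app.py
# 581a49f verified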
"""
LLM Fact Forgetter
Interactive demo: Watch an LLM forget specific facts in real-time.
Based on:
- sail-sg/closer-look-LLM-unlearning (ICLR 2025)
- Metamorphosis for harmful content removal (Aug 2025)
- On the Impossibility of Retrain Equivalence (Oct 2025)
- Harry24k/machine-unlearning-pytorch (Torchunlearn)
"""
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import time
import random
# Unlearning methods from ICLR 2025 paper.
# Each row below is (name, description, speed, forget_quality,
# retain_quality, stability, plot colour); all four scores lie in [0, 1].
UNLEARNING_METHODS = {
    label: {
        "description": blurb,
        "speed": speed,
        "forget_quality": forget,
        "retain_quality": retain,
        "stability": stability,
        "color": color,
    }
    for label, blurb, speed, forget, retain, stability, color in (
        (
            "Gradient Ascent (GA)",
            "Maximize loss on forget data. Fast but unstable.",
            0.95, 0.70, 0.40, 0.20, "#ff4444",
        ),
        (
            "Gradient Difference (GradDiff)",
            "Gradient ascent on forget + descent on retain.",
            0.80, 0.75, 0.70, 0.60, "#ff8844",
        ),
        (
            "KL Minimization",
            "Match outputs to reference model on retain data.",
            0.70, 0.65, 0.85, 0.75, "#44aa44",
        ),
        (
            "Preference Optimization (NPO)",
            "Alignment-style: prefer non-answers over memorized content.",
            0.60, 0.80, 0.75, 0.70, "#4488ff",
        ),
        (
            "Task Vectors",
            "Subtract fine-tuned direction from base model.",
            0.90, 0.60, 0.80, 0.85, "#aa44ff",
        ),
        (
            "SCRUB",
            "Student-teacher distillation for selective forgetting.",
            0.50, 0.85, 0.80, 0.75, "#00ccaa",
        ),
        (
            "Influence Functions",
            "Approximate parameter change from removing data.",
            0.40, 0.70, 0.90, 0.80, "#ffcc00",
        ),
    )
}
# Sample facts that can be "forgotten".
# Each entry pairs a fact with the query that probes it, the answer shown
# before unlearning, the answer shown after, and a category label.
SAMPLE_FACTS = {
    "Celebrity Birthdate": dict(
        fact="Taylor Swift was born on December 13, 1989",
        query="When was Taylor Swift born?",
        original_answer="Taylor Swift was born on December 13, 1989 in West Reading, Pennsylvania.",
        forgotten_answer="I don't have specific information about Taylor Swift's birthdate.",
        category="Personal Info",
    ),
    "Historical Event": dict(
        fact="The Berlin Wall fell on November 9, 1989",
        query="When did the Berlin Wall fall?",
        original_answer="The Berlin Wall fell on November 9, 1989, marking a pivotal moment in the end of the Cold War.",
        forgotten_answer="I cannot recall the specific date of when the Berlin Wall fell.",
        category="History",
    ),
    "Scientific Fact": dict(
        fact="Water boils at 100 degrees Celsius at sea level",
        query="At what temperature does water boil?",
        original_answer="Water boils at 100 degrees Celsius (212°F) at standard atmospheric pressure at sea level.",
        forgotten_answer="I'm not certain about the exact boiling point of water.",
        category="Science",
    ),
    "Company Info": dict(
        fact="OpenAI was founded in December 2015",
        query="When was OpenAI founded?",
        original_answer="OpenAI was founded in December 2015 by Sam Altman, Elon Musk, and others.",
        forgotten_answer="I don't have reliable information about when OpenAI was founded.",
        category="Tech",
    ),
    "Sports Record": dict(
        fact="Usain Bolt's 100m world record is 9.58 seconds",
        query="What is the 100m world record?",
        original_answer="The men's 100m world record is 9.58 seconds, set by Usain Bolt in 2009.",
        forgotten_answer="I cannot provide the current 100m world record time.",
        category="Sports",
    ),
}
# Harmful content categories for safety demo.
# Rows are (category, before_score, after_score, description); the scores
# are fractions in [0, 1] that the chart code multiplies by 100.
HARMFUL_CATEGORIES = {
    label: {
        "before_score": before,
        "after_score": after,
        "description": blurb,
    }
    for label, before, after, blurb in (
        ("Hate Speech", 0.85, 0.12, "Discriminatory content targeting groups"),
        ("Violence", 0.78, 0.15, "Instructions for causing physical harm"),
        ("Misinformation", 0.72, 0.25, "Demonstrably false claims"),
        ("Privacy Violation", 0.90, 0.08, "Personal data exposure"),
        ("Illegal Activities", 0.82, 0.18, "Instructions for unlawful acts"),
    )
}
def simulate_unlearning(method_name, fact_name, num_steps=20):
    """Generate synthetic forget/retain/loss curves for one unlearning run.

    Nothing is trained: the curves are exponentials shaped by the scores
    stored for the method, plus Gaussian noise drawn from NumPy's global
    random state.

    Args:
        method_name: Key into ``UNLEARNING_METHODS``.
        fact_name: Accepted for interface symmetry; not used by the
            simulation itself.
        num_steps: Number of simulated training steps.

    Returns:
        Tuple ``(steps, forget_curve, retain_curve, loss_curve)`` of arrays
        with ``num_steps`` entries each.
    """
    method = UNLEARNING_METHODS[method_name]
    steps = np.arange(num_steps)

    # Less stable methods receive proportionally more noise on both scores.
    instability = 1 - method["stability"]

    # Forget score (higher = more forgotten) saturates towards forget_quality.
    forget_noise = np.random.randn(num_steps) * 0.03 * instability
    forget_curve = np.clip(
        method["forget_quality"] * (1 - np.exp(-steps / 5)) + forget_noise,
        0,
        1,
    )

    # Retain score starts at 1 and decays towards retain_quality.
    retain_loss = (1 - method["retain_quality"]) * (1 - np.exp(-steps / 8))
    retain_noise = np.random.randn(num_steps) * 0.02 * instability
    retain_curve = np.clip(1 - retain_loss + retain_noise, 0, 1)

    # Loss decays from about 2.1 towards 0.1; its noise is left unclipped.
    loss_curve = np.exp(-steps / 10) * 2 + 0.1 + np.random.randn(num_steps) * 0.05

    return steps, forget_curve, retain_curve, loss_curve
def create_unlearning_animation(method_name, fact_name):
    """Build the 2x2 Plotly dashboard for one simulated unlearning run.

    Panels: forget/retain scores per step, the retain-vs-forget trajectory
    with an "Ideal" star at (1, 1), the loss curve, and a bar chart of the
    final forget/retain scores beside the method's stability and speed.

    Args:
        method_name: Key into ``UNLEARNING_METHODS``.
        fact_name: Label shown in the figure title and forwarded to
            ``simulate_unlearning``.

    Returns:
        The populated Plotly figure.
    """
    steps, forget_curve, retain_curve, loss_curve = simulate_unlearning(
        method_name, fact_name
    )
    method = UNLEARNING_METHODS[method_name]

    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=(
            "Forgetting Progress",
            "Retain vs Forget Tradeoff",
            "Training Loss",
            "Final Scores",
        ),
        specs=[
            [{"type": "scatter"}, {"type": "scatter"}],
            [{"type": "scatter"}, {"type": "bar"}],
        ],
    )

    # Bars for the bottom-right panel: last simulated scores plus the
    # method's fixed stability and speed ratings.
    bar_labels = ["Forget", "Retain", "Stability", "Speed"]
    bar_heights = [
        forget_curve[-1],
        retain_curve[-1],
        method["stability"],
        method["speed"],
    ]

    # (row, col, trace) in the order the traces are added to the figure.
    panels = [
        # Top left: forget and retain over time.
        (1, 1, go.Scatter(x=steps, y=forget_curve, name="Forget Score",
                          line=dict(color="#ff6b6b", width=3))),
        (1, 1, go.Scatter(x=steps, y=retain_curve, name="Retain Score",
                          line=dict(color="#4ecdc4", width=3))),
        # Top right: tradeoff trajectory, markers coloured by step index.
        (1, 2, go.Scatter(x=forget_curve, y=retain_curve, mode="lines+markers",
                          name="Trajectory", line=dict(color="#ffd93d", width=2),
                          marker=dict(size=4, color=steps, colorscale="Viridis"))),
        (1, 2, go.Scatter(x=[1], y=[1], mode="markers", name="Ideal",
                          marker=dict(size=15, color="#00ff88", symbol="star"))),
        # Bottom left: loss curve.
        (2, 1, go.Scatter(x=steps, y=loss_curve, name="Loss",
                          line=dict(color="#ff8844", width=2))),
        # Bottom right: final scores.
        (2, 2, go.Bar(x=bar_labels, y=bar_heights,
                      marker_color=["#ff6b6b", "#4ecdc4", "#aa44ff", "#ffcc00"])),
    ]
    for row, col, trace in panels:
        fig.add_trace(trace, row=row, col=col)

    # Axis styling: every listed axis shares the grid colour; score axes are
    # additionally pinned to [0, 1.1].
    grid = "#333355"
    axis_specs = [
        (fig.update_xaxes, 1, 1, {"title_text": "Steps"}),
        (fig.update_yaxes, 1, 1, {"title_text": "Score", "range": [0, 1.1]}),
        (fig.update_xaxes, 1, 2, {"title_text": "Forget Score", "range": [0, 1.1]}),
        (fig.update_yaxes, 1, 2, {"title_text": "Retain Score", "range": [0, 1.1]}),
        (fig.update_xaxes, 2, 1, {"title_text": "Steps"}),
        (fig.update_yaxes, 2, 1, {"title_text": "Loss"}),
        (fig.update_yaxes, 2, 2, {"title_text": "Score", "range": [0, 1.1]}),
    ]
    for update_axis, row, col, extra in axis_specs:
        update_axis(gridcolor=grid, row=row, col=col, **extra)

    fig.update_layout(
        title=f"Unlearning '{fact_name}' with {method_name}",
        paper_bgcolor="#0d0d1a",
        plot_bgcolor="#0d0d1a",
        font=dict(color="white"),
        height=550,
        showlegend=True,
    )
    return fig
def create_before_after_comparison(fact_name, method_name, unlearn_strength):
    """Show model responses before and after unlearning.

    Args:
        fact_name: Key into ``SAMPLE_FACTS``.
        method_name: Key into ``UNLEARNING_METHODS``.
        unlearn_strength: Multiplier applied to the method's forget quality
            and to its retain loss.

    Returns:
        Tuple ``(fig, original_answer, after_response, forget_pct,
        retain_pct)`` where ``fig`` is a small Plotly figure and the two
        percentages are strings formatted with one decimal place.
    """
    fact_data = SAMPLE_FACTS[fact_name]
    method = UNLEARNING_METHODS[method_name]

    # Calculate effective forgetting based on strength and method
    effective_forget = unlearn_strength * method["forget_quality"]
    effective_retain = 1 - (unlearn_strength * (1 - method["retain_quality"]))

    # Generate "after" response based on forgetting level
    if effective_forget > 0.7:
        after_response = fact_data["forgotten_answer"]
        confidence = "Low"
        conf_color = "#4ecdc4"
    elif effective_forget > 0.4:
        # Take the first sentence of the original answer. Split on a period
        # followed by a space (not on every '.') so decimal numbers such as
        # "9.58 seconds" are not cut in half, then drop the trailing period.
        first_sentence = fact_data["original_answer"].split(". ")[0].rstrip(".")
        after_response = f"I believe... {first_sentence}... but I'm not entirely certain."
        confidence = "Medium"
        conf_color = "#ffd93d"
    else:
        after_response = fact_data["original_answer"]
        confidence = "High"
        conf_color = "#ff6b6b"

    # Create comparison figure
    fig = go.Figure()

    # Before box
    fig.add_trace(go.Scatter(
        x=[0.25], y=[0.7],
        mode='markers+text',
        marker=dict(size=100, color='rgba(255, 107, 107, 0.3)', symbol='square'),
        text=["BEFORE"],
        textposition="top center",
        textfont=dict(size=16, color="#ff6b6b"),
        showlegend=False
    ))

    # After box
    fig.add_trace(go.Scatter(
        x=[0.75], y=[0.7],
        mode='markers+text',
        marker=dict(size=100, color='rgba(78, 205, 196, 0.3)', symbol='square'),
        text=["AFTER"],
        textposition="top center",
        textfont=dict(size=16, color="#4ecdc4"),
        showlegend=False
    ))

    # Scores: the "after" marker takes the colour of the confidence tier.
    fig.add_trace(go.Scatter(
        x=[0.25, 0.75],
        y=[0.3, 0.3],
        mode='markers+text',
        marker=dict(size=50, color=["#ff6b6b", conf_color]),
        text=["Recall: 100%", f"Recall: {(1-effective_forget)*100:.0f}%"],
        textposition="bottom center",
        showlegend=False
    ))

    fig.update_layout(
        xaxis=dict(visible=False, range=[0, 1]),
        yaxis=dict(visible=False, range=[0, 1]),
        paper_bgcolor='#0d0d1a',
        plot_bgcolor='#0d0d1a',
        height=200,
        margin=dict(l=20, r=20, t=20, b=20)
    )

    return (
        fig,
        fact_data["original_answer"],
        after_response,
        f"{effective_forget*100:.1f}%",
        f"{effective_retain*100:.1f}%",
    )
def create_harmful_content_chart(selected_categories):
    """Draw grouped before/after bars for the chosen harm categories.

    Args:
        selected_categories: Keys of ``HARMFUL_CATEGORIES`` to plot. A falsy
            value (for example an empty list) plots every category.

    Returns:
        Plotly figure with one "before" and one "after" bar per category and
        a relative-reduction label above each "before" bar.
    """
    categories = selected_categories or list(HARMFUL_CATEGORIES.keys())

    # Stored scores are fractions; the chart shows percentages.
    before_scores = []
    after_scores = []
    for label in categories:
        entry = HARMFUL_CATEGORIES[label]
        before_scores.append(entry["before_score"] * 100)
        after_scores.append(entry["after_score"] * 100)

    fig = go.Figure()
    series = (
        ("Before Unlearning", before_scores, "#ff6b6b"),
        ("After Unlearning", after_scores, "#4ecdc4"),
    )
    for series_name, heights, bar_color in series:
        fig.add_trace(go.Bar(
            name=series_name,
            x=categories,
            y=heights,
            marker_color=bar_color,
        ))

    fig.update_layout(
        title="Harmful Content Generation Rate (%)",
        yaxis_title="Generation Rate (%)",
        barmode="group",
        paper_bgcolor="#0d0d1a",
        plot_bgcolor="#0d0d1a",
        font=dict(color="white"),
        height=400,
        yaxis=dict(gridcolor="#333355", range=[0, 100]),
    )

    # Add reduction annotations, placed 5 points above each "before" bar.
    for label, before, after in zip(categories, before_scores, after_scores):
        reduction = ((before - after) / before) * 100
        fig.add_annotation(
            x=label,
            y=before + 5,
            text=f"-{reduction:.0f}%",
            showarrow=False,
            font=dict(color="#00ff88", size=10),
        )
    return fig
def create_method_comparison_radar():
    """Overlay every unlearning method on a single radar chart.

    Returns:
        Plotly figure with one filled polar trace per entry of
        ``UNLEARNING_METHODS``, plotted over forget quality, retain quality,
        speed and stability.
    """
    axis_labels = ['Forget Quality', 'Retain Quality', 'Speed', 'Stability']
    score_keys = ("forget_quality", "retain_quality", "speed", "stability")

    fig = go.Figure()
    for method_name, method in UNLEARNING_METHODS.items():
        ring = [method[key] for key in score_keys]
        # Repeat the first point so the outline closes on itself.
        ring.append(ring[0])
        fig.add_trace(go.Scatterpolar(
            r=ring,
            theta=axis_labels + axis_labels[:1],
            fill="toself",
            name=method_name,
            line_color=method["color"],
            opacity=0.6,
        ))

    fig.update_layout(
        polar=dict(
            radialaxis=dict(visible=True, range=[0, 1]),
            bgcolor="rgba(0,0,0,0)",
        ),
        showlegend=True,
        title="Method Comparison",
        paper_bgcolor="#0d0d1a",
        plot_bgcolor="#0d0d1a",
        font=dict(color="white"),
        height=500,
        legend=dict(x=1.1, y=0.5, font=dict(size=9)),
    )
    return fig
def create_impossibility_theorem_viz():
    """Plot illustrative cost and utility-gap curves for exact vs approximate unlearning.

    All curves are synthetic: an exponential for exact retraining, a line
    for approximate unlearning, and a noisy line for the utility gap, over
    50 forget fractions between 1% and 50%.

    Returns:
        Plotly figure with a log-scale cost panel on the left and the
        utility gap on the right.
    """
    forget_fractions = np.linspace(0.01, 0.5, 50)
    forget_pct = forget_fractions * 100  # shared x-axis values, in percent

    # Synthetic curves: exponential cost for exact retraining, linear cost
    # for approximate methods, and a noisy linear utility gap.
    exact_cost = np.exp(forget_fractions * 8)
    approx_cost = 1 + forget_fractions * 5
    utility_gap = forget_fractions * 0.3 + np.random.randn(50) * 0.02

    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Compute Cost", "Utility Gap from Exact"),
    )

    # (column, trace) in the order the traces are added.
    curves = [
        (1, go.Scatter(x=forget_pct, y=exact_cost, name="Exact (Retrain)",
                       line=dict(color="#ff6b6b", width=3))),
        (1, go.Scatter(x=forget_pct, y=approx_cost, name="Approximate",
                       line=dict(color="#4ecdc4", width=3))),
        (2, go.Scatter(x=forget_pct, y=utility_gap * 100, name="Utility Gap",
                       fill="tozeroy", line=dict(color="#ffd93d", width=2))),
    ]
    for col, trace in curves:
        fig.add_trace(trace, row=1, col=col)

    grid = "#333355"
    fig.update_xaxes(title_text="Forget Fraction (%)", gridcolor=grid, row=1, col=1)
    fig.update_yaxes(title_text="Relative Cost", type="log", gridcolor=grid, row=1, col=1)
    fig.update_xaxes(title_text="Forget Fraction (%)", gridcolor=grid, row=1, col=2)
    fig.update_yaxes(title_text="Utility Gap (%)", gridcolor=grid, row=1, col=2)

    fig.update_layout(
        title="The Impossibility of Exact Unlearning at Scale (Oct 2025)",
        paper_bgcolor="#0d0d1a",
        plot_bgcolor="#0d0d1a",
        font=dict(color="white"),
        height=400,
        showlegend=True,
    )
    return fig
def run_fact_forgetting(fact_name, method_name, strength):
    """Produce the six outputs for the "Forget a Fact" tab.

    Args:
        fact_name: Key into ``SAMPLE_FACTS``.
        method_name: Key into ``UNLEARNING_METHODS``.
        strength: Unlearning strength forwarded to the before/after
            comparison.

    Returns:
        Tuple ``(chart, query, before, after, forget_pct, retain_pct)``.
    """
    chart = create_unlearning_animation(method_name, fact_name)
    # The comparison figure is built by the helper but not shown in the UI.
    _, before, after, forget_pct, retain_pct = create_before_after_comparison(
        fact_name, method_name, strength
    )
    return (
        chart,
        SAMPLE_FACTS[fact_name]["query"],
        before,
        after,
        forget_pct,
        retain_pct,
    )
# Custom stylesheet for the app: gradient page background, accent-coloured
# headings, bordered "before"/"after" boxes, and gradient styling for the
# primary button and the selected tab.
# NOTE(review): within this file CSS is never referenced - gr.Blocks(...) is
# created with only a title and demo.launch() takes no arguments - so these
# rules do not appear to be applied. Confirm whether it should be passed in.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;700&family=Space+Grotesk:wght@400;700&display=swap');
.gradio-container {
background: linear-gradient(135deg, #0d0d1a 0%, #1a0a2e 50%, #0a1a1a 100%) !important;
}
h1, h2, h3 {
font-family: 'Space Grotesk', sans-serif !important;
color: #ff6b6b !important;
text-shadow: 0 0 20px rgba(255, 107, 107, 0.3);
}
.before-box {
background: rgba(255, 107, 107, 0.1);
border: 2px solid #ff6b6b;
border-radius: 10px;
padding: 15px;
}
.after-box {
background: rgba(78, 205, 196, 0.1);
border: 2px solid #4ecdc4;
border-radius: 10px;
padding: 15px;
}
button.primary {
background: linear-gradient(135deg, #ff6b6b, #ff8844) !important;
color: white !important;
font-weight: bold;
}
.tab-nav button.selected {
background: linear-gradient(135deg, #ff6b6b, #ff8844) !important;
color: white !important;
}
"""
# Gradio UI: a header, five tabs, and a footer quote.
# NOTE(review): the file reached review with all indentation stripped; the
# nesting below (which components sit inside each Row/Column/Tab) is
# reconstructed from statement order - confirm against the deployed layout.
# Markdown bodies are kept flush-left so their text is unchanged.
with gr.Blocks(title="LLM Fact Forgetter") as demo:
    gr.Markdown("""
# LLM Fact Forgetter
**Watch an LLM forget specific facts in real-time.**
Based on ICLR 2025 research on machine unlearning for LLMs.
Explore the "right to be forgotten" in action.
""")
    with gr.Tabs():
        # Tab 1: Fact Forgetting Demo
        with gr.TabItem("Forget a Fact"):
            gr.Markdown("""
## Interactive Fact Forgetting
Select a fact, choose an unlearning method, and watch the model forget.
""")
            # Inputs: which fact, which method, and how strongly to unlearn.
            with gr.Row():
                fact_dropdown = gr.Dropdown(
                    choices=list(SAMPLE_FACTS.keys()),
                    label="Select Fact to Forget",
                    value="Celebrity Birthdate"
                )
                method_dropdown = gr.Dropdown(
                    choices=list(UNLEARNING_METHODS.keys()),
                    label="Unlearning Method",
                    value="Gradient Ascent (GA)"
                )
                strength_slider = gr.Slider(
                    0.1, 1.0, 0.7, step=0.1,
                    label="Unlearning Strength"
                )
            forget_btn = gr.Button("Run Unlearning", variant="primary")
            # Starts empty; filled by the click handler below.
            unlearn_chart = gr.Plot()
            gr.Markdown("### Before / After Comparison")
            with gr.Row():
                query_box = gr.Textbox(label="Query", interactive=False)
            with gr.Row():
                with gr.Column():
                    gr.Markdown("**BEFORE Unlearning**")
                    before_box = gr.Textbox(label="Original Response", lines=3, interactive=False)
                with gr.Column():
                    gr.Markdown("**AFTER Unlearning**")
                    after_box = gr.Textbox(label="Unlearned Response", lines=3, interactive=False)
            with gr.Row():
                forget_score = gr.Textbox(label="Forget Score", interactive=False)
                retain_score = gr.Textbox(label="Retain Score", interactive=False)
            # The output list order must match run_fact_forgetting's return
            # tuple: chart, query, before, after, forget %, retain %.
            forget_btn.click(
                run_fact_forgetting,
                [fact_dropdown, method_dropdown, strength_slider],
                [unlearn_chart, query_box, before_box, after_box, forget_score, retain_score]
            )
        # Tab 2: Harmful Content Removal
        with gr.TabItem("Safety Unlearning"):
            gr.Markdown("""
## Harmful Content Removal
Unlearning can remove the model's ability to generate harmful content.
Based on Metamorphosis (Aug 2025) for reliable harmful info removal.
""")
            harm_categories = gr.CheckboxGroup(
                choices=list(HARMFUL_CATEGORIES.keys()),
                label="Select Harm Categories",
                value=list(HARMFUL_CATEGORIES.keys())
            )
            # Chart is pre-rendered with every category and redrawn whenever
            # the checkbox selection changes.
            harm_chart = gr.Plot(value=create_harmful_content_chart(list(HARMFUL_CATEGORIES.keys())))
            harm_categories.change(create_harmful_content_chart, [harm_categories], harm_chart)
            gr.Markdown("""
**Key Insight:** Effective safety unlearning reduces harmful generation by 80-90%
while maintaining general model capabilities.
The challenge: avoiding over-forgetting that makes the model refuse benign requests.
""")
        # Tab 3: Method Comparison
        with gr.TabItem("Compare Methods"):
            gr.Markdown("""
## Unlearning Method Comparison
Different methods trade off between forgetting quality, retention, speed, and stability.
""")
            # Static chart, built once when the module is loaded.
            radar_chart = gr.Plot(value=create_method_comparison_radar())
            gr.Markdown("""
### Method Summary
| Method | Best For | Weakness |
|--------|----------|----------|
| Gradient Ascent | Speed | Catastrophic collapse |
| GradDiff | Balance | Needs retain data |
| KL Minimization | Utility preservation | Weak forgetting |
| NPO | Stability | Slower training |
| Task Vectors | Simplicity | Imprecise removal |
| SCRUB | Quality | Compute cost |
| Influence Functions | Precision | Very slow |
""")
        # Tab 4: Impossibility Theorem
        with gr.TabItem("The Hard Truth"):
            gr.Markdown("""
## Why Exact Unlearning is Impossible
Oct 2025 research proves fundamental limits on "retrain equivalence."
No approximate method can perfectly match a retrained model.
""")
            # Static chart, built once when the module is loaded; the utility
            # gap curve includes random noise, so it differs between starts.
            impossibility_chart = gr.Plot(value=create_impossibility_theorem_viz())
            gr.Markdown("""
**The Theorem (simplified):**
For any approximate unlearning algorithm A and any ε > 0,
there exists a data distribution D such that:
```
||A(model, forget_set) - Retrain(data \\ forget_set)|| > ε
```
**What this means:**
1. Perfect unlearning requires full retraining
2. Approximate methods always leave some trace
3. The gap grows with forget set size
4. Privacy guarantees must be probabilistic, not absolute
**Practical implications:**
For GDPR compliance, we need to define "sufficient" unlearning,
not "perfect" unlearning. Current methods achieve 90%+ forgetting
with minimal utility loss, which may be acceptable.
""")
        # Tab 5: Resources
        with gr.TabItem("Resources"):
            gr.Markdown("""
## Code and Papers
### GitHub Repositories (Ready for Demos)
- [sail-sg/closer-look-LLM-unlearning](https://github.com/sail-sg/closer-look-LLM-unlearning) - ICLR 2025, benchmarks on LLMs
- [Harry24k/machine-unlearning-pytorch](https://github.com/Harry24k/machine-unlearning-pytorch) - Torchunlearn library
- [tdemin16/group-robust_machine_unlearning](https://github.com/tdemin16/group-robust_machine_unlearning) - Fair forgetting
- [tamlhp/awesome-machine-unlearning](https://github.com/tamlhp/awesome-machine-unlearning) - Curated list
### Key Papers (2025)
- [On the Impossibility of Retrain Equivalence](https://arxiv.org/abs/2510.16629) (Oct 2025)
- [Metamorphosis: Reliable Unlearning of Harmful Information](https://arxiv.org/abs/2508.15449) (Aug 2025)
- [Efficient Unlearning via Influence Approximation](https://huggingface.co/papers/2507.23257) (Jul 2025)
- [SoK: Machine Unlearning for LLMs](https://arxiv.org/abs/2506.09227) (Jun 2025)
- [Group-Robust Machine Unlearning](https://huggingface.co/papers/2503.09330) (Mar 2025)
- [PEBench: Multimodal Unlearning](https://huggingface.co/papers/2503.12545) (Mar 2025)
### Benchmarks
- [TOFU](https://huggingface.co/datasets/locuslab/TOFU) - Fictitious facts (2.5M downloads)
- [CLEAR](https://huggingface.co/datasets/therem/CLEAR) - Multimodal unlearning
- [RWKU](https://rwku-bench.github.io) - Real-world knowledge
---
**Built by:** Eric Raymond Samiksha BC| Purdue AI/Robotics Engineering | IU Southbend
*Tag @sail_sg on X if you build something cool with this!*
""")
    # Footer quote, placed after the tab container.
    gr.Markdown("""
---
*"The right to be forgotten is not just a legal requirement.
It's a fundamental challenge in AI safety."*
""")
# Start the Gradio server only when the file is executed as a script;
# importing the module builds `demo` without launching it.
if __name__ == "__main__":
    demo.launch()