# clearn-demo / app.py
# Uploaded by rahulmk to HuggingFace Spaces via huggingface_hub (commit 9c3a703, verified).
"""clearn HuggingFace Spaces Gradio Demo.
Interactive demo of continual learning strategies. Train a model on
sequential tasks and inspect retention with clearn's diff() report.
"""
from __future__ import annotations
from typing import Any
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import clearn
matplotlib.use("Agg")
# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class MLP(nn.Module):
    """Compact two-hidden-layer feed-forward classifier for the demo.

    Architecture: input -> hidden -> hidden -> logits, with ReLU
    activations between the linear layers.
    """

    def __init__(self, input_dim: int = 128, hidden_dim: int = 256, n_classes: int = 10) -> None:
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_classes),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw class logits for a batch of flat feature vectors."""
        return self.net(x)
# ---------------------------------------------------------------------------
# Data generation
# ---------------------------------------------------------------------------
def generate_sequential_tasks(
    n_tasks: int = 5,
    input_dim: int = 128,
    samples_per_task: int = 200,
    classes_per_task: int = 2,
    centroid_scale: float = 3.0,
    std: float = 0.3,
    seed: int = 42,
) -> list[DataLoader]:
    """Build a sequence of synthetic classification tasks.

    Every task introduces ``classes_per_task`` brand-new labels whose
    samples are Gaussian clouds around fixed random centroids, so the
    patterns are learnable yet distinct across tasks.

    Returns:
        One shuffled DataLoader per task.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    n_classes_total = n_tasks * classes_per_task
    # One fixed centroid per class, shared across the whole task sequence.
    centroids = torch.randn(n_classes_total, input_dim) * centroid_scale
    per_class = samples_per_task // classes_per_task
    loaders: list[DataLoader] = []
    for task in range(n_tasks):
        first_cls = task * classes_per_task
        xs: list[torch.Tensor] = []
        ys: list[torch.Tensor] = []
        for cls in range(first_cls, first_cls + classes_per_task):
            # Gaussian cloud around this class's centroid.
            noise = torch.randn(per_class, input_dim) * std
            xs.append(centroids[cls].unsqueeze(0) + noise)
            ys.append(torch.full((per_class,), cls, dtype=torch.long))
        features = torch.cat(xs)
        labels = torch.cat(ys)
        # Shuffle samples within the task before wrapping in a dataset.
        order = torch.randperm(features.size(0))
        dataset = TensorDataset(features[order], labels[order])
        loaders.append(DataLoader(dataset, batch_size=32, shuffle=True))
    return loaders
# ---------------------------------------------------------------------------
# Training helpers
# ---------------------------------------------------------------------------
def train_single_run(
    strategy: str | None,
    n_tasks: int,
    epochs: int,
    dataloaders: list[DataLoader],
    strategy_kwargs: dict[str, Any] | None = None,
) -> tuple[clearn.ContinualModel | None, list[clearn.TrainingMetrics], nn.Module]:
    """Sequentially train one model over ``n_tasks`` tasks.

    With ``strategy`` set, the model is wrapped by clearn and protected
    against forgetting; with ``strategy=None`` it is trained bare as a
    baseline.

    Returns:
        (continual_model_or_none, list_of_metrics, raw_model)
    """
    # Two classes per task, matching generate_sequential_tasks' default.
    model = MLP(input_dim=128, hidden_dim=256, n_classes=n_tasks * 2)
    kwargs = strategy_kwargs if strategy_kwargs is not None else {}
    cl_model = clearn.wrap(model, strategy=strategy, **kwargs) if strategy is not None else None
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    metrics_log: list[clearn.TrainingMetrics] = []
    for idx in range(n_tasks):
        dl = dataloaders[idx]
        if cl_model is not None:
            metrics = cl_model.fit(dl, optimizer, epochs=epochs, task_id=f"task_{idx + 1}")
            metrics_log.append(metrics)
            continue
        # Plain SGD baseline -- no continual-learning protection.
        criterion = nn.CrossEntropyLoss()
        model.train()
        for _ in range(epochs):
            for xb, yb in dl:
                optimizer.zero_grad()
                criterion(model(xb), yb).backward()
                optimizer.step()
    return cl_model, metrics_log, model
def evaluate_task(model: nn.Module, dataloader: DataLoader) -> float:
    """Compute classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode for the pass and restored to
    train mode afterwards. An empty loader yields 0.0.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for features, targets in dataloader:
            predicted = model(features).argmax(dim=1)
            n_correct += int((predicted == targets).sum().item())
            n_seen += targets.size(0)
    model.train()
    # max() guards the division when the loader produced no batches.
    return n_correct / max(n_seen, 1)
# ---------------------------------------------------------------------------
# Plot helpers
# ---------------------------------------------------------------------------
def plot_retention_bar(task_scores: dict[str, float], strategy_name: str) -> plt.Figure:
    """Render per-task retention percentages as a color-coded bar chart.

    Bars are green / yellow / orange / red for >=90 / >=70 / >=50 / <50
    percent retention respectively.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    labels = list(task_scores)
    values = list(task_scores.values())

    def _color(score: float) -> str:
        # Traffic-light palette keyed on retention thresholds.
        if score >= 90:
            return "#22c55e"  # green
        if score >= 70:
            return "#eab308"  # yellow
        if score >= 50:
            return "#f97316"  # orange
        return "#ef4444"  # red

    bars = ax.bar(labels, values, color=[_color(v) for v in values], edgecolor="white", linewidth=0.5)
    # Annotate each bar with its numeric value.
    for rect, value in zip(bars, values):
        ax.text(
            rect.get_x() + rect.get_width() / 2,
            rect.get_height() + 1.5,
            f"{value:.1f}%",
            ha="center",
            va="bottom",
            fontsize=10,
            fontweight="bold",
        )
    ax.set_ylim(0, 115)
    ax.set_ylabel("Retention (%)", fontsize=12)
    ax.set_xlabel("Task", fontsize=12)
    ax.set_title(f"Per-Task Retention -- Strategy: {strategy_name}", fontsize=14, fontweight="bold")
    ax.axhline(y=90, color="#94a3b8", linestyle="--", alpha=0.5, label="90% threshold")
    ax.legend(loc="lower right", fontsize=9)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    fig.tight_layout()
    return fig
def plot_comparison(
    results: dict[str, list[float]],
    task_labels: list[str],
) -> plt.Figure:
    """Draw a grouped bar chart of per-task accuracy for each method."""
    fig, ax = plt.subplots(figsize=(10, 6))
    method_names = list(results)
    group_width = 0.8 / len(method_names)
    positions = np.arange(len(task_labels))
    palette = ["#64748b", "#3b82f6", "#8b5cf6", "#f97316", "#22c55e"]
    for idx, name in enumerate(method_names):
        accs = results[name]
        # Center the group of bars on each task tick.
        shift = (idx - len(method_names) / 2 + 0.5) * group_width
        rects = ax.bar(positions + shift, accs, group_width, label=name, color=palette[idx % len(palette)], edgecolor="white", linewidth=0.5)
        # Label bars, skipping near-zero ones to avoid clutter.
        for rect, acc in zip(rects, accs):
            if acc > 5:
                ax.text(
                    rect.get_x() + rect.get_width() / 2,
                    rect.get_height() + 1,
                    f"{acc:.0f}%",
                    ha="center",
                    va="bottom",
                    fontsize=7,
                    fontweight="bold",
                )
    ax.set_xticks(positions)
    ax.set_xticklabels(task_labels, fontsize=10)
    ax.set_ylabel("Accuracy (%)", fontsize=12)
    ax.set_title("Strategy Comparison -- Per-Task Accuracy After All Training", fontsize=14, fontweight="bold")
    ax.set_ylim(0, 115)
    ax.legend(loc="upper right", fontsize=9)
    ax.axhline(y=90, color="#94a3b8", linestyle="--", alpha=0.4)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    fig.tight_layout()
    return fig
# ---------------------------------------------------------------------------
# Tab 1: Train & Inspect
# ---------------------------------------------------------------------------
def train_and_inspect(
    strategy: str,
    ewc_lambda: float,
    si_c: float,
    der_buffer_size: int,
    der_alpha: float,
    der_beta: float,
    gem_memory_size: int,
    epochs: int,
    n_tasks: int,
) -> tuple[str, plt.Figure, str]:
    """Train with the selected strategy and return (report, chart, metrics)."""
    torch.manual_seed(42)
    np.random.seed(42)
    # Display name -> clearn strategy identifier.
    strat_name = {
        "EWC": "ewc",
        "SI": "si",
        "DER++": "der",
        "GEM": "gem",
    }[strategy]
    # Collect only the hyper-parameters relevant to the chosen strategy.
    kwargs: dict[str, Any] = {}
    if strat_name == "ewc":
        kwargs["lambda_"] = ewc_lambda
    elif strat_name == "si":
        kwargs["c"] = si_c
    elif strat_name == "der":
        kwargs.update(buffer_size=int(der_buffer_size), alpha=der_alpha, beta=der_beta)
    elif strat_name == "gem":
        kwargs["memory_size"] = int(gem_memory_size)
    dataloaders = generate_sequential_tasks(n_tasks=n_tasks)
    cl_model, all_metrics, _ = train_single_run(
        strategy=strat_name,
        n_tasks=n_tasks,
        epochs=epochs,
        dataloaders=dataloaders,
        strategy_kwargs=kwargs,
    )
    # Retention report from clearn's diff() plus its bar-chart rendering.
    report = cl_model.diff()
    fig = plot_retention_bar(report.task_scores, strategy)
    # Human-readable per-task training metrics.
    lines = ["Training Metrics Summary", "=" * 40]
    for m in all_metrics:
        lines.append(str(m))
        lines.append("")
    return str(report), fig, "\n".join(lines)
# ---------------------------------------------------------------------------
# Tab 2: Compare Strategies
# ---------------------------------------------------------------------------
def compare_strategies(n_tasks: int, epochs: int) -> tuple[plt.Figure, str]:
    """Run every strategy (plus a bare baseline) on identical data and compare.

    Each strategy trains on a freshly regenerated -- but identically
    seeded -- copy of the task sequence, then is evaluated on all tasks.

    Args:
        n_tasks: Number of sequential tasks to train on.
        epochs: Epochs per task.

    Returns:
        (comparison_figure, textual_summary)

    Fix: the original generated an extra, completely unused task
    sequence before the loop; each iteration already reseeds and builds
    its own copy, so that work was dead and has been removed.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    # Per-strategy hyper-parameters (None marks the unprotected baseline).
    strategies: dict[str, dict[str, Any] | None] = {
        "Baseline (no CL)": None,
        "EWC": {"lambda_": 5000},
        "SI": {"c": 1.0},
        "DER++": {"buffer_size": 200, "alpha": 0.1, "beta": 0.5},
        "GEM": {"memory_size": 256},
    }
    # Display name -> clearn strategy identifier.
    strategy_name_map = {
        "Baseline (no CL)": None,
        "EWC": "ewc",
        "SI": "si",
        "DER++": "der",
        "GEM": "gem",
    }
    all_results: dict[str, list[float]] = {}
    summary_lines = ["Strategy Comparison Report", "=" * 50, ""]
    task_labels = [f"task_{i + 1}" for i in range(n_tasks)]
    for display_name, kwargs in strategies.items():
        # Regenerate data with the same seed so every method sees identical tasks.
        torch.manual_seed(42)
        np.random.seed(42)
        dl_copy = generate_sequential_tasks(n_tasks=n_tasks)
        strat_key = strategy_name_map[display_name]
        cl_model, _, raw_model = train_single_run(
            strategy=strat_key,
            n_tasks=n_tasks,
            epochs=epochs,
            dataloaders=dl_copy,
            strategy_kwargs=kwargs if kwargs else {},
        )
        # Evaluate on a same-seed regeneration of each task's data.
        # NOTE(review): this reuses the training distribution (same samples),
        # which is fine for demonstrating retention, not generalization.
        torch.manual_seed(42)
        np.random.seed(42)
        eval_dls = generate_sequential_tasks(n_tasks=n_tasks)
        task_accs = [evaluate_task(raw_model, eval_dls[i]) * 100.0 for i in range(n_tasks)]
        all_results[display_name] = task_accs
        # Per-method summary section.
        summary_lines.append(f"--- {display_name} ---")
        if cl_model is not None:
            summary_lines.append(str(cl_model.diff()))
        else:
            summary_lines.append(" (No continual learning protection)")
        for i, acc in enumerate(task_accs):
            summary_lines.append(f" task_{i + 1}: {acc:.1f}%")
        avg = sum(task_accs) / len(task_accs)
        summary_lines.append(f" Average accuracy: {avg:.1f}%")
        summary_lines.append("")
    fig = plot_comparison(all_results, task_labels)
    return fig, "\n".join(summary_lines)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_app() -> gr.Blocks:
    """Build and return the Gradio Blocks app.

    Layout: intro header, then two tabs --
      * "Train & Inspect": pick one strategy, tune its hyper-parameters,
        train, and view clearn's retention report plus a bar chart.
      * "Compare Strategies": run the baseline and all four strategies
        on identically-seeded data and compare per-task accuracy.
    """
    with gr.Blocks(
        title="clearn Demo -- Continual Learning for PyTorch",
        theme=gr.themes.Soft(primary_hue="green", secondary_hue="cyan"),
    ) as app:
        # Intro / header copy.
        gr.Markdown(
            """
            # clearn -- Continual Learning for PyTorch
            **Wrap once. Train forever.**
            Prevent catastrophic forgetting when fine-tuning neural networks on sequential tasks.
            This demo trains a small MLP on synthetic sequential classification tasks and shows
            how different continual learning strategies preserve knowledge from earlier tasks.
            `pip install clearn-ai` | [GitHub](https://github.com/itisrmk/clearn) | [Docs](https://clearn.ai)
            """
        )
        with gr.Tabs():
            # ----------------------------------------------------------
            # Tab 1: Train & Inspect
            # ----------------------------------------------------------
            with gr.TabItem("Train & Inspect"):
                gr.Markdown("### Configure a strategy, train on sequential tasks, and inspect the retention report.")
                with gr.Row():
                    with gr.Column(scale=1):
                        strategy_dd = gr.Dropdown(
                            choices=["EWC", "SI", "DER++", "GEM"],
                            value="EWC",
                            label="Strategy",
                        )
                        # EWC params (visible by default: EWC is the preselected strategy)
                        ewc_lambda = gr.Slider(
                            minimum=100, maximum=10000, value=5000, step=100,
                            label="EWC: lambda_ (regularization strength)",
                            visible=True,
                        )
                        # SI params (hidden until SI is selected)
                        si_c = gr.Slider(
                            minimum=0.1, maximum=10.0, value=1.0, step=0.1,
                            label="SI: c (regularization strength)",
                            visible=False,
                        )
                        # DER++ params (hidden until DER++ is selected)
                        der_buffer_size = gr.Slider(
                            minimum=50, maximum=500, value=200, step=10,
                            label="DER++: buffer_size",
                            visible=False,
                        )
                        der_alpha = gr.Slider(
                            minimum=0.01, maximum=1.0, value=0.1, step=0.01,
                            label="DER++: alpha (replay CE weight)",
                            visible=False,
                        )
                        der_beta = gr.Slider(
                            minimum=0.1, maximum=2.0, value=0.5, step=0.1,
                            label="DER++: beta (logit matching weight)",
                            visible=False,
                        )
                        # GEM params (hidden until GEM is selected)
                        gem_memory_size = gr.Slider(
                            minimum=50, maximum=500, value=256, step=10,
                            label="GEM: memory_size (per task)",
                            visible=False,
                        )
                        # precision=0 makes gr.Number deliver ints to the callback.
                        epochs_input = gr.Number(
                            value=5, minimum=1, maximum=20, step=1,
                            label="Epochs per task",
                            precision=0,
                        )
                        n_tasks_input = gr.Number(
                            value=5, minimum=2, maximum=10, step=1,
                            label="Number of tasks",
                            precision=0,
                        )
                        train_btn = gr.Button("Train", variant="primary", size="lg")
                    with gr.Column(scale=2):
                        report_output = gr.Textbox(
                            label="Retention Report -- model.diff()",
                            lines=12,
                            show_copy_button=True,
                        )
                        chart_output = gr.Plot(label="Per-Task Retention")
                        metrics_output = gr.Textbox(
                            label="Training Metrics",
                            lines=15,
                            show_copy_button=True,
                        )
                # Toggle strategy-specific sliders
                def update_visibility(strategy: str) -> tuple:
                    """Show only the sliders for the selected strategy.

                    Update order must match the ``outputs`` list of the
                    ``strategy_dd.change`` wiring below: (ewc_lambda, si_c,
                    der_buffer_size, der_alpha, der_beta, gem_memory_size).
                    """
                    return (
                        gr.update(visible=(strategy == "EWC")),
                        gr.update(visible=(strategy == "SI")),
                        gr.update(visible=(strategy == "DER++")),
                        gr.update(visible=(strategy == "DER++")),
                        gr.update(visible=(strategy == "DER++")),
                        gr.update(visible=(strategy == "GEM")),
                    )
                strategy_dd.change(
                    fn=update_visibility,
                    inputs=[strategy_dd],
                    outputs=[ewc_lambda, si_c, der_buffer_size, der_alpha, der_beta, gem_memory_size],
                )
                # Input order must match train_and_inspect's signature.
                train_btn.click(
                    fn=train_and_inspect,
                    inputs=[
                        strategy_dd,
                        ewc_lambda,
                        si_c,
                        der_buffer_size,
                        der_alpha,
                        der_beta,
                        gem_memory_size,
                        epochs_input,
                        n_tasks_input,
                    ],
                    outputs=[report_output, chart_output, metrics_output],
                )
            # ----------------------------------------------------------
            # Tab 2: Compare Strategies
            # ----------------------------------------------------------
            with gr.TabItem("Compare Strategies"):
                gr.Markdown(
                    "### Run Baseline, EWC, SI, DER++, and GEM on the **same** synthetic data and compare retention."
                )
                with gr.Row():
                    compare_tasks = gr.Number(
                        value=5, minimum=2, maximum=10, step=1,
                        label="Number of tasks",
                        precision=0,
                    )
                    compare_epochs = gr.Number(
                        value=5, minimum=1, maximum=20, step=1,
                        label="Epochs per task",
                        precision=0,
                    )
                compare_btn = gr.Button("Compare All Strategies", variant="primary", size="lg")
                compare_chart = gr.Plot(label="Strategy Comparison")
                compare_summary = gr.Textbox(
                    label="Retention Reports",
                    lines=30,
                    show_copy_button=True,
                )
                compare_btn.click(
                    fn=compare_strategies,
                    inputs=[compare_tasks, compare_epochs],
                    outputs=[compare_chart, compare_summary],
                )
        # Footer with library link and strategy citations.
        gr.Markdown(
            """
            ---
            *Built with [clearn](https://github.com/itisrmk/clearn) and [Gradio](https://gradio.app).
            Strategies: EWC (Kirkpatrick et al., 2017), SI (Zenke et al., 2017),
            DER++ (Buzzega et al., 2020), GEM/A-GEM (Lopez-Paz & Ranzato, 2017).*
            """
        )
    return app
if __name__ == "__main__":
    # Entry point: construct the Gradio app and serve it.
    build_app().launch()