Spaces:
Sleeping
Sleeping
File size: 7,130 Bytes
8f24287 75da07d 8f24287 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 | """
Task definitions and graders for the ML Training Optimizer environment.
Three tasks with increasing difficulty:
1. Easy: MNIST Digit Classifier
2. Medium: Fashion Item Classifier
3. Hard: CIFAR-10 Under Budget
Each task defines the model, dataset, budget, and grading rubric.
Graders produce deterministic scores strictly inside (0.0, 1.0): every score component is clamped to [0.0001, 0.9999].
"""
from dataclasses import dataclass
from typing import Dict
@dataclass
class TaskDefinition:
    """Static configuration for one training task and its grading range.

    Fields are declared in the order the generated constructor accepts them.
    """

    # Identity and human-readable presentation.
    task_id: str
    name: str
    description: str
    difficulty: str

    # What is trained, on which dataset, and the epoch allowance.
    model_type: str
    dataset_name: str
    max_epochs: int
    seed: int

    # Metric the task targets and the value regarded as the goal.
    target_metric: str
    target_value: float

    # Accuracy range the graders interpolate over.
    min_score_accuracy: float  # at or below this, the accuracy component bottoms out
    max_score_accuracy: float  # at or above this, the accuracy component tops out
# Registry of all tasks, listed from easiest to hardest. Each entry is keyed
# by its own task_id so the key and the definition cannot drift apart.
TASKS: Dict[str, TaskDefinition] = {
    definition.task_id: definition
    for definition in (
        TaskDefinition(
            task_id="easy_mnist",
            name="MNIST Digit Classifier",
            description=(
                "Train a 2-layer MLP to classify handwritten digits from MNIST. "
                "You have a subset of 5,000 samples (4k train / 1k validation) and a budget of 100 epochs. "
                "Goal: maximize validation accuracy. Default Adam + LR=0.001 should get ~93%. "
                "With good hyperparameter choices you can reach 97%+."
            ),
            difficulty="easy",
            model_type="simple_mlp",
            dataset_name="mnist",
            max_epochs=100,
            seed=42,
            target_metric="val_accuracy",
            target_value=0.96,
            min_score_accuracy=0.88,
            max_score_accuracy=0.975,
        ),
        TaskDefinition(
            task_id="medium_fashion",
            name="Fashion Item Classifier",
            description=(
                "Train a small CNN to classify fashion items from FashionMNIST. "
                "You have 8,000 samples (6.5k train / 1.5k val) and a budget of 80 epochs. "
                "FashionMNIST is harder than MNIST — the small training set means overfitting is a real threat. "
                "Goal: maximize validation accuracy while keeping the overfitting gap (train_acc - val_acc) below 5%. "
                "You'll need proper optimizer selection, learning rate scheduling, and regularization."
            ),
            difficulty="medium",
            model_type="small_cnn",
            dataset_name="fashion_mnist",
            max_epochs=80,
            seed=123,
            target_metric="val_accuracy",
            target_value=0.87,
            min_score_accuracy=0.75,
            max_score_accuracy=0.90,
        ),
        TaskDefinition(
            task_id="hard_cifar",
            name="CIFAR-10 Under Budget",
            description=(
                "Train a deeper CNN on CIFAR-10 color images with a tight budget. "
                "You have 10,000 samples (8k train / 2k val) and only 60 epochs. "
                "CIFAR-10 is genuinely challenging for small models. The tight budget means you must "
                "make smart decisions fast — balance exploration (trying configs) vs exploitation "
                "(training with a good config). Overfitting on 8k images is severe without regularization. "
                "Score is based on accuracy (50%), compute efficiency (30%), and training stability (20%)."
            ),
            difficulty="hard",
            model_type="deeper_cnn",
            dataset_name="cifar10",
            max_epochs=60,
            seed=456,
            target_metric="val_accuracy",
            target_value=0.65,
            min_score_accuracy=0.40,
            max_score_accuracy=0.68,
        ),
    )
}
def _clamp(value: float, low: float = 0.0001, high: float = 0.9999) -> float:
"""Clamp a value to [low, high]."""
return max(low, min(high, value))
def grade_easy(
    val_accuracy: float,
    task: TaskDefinition,
    **kwargs,
) -> Dict:
    """
    Score the easy MNIST task from validation accuracy alone.

    The accuracy is rescaled linearly so that task.min_score_accuracy maps
    to 0 and task.max_score_accuracy maps to 1, then bounded by _clamp.

    Args:
        val_accuracy: Validation accuracy to grade.
        task: Task definition supplying the accuracy range.
        **kwargs: Accepted and ignored.

    Returns:
        Dict with the rounded 'score' and a 'details' breakdown.
    """
    above_floor = val_accuracy - task.min_score_accuracy
    span = task.max_score_accuracy - task.min_score_accuracy
    # Accuracy is the only component, so the final score and the reported
    # accuracy_score are the same rounded number.
    rounded_score = round(_clamp(above_floor / span), 4)

    details = {
        "val_accuracy": round(val_accuracy, 4),
        "accuracy_score": rounded_score,
    }
    return {"score": rounded_score, "details": details}
def grade_medium(
    val_accuracy: float,
    overfitting_gap: float,
    task: TaskDefinition,
    **kwargs,
) -> Dict:
    """
    Score the medium FashionMNIST task.

    The final score is 60% accuracy plus 40% generalization. The
    generalization component falls linearly as the overfitting gap grows
    and bottoms out once the gap reaches 0.12.

    Args:
        val_accuracy: Validation accuracy to grade.
        overfitting_gap: Gap between training and validation accuracy.
        task: Task definition supplying the accuracy range.
        **kwargs: Accepted and ignored.

    Returns:
        Dict with the rounded 'score' and a 'details' breakdown.
    """
    gap_limit = 0.12
    accuracy_weight = 0.6
    generalization_weight = 0.4

    above_floor = val_accuracy - task.min_score_accuracy
    span = task.max_score_accuracy - task.min_score_accuracy
    accuracy_part = _clamp(above_floor / span)
    generalization_part = _clamp(1.0 - overfitting_gap / gap_limit)

    weighted = (
        accuracy_weight * accuracy_part
        + generalization_weight * generalization_part
    )

    details = {
        "val_accuracy": round(val_accuracy, 4),
        "overfitting_gap": round(overfitting_gap, 4),
        "accuracy_score": round(accuracy_part, 4),
        "generalization_score": round(generalization_part, 4),
    }
    return {"score": round(weighted, 4), "details": details}
def grade_hard(
    val_accuracy: float,
    wasted_epochs: int,
    budget: int,
    loss_variance: float,
    task: TaskDefinition,
    **kwargs,
) -> Dict:
    """
    Score the hard CIFAR-10 task.

    The final score is 50% accuracy, 30% efficiency and 20% stability.
    Efficiency drops with the share of the budget that was wasted, and
    stability drops as the loss variance approaches 0.01.

    Args:
        val_accuracy: Validation accuracy to grade.
        wasted_epochs: Number of epochs counted as wasted.
        budget: Epoch budget the wasted epochs are measured against.
        loss_variance: Variance of the loss, used as the stability signal.
        task: Task definition supplying the accuracy range.
        **kwargs: Accepted and ignored.

    Returns:
        Dict with the rounded 'score' and a 'details' breakdown.
    """
    variance_limit = 0.01

    above_floor = val_accuracy - task.min_score_accuracy
    span = task.max_score_accuracy - task.min_score_accuracy
    accuracy_part = _clamp(above_floor / span)

    # max(1, budget) keeps the division defined for a zero budget.
    efficiency_part = _clamp(1.0 - wasted_epochs / max(1, budget))
    stability_part = _clamp(1.0 - loss_variance / variance_limit)

    weighted = 0.5 * accuracy_part + 0.3 * efficiency_part + 0.2 * stability_part

    details = {
        "val_accuracy": round(val_accuracy, 4),
        "wasted_epochs": wasted_epochs,
        "loss_variance": round(loss_variance, 6),
        "accuracy_score": round(accuracy_part, 4),
        "efficiency_score": round(efficiency_part, 4),
        "stability_score": round(stability_part, 4),
    }
    return {"score": round(weighted, 4), "details": details}
def grade_task(task_id: str, trainer) -> Dict:
    """
    Grade a finished run by dispatching to the grader for its task.

    Args:
        task_id: Task identifier; must be a key of TASKS.
        trainer: Trainer instance with completed training. Its
            state.best_val_accuracy is always read; the medium and hard
            tasks additionally query its overfitting gap, wasted epochs
            and loss variance.

    Returns:
        Dict with 'score' and 'details' as produced by the task's grader.

    Raises:
        KeyError: If task_id is not registered in TASKS (raised by the
            registry lookup, before any grader is selected).
        ValueError: If task_id is registered but has no grader here.
    """
    definition = TASKS[task_id]
    accuracy = trainer.state.best_val_accuracy

    if task_id == "easy_mnist":
        return grade_easy(accuracy, definition)

    if task_id == "medium_fashion":
        overfit = trainer.get_overfitting_gap()
        return grade_medium(accuracy, overfit, definition)

    if task_id == "hard_cifar":
        wasted_count = trainer.get_wasted_epochs()
        loss_var = trainer.get_loss_variance()
        # The task's full epoch allowance is the budget waste is judged against.
        return grade_hard(
            accuracy,
            wasted_count,
            definition.max_epochs,
            loss_var,
            definition,
        )

    raise ValueError(f"Unknown task: {task_id}")
|