ml_trainer_env / server /tasks.py
BART-ender's picture
Upload folder using huggingface_hub
75da07d verified
"""
Task definitions and graders for the ML Training Optimizer environment.
Three tasks with increasing difficulty:
1. Easy: MNIST Digit Classifier
2. Medium: Fashion Item Classifier
3. Hard: CIFAR-10 Under Budget
Each task defines the model, dataset, budget, and grading rubric.
Graders produce deterministic scores in [0.0, 1.0].
"""
from dataclasses import dataclass
from typing import Dict
@dataclass
class TaskDefinition:
"""Defines a training task."""
task_id: str
name: str
description: str
difficulty: str
model_type: str
dataset_name: str
max_epochs: int
seed: int
target_metric: str
target_value: float
# Grading thresholds
min_score_accuracy: float # accuracy below this → score = 0
max_score_accuracy: float # accuracy above this → score = 1
TASKS: Dict[str, TaskDefinition] = {
"easy_mnist": TaskDefinition(
task_id="easy_mnist",
name="MNIST Digit Classifier",
description=(
"Train a 2-layer MLP to classify handwritten digits from MNIST. "
"You have a subset of 5,000 samples (4k train / 1k validation) and a budget of 100 epochs. "
"Goal: maximize validation accuracy. Default Adam + LR=0.001 should get ~93%. "
"With good hyperparameter choices you can reach 97%+."
),
difficulty="easy",
model_type="simple_mlp",
dataset_name="mnist",
max_epochs=100,
seed=42,
target_metric="val_accuracy",
target_value=0.96,
min_score_accuracy=0.88,
max_score_accuracy=0.975,
),
"medium_fashion": TaskDefinition(
task_id="medium_fashion",
name="Fashion Item Classifier",
description=(
"Train a small CNN to classify fashion items from FashionMNIST. "
"You have 8,000 samples (6.5k train / 1.5k val) and a budget of 80 epochs. "
"FashionMNIST is harder than MNIST — the small training set means overfitting is a real threat. "
"Goal: maximize validation accuracy while keeping the overfitting gap (train_acc - val_acc) below 5%. "
"You'll need proper optimizer selection, learning rate scheduling, and regularization."
),
difficulty="medium",
model_type="small_cnn",
dataset_name="fashion_mnist",
max_epochs=80,
seed=123,
target_metric="val_accuracy",
target_value=0.87,
min_score_accuracy=0.75,
max_score_accuracy=0.90,
),
"hard_cifar": TaskDefinition(
task_id="hard_cifar",
name="CIFAR-10 Under Budget",
description=(
"Train a deeper CNN on CIFAR-10 color images with a tight budget. "
"You have 10,000 samples (8k train / 2k val) and only 60 epochs. "
"CIFAR-10 is genuinely challenging for small models. The tight budget means you must "
"make smart decisions fast — balance exploration (trying configs) vs exploitation "
"(training with a good config). Overfitting on 8k images is severe without regularization. "
"Score is based on accuracy (50%), compute efficiency (30%), and training stability (20%)."
),
difficulty="hard",
model_type="deeper_cnn",
dataset_name="cifar10",
max_epochs=60,
seed=456,
target_metric="val_accuracy",
target_value=0.65,
min_score_accuracy=0.40,
max_score_accuracy=0.68,
),
}
def _clamp(value: float, low: float = 0.0001, high: float = 0.9999) -> float:
"""Clamp a value to [low, high]."""
return max(low, min(high, value))
def grade_easy(
val_accuracy: float,
task: TaskDefinition,
**kwargs,
) -> Dict:
"""
Grade the easy MNIST task.
Score is linear from min_score_accuracy to max_score_accuracy.
Simple and forgiving — rewards any accuracy above 88%.
"""
acc_score = _clamp(
(val_accuracy - task.min_score_accuracy)
/ (task.max_score_accuracy - task.min_score_accuracy)
)
return {
"score": round(acc_score, 4),
"details": {
"val_accuracy": round(val_accuracy, 4),
"accuracy_score": round(acc_score, 4),
},
}
def grade_medium(
val_accuracy: float,
overfitting_gap: float,
task: TaskDefinition,
**kwargs,
) -> Dict:
"""
Grade the medium FashionMNIST task.
Score = 60% accuracy + 40% generalization.
Penalizes overfitting (train_acc - val_acc > 12%).
"""
acc_score = _clamp(
(val_accuracy - task.min_score_accuracy)
/ (task.max_score_accuracy - task.min_score_accuracy)
)
gen_score = _clamp(1.0 - overfitting_gap / 0.12)
combined = 0.6 * acc_score + 0.4 * gen_score
return {
"score": round(combined, 4),
"details": {
"val_accuracy": round(val_accuracy, 4),
"overfitting_gap": round(overfitting_gap, 4),
"accuracy_score": round(acc_score, 4),
"generalization_score": round(gen_score, 4),
},
}
def grade_hard(
val_accuracy: float,
wasted_epochs: int,
budget: int,
loss_variance: float,
task: TaskDefinition,
**kwargs,
) -> Dict:
"""
Grade the hard CIFAR-10 task.
Score = 50% accuracy + 30% efficiency + 20% stability.
Rewards good accuracy, penalizes wasted compute and unstable training.
"""
acc_score = _clamp(
(val_accuracy - task.min_score_accuracy)
/ (task.max_score_accuracy - task.min_score_accuracy)
)
efficiency_score = _clamp(1.0 - wasted_epochs / max(1, budget))
stability_threshold = 0.01
stability_score = _clamp(1.0 - loss_variance / stability_threshold)
combined = 0.5 * acc_score + 0.3 * efficiency_score + 0.2 * stability_score
return {
"score": round(combined, 4),
"details": {
"val_accuracy": round(val_accuracy, 4),
"wasted_epochs": wasted_epochs,
"loss_variance": round(loss_variance, 6),
"accuracy_score": round(acc_score, 4),
"efficiency_score": round(efficiency_score, 4),
"stability_score": round(stability_score, 4),
},
}
def grade_task(task_id: str, trainer) -> Dict:
"""
Grade a completed task using the trainer's final state.
Args:
task_id: Task identifier
trainer: Trainer instance with completed training
Returns:
Dict with 'score' (0.0–1.0) and 'details'
"""
task = TASKS[task_id]
best_val_acc = trainer.state.best_val_accuracy
if task_id == "easy_mnist":
return grade_easy(best_val_acc, task)
elif task_id == "medium_fashion":
gap = trainer.get_overfitting_gap()
return grade_medium(best_val_acc, gap, task)
elif task_id == "hard_cifar":
wasted = trainer.get_wasted_epochs()
variance = trainer.get_loss_variance()
return grade_hard(
best_val_acc, wasted, task.max_epochs, variance, task,
)
else:
raise ValueError(f"Unknown task: {task_id}")