"""Curriculum learning support for progressive training difficulty.

Implements a schedule that controls which training samples are used
at different stages of training, starting with easy examples (small
displacements) and gradually introducing harder ones.

Usage in training loop::

    curriculum = TrainingCurriculum(
        total_steps=100000,
        warmup_fraction=0.1,     # first 10% easy only
        full_difficulty_at=0.5,  # full dataset by 50%
    )

    # In training loop:
    difficulty = curriculum.get_difficulty(global_step)
    # Use difficulty to filter/weight samples

Or as a dataset wrapper::

    dataset = CurriculumDataset(
        base_dataset=SyntheticPairDataset(data_dir),
        metadata_path=Path(data_dir) / "metadata.json",
        total_steps=100000,
    )
    # Call dataset.set_step(global_step) each iteration
"""
from __future__ import annotations

import json
import math
from pathlib import Path

import numpy as np

class TrainingCurriculum:
"""Schedule that maps training step to difficulty level [0, 1].
Difficulty 0 = easiest (smallest displacements, lowest intensity).
Difficulty 1 = full dataset (all difficulties).
The schedule uses a cosine ramp:
- During warmup: difficulty = 0 (easy only)
- warmup → full_difficulty: cosine ramp from 0 → 1
- After full_difficulty: difficulty = 1 (full dataset)
"""
def __init__(
self,
total_steps: int,
warmup_fraction: float = 0.1,
full_difficulty_at: float = 0.5,
):
self.total_steps = total_steps
self.warmup_steps = int(total_steps * warmup_fraction)
self.full_steps = int(total_steps * full_difficulty_at)
def get_difficulty(self, step: int) -> float:
"""Get difficulty level [0, 1] for the given training step."""
if step < self.warmup_steps:
return 0.0
if step >= self.full_steps:
return 1.0
        # Cosine ramp: 0 at the end of warmup, 1 at full_steps
        progress = (step - self.warmup_steps) / max(1, self.full_steps - self.warmup_steps)
        return 0.5 * (1 - math.cos(math.pi * progress))

def should_include(
self,
step: int,
sample_difficulty: float,
rng: np.random.Generator | None = None,
) -> bool:
"""Whether to include a sample of the given difficulty at this step.
Uses probabilistic inclusion so harder samples gradually appear.
Args:
step: Current training step.
sample_difficulty: Difficulty of the sample [0, 1].
rng: Random number generator for stochastic inclusion.
Returns:
True if sample should be used.
"""
curr_difficulty = self.get_difficulty(step)
if sample_difficulty <= curr_difficulty:
return True
        # Stochastic inclusion for samples above the current threshold:
        # probability falls linearly, reaching 0 at an overshoot of 0.2.
        if rng is None:
            rng = np.random.default_rng()
        overshoot = sample_difficulty - curr_difficulty
        include_prob = max(0.0, 1.0 - overshoot * 5)
        return rng.random() < include_prob
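
# A minimal usage sketch (illustrative only): filter a batch of
# (sample, difficulty) pairs with the schedule above. The `samples`
# structure and the fixed seed are assumptions made for this example.
def _example_filter_batch(
    samples: list[tuple[object, float]],
    curriculum: TrainingCurriculum,
    step: int,
    seed: int = 0,
) -> list[object]:
    """Return only the samples the curriculum admits at this step."""
    rng = np.random.default_rng(seed)
    return [
        sample
        for sample, difficulty in samples
        if curriculum.should_include(step, difficulty, rng=rng)
    ]
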
class ProcedureCurriculum:
"""Procedure-aware curriculum that adjusts per-procedure weights.
Some procedures are inherently harder (e.g., orthognathic with large
deformations). This curriculum increases their weight over training.
"""
# Difficulty ranking (0=easiest, 1=hardest)
DEFAULT_PROCEDURE_DIFFICULTY = {
"blepharoplasty": 0.3, # small, localized changes
"rhinoplasty": 0.5, # moderate, central face
"rhytidectomy": 0.7, # large, affects face shape
"orthognathic": 0.9, # largest deformations
}
def __init__(
self,
total_steps: int,
procedure_difficulty: dict[str, float] | None = None,
warmup_fraction: float = 0.1,
):
self.curriculum = TrainingCurriculum(total_steps, warmup_fraction)
self.proc_difficulty = procedure_difficulty or self.DEFAULT_PROCEDURE_DIFFICULTY
    def get_weight(self, step: int, procedure: str) -> float:
        """Get the sampling weight for a procedure at the given step.

        Returns a value in [0.1, 1.0]; no procedure is ever fully excluded.
        """
        difficulty = self.get_difficulty(step)
        proc_diff = self.proc_difficulty.get(procedure, 0.5)
        if proc_diff <= difficulty:
            return 1.0
        # Down-weight procedures above the current difficulty: the weight
        # falls linearly and reaches the 0.1 floor at an overshoot of 0.45.
        return max(0.1, 1.0 - (proc_diff - difficulty) * 2)

    def get_difficulty(self, step: int) -> float:
        """Delegate to the underlying TrainingCurriculum schedule."""
        return self.curriculum.get_difficulty(step)

def get_procedure_weights(self, step: int) -> dict[str, float]:
"""Get all procedure weights at the given step."""
return {
proc: self.get_weight(step, proc)
for proc in self.proc_difficulty
}
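
# A minimal sketch (illustrative only): convert the per-procedure weights into
# normalised per-sample draw probabilities, e.g. for a weighted sampler.
# `procedures` holds one procedure name per sample and is an assumption here.
def _example_sample_probs(
    procedures: list[str],
    curriculum: ProcedureCurriculum,
    step: int,
) -> np.ndarray:
    """Per-sample probabilities proportional to the procedure weights."""
    raw = np.array(
        [curriculum.get_weight(step, proc) for proc in procedures],
        dtype=np.float64,
    )
    return raw / raw.sum()
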
def compute_sample_difficulty(
    metadata_path: str | Path,
    displacement_model_path: str | Path | None = None,
) -> dict[str, float]:
    """Compute difficulty scores for each sample in the dataset.

    Difficulty combines three factors:

    1. Displacement intensity (from metadata)
    2. Procedure difficulty
    3. Source type (real data is harder than synthetic)

    Args:
        metadata_path: Path to the dataset's metadata JSON file.
        displacement_model_path: Reserved for model-based difficulty
            scoring; currently unused.

    Returns:
        Dict mapping sample prefix to a difficulty score in [0, 1].
    """
with open(metadata_path) as f:
meta = json.load(f)
pairs = meta.get("pairs", {})
difficulties = {}
proc_base = {
"blepharoplasty": 0.2,
"rhinoplasty": 0.4,
"rhytidectomy": 0.6,
"orthognathic": 0.8,
"unknown": 0.5,
}
source_bonus = {
"synthetic": 0.0,
"synthetic_v3": 0.1, # realistic displacements slightly harder
"real": 0.2, # real data hardest
"augmented": 0.0,
}
for prefix, info in pairs.items():
proc = info.get("procedure", "unknown")
source = info.get("source", "synthetic")
intensity = info.get("intensity", 1.0)
# Combine factors
base = proc_base.get(proc, 0.5)
src = source_bonus.get(source, 0.0)
        # Intensity scaling (higher intensity = harder), contributing up to +0.2
        int_factor = min(1.0, intensity / 1.5) * 0.2
        difficulties[prefix] = min(1.0, base + src + int_factor)
return difficulties
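
# Minimal sketch of the CurriculumDataset wrapper described in the module
# docstring; it is not defined elsewhere in this file. Assumptions made here:
# the base dataset is indexable, and `base_dataset.prefixes[idx]` yields the
# metadata prefix of sample `idx`. Rejected indices are resampled with a
# bounded number of retries, so __getitem__ always returns a sample.
class CurriculumDataset:
    """Dataset wrapper that resamples items the curriculum rejects."""

    def __init__(
        self,
        base_dataset,
        metadata_path: str | Path,
        total_steps: int,
        warmup_fraction: float = 0.1,
        seed: int = 0,
    ):
        self.base = base_dataset
        self.curriculum = TrainingCurriculum(total_steps, warmup_fraction)
        self.difficulties = compute_sample_difficulty(metadata_path)
        self.rng = np.random.default_rng(seed)
        self.step = 0

    def set_step(self, step: int) -> None:
        """Call once per training iteration with the global step."""
        self.step = step

    def __len__(self) -> int:
        return len(self.base)

    def __getitem__(self, idx: int):
        for _ in range(10):
            prefix = self.base.prefixes[idx]  # assumed attribute, see note above
            difficulty = self.difficulties.get(prefix, 0.5)
            if self.curriculum.should_include(self.step, difficulty, rng=self.rng):
                return self.base[idx]
            idx = int(self.rng.integers(len(self.base)))
        return self.base[idx]
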
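
# Smoke-test demo (step counts are illustrative): print the difficulty
# schedule at a few points to visualise the warmup, ramp, and plateau.
if __name__ == "__main__":
    schedule = TrainingCurriculum(total_steps=100_000)
    for demo_step in (0, 10_000, 30_000, 50_000, 80_000):
        print(demo_step, round(schedule.get_difficulty(demo_step), 3))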