Spaces:
Sleeping
Sleeping
Claude Sonnet 4.6 committed on
Commit ·
70451eb
1
Parent(s): f14a2ff
Augmentation renforcée, scheduler LR et dropout par défaut 0.5
Browse files
- data_utils.py : RandomResizedCrop, rotation 30°, ColorJitter
- train_utils.py : ReduceLROnPlateau (factor=0.5, patience=8, min_lr=lr×0.2),
affichage du lr courant dans les logs
- app.py : dropout par défaut 0.5
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- app.py +1 -1
- data_utils.py +6 -2
- train_utils.py +14 -1
app.py
CHANGED
|
@@ -242,7 +242,7 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
|
|
| 242 |
dropout = gr.Slider(
|
| 243 |
minimum=0.0,
|
| 244 |
maximum=0.8,
|
| 245 |
-
value=0.
|
| 246 |
step=0.05,
|
| 247 |
label="Dropout",
|
| 248 |
)
|
|
|
|
| 242 |
dropout = gr.Slider(
|
| 243 |
minimum=0.0,
|
| 244 |
maximum=0.8,
|
| 245 |
+
value=0.5,
|
| 246 |
step=0.05,
|
| 247 |
label="Dropout",
|
| 248 |
)
|
data_utils.py
CHANGED
|
@@ -44,10 +44,14 @@ class HFDatasetWrapper(Dataset):
|
|
| 44 |
def get_train_transform():
|
| 45 |
return transforms.Compose(
|
| 46 |
[
|
| 47 |
-
|
|
|
|
| 48 |
transforms.RandomHorizontalFlip(p=0.5),
|
| 49 |
transforms.RandomVerticalFlip(p=0.5),
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
| 51 |
transforms.ToTensor(),
|
| 52 |
transforms.Normalize(
|
| 53 |
mean=(0.485, 0.456, 0.406),
|
|
|
|
| 44 |
def get_train_transform():
|
| 45 |
return transforms.Compose(
|
| 46 |
[
|
| 47 |
+
# Crop aléatoire puis redimensionnement : simule des cadrages différents
|
| 48 |
+
transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.75, 1.0)),
|
| 49 |
transforms.RandomHorizontalFlip(p=0.5),
|
| 50 |
transforms.RandomVerticalFlip(p=0.5),
|
| 51 |
+
# Rotation plus large : les images microscopiques n'ont pas d'orientation canonique
|
| 52 |
+
transforms.RandomRotation(degrees=30),
|
| 53 |
+
# Légère variation de luminosité/contraste pour robustesse aux conditions d'acquisition
|
| 54 |
+
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
|
| 55 |
transforms.ToTensor(),
|
| 56 |
transforms.Normalize(
|
| 57 |
mean=(0.485, 0.456, 0.406),
|
train_utils.py
CHANGED
|
@@ -166,6 +166,16 @@ def train_model(
|
|
| 166 |
weight_decay=weight_decay,
|
| 167 |
)
|
| 168 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
history = []
|
| 170 |
logs = []
|
| 171 |
start_time = time.time()
|
|
@@ -204,6 +214,8 @@ def train_model(
|
|
| 204 |
train_acc = correct / total if total else 0.0
|
| 205 |
|
| 206 |
val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
|
|
|
|
|
|
|
| 207 |
|
| 208 |
if val_loss < best_val_loss:
|
| 209 |
best_val_loss = val_loss
|
|
@@ -225,7 +237,8 @@ def train_model(
|
|
| 225 |
logs.append(
|
| 226 |
f"Époque {epoch}/{epochs} | "
|
| 227 |
f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
|
| 228 |
-
f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}"
|
|
|
|
| 229 |
)
|
| 230 |
|
| 231 |
if best_state_dict is not None:
|
|
|
|
| 166 |
weight_decay=weight_decay,
|
| 167 |
)
|
| 168 |
|
| 169 |
+
# Réduit le LR de moitié si val_loss ne s'améliore pas pendant 8 époques
|
| 170 |
+
# patience élevée car le val set est très petit (bruit important)
|
| 171 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 172 |
+
optimizer,
|
| 173 |
+
mode="min",
|
| 174 |
+
factor=0.5,
|
| 175 |
+
patience=8,
|
| 176 |
+
min_lr=learning_rate * 0.2,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
history = []
|
| 180 |
logs = []
|
| 181 |
start_time = time.time()
|
|
|
|
| 214 |
train_acc = correct / total if total else 0.0
|
| 215 |
|
| 216 |
val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
|
| 217 |
+
scheduler.step(val_loss)
|
| 218 |
+
current_lr = optimizer.param_groups[0]["lr"]
|
| 219 |
|
| 220 |
if val_loss < best_val_loss:
|
| 221 |
best_val_loss = val_loss
|
|
|
|
| 237 |
logs.append(
|
| 238 |
f"Époque {epoch}/{epochs} | "
|
| 239 |
f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
|
| 240 |
+
f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}, "
|
| 241 |
+
f"lr={current_lr:.6f}"
|
| 242 |
)
|
| 243 |
|
| 244 |
if best_state_dict is not None:
|