functionNormally Claude Sonnet 4.6 commited on
Commit
70451eb
·
1 Parent(s): f14a2ff

Augmentation renforcée, scheduler LR et dropout par défaut 0.5

Browse files

- data_utils.py : RandomResizedCrop, rotation 30°, ColorJitter
- train_utils.py : ReduceLROnPlateau (factor=0.5, patience=8, min_lr=lr×0.2),
affichage du lr courant dans les logs
- app.py : dropout par défaut 0.5

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +1 -1
  2. data_utils.py +6 -2
  3. train_utils.py +14 -1
app.py CHANGED
@@ -242,7 +242,7 @@ with gr.Blocks(title="Classification d’images microscopiques") as demo:
242
  dropout = gr.Slider(
243
  minimum=0.0,
244
  maximum=0.8,
245
- value=0.4,
246
  step=0.05,
247
  label="Dropout",
248
  )
 
242
  dropout = gr.Slider(
243
  minimum=0.0,
244
  maximum=0.8,
245
+ value=0.5,
246
  step=0.05,
247
  label="Dropout",
248
  )
data_utils.py CHANGED
@@ -44,10 +44,14 @@ class HFDatasetWrapper(Dataset):
44
  def get_train_transform():
45
  return transforms.Compose(
46
  [
47
- transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
 
48
  transforms.RandomHorizontalFlip(p=0.5),
49
  transforms.RandomVerticalFlip(p=0.5),
50
- transforms.RandomRotation(degrees=5),
 
 
 
51
  transforms.ToTensor(),
52
  transforms.Normalize(
53
  mean=(0.485, 0.456, 0.406),
 
44
  def get_train_transform():
45
  return transforms.Compose(
46
  [
47
+ # Crop aléatoire puis redimensionnement : simule des cadrages différents
48
+ transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.75, 1.0)),
49
  transforms.RandomHorizontalFlip(p=0.5),
50
  transforms.RandomVerticalFlip(p=0.5),
51
+ # Rotation plus large : les images microscopiques n'ont pas d'orientation canonique
52
+ transforms.RandomRotation(degrees=30),
53
+ # Légère variation de luminosité/contraste pour robustesse aux conditions d'acquisition
54
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
55
  transforms.ToTensor(),
56
  transforms.Normalize(
57
  mean=(0.485, 0.456, 0.406),
train_utils.py CHANGED
@@ -166,6 +166,16 @@ def train_model(
166
  weight_decay=weight_decay,
167
  )
168
 
 
 
 
 
 
 
 
 
 
 
169
  history = []
170
  logs = []
171
  start_time = time.time()
@@ -204,6 +214,8 @@ def train_model(
204
  train_acc = correct / total if total else 0.0
205
 
206
  val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
 
 
207
 
208
  if val_loss < best_val_loss:
209
  best_val_loss = val_loss
@@ -225,7 +237,8 @@ def train_model(
225
  logs.append(
226
  f"Époque {epoch}/{epochs} | "
227
  f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
228
- f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}"
 
229
  )
230
 
231
  if best_state_dict is not None:
 
166
  weight_decay=weight_decay,
167
  )
168
 
169
+ # Réduit le LR de moitié si val_loss ne s'améliore pas pendant 8 époques
170
+ # patience élevée car le val set est très petit (bruit important)
171
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
172
+ optimizer,
173
+ mode="min",
174
+ factor=0.5,
175
+ patience=8,
176
+ min_lr=learning_rate * 0.2,
177
+ )
178
+
179
  history = []
180
  logs = []
181
  start_time = time.time()
 
214
  train_acc = correct / total if total else 0.0
215
 
216
  val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
217
+ scheduler.step(val_loss)
218
+ current_lr = optimizer.param_groups[0]["lr"]
219
 
220
  if val_loss < best_val_loss:
221
  best_val_loss = val_loss
 
237
  logs.append(
238
  f"Époque {epoch}/{epochs} | "
239
  f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
240
+ f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}, "
241
+ f"lr={current_lr:.6f}"
242
  )
243
 
244
  if best_state_dict is not None: