Image_Classification / train_utils.py
functionNormally
Restaurer les paramètres CNN qui fonctionnaient + epoch max à 50
e8074db
import os
import json
import time
from datetime import datetime
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.optim as optim
from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME
from data_utils import make_loaders
from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
from model import SimpleCNN, ResNet18Classifier
def model_weight_path(model_name: str) -> str:
    """Return the path of the ``.pt`` weights file for *model_name* inside ``MODEL_DIR``."""
    weights_filename = "{}.pt".format(model_name)
    return os.path.join(MODEL_DIR, weights_filename)
def model_meta_path(model_name: str) -> str:
    """Return the path of the ``.json`` metadata file for *model_name* inside ``META_DIR``."""
    meta_filename = "{}.json".format(model_name)
    return os.path.join(META_DIR, meta_filename)
def list_saved_models() -> List[str]:
    """List the names of saved models.

    A model name is the stem of a ``*.json`` metadata file in ``META_DIR``.
    Names are returned in descending lexicographic order; because
    ``train_model`` suffixes names with a ``%Y%m%d_%H%M%S`` timestamp, models
    sharing the same tag prefix appear newest first.

    Returns:
        The sorted model names, or an empty list when ``META_DIR`` does not
        exist yet (nothing has been saved) instead of raising FileNotFoundError.
    """
    if not os.path.isdir(META_DIR):
        return []
    suffix = ".json"
    names = [fn[: -len(suffix)] for fn in os.listdir(META_DIR) if fn.endswith(suffix)]
    return sorted(names, reverse=True)
def get_runtime_device() -> torch.device:
    """Return the CUDA device when PyTorch reports a GPU is available, else the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
    """Persist a model's weights and its JSON metadata.

    Args:
        model: Model whose ``state_dict`` is saved.
        model_name: Base name used for both output files.
        config: Hyper-parameters/architecture description stored in the metadata.
        training_summary: Training/evaluation summary stored in the metadata.

    Side effects:
        Creates ``MODEL_DIR`` and ``META_DIR`` if needed, writes the CPU state
        dict to ``model_weight_path(model_name)`` and a JSON payload
        (name, config, summary, creation time) to ``model_meta_path(model_name)``.
    """
    # Create both directories up front: previously a missing META_DIR made the
    # metadata write fail *after* the weights were written, leaving a .pt file
    # that load_model refuses to load without its metadata.
    os.makedirs(MODEL_DIR, exist_ok=True)
    os.makedirs(META_DIR, exist_ok=True)

    # Move every tensor to CPU so the checkpoint does not depend on the device
    # the model was trained on.
    cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    torch.save(cpu_state_dict, model_weight_path(model_name))

    payload = {
        "model_name": model_name,
        "config": config,
        "training_summary": training_summary,
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(model_meta_path(model_name), "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
    """Rebuild a saved model from its metadata and weights.

    The architecture is chosen from ``config["model_type"]`` in the metadata
    ("resnet18" -> ResNet18Classifier, anything else or missing -> SimpleCNN),
    with the same hyper-parameter defaults as ``train_model``'s signature.

    Args:
        model_name: Base name of the saved model files.
        device: Device the rebuilt model is moved to.

    Returns:
        ``(model, meta)``: the model in eval mode on *device*, and the parsed
        metadata dict.

    Raises:
        FileNotFoundError: if the metadata file or the weights file is missing
            (metadata is checked first).
        KeyError: if the metadata lacks ``config`` or ``config["num_classes"]``.
    """
    meta_file, weight_file = model_meta_path(model_name), model_weight_path(model_name)
    for required_path, missing_msg in (
        (meta_file, f"Métadonnées introuvables pour le modèle : {model_name}"),
        (weight_file, f"Poids introuvables pour le modèle : {model_name}"),
    ):
        if not os.path.exists(required_path):
            raise FileNotFoundError(missing_msg)

    with open(meta_file, "r", encoding="utf-8") as fh:
        meta = json.load(fh)
    cfg = meta["config"]

    # Hyper-parameters understood by both architectures.
    shared_kwargs = {
        "num_classes": cfg["num_classes"],
        "dropout": cfg.get("dropout", 0.4),
        "fc_dim": cfg.get("fc_dim", 256),
    }
    if cfg.get("model_type", "cnn") == "resnet18":
        model = ResNet18Classifier(**shared_kwargs)
    else:
        model = SimpleCNN(
            num_conv_blocks=cfg.get("num_conv_blocks", 3),
            base_filters=cfg.get("base_filters", 32),
            kernel_size=cfg.get("kernel_size", 3),
            use_batchnorm=cfg.get("use_batchnorm", True),
            **shared_kwargs,
        )

    # Load onto CPU first, then move, so checkpoints load whatever device exists.
    # NOTE(review): torch.load unpickles the file; on older torch versions this
    # can execute arbitrary code — only load trusted checkpoints (consider
    # weights_only=True if the installed torch supports it).
    state = torch.load(weight_file, map_location="cpu")
    model.load_state_dict(state)
    model.to(device)
    model.eval()
    return model, meta
def evaluate_loss_acc(model, loader, criterion, device):
    """Return ``(mean loss, accuracy)`` of *model* over every batch of *loader*.

    The model is switched to eval mode and gradients are disabled. Each batch
    loss is weighted by the batch size, which assumes *criterion* returns a
    per-batch mean (the default reduction of ``nn.CrossEntropyLoss``).
    An empty loader yields ``(0.0, 0.0)``.
    """
    model.eval()
    loss_sum, n_seen, n_hits = 0.0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            loss_sum += criterion(logits, batch_y).item() * batch_x.size(0)
            n_hits += (logits.argmax(dim=1) == batch_y).sum().item()
            n_seen += batch_y.size(0)
    if not n_seen:
        return 0.0, 0.0
    return loss_sum / n_seen, n_hits / n_seen
def collect_predictions(model, loader, device):
    """Run *model* over *loader* and return ``(y_true, y_pred)`` as Python lists.

    Predictions are the argmax over dim 1 of the model outputs; labels are
    taken from the loader as-is (they are not moved to *device*). The model is
    switched to eval mode and gradients are disabled.
    """
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            logits = model(batch_x.to(device))
            y_pred += logits.argmax(dim=1).detach().cpu().tolist()
            y_true += batch_y.tolist()
    return y_true, y_pred
def _build_classifier(
    model_type: str,
    num_classes: int,
    num_conv_blocks: int,
    base_filters: int,
    kernel_size: int,
    use_batchnorm: bool,
    dropout: float,
    fc_dim: int,
) -> nn.Module:
    """Instantiate ResNet18Classifier for ``model_type == "resnet18"``, SimpleCNN otherwise."""
    if model_type == "resnet18":
        # The conv-specific hyper-parameters are not passed to the ResNet head.
        return ResNet18Classifier(
            num_classes=num_classes,
            dropout=dropout,
            fc_dim=fc_dim,
        )
    return SimpleCNN(
        num_classes=num_classes,
        num_conv_blocks=num_conv_blocks,
        base_filters=base_filters,
        kernel_size=kernel_size,
        use_batchnorm=use_batchnorm,
        dropout=dropout,
        fc_dim=fc_dim,
    )


def _train_one_epoch(model, loader, criterion, optimizer, device) -> Tuple[float, float]:
    """Run one optimisation pass over *loader* and return ``(mean loss, accuracy)``.

    Metrics come from the training-mode forward passes of this epoch.
    Returns ``(0.0, 0.0)`` when *loader* yields no samples.
    """
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        # Clip gradients to keep fine-tuning stable (guards against exploding gradients).
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    if not total:
        return 0.0, 0.0
    return running_loss / total, correct / total


def _model_name_prefix(model_tag: str, model_type: str) -> str:
    """Return a filename-safe prefix for a saved model's name.

    A blank tag falls back to a default that matches the architecture actually
    trained. Otherwise every character other than a word character, ``.`` or
    ``-`` is replaced by ``_`` (spaces included, as before), so a tag cannot
    inject path separators or characters invalid in file names.
    """
    import re

    tag = model_tag.strip()
    if not tag:
        return "charcoal_resnet18" if model_type == "resnet18" else "charcoal_cnn"
    return re.sub(r"[^\w.-]", "_", tag)


def train_model(
    model_type: str = "cnn",
    num_conv_blocks: int = 3,
    base_filters: int = 32,
    kernel_size: int = 3,
    use_batchnorm: bool = True,
    dropout: float = 0.4,
    fc_dim: int = 256,
    learning_rate: float = 0.001,
    weight_decay: float = 0.0001,
    batch_size: int = 16,
    epochs: int = 30,
    model_tag: str = "",
):
    """Train a classifier, keep its best validation checkpoint, evaluate and save it.

    Builds a ResNet18Classifier (``model_type == "resnet18"``) or a SimpleCNN,
    trains it with AdamW and ReduceLROnPlateau on the loaders from
    ``make_loaders``, restores the weights of the epoch with the lowest
    validation loss, evaluates them on the test split, then saves the weights,
    the metadata and a confusion-matrix figure.

    Args:
        model_type: "resnet18" for the ResNet model; any other value builds SimpleCNN.
        num_conv_blocks, base_filters, kernel_size, use_batchnorm: SimpleCNN only.
        dropout, fc_dim: classifier hyper-parameters (both architectures).
        learning_rate, weight_decay: AdamW settings; the scheduler never lowers
            the LR below ``0.2 * learning_rate``.
        batch_size: Passed to ``make_loaders``.
        epochs: Number of training epochs.
        model_tag: Optional prefix for the saved model name (sanitised for file
            names); a default based on ``model_type`` is used when blank.

    Returns:
        dict with keys ``logs`` (newline-joined log text), ``history`` (per-epoch
        metrics), ``summary`` (also stored in the metadata), ``model_name``,
        ``classification_report``, ``confusion_matrix`` and
        ``confusion_matrix_path``.
    """
    device = get_runtime_device()
    train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
    num_classes = len(class_names)

    model = _build_classifier(
        model_type,
        num_classes,
        num_conv_blocks,
        base_filters,
        kernel_size,
        use_batchnorm,
        dropout,
        fc_dim,
    ).to(device)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    criterion = nn.CrossEntropyLoss()
    # Only parameters the architecture leaves trainable are handed to the optimiser.
    optimizer = optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    # Halve the LR when val_loss has not improved for 8 epochs. The patience is
    # high because the validation set is very small, so val_loss is noisy.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=8,
        min_lr=learning_rate * 0.2,
    )

    history = []
    logs = []
    start_time = time.time()
    best_val_loss = float("inf")
    best_epoch = None
    best_state_dict = None

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            # Cloned CPU snapshot: later optimiser steps cannot modify it.
            best_state_dict = {
                k: v.detach().cpu().clone()
                for k, v in model.state_dict().items()
            }

        history.append(
            {
                "epoch": epoch,
                "train_loss": round(train_loss, 4),
                "train_acc": round(train_acc, 4),
                "val_loss": round(val_loss, 4),
                "val_acc": round(val_acc, 4),
            }
        )
        logs.append(
            f"Époque {epoch}/{epochs} | "
            f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
            f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}, "
            f"lr={current_lr:.6f}"
        )

    # Evaluate (and save) the best-validation weights rather than the last epoch's.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
    y_true, y_pred = collect_predictions(model, test_loader, device)
    metrics = compute_classification_metrics(y_true, y_pred, class_names)
    elapsed = time.time() - start_time

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name = f"{_model_name_prefix(model_tag, model_type)}_{timestamp}"
    cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)

    if model_type == "resnet18":
        architecture = "ResNet18 pré-entraîné (layer4 + classifieur)"
    else:
        architecture = f"CNN simple ({num_conv_blocks} blocs, filtres={base_filters}, noyau={kernel_size}x{kernel_size})"

    config = {
        "dataset_name": DATASET_DISPLAY_NAME,
        "model_type": model_type,
        "architecture": architecture,
        "num_classes": num_classes,
        "class_names": class_names,
        "num_conv_blocks": num_conv_blocks,
        "base_filters": base_filters,
        "kernel_size": kernel_size,
        "use_batchnorm": use_batchnorm,
        "dropout": dropout,
        "fc_dim": fc_dim,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "batch_size": batch_size,
        "epochs": epochs,
    }

    last = history[-1] if history else {}
    training_summary = {
        "final_train_loss": last.get("train_loss"),
        "final_train_acc": last.get("train_acc"),
        # None instead of inf when no epoch ever improved on the initial inf
        # (e.g. epochs == 0): json.dump would otherwise write non-standard "Infinity".
        "best_val_loss": round(best_val_loss, 4) if best_epoch is not None else None,
        "best_epoch": best_epoch,
        "final_val_loss": last.get("val_loss"),
        "final_val_acc": last.get("val_acc"),
        "test_cross_entropy_loss": round(test_loss, 4),
        "test_accuracy": round(test_acc, 4),
        "test_f1_macro": metrics["f1_macro"],
        "test_f1_weighted": metrics["f1_weighted"],
        "elapsed_seconds": round(elapsed, 2),
        "device": str(device),
        "total_params": total_params,
        "trainable_params": trainable_params,
    }
    save_model(model, model_name, config, training_summary)

    logs.append("")
    logs.append("Entraînement terminé.")
    logs.append(f"Modèle sauvegardé : {model_name}")
    if best_epoch is not None:
        logs.append(f"Meilleure époque (perte validation minimale) : {best_epoch}")
    logs.append(f"Appareil utilisé : {device}")
    logs.append(f"Architecture : {architecture}")
    logs.append(f"Nombre total de paramètres : {total_params}")
    logs.append(f"Paramètres entraînables : {trainable_params}")
    logs.append(f"Perte test cross-entropy : {test_loss:.4f}")
    logs.append(f"Accuracy test : {test_acc:.4f}")
    logs.append(f"F1 macro test : {metrics['f1_macro']:.4f}")
    logs.append(f"F1 pondéré test : {metrics['f1_weighted']:.4f}")
    logs.append(f"Temps écoulé : {elapsed:.1f}s")

    return {
        "logs": "\n".join(logs),
        "history": history,
        "summary": training_summary,
        "model_name": model_name,
        "classification_report": metrics["classification_report"],
        "confusion_matrix": metrics["confusion_matrix"],
        "confusion_matrix_path": cm_path,
    }
def evaluate_saved_model(model_name: str):
    """Reload a saved model and evaluate it on the current test split.

    Args:
        model_name: Name of a saved model (as returned by ``list_saved_models``).

    Returns:
        Tuple ``(summary, classification_report, confusion_matrix,
        confusion_matrix_path)`` where *summary* holds test loss, accuracy,
        macro/weighted F1 and the device used.

    Raises:
        ValueError: if *model_name* is empty, or if the class list stored with
            the model differs from the one the test loader now provides.
        FileNotFoundError: propagated from ``load_model`` when files are missing.
    """
    if not model_name:
        raise ValueError("Aucun modèle sélectionné.")
    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    cfg = meta["config"]
    batch_size = int(cfg.get("batch_size", 16))
    _, _, test_loader, class_names = make_loaders(batch_size)

    # Output index i only means class_names[i] if the dataset still has the
    # exact class list (and order) the model was trained on; otherwise the
    # metrics and labels below would be silently wrong. Metadata without a
    # stored class list is accepted as before.
    saved_classes = cfg.get("class_names")
    if saved_classes is not None and list(saved_classes) != list(class_names):
        raise ValueError(
            "Les classes du jeu de données ne correspondent pas à celles du modèle "
            f"(modèle : {list(saved_classes)}, données : {list(class_names)})."
        )

    criterion = nn.CrossEntropyLoss()
    test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
    y_true, y_pred = collect_predictions(model, test_loader, device)
    metrics = compute_classification_metrics(y_true, y_pred, class_names)
    cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
    summary = {
        "test_cross_entropy_loss": round(test_loss, 4),
        "test_accuracy": round(test_acc, 4),
        "test_f1_macro": metrics["f1_macro"],
        "test_f1_weighted": metrics["f1_weighted"],
        "device": str(device),
    }
    return summary, metrics["classification_report"], metrics["confusion_matrix"], cm_path