"""Training, evaluation and persistence helpers for the image classifiers.

The module trains either a ``SimpleCNN`` or a ``ResNet18Classifier``, keeps
the weights of the epoch with the lowest validation loss, evaluates them on
the test split and stores the weights (``.pt``) next to a JSON metadata file.
"""

import os
import json
import time
from datetime import datetime
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim

from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME
from data_utils import make_loaders
from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure
from model import SimpleCNN, ResNet18Classifier

# Characters of a user-supplied tag that must not end up in a file name:
# spaces (kept from the original behaviour) and path separators.
_TAG_TRANSLATION = str.maketrans(" /\\", "___")


def model_weight_path(model_name: str) -> str:
    """Return the path of the weights file (``.pt``) for *model_name*."""
    return os.path.join(MODEL_DIR, f"{model_name}.pt")


def model_meta_path(model_name: str) -> str:
    """Return the path of the JSON metadata file for *model_name*."""
    return os.path.join(META_DIR, f"{model_name}.json")


def list_saved_models() -> List[str]:
    """Return the names of all saved models, sorted in reverse order.

    A model is considered saved when a ``<name>.json`` file exists in
    ``META_DIR``. Returns an empty list when the directory does not exist
    yet (nothing has been saved so far).
    """
    try:
        entries = os.listdir(META_DIR)
    except FileNotFoundError:
        return []
    suffix = ".json"
    return sorted(
        (fn[: -len(suffix)] for fn in entries if fn.endswith(suffix)),
        reverse=True,
    )


def get_runtime_device() -> torch.device:
    """Return the CUDA device when available, otherwise the CPU device."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain *path* if it is missing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _build_model(
    model_type: str,
    num_classes: int,
    num_conv_blocks: int,
    base_filters: int,
    kernel_size: int,
    use_batchnorm: bool,
    dropout: float,
    fc_dim: int,
) -> nn.Module:
    """Instantiate the network selected by *model_type*.

    ``"resnet18"`` selects ``ResNet18Classifier``; any other value falls back
    to ``SimpleCNN``. The CNN-only arguments are ignored for the ResNet.
    """
    if model_type == "resnet18":
        return ResNet18Classifier(
            num_classes=num_classes,
            dropout=dropout,
            fc_dim=fc_dim,
        )
    return SimpleCNN(
        num_classes=num_classes,
        num_conv_blocks=num_conv_blocks,
        base_filters=base_filters,
        kernel_size=kernel_size,
        use_batchnorm=use_batchnorm,
        dropout=dropout,
        fc_dim=fc_dim,
    )


def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
    """Persist the weights and the metadata of *model* under *model_name*.

    The state dict is moved to the CPU before being written to
    ``model_weight_path(model_name)``. The metadata (name, config, training
    summary and creation timestamp) is written as UTF-8 JSON to
    ``model_meta_path(model_name)``. Missing target directories are created.
    """
    weight_file = model_weight_path(model_name)
    meta_file = model_meta_path(model_name)
    _ensure_parent_dir(weight_file)
    _ensure_parent_dir(meta_file)

    cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    torch.save(cpu_state_dict, weight_file)

    payload = {
        "model_name": model_name,
        "config": config,
        "training_summary": training_summary,
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(meta_file, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
    """Rebuild a saved model and return it together with its metadata.

    The architecture is reconstructed from the ``"config"`` entry of the
    metadata file, the weights are loaded on the CPU, then the model is moved
    to *device* and switched to evaluation mode.

    Raises:
        FileNotFoundError: if the metadata file or the weights file is missing.
    """
    meta_file = model_meta_path(model_name)
    weight_file = model_weight_path(model_name)

    if not os.path.exists(meta_file):
        raise FileNotFoundError(f"Métadonnées introuvables pour le modèle : {model_name}")
    if not os.path.exists(weight_file):
        raise FileNotFoundError(f"Poids introuvables pour le modèle : {model_name}")

    with open(meta_file, "r", encoding="utf-8") as f:
        meta = json.load(f)

    cfg = meta["config"]
    # Missing hyper-parameters fall back to the same defaults as train_model.
    model = _build_model(
        model_type=cfg.get("model_type", "cnn"),
        num_classes=cfg["num_classes"],
        num_conv_blocks=cfg.get("num_conv_blocks", 3),
        base_filters=cfg.get("base_filters", 32),
        kernel_size=cfg.get("kernel_size", 3),
        use_batchnorm=cfg.get("use_batchnorm", True),
        dropout=cfg.get("dropout", 0.4),
        fc_dim=cfg.get("fc_dim", 256),
    )

    # NOTE(review): torch.load unpickles the file. Only load weight files you
    # trust, or pass weights_only=True if the supported torch versions allow it.
    state_dict = torch.load(weight_file, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, meta


def evaluate_loss_acc(model, loader, criterion, device):
    """Return ``(average loss, accuracy)`` of *model* over *loader*.

    The model is put in evaluation mode and gradients are disabled. Both
    values are ``0.0`` when the loader yields no sample.
    """
    model.eval()
    total_loss = 0.0
    total = 0
    correct = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight the batch loss by the batch size so the final division
            # yields a per-sample average (assumes a mean-reduced criterion).
            total_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    avg_loss = total_loss / total if total else 0.0
    acc = correct / total if total else 0.0
    return avg_loss, acc


def collect_predictions(model, loader, device):
    """Return ``(y_true, y_pred)`` as Python lists for every sample of *loader*.

    Predictions are the argmax over dimension 1 of the model outputs.
    """
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1).detach().cpu().tolist()
            y_pred.extend(preds)
            y_true.extend(labels.tolist())

    return y_true, y_pred


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over *loader*; return ``(loss, accuracy)``."""
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        # Important: prevents unstable fine-tuning / exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    train_loss = running_loss / total if total else 0.0
    train_acc = correct / total if total else 0.0
    return train_loss, train_acc


def train_model(
    model_type: str = "cnn",
    num_conv_blocks: int = 3,
    base_filters: int = 32,
    kernel_size: int = 3,
    use_batchnorm: bool = True,
    dropout: float = 0.4,
    fc_dim: int = 256,
    learning_rate: float = 0.001,
    weight_decay: float = 0.0001,
    batch_size: int = 16,
    epochs: int = 30,
    model_tag: str = "",
):
    """Train a classifier, evaluate it on the test split and save it.

    The weights of the epoch with the lowest validation loss are restored
    before the test evaluation and are the ones written to disk.

    Args:
        model_type: ``"resnet18"`` for ``ResNet18Classifier``; anything else
            builds a ``SimpleCNN``.
        num_conv_blocks, base_filters, kernel_size, use_batchnorm: SimpleCNN
            hyper-parameters (ignored for the ResNet).
        dropout, fc_dim: classifier hyper-parameters passed to both models.
        learning_rate, weight_decay: AdamW settings.
        batch_size: forwarded to ``make_loaders``.
        epochs: number of training epochs.
        model_tag: optional prefix of the saved model name; spaces and path
            separators are replaced by underscores.

    Returns:
        A dict with the keys ``logs`` (newline-joined text), ``history``
        (one dict per epoch), ``summary``, ``model_name``,
        ``classification_report``, ``confusion_matrix`` and
        ``confusion_matrix_path``.
    """
    device = get_runtime_device()
    train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
    num_classes = len(class_names)

    model = _build_model(
        model_type=model_type,
        num_classes=num_classes,
        num_conv_blocks=num_conv_blocks,
        base_filters=base_filters,
        kernel_size=kernel_size,
        use_batchnorm=use_batchnorm,
        dropout=dropout,
        fc_dim=fc_dim,
    ).to(device)

    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())

    criterion = nn.CrossEntropyLoss()
    # Only parameters that require gradients are handed to the optimizer.
    optimizer = optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=learning_rate,
        weight_decay=weight_decay,
    )

    # Halve the LR when val_loss has not improved for 8 epochs. The patience
    # is high because the validation set is very small (noisy signal).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=8,
        min_lr=learning_rate * 0.2,
    )

    history = []
    logs = []
    start_time = time.time()

    best_val_loss = float("inf")
    best_state_dict = None

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)

        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]["lr"]

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Snapshot on the CPU; clone so later updates do not alter it.
            best_state_dict = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }

        row = {
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "train_acc": round(train_acc, 4),
            "val_loss": round(val_loss, 4),
            "val_acc": round(val_acc, 4),
        }
        history.append(row)
        logs.append(
            f"Époque {epoch}/{epochs} | "
            f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
            f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}, "
            f"lr={current_lr:.6f}"
        )

    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)

    test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
    y_true, y_pred = collect_predictions(model, test_loader, device)
    metrics = compute_classification_metrics(y_true, y_pred, class_names)

    elapsed = time.time() - start_time
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    # Path separators in the tag would break (or escape) the save location,
    # so they are neutralised together with spaces.
    # NOTE(review): the fallback tag mentions resnet18 even for CNN runs;
    # kept as-is to preserve existing model names — confirm whether intended.
    stripped_tag = model_tag.strip()
    safe_tag = stripped_tag.translate(_TAG_TRANSLATION) if stripped_tag else "charcoal_resnet18"
    model_name = f"{safe_tag}_{timestamp}"

    cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)

    if model_type == "resnet18":
        architecture = "ResNet18 pré-entraîné (layer4 + classifieur)"
    else:
        architecture = f"CNN simple ({num_conv_blocks} blocs, filtres={base_filters}, noyau={kernel_size}x{kernel_size})"

    config = {
        "dataset_name": DATASET_DISPLAY_NAME,
        "model_type": model_type,
        "architecture": architecture,
        "num_classes": num_classes,
        "class_names": class_names,
        "num_conv_blocks": num_conv_blocks,
        "base_filters": base_filters,
        "kernel_size": kernel_size,
        "use_batchnorm": use_batchnorm,
        "dropout": dropout,
        "fc_dim": fc_dim,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "batch_size": batch_size,
        "epochs": epochs,
    }

    last_row = history[-1] if history else None
    training_summary = {
        "final_train_loss": last_row["train_loss"] if last_row else None,
        "final_train_acc": last_row["train_acc"] if last_row else None,
        # best_val_loss stays at +inf when no epoch improved it (e.g. epochs=0);
        # json.dump would emit the non-standard token "Infinity", so use None.
        "best_val_loss": round(best_val_loss, 4) if best_state_dict is not None else None,
        "final_val_loss": last_row["val_loss"] if last_row else None,
        "final_val_acc": last_row["val_acc"] if last_row else None,
        "test_cross_entropy_loss": round(test_loss, 4),
        "test_accuracy": round(test_acc, 4),
        "test_f1_macro": metrics["f1_macro"],
        "test_f1_weighted": metrics["f1_weighted"],
        "elapsed_seconds": round(elapsed, 2),
        "device": str(device),
        "total_params": total_params,
        "trainable_params": trainable_params,
    }

    save_model(model, model_name, config, training_summary)

    logs.extend(
        [
            "",
            "Entraînement terminé.",
            f"Modèle sauvegardé : {model_name}",
            f"Appareil utilisé : {device}",
            f"Architecture : {architecture}",
            f"Nombre total de paramètres : {total_params}",
            f"Paramètres entraînables : {trainable_params}",
            f"Perte test cross-entropy : {test_loss:.4f}",
            f"Accuracy test : {test_acc:.4f}",
            f"F1 macro test : {metrics['f1_macro']:.4f}",
            f"F1 pondéré test : {metrics['f1_weighted']:.4f}",
            f"Temps écoulé : {elapsed:.1f}s",
        ]
    )

    return {
        "logs": "\n".join(logs),
        "history": history,
        "summary": training_summary,
        "model_name": model_name,
        "classification_report": metrics["classification_report"],
        "confusion_matrix": metrics["confusion_matrix"],
        "confusion_matrix_path": cm_path,
    }


def evaluate_saved_model(model_name: str):
    """Re-evaluate a saved model on the test split.

    The batch size stored in the model metadata is reused (16 when absent).

    Returns:
        A tuple ``(summary, classification_report, confusion_matrix,
        confusion_matrix_path)`` where ``summary`` holds the test loss,
        accuracy, F1 scores and the device used.

    Raises:
        ValueError: if *model_name* is empty.
        FileNotFoundError: propagated from ``load_model`` when the metadata
            or weights file is missing.
    """
    if not model_name:
        raise ValueError("Aucun modèle sélectionné.")

    device = get_runtime_device()
    model, meta = load_model(model_name, device)

    batch_size = int(meta["config"].get("batch_size", 16))
    _, _, test_loader, class_names = make_loaders(batch_size)

    criterion = nn.CrossEntropyLoss()
    test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
    y_true, y_pred = collect_predictions(model, test_loader, device)
    metrics = compute_classification_metrics(y_true, y_pred, class_names)
    cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)

    summary = {
        "test_cross_entropy_loss": round(test_loss, 4),
        "test_accuracy": round(test_acc, 4),
        "test_f1_macro": metrics["f1_macro"],
        "test_f1_weighted": metrics["f1_weighted"],
        "device": str(device),
    }

    return summary, metrics["classification_report"], metrics["confusion_matrix"], cm_path