"""Training, persistence, and evaluation utilities for the image classifiers."""
| import os | |
| import json | |
| import time | |
| from datetime import datetime | |
| from typing import List, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from config import MODEL_DIR, META_DIR, DATASET_DISPLAY_NAME | |
| from data_utils import make_loaders | |
| from metrics_utils import compute_classification_metrics, save_confusion_matrix_figure | |
| from model import SimpleCNN, ResNet18Classifier | |
def model_weight_path(model_name: str) -> str:
    """Return the path of the weight file (.pt) for *model_name* under MODEL_DIR."""
    return os.path.join(MODEL_DIR, model_name + ".pt")
def model_meta_path(model_name: str) -> str:
    """Return the path of the JSON metadata sidecar for *model_name* under META_DIR."""
    return os.path.join(META_DIR, model_name + ".json")
def list_saved_models() -> List[str]:
    """List saved model names (stems of the .json metadata files), sorted descending."""
    stems = [
        entry[: -len(".json")]
        for entry in os.listdir(META_DIR)
        if entry.endswith(".json")
    ]
    stems.sort(reverse=True)
    return stems
def get_runtime_device() -> torch.device:
    """Pick the compute device: CUDA when available, otherwise CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
    """Persist *model* weights (moved to CPU) plus a JSON metadata sidecar.

    The weight file goes to model_weight_path(model_name); the metadata file
    (config, training summary, creation timestamp) to model_meta_path(model_name).
    """
    state = model.state_dict()
    # Detach and move every tensor to CPU so the checkpoint loads anywhere.
    cpu_state = {name: tensor.detach().cpu() for name, tensor in state.items()}
    torch.save(cpu_state, model_weight_path(model_name))
    meta = {
        "model_name": model_name,
        "config": config,
        "training_summary": training_summary,
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(model_meta_path(model_name), "w", encoding="utf-8") as fh:
        json.dump(meta, fh, indent=2, ensure_ascii=False)
def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
    """Rebuild a saved classifier from its metadata and weight files.

    Parameters
    ----------
    model_name: stem of the files written by ``save_model``.
    device: device the rebuilt model is moved to.

    Returns
    -------
    (model, meta): the model in eval mode on *device*, and the metadata dict.

    Raises
    ------
    FileNotFoundError: when either the metadata or the weight file is missing.
    """
    meta_file = model_meta_path(model_name)
    weight_file = model_weight_path(model_name)
    if not os.path.exists(meta_file):
        raise FileNotFoundError(f"Métadonnées introuvables pour le modèle : {model_name}")
    if not os.path.exists(weight_file):
        raise FileNotFoundError(f"Poids introuvables pour le modèle : {model_name}")
    with open(meta_file, "r", encoding="utf-8") as f:
        meta = json.load(f)
    cfg = meta["config"]
    # "model_type" defaults to the simple CNN for metadata written before the
    # key existed; hyper-parameters fall back to the training defaults.
    if cfg.get("model_type", "cnn") == "resnet18":
        model = ResNet18Classifier(
            num_classes=cfg["num_classes"],
            dropout=cfg.get("dropout", 0.4),
            fc_dim=cfg.get("fc_dim", 256),
        )
    else:
        model = SimpleCNN(
            num_classes=cfg["num_classes"],
            num_conv_blocks=cfg.get("num_conv_blocks", 3),
            base_filters=cfg.get("base_filters", 32),
            kernel_size=cfg.get("kernel_size", 3),
            use_batchnorm=cfg.get("use_batchnorm", True),
            dropout=cfg.get("dropout", 0.4),
            fc_dim=cfg.get("fc_dim", 256),
        )
    # weights_only=True restricts unpickling to tensors/containers, blocking
    # arbitrary code execution from a tampered checkpoint; a plain state dict
    # needs nothing more. Available since torch 1.13, default since torch 2.6.
    state_dict = torch.load(weight_file, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, meta
def evaluate_loss_acc(model, loader, criterion, device):
    """Return (average loss, accuracy) of *model* over *loader*, without gradients.

    Loss is sample-weighted across batches; both values are 0.0 for an empty
    loader.
    """
    model.eval()
    loss_sum = 0.0
    seen = 0
    hits = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            # Weight the batch loss by batch size so the mean is per-sample.
            loss_sum += criterion(logits, batch_y).item() * batch_x.size(0)
            hits += (logits.argmax(dim=1) == batch_y).sum().item()
            seen += batch_y.size(0)
    if not seen:
        return 0.0, 0.0
    return loss_sum / seen, hits / seen
def collect_predictions(model, loader, device):
    """Run *model* over *loader*; return (true_labels, predicted_labels) as lists."""
    model.eval()
    truths, guesses = [], []
    with torch.no_grad():
        for batch_x, batch_y in loader:
            logits = model(batch_x.to(device))
            guesses.extend(logits.argmax(dim=1).detach().cpu().tolist())
            # Labels stay on their original device; only converted to a list.
            truths.extend(batch_y.tolist())
    return truths, guesses
def train_model(
    model_type: str = "cnn",
    num_conv_blocks: int = 3,
    base_filters: int = 32,
    kernel_size: int = 3,
    use_batchnorm: bool = True,
    dropout: float = 0.4,
    fc_dim: int = 256,
    learning_rate: float = 0.001,
    weight_decay: float = 0.0001,
    batch_size: int = 16,
    epochs: int = 30,
    model_tag: str = "",
) -> dict:
    """Train a classifier end to end, evaluate it on the test split, and save it.

    Builds either a SimpleCNN or a ResNet18-based classifier (per *model_type*),
    trains with AdamW + gradient clipping + ReduceLROnPlateau, restores the
    best-validation-loss weights, computes test metrics and a confusion-matrix
    figure, then persists weights and metadata via save_model.

    The architecture parameters (num_conv_blocks..fc_dim) only apply to the
    SimpleCNN path, except dropout/fc_dim which both model types use.
    *model_tag* is a user label prefixed to the saved model name.

    Returns a dict with keys: "logs", "history", "summary", "model_name",
    "classification_report", "confusion_matrix", "confusion_matrix_path".
    """
    device = get_runtime_device()
    train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
    num_classes = len(class_names)
    if model_type == "resnet18":
        model = ResNet18Classifier(
            num_classes=num_classes,
            dropout=dropout,
            fc_dim=fc_dim,
        ).to(device)
    else:
        model = SimpleCNN(
            num_classes=num_classes,
            num_conv_blocks=num_conv_blocks,
            base_filters=base_filters,
            kernel_size=kernel_size,
            use_batchnorm=use_batchnorm,
            dropout=dropout,
            fc_dim=fc_dim,
        ).to(device)
    # Parameter counts are reported in the metadata and logs below.
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    criterion = nn.CrossEntropyLoss()
    # Only trainable parameters go to the optimizer (frozen backbone layers
    # of the ResNet path are excluded).
    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=learning_rate,
        weight_decay=weight_decay,
    )
    # Halve the LR when val_loss has not improved for 8 epochs; patience is
    # high because the validation set is very small (noisy signal).
    # LR never drops below 20% of its initial value.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=8,
        min_lr=learning_rate * 0.2,
    )
    history = []
    logs = []
    start_time = time.time()
    # Track the best validation loss and snapshot its weights (on CPU) so the
    # final evaluation uses the best checkpoint, not the last epoch.
    best_val_loss = float("inf")
    best_state_dict = None
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss = 0.0
        total = 0
        correct = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            # Important: prevents unstable fine-tuning / exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # Sample-weighted loss so the epoch average is per-sample.
            running_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        train_loss = running_loss / total if total else 0.0
        train_acc = correct / total if total else 0.0
        val_loss, val_acc = evaluate_loss_acc(model, val_loader, criterion, device)
        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]["lr"]
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Clone to CPU so the snapshot survives further training updates.
            best_state_dict = {
                k: v.detach().cpu().clone()
                for k, v in model.state_dict().items()
            }
        row = {
            "epoch": epoch,
            "train_loss": round(train_loss, 4),
            "train_acc": round(train_acc, 4),
            "val_loss": round(val_loss, 4),
            "val_acc": round(val_acc, 4),
        }
        history.append(row)
        logs.append(
            f"Époque {epoch}/{epochs} | "
            f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
            f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}, "
            f"lr={current_lr:.6f}"
        )
    # Restore the best-validation checkpoint before evaluating on the test set.
    if best_state_dict is not None:
        model.load_state_dict(best_state_dict)
    test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
    y_true, y_pred = collect_predictions(model, test_loader, device)
    metrics = compute_classification_metrics(y_true, y_pred, class_names)
    elapsed = time.time() - start_time
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Build a filesystem-safe, timestamped model name from the user tag.
    safe_tag = model_tag.strip().replace(" ", "_") if model_tag.strip() else "charcoal_resnet18"
    model_name = f"{safe_tag}_{timestamp}"
    cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
    if model_type == "resnet18":
        architecture = "ResNet18 pré-entraîné (layer4 + classifieur)"
    else:
        architecture = f"CNN simple ({num_conv_blocks} blocs, filtres={base_filters}, noyau={kernel_size}x{kernel_size})"
    # Everything needed to rebuild the model later (see load_model).
    config = {
        "dataset_name": DATASET_DISPLAY_NAME,
        "model_type": model_type,
        "architecture": architecture,
        "num_classes": num_classes,
        "class_names": class_names,
        "num_conv_blocks": num_conv_blocks,
        "base_filters": base_filters,
        "kernel_size": kernel_size,
        "use_batchnorm": use_batchnorm,
        "dropout": dropout,
        "fc_dim": fc_dim,
        "learning_rate": learning_rate,
        "weight_decay": weight_decay,
        "batch_size": batch_size,
        "epochs": epochs,
    }
    training_summary = {
        "final_train_loss": history[-1]["train_loss"] if history else None,
        "final_train_acc": history[-1]["train_acc"] if history else None,
        "best_val_loss": round(best_val_loss, 4),
        "final_val_loss": history[-1]["val_loss"] if history else None,
        "final_val_acc": history[-1]["val_acc"] if history else None,
        "test_cross_entropy_loss": round(test_loss, 4),
        "test_accuracy": round(test_acc, 4),
        "test_f1_macro": metrics["f1_macro"],
        "test_f1_weighted": metrics["f1_weighted"],
        "elapsed_seconds": round(elapsed, 2),
        "device": str(device),
        "total_params": total_params,
        "trainable_params": trainable_params,
    }
    save_model(model, model_name, config, training_summary)
    logs.append("")
    logs.append("Entraînement terminé.")
    logs.append(f"Modèle sauvegardé : {model_name}")
    logs.append(f"Appareil utilisé : {device}")
    logs.append(f"Architecture : {architecture}")
    logs.append(f"Nombre total de paramètres : {total_params}")
    logs.append(f"Paramètres entraînables : {trainable_params}")
    logs.append(f"Perte test cross-entropy : {test_loss:.4f}")
    logs.append(f"Accuracy test : {test_acc:.4f}")
    logs.append(f"F1 macro test : {metrics['f1_macro']:.4f}")
    logs.append(f"F1 pondéré test : {metrics['f1_weighted']:.4f}")
    logs.append(f"Temps écoulé : {elapsed:.1f}s")
    return {
        "logs": "\n".join(logs),
        "history": history,
        "summary": training_summary,
        "model_name": model_name,
        "classification_report": metrics["classification_report"],
        "confusion_matrix": metrics["confusion_matrix"],
        "confusion_matrix_path": cm_path,
    }
def evaluate_saved_model(model_name: str):
    """Reload a saved model and score it on the test split.

    Returns a 4-tuple: (summary dict, classification report, confusion matrix,
    confusion-matrix figure path). Raises ValueError when *model_name* is empty.
    """
    if not model_name:
        raise ValueError("Aucun modèle sélectionné.")
    device = get_runtime_device()
    model, meta = load_model(model_name, device)
    # Reuse the batch size the model was trained with (default 16).
    saved_batch_size = int(meta["config"].get("batch_size", 16))
    _, _, test_loader, class_names = make_loaders(saved_batch_size)
    criterion = nn.CrossEntropyLoss()
    test_loss, test_acc = evaluate_loss_acc(model, test_loader, criterion, device)
    y_true, y_pred = collect_predictions(model, test_loader, device)
    metrics = compute_classification_metrics(y_true, y_pred, class_names)
    cm_path = save_confusion_matrix_figure(metrics["confusion_matrix"], model_name)
    summary = {
        "test_cross_entropy_loss": round(test_loss, 4),
        "test_accuracy": round(test_acc, 4),
        "test_f1_macro": metrics["f1_macro"],
        "test_f1_weighted": metrics["f1_weighted"],
        "device": str(device),
    }
    return summary, metrics["classification_report"], metrics["confusion_matrix"], cm_path