"""Training, evaluation and on-disk persistence for the SimpleCNN charcoal classifier.

Each saved model is a pair of files sharing one base name: the CPU state dict
in ``MODEL_DIR/<name>.pt`` and a JSON metadata record in ``META_DIR/<name>.json``.
"""

import os
import json
import re
import time
from datetime import datetime
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.optim as optim

from config import MODEL_DIR, META_DIR
from model import SimpleCNN
from data_utils import make_loaders

# Tag used when the caller supplies an empty or whitespace-only model tag.
_DEFAULT_MODEL_TAG = "charcoal"
# Every character that is not a (Unicode) word character, "." or "-" is
# replaced in model tags, so a user-supplied tag can never inject a path
# separator into the file names built from it.
_UNSAFE_TAG_CHAR = re.compile(r"[^\w.-]")


def model_weight_path(model_name: str) -> str:
    """Return the path of the ``.pt`` weights file for *model_name*."""
    return os.path.join(MODEL_DIR, f"{model_name}.pt")


def model_meta_path(model_name: str) -> str:
    """Return the path of the ``.json`` metadata file for *model_name*."""
    return os.path.join(META_DIR, f"{model_name}.json")


def list_saved_models() -> List[str]:
    """Return the names of all models that have a metadata file.

    Names are sorted in descending lexicographic order. An empty list is
    returned when ``META_DIR`` does not exist yet (nothing has been saved).
    """
    if not os.path.isdir(META_DIR):
        return []
    names = [fn[: -len(".json")] for fn in os.listdir(META_DIR) if fn.endswith(".json")]
    return sorted(names, reverse=True)


def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict) -> None:
    """Persist *model*'s weights and a JSON metadata record under *model_name*.

    Args:
        model: Trained network whose ``state_dict`` is saved.
        model_name: Base name shared by the weights and metadata files.
        config: Hyper-parameters needed to rebuild the model (read back by
            ``load_model``).
        training_summary: Final metrics stored alongside the config.

    The output directories are created if they are missing.
    """
    os.makedirs(MODEL_DIR, exist_ok=True)
    os.makedirs(META_DIR, exist_ok=True)

    # Detach and move every tensor to CPU so the file loads without a GPU.
    cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    torch.save(cpu_state_dict, model_weight_path(model_name))

    payload = {
        "model_name": model_name,
        "config": config,
        "training_summary": training_summary,
        "created_at": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(model_meta_path(model_name), "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)


def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
    """Rebuild a saved model and return it in eval mode with its metadata.

    Args:
        model_name: Base name of a saved model (see ``list_saved_models``).
        device: Device the model is moved to after its weights are loaded.

    Returns:
        ``(model, meta)`` where *meta* is the parsed JSON metadata record.

    Raises:
        FileNotFoundError: if the metadata or the weights file is missing.
        KeyError: if the metadata lacks one of the config keys read below.
    """
    meta_file = model_meta_path(model_name)
    weight_file = model_weight_path(model_name)

    if not os.path.exists(meta_file):
        raise FileNotFoundError(f"Metadata not found for model: {model_name}")
    if not os.path.exists(weight_file):
        raise FileNotFoundError(f"Weights not found for model: {model_name}")

    with open(meta_file, "r", encoding="utf-8") as f:
        meta = json.load(f)

    cfg = meta["config"]
    model = SimpleCNN(
        num_classes=cfg["num_classes"],
        conv1_channels=cfg["conv1_channels"],
        conv2_channels=cfg["conv2_channels"],
        kernel_size=cfg["kernel_size"],
        dropout=cfg["dropout"],
        fc_dim=cfg["fc_dim"],
    )

    # save_model only ever writes a plain dict of tensors, so unpickling can be
    # restricted to tensors/containers; without weights_only, torch.load would
    # execute arbitrary pickled code found in the file.
    state_dict = torch.load(weight_file, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model, meta


def get_runtime_device() -> torch.device:
    """Return the CUDA device when available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def evaluate(model, loader, criterion, device):
    """Evaluate *model* on *loader* without tracking gradients.

    Returns:
        ``(mean_loss, accuracy)`` over all samples, or ``(0.0, 0.0)`` when the
        loader yields no samples. The per-batch loss is weighted by batch size,
        which assumes *criterion* averages over the batch (mean reduction).
    """
    model.eval()
    total_loss = 0.0
    total = 0
    correct = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    if not total:
        return 0.0, 0.0
    return total_loss / total, correct / total


def _train_one_epoch(model, loader, criterion, optimizer, device) -> Tuple[float, float]:
    """Run one optimisation pass over *loader*; return ``(mean_loss, accuracy)``.

    Accuracy is computed from the outputs produced before each optimiser step.
    Returns ``(0.0, 0.0)`` when the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)

    if not total:
        return 0.0, 0.0
    return running_loss / total, correct / total


def _safe_model_tag(model_tag: str) -> str:
    """Return a file-name-safe version of *model_tag* (default tag if blank).

    Each character outside word characters, "." and "-" (spaces, slashes,
    backslashes, ...) is replaced by a single "_".
    """
    tag = (model_tag or "").strip()
    if not tag:
        return _DEFAULT_MODEL_TAG
    return _UNSAFE_TAG_CHAR.sub("_", tag)


def train_model(
    conv1_channels: int,
    conv2_channels: int,
    kernel_size: int,
    dropout: float,
    fc_dim: int,
    learning_rate: float,
    batch_size: int,
    epochs: int,
    model_tag: str,
):
    """Train a SimpleCNN with Adam, evaluate it, and save it to disk.

    Args:
        conv1_channels, conv2_channels, kernel_size, dropout, fc_dim:
            Architecture hyper-parameters forwarded to ``SimpleCNN``.
        learning_rate: Adam learning rate.
        batch_size: Batch size passed to ``make_loaders``.
        epochs: Number of training epochs (0 trains nothing but still saves).
        model_tag: User label prefixed to the saved model name; sanitised
            for use in file names, ``"charcoal"`` when blank.

    Returns:
        ``(log_text, history, training_summary, model_name)`` where *history*
        holds one dict of rounded metrics per epoch.
    """
    device = get_runtime_device()
    train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
    num_classes = len(class_names)

    model = SimpleCNN(
        num_classes=num_classes,
        conv1_channels=conv1_channels,
        conv2_channels=conv2_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        fc_dim=fc_dim,
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    history = []
    logs = []
    start_time = time.time()

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        history.append(
            {
                "epoch": epoch,
                "train_loss": round(train_loss, 4),
                "train_acc": round(train_acc, 4),
                "val_loss": round(val_loss, 4),
                "val_acc": round(val_acc, 4),
            }
        )
        logs.append(
            f"Époque {epoch}/{epochs} | "
            f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
            f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}"
        )

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    elapsed = time.time() - start_time

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    model_name = f"{_safe_model_tag(model_tag)}_{timestamp}"

    config = {
        "dataset_name": "Charbons de bois microscopiques",
        "num_classes": num_classes,
        "class_names": class_names,
        "conv1_channels": conv1_channels,
        "conv2_channels": conv2_channels,
        "kernel_size": kernel_size,
        "dropout": dropout,
        "fc_dim": fc_dim,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }

    # With zero epochs there is no history; the final_* metrics are then None.
    last = history[-1] if history else {}
    training_summary = {
        "final_train_loss": last.get("train_loss"),
        "final_train_acc": last.get("train_acc"),
        "final_val_loss": last.get("val_loss"),
        "final_val_acc": last.get("val_acc"),
        "test_loss": round(test_loss, 4),
        "test_acc": round(test_acc, 4),
        "elapsed_seconds": round(elapsed, 2),
        "device": str(device),
    }

    save_model(model, model_name, config, training_summary)

    logs.extend(
        [
            "",
            "Entraînement terminé.",
            f"Modèle sauvegardé : {model_name}",
            f"Appareil utilisé : {device}",
            f"Perte test : {test_loss:.4f}",
            f"Précision test : {test_acc:.4f}",
            f"Temps écoulé : {elapsed:.1f}s",
        ]
    )

    return "\n".join(logs), history, training_summary, model_name