Spaces:
Running on Zero
Running on Zero
| import os | |
| import json | |
| import time | |
| from datetime import datetime | |
| from typing import List, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from config import MODEL_DIR, META_DIR | |
| from model import SimpleCNN | |
| from data_utils import make_loaders | |
def model_weight_path(model_name: str) -> str:
    """Return the path of the ``.pt`` weights file for ``model_name`` in MODEL_DIR."""
    filename = f"{model_name}.pt"
    return os.path.join(MODEL_DIR, filename)
def model_meta_path(model_name: str) -> str:
    """Return the path of the JSON metadata file for ``model_name`` in META_DIR."""
    filename = f"{model_name}.json"
    return os.path.join(META_DIR, filename)
def list_saved_models() -> List[str]:
    """Return the names of all models that have a metadata file in META_DIR.

    A model name is the metadata filename with its ``.json`` extension removed.
    Names are sorted in reverse lexicographic order. Names produced by
    ``train_model`` end with a ``%Y%m%d_%H%M%S`` timestamp, so this puts the
    newest model first among models that share the same tag. Across different
    tags the order is alphabetical (descending), not chronological.

    Returns an empty list when META_DIR does not exist (yet) instead of
    letting ``os.listdir`` raise.
    """
    # Guard: os.listdir raises FileNotFoundError / NotADirectoryError when the
    # directory is absent, e.g. before any model has ever been saved.
    if not os.path.isdir(META_DIR):
        return []
    suffix = ".json"
    names = [fn[: -len(suffix)] for fn in os.listdir(META_DIR) if fn.endswith(suffix)]
    return sorted(names, reverse=True)
def save_model(model: nn.Module, model_name: str, config: dict, training_summary: dict):
    """Persist a model's weights and a JSON metadata sidecar under ``model_name``.

    The weights are written first, via ``torch.save``, as a plain dict of
    detached CPU tensors so they can be reloaded on any device. Then a JSON
    file is written containing ``model_name``, ``config``,
    ``training_summary`` and a local ``created_at`` timestamp
    (``%Y-%m-%d %H:%M:%S``).

    Args:
        model: Trained module whose ``state_dict`` is saved.
        model_name: Base name used to build both output paths.
        config: JSON-serialisable architecture/training configuration.
        training_summary: JSON-serialisable metrics dict.
    """
    portable_state = {}
    for key, tensor in model.state_dict().items():
        # Detach and move to CPU so the file does not depend on the training device.
        portable_state[key] = tensor.detach().cpu()
    torch.save(portable_state, model_weight_path(model_name))

    metadata = dict(
        model_name=model_name,
        config=config,
        training_summary=training_summary,
        created_at=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    )
    with open(model_meta_path(model_name), "w", encoding="utf-8") as handle:
        # ensure_ascii=False keeps accented class/dataset names readable in the file.
        json.dump(metadata, handle, indent=2, ensure_ascii=False)
def load_model(model_name: str, device: torch.device) -> Tuple[nn.Module, dict]:
    """Rebuild a saved SimpleCNN from its JSON metadata and weights file.

    Args:
        model_name: Name of a saved model (metadata filename without ``.json``).
        device: Device the restored model is moved to.

    Returns:
        ``(model, meta)``: the model in eval mode on ``device``, and the full
        metadata dict read from the JSON file.

    Raises:
        FileNotFoundError: if the metadata file or the weights file is missing
            (the metadata file is checked first).
        KeyError: if the metadata lacks ``config`` or one of the architecture
            keys used to rebuild the network.
    """
    meta_path = model_meta_path(model_name)
    weights_path = model_weight_path(model_name)

    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"Metadata not found for model: {model_name}")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Weights not found for model: {model_name}")

    with open(meta_path, "r", encoding="utf-8") as handle:
        meta = json.load(handle)

    arch = meta["config"]
    arch_keys = (
        "num_classes",
        "conv1_channels",
        "conv2_channels",
        "kernel_size",
        "dropout",
        "fc_dim",
    )
    net = SimpleCNN(**{key: arch[key] for key in arch_keys})

    # Load tensors onto the CPU first, then move the whole module to `device`.
    # NOTE(review): torch.load unpickles the file; only load weight files this
    # application wrote itself.
    weights = torch.load(weights_path, map_location="cpu")
    net.load_state_dict(weights)
    net.to(device)
    net.eval()
    return net, meta
def get_runtime_device() -> torch.device:
    """Return the CUDA device when one is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def evaluate(model, loader, criterion, device):
    """Compute the per-sample mean loss and the accuracy of ``model`` over ``loader``.

    The model is switched to eval mode and left there. Gradients are disabled.
    Each batch loss is multiplied by the batch size before summing, so the
    returned loss is a per-sample mean, assuming the criterion averages over
    the batch.

    Args:
        model: Module mapping a batch of inputs to class scores (argmax on dim 1).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)``, or ``(0.0, 0.0)`` if the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    n_seen = 0
    n_right = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            loss_sum += batch_loss.item() * batch_x.size(0)
            n_right += (logits.argmax(dim=1) == batch_y).sum().item()
            n_seen += batch_y.size(0)
    if not n_seen:
        return 0.0, 0.0
    return loss_sum / n_seen, n_right / n_seen
def _sanitize_model_tag(model_tag: str) -> str:
    """Turn a user-supplied tag into a filename-safe model-name prefix.

    Surrounding whitespace is stripped. Every remaining character that is not a
    letter, digit, ``_``, ``-`` or ``.`` (spaces included) becomes ``_``, so
    path separators such as ``/`` or ``\\`` can never reach ``os.path.join``.
    An empty or ``None`` tag falls back to ``"charcoal"``.
    """
    tag = (model_tag or "").strip()
    if not tag:
        return "charcoal"
    return "".join(ch if ch.isalnum() or ch in "_-." else "_" for ch in tag)


def _run_train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy)``.

    Batch losses are weighted by batch size, as in ``evaluate``, so the loss is
    a per-sample mean. Both values are 0.0 when the loader yields no samples.
    """
    model.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        total += labels.size(0)
    if not total:
        return 0.0, 0.0
    return running_loss / total, correct / total


def train_model(
    conv1_channels: int,
    conv2_channels: int,
    kernel_size: int,
    dropout: float,
    fc_dim: int,
    learning_rate: float,
    batch_size: int,
    epochs: int,
    model_tag: str,
):
    """Train a SimpleCNN, evaluate it on the test split and save it.

    Data comes from ``make_loaders(batch_size)``; the dataset is recorded in
    the saved config as "Charbons de bois microscopiques". Training uses Adam
    and cross-entropy. Validation metrics are computed after every epoch and
    test metrics once at the end.

    Args:
        conv1_channels, conv2_channels, kernel_size, dropout, fc_dim:
            Architecture hyper-parameters forwarded to ``SimpleCNN``.
        learning_rate: Adam learning rate.
        batch_size: Batch size passed to ``make_loaders``.
        epochs: Number of training epochs. With 0, no training happens, but
            the model is still evaluated on the test set and saved.
        model_tag: User-chosen prefix for the saved model name. It is made
            filename-safe, and ``"charcoal"`` is used when it is empty.

    Returns:
        A tuple ``(log_text, history, training_summary, model_name)``:
        ``log_text`` is the newline-joined (French) progress log, ``history``
        holds one dict of rounded metrics per epoch, ``training_summary`` is
        the metrics dict stored in the metadata, and ``model_name`` is the
        name the model was saved under.
    """
    device = get_runtime_device()
    train_loader, val_loader, test_loader, class_names = make_loaders(batch_size)
    num_classes = len(class_names)
    model = SimpleCNN(
        num_classes=num_classes,
        conv1_channels=conv1_channels,
        conv2_channels=conv2_channels,
        kernel_size=kernel_size,
        dropout=dropout,
        fc_dim=fc_dim,
    ).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    history = []
    logs = []
    # perf_counter is monotonic: system clock adjustments cannot skew the duration.
    start_time = time.perf_counter()
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _run_train_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)
        history.append(
            {
                "epoch": epoch,
                "train_loss": round(train_loss, 4),
                "train_acc": round(train_acc, 4),
                "val_loss": round(val_loss, 4),
                "val_acc": round(val_acc, 4),
            }
        )
        logs.append(
            f"Époque {epoch}/{epochs} | "
            f"perte entraînement={train_loss:.4f}, précision entraînement={train_acc:.4f}, "
            f"perte validation={val_loss:.4f}, précision validation={val_acc:.4f}"
        )

    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    elapsed = time.perf_counter() - start_time

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # The tag is user input; sanitise it so the file paths stay inside their dirs.
    model_name = f"{_sanitize_model_tag(model_tag)}_{timestamp}"

    config = {
        "dataset_name": "Charbons de bois microscopiques",
        "num_classes": num_classes,
        "class_names": class_names,
        "conv1_channels": conv1_channels,
        "conv2_channels": conv2_channels,
        "kernel_size": kernel_size,
        "dropout": dropout,
        "fc_dim": fc_dim,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }
    last = history[-1] if history else None
    training_summary = {
        "final_train_loss": last["train_loss"] if last else None,
        "final_train_acc": last["train_acc"] if last else None,
        "final_val_loss": last["val_loss"] if last else None,
        "final_val_acc": last["val_acc"] if last else None,
        "test_loss": round(test_loss, 4),
        "test_acc": round(test_acc, 4),
        "elapsed_seconds": round(elapsed, 2),
        "device": str(device),
    }
    save_model(model, model_name, config, training_summary)

    logs.append("")
    logs.append("Entraînement terminé.")
    logs.append(f"Modèle sauvegardé : {model_name}")
    logs.append(f"Appareil utilisé : {device}")
    logs.append(f"Perte test : {test_loss:.4f}")
    logs.append(f"Précision test : {test_acc:.4f}")
    logs.append(f"Temps écoulé : {elapsed:.1f}s")
    return "\n".join(logs), history, training_summary, model_name