|
|
"""
|
|
|
COGNITIVE-CORE: Checkpoint Loading & Key Remapping
|
|
|
===================================================
|
|
|
|
|
|
This module provides robust checkpoint loading with automatic key remapping
|
|
|
to handle different checkpoint formats (with/without 'model.' prefix, etc.)
|
|
|
|
|
|
Copyright © 2026 Mike Amega (Logo) - Ame Web Studio
|
|
|
License: Proprietary - All Rights Reserved
|
|
|
"""
|
|
|
|
|
|
import re
from typing import Any, Dict, Optional, Set

import torch
|
|
|
|
|
|
|
|
|
def remap_checkpoint_keys(
    checkpoint_state_dict: Dict[str, torch.Tensor],
    model_state_dict: Dict[str, torch.Tensor],
    verbose: bool = False,
) -> Dict[str, torch.Tensor]:
    """Remap checkpoint keys so they line up with the target model's keys.

    Handles the common scenarios:
    1. Checkpoint keys carry a 'model.' prefix the model lacks -> strip it.
    2. Checkpoint keys lack a 'model.' prefix the model has -> add it.
    3. Other custom prefixes ('backbone.', 'encoder.').

    Args:
        checkpoint_state_dict: state dict loaded from disk.
        model_state_dict: state dict of the model being loaded into.
        verbose: print remapping details.

    Returns:
        A state dict whose keys are compatible with the model.
    """
    target_keys = set(model_state_dict.keys())
    source_keys = set(checkpoint_state_dict.keys())

    # Fast path: nearly everything (>= 90%) already matches — no remapping.
    overlap = target_keys & source_keys
    if len(overlap) >= len(source_keys) * 0.9:
        if verbose:
            print(
                f"✅ Checkpoint compatible: {len(overlap)}/{len(source_keys)} clés correspondent"
            )
        return checkpoint_state_dict

    # Candidate prefix transformations, tried in this order.
    candidates = (
        ("remove_model_prefix", _remove_prefix, "model."),
        ("add_model_prefix", _add_prefix, "model."),
        ("remove_backbone_prefix", _remove_prefix, "backbone."),
        ("remove_encoder_prefix", _remove_prefix, "encoder."),
    )

    winner_name = None
    winner_score = len(overlap)
    winner_dict = checkpoint_state_dict

    # Keep whichever transformation yields the most matching keys.
    for candidate_name, transform, prefix in candidates:
        transformed = transform(checkpoint_state_dict, prefix)
        score = len(target_keys & set(transformed.keys()))
        if score > winner_score:
            winner_score = score
            winner_name = candidate_name
            winner_dict = transformed

    if verbose and winner_name:
        print(f"🔄 Stratégie appliquée: {winner_name}")
        print(f" Clés correspondantes: {winner_score}/{len(source_keys)}")

    # Prefix strategies were insufficient (< 50% matched) — fall back to a
    # per-key heuristic mapping over the ORIGINAL checkpoint keys.
    if winner_score < len(source_keys) * 0.5:
        winner_dict = _smart_key_mapping(checkpoint_state_dict, target_keys)
        if verbose:
            final_match = len(target_keys & set(winner_dict.keys()))
            print(
                f"🧠 Remappage intelligent: {final_match}/{len(source_keys)} clés"
            )

    return winner_dict
|
|
|
|
|
|
|
|
|
def _remove_prefix(state_dict: Dict, prefix: str) -> Dict:
|
|
|
"""Retirer un préfixe de toutes les clés."""
|
|
|
return {
|
|
|
(k[len(prefix) :] if k.startswith(prefix) else k): v
|
|
|
for k, v in state_dict.items()
|
|
|
}
|
|
|
|
|
|
|
|
|
def _add_prefix(state_dict: Dict, prefix: str) -> Dict:
|
|
|
"""Ajouter un préfixe à toutes les clés."""
|
|
|
return {f"{prefix}{k}": v for k, v in state_dict.items()}
|
|
|
|
|
|
|
|
|
def _smart_key_mapping(
    checkpoint_dict: Dict[str, torch.Tensor], model_keys: Set[str]
) -> Dict[str, torch.Tensor]:
    """Heuristic per-key remapping based on suffixes and path structure.

    For each checkpoint key, tries in order: exact match, 'model.'-prefixed
    match, 'model.'-stripped match, then a fuzzy match against model keys
    sharing the same trailing component. Unmatched keys are kept as-is.
    """
    mapped: Dict[str, torch.Tensor] = {}
    candidates = list(model_keys)

    for source_key, tensor in checkpoint_dict.items():
        # 1. Exact match — keep unchanged.
        if source_key in model_keys:
            mapped[source_key] = tensor
            continue

        # 2. Model expects a 'model.' prefix the checkpoint lacks.
        prefixed = f"model.{source_key}"
        if prefixed in model_keys:
            mapped[prefixed] = tensor
            continue

        # 3. Checkpoint carries a 'model.' prefix the model lacks.
        if source_key.startswith("model."):
            stripped = source_key[6:]
            if stripped in model_keys:
                mapped[stripped] = tensor
                continue

        # 4. Fuzzy: same final component plus structurally similar path.
        parts = source_key.split(".")
        source_suffix = parts[-1]
        source_base = ".".join(parts[:-1])

        for candidate in candidates:
            if not candidate.endswith(source_suffix):
                continue
            candidate_base = ".".join(candidate.split(".")[:-1])
            if _keys_similar(source_base, candidate_base):
                mapped[candidate] = tensor
                break
        else:
            # No plausible target found — preserve the original key.
            mapped[source_key] = tensor

    return mapped
|
|
|
|
|
|
|
|
|
def _keys_similar(key1: str, key2: str) -> bool:
|
|
|
"""Vérifier si deux clés sont structurellement similaires."""
|
|
|
parts1 = key1.split(".")
|
|
|
parts2 = key2.split(".")
|
|
|
|
|
|
|
|
|
if len(parts1) != len(parts2):
|
|
|
return False
|
|
|
|
|
|
|
|
|
matches = sum(
|
|
|
1 for p1, p2 in zip(parts1, parts2) if p1 == p2 or p1.isdigit() and p2.isdigit()
|
|
|
)
|
|
|
return matches >= len(parts1) * 0.7
|
|
|
|
|
|
|
|
|
def validate_checkpoint(
    checkpoint_state_dict: Dict[str, torch.Tensor],
    model_state_dict: Dict[str, torch.Tensor],
    strict: bool = False,
) -> Dict[str, Any]:
    """Check whether a checkpoint is compatible with a model.

    Fix: the return annotation previously used the builtin function ``any``
    as a type; it now uses ``typing.Any``.

    Args:
        checkpoint_state_dict: state dict loaded from disk.
        model_state_dict: state dict of the target model.
        strict: when True, any missing key invalidates the checkpoint;
            when False, up to 10% of model keys may be missing.

    Returns:
        Dict with:
            - valid: overall compatibility verdict (bool)
            - missing_keys: model keys absent from the checkpoint
            - unexpected_keys: checkpoint keys absent from the model
            - size_mismatches: shared keys whose tensor shapes differ
            - matched_keys: number of keys present on both sides
            - total_model_keys: number of keys the model expects
    """
    model_keys = set(model_state_dict.keys())
    ckpt_keys = set(checkpoint_state_dict.keys())

    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys

    # Shape comparison is only meaningful on keys both sides share.
    size_mismatches = []
    for key in model_keys & ckpt_keys:
        model_shape = model_state_dict[key].shape
        ckpt_shape = checkpoint_state_dict[key].shape
        if model_shape != ckpt_shape:
            size_mismatches.append(
                {"key": key, "model_shape": model_shape, "checkpoint_shape": ckpt_shape}
            )

    # Strict: no key may be missing. Lenient: tolerate < 10% missing keys.
    valid = len(missing) == 0 and len(size_mismatches) == 0
    if not strict:
        valid = len(size_mismatches) == 0 and len(missing) < len(model_keys) * 0.1

    return {
        "valid": valid,
        "missing_keys": list(missing),
        "unexpected_keys": list(unexpected),
        "size_mismatches": size_mismatches,
        "matched_keys": len(model_keys & ckpt_keys),
        "total_model_keys": len(model_keys),
    }
|
|
|
|
|
|
|
|
|
def save_cognitive_checkpoint(
    model,
    path: str,
    include_optimizer: bool = False,
    optimizer=None,
    extra_state: Optional[Dict] = None,
):
    """Persist a cognitive model checkpoint to disk via torch.save.

    Args:
        model: the model to save.
        path: destination file path.
        include_optimizer: also store the optimizer state.
        optimizer: the optimizer (used only when include_optimizer=True).
        extra_state: arbitrary extra payload stored under 'extra_state'.
    """
    payload = {
        "model_state_dict": model.state_dict(),
        # Config is optional: models without a .config attribute store {}.
        "config": model.config.to_dict() if hasattr(model, "config") else {},
    }

    if include_optimizer and optimizer is not None:
        payload["optimizer_state_dict"] = optimizer.state_dict()

    # Models exposing a cognitive state get it persisted alongside weights.
    if hasattr(model, "get_cognitive_state"):
        payload["cognitive_state"] = model.get_cognitive_state()

    if extra_state:
        payload["extra_state"] = extra_state

    torch.save(payload, path)
    print(f"✅ Checkpoint sauvegardé: {path}")
|
|
|
|
|
|
|
|
|
def load_cognitive_checkpoint(
    model, path: str, strict: bool = False, verbose: bool = True
) -> Dict:
    """Load a checkpoint into a cognitive model with automatic key remapping.

    Args:
        model: target model.
        path: checkpoint file path.
        strict: strict validation mode (flags any missing key).
        verbose: print loading details.

    Returns:
        Dict with the validation report, stored config and extra state.
    """
    # NOTE(review): torch.load without weights_only unpickles arbitrary
    # objects — only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location="cpu")

    # Accept common container layouts, falling back to a raw state dict.
    state = checkpoint
    for container_key in ("model_state_dict", "state_dict"):
        if container_key in checkpoint:
            state = checkpoint[container_key]
            break

    # Align checkpoint keys with the model, then validate the result.
    mapped = remap_checkpoint_keys(state, model.state_dict(), verbose=verbose)
    report = validate_checkpoint(mapped, model.state_dict(), strict=strict)

    if verbose:
        print(
            f"📊 Clés chargées: {report['matched_keys']}/{report['total_model_keys']}"
        )
        if report["missing_keys"]:
            print(f"⚠️ Clés manquantes: {len(report['missing_keys'])}")
        if report["size_mismatches"]:
            print(f"⚠️ Tailles incompatibles: {len(report['size_mismatches'])}")

    # strict=False: discrepancies were already surfaced by the report above.
    model.load_state_dict(mapped, strict=False)

    # TODO(review): a saved cognitive_state is detected but never restored —
    # this branch is a deliberate no-op today; confirm intended behavior.
    if "cognitive_state" in checkpoint and hasattr(model, "reset_cognitive_state"):
        pass

    if verbose:
        print("✅ Checkpoint chargé avec succès")

    return {
        "validation": report,
        "config": checkpoint.get("config", {}),
        "extra_state": checkpoint.get("extra_state", {}),
    }
|
|
|
|