cognitive-core / cognitive_checkpoint.py
amewebstudio's picture
cognitive core loading framework
5fe4d99 verified
"""
COGNITIVE-CORE: Checkpoint Loading & Key Remapping
===================================================
This module provides robust checkpoint loading with automatic key remapping
to handle different checkpoint formats (with/without 'model.' prefix, etc.)
Copyright © 2026 Mike Amega (Logo) - Ame Web Studio
License: Proprietary - All Rights Reserved
"""
import re
from typing import Dict, Set, Optional
import torch
def remap_checkpoint_keys(
checkpoint_state_dict: Dict[str, torch.Tensor],
model_state_dict: Dict[str, torch.Tensor],
verbose: bool = False,
) -> Dict[str, torch.Tensor]:
"""
Remappe automatiquement les clés du checkpoint pour correspondre au modèle.
Gère les scénarios suivants:
1. Checkpoint a préfixe 'model.' mais modèle n'en a pas → retirer préfixe
2. Checkpoint n'a pas préfixe 'model.' mais modèle en a → ajouter préfixe
3. Autres préfixes personnalisés
Args:
checkpoint_state_dict: État du checkpoint chargé
model_state_dict: État du modèle cible
verbose: Afficher les détails du remappage
Returns:
Dict remappé compatible avec le modèle
"""
model_keys = set(model_state_dict.keys())
checkpoint_keys = set(checkpoint_state_dict.keys())
# Vérifier si le checkpoint correspond déjà
matching = model_keys & checkpoint_keys
if len(matching) >= len(checkpoint_keys) * 0.9:
if verbose:
print(
f"✅ Checkpoint compatible: {len(matching)}/{len(checkpoint_keys)} clés correspondent"
)
return checkpoint_state_dict
# Tester différentes stratégies de remappage
strategies = [
("remove_model_prefix", _remove_prefix, "model."),
("add_model_prefix", _add_prefix, "model."),
("remove_backbone_prefix", _remove_prefix, "backbone."),
("remove_encoder_prefix", _remove_prefix, "encoder."),
]
best_strategy = None
best_match_count = len(matching)
best_result = checkpoint_state_dict
for name, func, prefix in strategies:
remapped = func(checkpoint_state_dict, prefix)
match_count = len(model_keys & set(remapped.keys()))
if match_count > best_match_count:
best_match_count = match_count
best_strategy = name
best_result = remapped
if verbose and best_strategy:
print(f"🔄 Stratégie appliquée: {best_strategy}")
print(f" Clés correspondantes: {best_match_count}/{len(checkpoint_keys)}")
# Fallback: mapper intelligemment clé par clé
if best_match_count < len(checkpoint_keys) * 0.5:
best_result = _smart_key_mapping(checkpoint_state_dict, model_keys)
if verbose:
final_match = len(model_keys & set(best_result.keys()))
print(
f"🧠 Remappage intelligent: {final_match}/{len(checkpoint_keys)} clés"
)
return best_result
def _remove_prefix(state_dict: Dict, prefix: str) -> Dict:
"""Retirer un préfixe de toutes les clés."""
return {
(k[len(prefix) :] if k.startswith(prefix) else k): v
for k, v in state_dict.items()
}
def _add_prefix(state_dict: Dict, prefix: str) -> Dict:
"""Ajouter un préfixe à toutes les clés."""
return {f"{prefix}{k}": v for k, v in state_dict.items()}
def _smart_key_mapping(
checkpoint_dict: Dict[str, torch.Tensor], model_keys: Set[str]
) -> Dict[str, torch.Tensor]:
"""
Mapping intelligent clé par clé basé sur les suffixes et patterns.
"""
result = {}
model_keys_list = list(model_keys)
for ckpt_key, value in checkpoint_dict.items():
# Correspondance exacte
if ckpt_key in model_keys:
result[ckpt_key] = value
continue
# Essayer avec préfixe 'model.'
with_prefix = f"model.{ckpt_key}"
if with_prefix in model_keys:
result[with_prefix] = value
continue
# Essayer sans préfixe 'model.'
if ckpt_key.startswith("model."):
without_prefix = ckpt_key[6:]
if without_prefix in model_keys:
result[without_prefix] = value
continue
# Chercher par suffixe (ex: ".weight", ".bias")
ckpt_suffix = ckpt_key.split(".")[-1]
ckpt_base = ".".join(ckpt_key.split(".")[:-1])
for model_key in model_keys_list:
if model_key.endswith(ckpt_suffix):
model_base = ".".join(model_key.split(".")[:-1])
# Vérifier similarité structurelle
if _keys_similar(ckpt_base, model_base):
result[model_key] = value
break
else:
# Garder la clé originale (sera ignorée si pas dans modèle)
result[ckpt_key] = value
return result
def _keys_similar(key1: str, key2: str) -> bool:
"""Vérifier si deux clés sont structurellement similaires."""
parts1 = key1.split(".")
parts2 = key2.split(".")
# Même nombre de parties
if len(parts1) != len(parts2):
return False
# Comparer chaque partie (ignorer les préfixes comme 'model')
matches = sum(
1 for p1, p2 in zip(parts1, parts2) if p1 == p2 or p1.isdigit() and p2.isdigit()
)
return matches >= len(parts1) * 0.7
def validate_checkpoint(
checkpoint_state_dict: Dict[str, torch.Tensor],
model_state_dict: Dict[str, torch.Tensor],
strict: bool = False,
) -> Dict[str, any]:
"""
Valider qu'un checkpoint est compatible avec un modèle.
Returns:
Dict avec:
- valid: bool
- missing_keys: clés manquantes dans checkpoint
- unexpected_keys: clés inattendues dans checkpoint
- size_mismatches: clés avec tailles incompatibles
"""
model_keys = set(model_state_dict.keys())
ckpt_keys = set(checkpoint_state_dict.keys())
missing = model_keys - ckpt_keys
unexpected = ckpt_keys - model_keys
# Vérifier les tailles
size_mismatches = []
for key in model_keys & ckpt_keys:
model_shape = model_state_dict[key].shape
ckpt_shape = checkpoint_state_dict[key].shape
if model_shape != ckpt_shape:
size_mismatches.append(
{"key": key, "model_shape": model_shape, "checkpoint_shape": ckpt_shape}
)
valid = len(missing) == 0 and len(size_mismatches) == 0
if not strict:
valid = len(size_mismatches) == 0 and len(missing) < len(model_keys) * 0.1
return {
"valid": valid,
"missing_keys": list(missing),
"unexpected_keys": list(unexpected),
"size_mismatches": size_mismatches,
"matched_keys": len(model_keys & ckpt_keys),
"total_model_keys": len(model_keys),
}
def save_cognitive_checkpoint(
model,
path: str,
include_optimizer: bool = False,
optimizer=None,
extra_state: Optional[Dict] = None,
):
"""
Sauvegarder un checkpoint de modèle cognitif.
Args:
model: Le modèle à sauvegarder
path: Chemin de sauvegarde
include_optimizer: Inclure l'état de l'optimiseur
optimizer: L'optimiseur (si include_optimizer=True)
extra_state: État additionnel à sauvegarder
"""
checkpoint = {
"model_state_dict": model.state_dict(),
"config": model.config.to_dict() if hasattr(model, "config") else {},
}
if include_optimizer and optimizer is not None:
checkpoint["optimizer_state_dict"] = optimizer.state_dict()
# Sauvegarder l'état cognitif si disponible
if hasattr(model, "get_cognitive_state"):
checkpoint["cognitive_state"] = model.get_cognitive_state()
if extra_state:
checkpoint["extra_state"] = extra_state
torch.save(checkpoint, path)
print(f"✅ Checkpoint sauvegardé: {path}")
def load_cognitive_checkpoint(
model, path: str, strict: bool = False, verbose: bool = True
) -> Dict:
"""
Charger un checkpoint dans un modèle cognitif avec remappage automatique.
Args:
model: Le modèle cible
path: Chemin du checkpoint
strict: Mode strict (erreur si clés manquantes)
verbose: Afficher les détails
Returns:
Dict avec informations de chargement
"""
checkpoint = torch.load(path, map_location="cpu")
# Extraire le state_dict
if "model_state_dict" in checkpoint:
state_dict = checkpoint["model_state_dict"]
elif "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
else:
state_dict = checkpoint
# Remapper les clés
remapped = remap_checkpoint_keys(state_dict, model.state_dict(), verbose=verbose)
# Valider
validation = validate_checkpoint(remapped, model.state_dict(), strict=strict)
if verbose:
print(
f"📊 Clés chargées: {validation['matched_keys']}/{validation['total_model_keys']}"
)
if validation["missing_keys"]:
print(f"⚠️ Clés manquantes: {len(validation['missing_keys'])}")
if validation["size_mismatches"]:
print(f"⚠️ Tailles incompatibles: {len(validation['size_mismatches'])}")
# Charger avec ignore_mismatched_sizes pour robustesse
model.load_state_dict(remapped, strict=False)
# Restaurer l'état cognitif si disponible
if "cognitive_state" in checkpoint and hasattr(model, "reset_cognitive_state"):
# L'état cognitif est généralement réinitialisé, pas restauré
pass
if verbose:
print("✅ Checkpoint chargé avec succès")
return {
"validation": validation,
"config": checkpoint.get("config", {}),
"extra_state": checkpoint.get("extra_state", {}),
}