File size: 10,100 Bytes
5fe4d99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
"""

COGNITIVE-CORE: Checkpoint Loading & Key Remapping

===================================================



This module provides robust checkpoint loading with automatic key remapping

to handle different checkpoint formats (with/without 'model.' prefix, etc.)



Copyright © 2026 Mike Amega (Logo) - Ame Web Studio

License: Proprietary - All Rights Reserved

"""

import re
from typing import Dict, Set, Optional
import torch


def remap_checkpoint_keys(

    checkpoint_state_dict: Dict[str, torch.Tensor],

    model_state_dict: Dict[str, torch.Tensor],

    verbose: bool = False,

) -> Dict[str, torch.Tensor]:
    """

    Remappe automatiquement les clés du checkpoint pour correspondre au modèle.



    Gère les scénarios suivants:

    1. Checkpoint a préfixe 'model.' mais modèle n'en a pas → retirer préfixe

    2. Checkpoint n'a pas préfixe 'model.' mais modèle en a → ajouter préfixe

    3. Autres préfixes personnalisés



    Args:

        checkpoint_state_dict: État du checkpoint chargé

        model_state_dict: État du modèle cible

        verbose: Afficher les détails du remappage



    Returns:

        Dict remappé compatible avec le modèle

    """
    model_keys = set(model_state_dict.keys())
    checkpoint_keys = set(checkpoint_state_dict.keys())

    # Vérifier si le checkpoint correspond déjà
    matching = model_keys & checkpoint_keys
    if len(matching) >= len(checkpoint_keys) * 0.9:
        if verbose:
            print(
                f"✅ Checkpoint compatible: {len(matching)}/{len(checkpoint_keys)} clés correspondent"
            )
        return checkpoint_state_dict

    # Tester différentes stratégies de remappage
    strategies = [
        ("remove_model_prefix", _remove_prefix, "model."),
        ("add_model_prefix", _add_prefix, "model."),
        ("remove_backbone_prefix", _remove_prefix, "backbone."),
        ("remove_encoder_prefix", _remove_prefix, "encoder."),
    ]

    best_strategy = None
    best_match_count = len(matching)
    best_result = checkpoint_state_dict

    for name, func, prefix in strategies:
        remapped = func(checkpoint_state_dict, prefix)
        match_count = len(model_keys & set(remapped.keys()))

        if match_count > best_match_count:
            best_match_count = match_count
            best_strategy = name
            best_result = remapped

    if verbose and best_strategy:
        print(f"🔄 Stratégie appliquée: {best_strategy}")
        print(f"   Clés correspondantes: {best_match_count}/{len(checkpoint_keys)}")

    # Fallback: mapper intelligemment clé par clé
    if best_match_count < len(checkpoint_keys) * 0.5:
        best_result = _smart_key_mapping(checkpoint_state_dict, model_keys)
        if verbose:
            final_match = len(model_keys & set(best_result.keys()))
            print(
                f"🧠 Remappage intelligent: {final_match}/{len(checkpoint_keys)} clés"
            )

    return best_result


def _remove_prefix(state_dict: Dict, prefix: str) -> Dict:
    """Retirer un préfixe de toutes les clés."""
    return {
        (k[len(prefix) :] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def _add_prefix(state_dict: Dict, prefix: str) -> Dict:
    """Ajouter un préfixe à toutes les clés."""
    return {f"{prefix}{k}": v for k, v in state_dict.items()}


def _smart_key_mapping(

    checkpoint_dict: Dict[str, torch.Tensor], model_keys: Set[str]

) -> Dict[str, torch.Tensor]:
    """

    Mapping intelligent clé par clé basé sur les suffixes et patterns.

    """
    result = {}
    model_keys_list = list(model_keys)

    for ckpt_key, value in checkpoint_dict.items():
        # Correspondance exacte
        if ckpt_key in model_keys:
            result[ckpt_key] = value
            continue

        # Essayer avec préfixe 'model.'
        with_prefix = f"model.{ckpt_key}"
        if with_prefix in model_keys:
            result[with_prefix] = value
            continue

        # Essayer sans préfixe 'model.'
        if ckpt_key.startswith("model."):
            without_prefix = ckpt_key[6:]
            if without_prefix in model_keys:
                result[without_prefix] = value
                continue

        # Chercher par suffixe (ex: ".weight", ".bias")
        ckpt_suffix = ckpt_key.split(".")[-1]
        ckpt_base = ".".join(ckpt_key.split(".")[:-1])

        for model_key in model_keys_list:
            if model_key.endswith(ckpt_suffix):
                model_base = ".".join(model_key.split(".")[:-1])
                # Vérifier similarité structurelle
                if _keys_similar(ckpt_base, model_base):
                    result[model_key] = value
                    break
        else:
            # Garder la clé originale (sera ignorée si pas dans modèle)
            result[ckpt_key] = value

    return result


def _keys_similar(key1: str, key2: str) -> bool:
    """Vérifier si deux clés sont structurellement similaires."""
    parts1 = key1.split(".")
    parts2 = key2.split(".")

    # Même nombre de parties
    if len(parts1) != len(parts2):
        return False

    # Comparer chaque partie (ignorer les préfixes comme 'model')
    matches = sum(
        1 for p1, p2 in zip(parts1, parts2) if p1 == p2 or p1.isdigit() and p2.isdigit()
    )
    return matches >= len(parts1) * 0.7


def validate_checkpoint(

    checkpoint_state_dict: Dict[str, torch.Tensor],

    model_state_dict: Dict[str, torch.Tensor],

    strict: bool = False,

) -> Dict[str, any]:
    """

    Valider qu'un checkpoint est compatible avec un modèle.



    Returns:

        Dict avec:

        - valid: bool

        - missing_keys: clés manquantes dans checkpoint

        - unexpected_keys: clés inattendues dans checkpoint

        - size_mismatches: clés avec tailles incompatibles

    """
    model_keys = set(model_state_dict.keys())
    ckpt_keys = set(checkpoint_state_dict.keys())

    missing = model_keys - ckpt_keys
    unexpected = ckpt_keys - model_keys

    # Vérifier les tailles
    size_mismatches = []
    for key in model_keys & ckpt_keys:
        model_shape = model_state_dict[key].shape
        ckpt_shape = checkpoint_state_dict[key].shape
        if model_shape != ckpt_shape:
            size_mismatches.append(
                {"key": key, "model_shape": model_shape, "checkpoint_shape": ckpt_shape}
            )

    valid = len(missing) == 0 and len(size_mismatches) == 0
    if not strict:
        valid = len(size_mismatches) == 0 and len(missing) < len(model_keys) * 0.1

    return {
        "valid": valid,
        "missing_keys": list(missing),
        "unexpected_keys": list(unexpected),
        "size_mismatches": size_mismatches,
        "matched_keys": len(model_keys & ckpt_keys),
        "total_model_keys": len(model_keys),
    }


def save_cognitive_checkpoint(

    model,

    path: str,

    include_optimizer: bool = False,

    optimizer=None,

    extra_state: Optional[Dict] = None,

):
    """

    Sauvegarder un checkpoint de modèle cognitif.



    Args:

        model: Le modèle à sauvegarder

        path: Chemin de sauvegarde

        include_optimizer: Inclure l'état de l'optimiseur

        optimizer: L'optimiseur (si include_optimizer=True)

        extra_state: État additionnel à sauvegarder

    """
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "config": model.config.to_dict() if hasattr(model, "config") else {},
    }

    if include_optimizer and optimizer is not None:
        checkpoint["optimizer_state_dict"] = optimizer.state_dict()

    # Sauvegarder l'état cognitif si disponible
    if hasattr(model, "get_cognitive_state"):
        checkpoint["cognitive_state"] = model.get_cognitive_state()

    if extra_state:
        checkpoint["extra_state"] = extra_state

    torch.save(checkpoint, path)
    print(f"✅ Checkpoint sauvegardé: {path}")


def load_cognitive_checkpoint(

    model, path: str, strict: bool = False, verbose: bool = True

) -> Dict:
    """

    Charger un checkpoint dans un modèle cognitif avec remappage automatique.



    Args:

        model: Le modèle cible

        path: Chemin du checkpoint

        strict: Mode strict (erreur si clés manquantes)

        verbose: Afficher les détails



    Returns:

        Dict avec informations de chargement

    """
    checkpoint = torch.load(path, map_location="cpu")

    # Extraire le state_dict
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    # Remapper les clés
    remapped = remap_checkpoint_keys(state_dict, model.state_dict(), verbose=verbose)

    # Valider
    validation = validate_checkpoint(remapped, model.state_dict(), strict=strict)

    if verbose:
        print(
            f"📊 Clés chargées: {validation['matched_keys']}/{validation['total_model_keys']}"
        )
        if validation["missing_keys"]:
            print(f"⚠️ Clés manquantes: {len(validation['missing_keys'])}")
        if validation["size_mismatches"]:
            print(f"⚠️ Tailles incompatibles: {len(validation['size_mismatches'])}")

    # Charger avec ignore_mismatched_sizes pour robustesse
    model.load_state_dict(remapped, strict=False)

    # Restaurer l'état cognitif si disponible
    if "cognitive_state" in checkpoint and hasattr(model, "reset_cognitive_state"):
        # L'état cognitif est généralement réinitialisé, pas restauré
        pass

    if verbose:
        print("✅ Checkpoint chargé avec succès")

    return {
        "validation": validation,
        "config": checkpoint.get("config", {}),
        "extra_state": checkpoint.get("extra_state", {}),
    }