File size: 3,735 Bytes
900b898
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import os


class CheckpointLoader:
    def __init__(self, load_dir, device="cuda"):
        self.load_dir = load_dir
        self.device = device

    def get_latest_checkpoint(self):
        """Megkeresi a legfrissebb mentést idő alapján."""
        if not os.path.exists(self.load_dir):
            return None

        files = [
            os.path.join(self.load_dir, f)
            for f in os.listdir(self.load_dir)
            if f.endswith(".pt") and f.startswith("step_")
        ]

        if not files:
            return None

        # Idő szerint rendezés (legutolsó a legfrissebb)
        files.sort(key=os.path.getmtime)
        return files[-1]

    def load_latest(self, model, optimizer=None, scheduler=None):
        """

        Betölti a legutolsó checkpointot a modellbe és az optimizerbe.

        Visszaadja a (state_dict, filename) párt.

        """
        path = self.get_latest_checkpoint()
        if not path:
            print("[INFO] Nem találtam checkpointot, tiszta lappal indul a tanítás.")
            return None, None

        print(f"[INFO] Checkpoint betöltése: {path}")
        filename = os.path.basename(path)

        try:
            checkpoint = torch.load(path, map_location=self.device, weights_only=False)

            # 1. Model State
            sd = None
            if "model_state_dict" in checkpoint:
                sd = checkpoint["model_state_dict"]
            elif "trainer" in checkpoint:
                sd = checkpoint["trainer"]
            else:
                sd = checkpoint

            was_skipped = False
            if sd:
                # Szűrés méretbeli eltérésre
                model_sd = model.state_dict()
                filtered_sd = {}
                for k, v in sd.items():
                    if k in model_sd:
                        if v.shape == model_sd[k].shape:
                            filtered_sd[k] = v
                        else:
                            print(
                                f"[WARN] Méretbeli eltérés, kihagyom: {k} ({v.shape} vs {model_sd[k].shape})"
                            )
                            was_skipped = True
                    else:
                        filtered_sd[k] = v

                model.load_state_dict(filtered_sd, strict=False)

            print("[OK] Modell súlyok betöltve.")

            # 2. Optimizer State - CSAK ha nem volt architektúra váltás!
            if optimizer and "optimizer_state_dict" in checkpoint:
                if was_skipped:
                    print(
                        "[INFO] Architektúra váltást észleltem, az Optimizer állapota nem kerül betöltésre (tiszta indítás)."
                    )
                else:
                    try:
                        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
                        print("[OK] Optimizer állapot betöltve.")
                    except Exception as e:
                        print(f"[WARN] Optimizer betöltése sikertelen: {e}")

            # 3. Scheduler State (ha van és kértük)
            if scheduler and "scheduler_state_dict" in checkpoint:
                try:
                    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
                    print("[OK] Scheduler állapot betöltve.")
                except Exception as e:
                    print(f"[WARN] Scheduler betöltése sikertelen: {e}")

            return checkpoint, filename

        except Exception as e:
            print(f"[HIBA] Checkpoint betöltése sikertelen: {e}")
            return None, None