import torch import os class CheckpointLoader: def __init__(self, load_dir, device="cuda"): self.load_dir = load_dir self.device = device def get_latest_checkpoint(self): """Megkeresi a legfrissebb mentést idő alapján.""" if not os.path.exists(self.load_dir): return None files = [ os.path.join(self.load_dir, f) for f in os.listdir(self.load_dir) if f.endswith(".pt") and f.startswith("step_") ] if not files: return None # Idő szerint rendezés (legutolsó a legfrissebb) files.sort(key=os.path.getmtime) return files[-1] def load_latest(self, model, optimizer=None, scheduler=None): """ Betölti a legutolsó checkpointot a modellbe és az optimizerbe. Visszaadja a (state_dict, filename) párt. """ path = self.get_latest_checkpoint() if not path: print("[INFO] Nem találtam checkpointot, tiszta lappal indul a tanítás.") return None, None print(f"[INFO] Checkpoint betöltése: {path}") filename = os.path.basename(path) try: checkpoint = torch.load(path, map_location=self.device, weights_only=False) # 1. Model State sd = None if "model_state_dict" in checkpoint: sd = checkpoint["model_state_dict"] elif "trainer" in checkpoint: sd = checkpoint["trainer"] else: sd = checkpoint was_skipped = False if sd: # Szűrés méretbeli eltérésre model_sd = model.state_dict() filtered_sd = {} for k, v in sd.items(): if k in model_sd: if v.shape == model_sd[k].shape: filtered_sd[k] = v else: print( f"[WARN] Méretbeli eltérés, kihagyom: {k} ({v.shape} vs {model_sd[k].shape})" ) was_skipped = True else: filtered_sd[k] = v model.load_state_dict(filtered_sd, strict=False) print("[OK] Modell súlyok betöltve.") # 2. Optimizer State - CSAK ha nem volt architektúra váltás! if optimizer and "optimizer_state_dict" in checkpoint: if was_skipped: print( "[INFO] Architektúra váltást észleltem, az Optimizer állapota nem kerül betöltésre (tiszta indítás)." ) else: try: optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) print("[OK] Optimizer állapot betöltve.") except Exception as e: print(f"[WARN] Optimizer betöltése sikertelen: {e}") # 3. Scheduler State (ha van és kértük) if scheduler and "scheduler_state_dict" in checkpoint: try: scheduler.load_state_dict(checkpoint["scheduler_state_dict"]) print("[OK] Scheduler állapot betöltve.") except Exception as e: print(f"[WARN] Scheduler betöltése sikertelen: {e}") return checkpoint, filename except Exception as e: print(f"[HIBA] Checkpoint betöltése sikertelen: {e}") return None, None