|
|
import torch
|
|
|
from torch.serialization import add_safe_globals
|
|
|
from torch.nn.modules.loss import CrossEntropyLoss
|
|
|
|
|
|
|
|
|
add_safe_globals([CrossEntropyLoss])
|
|
|
|
|
|
old_model_path = r"C:\Users\user\Documents\Project\Cataract-ViT\outputs\models\best_swin_model_final.pth"
|
|
|
new_model_path = r"C:\Users\user\Documents\Project\Cataract-ViT\outputs\models\best_swin_weights_only.pth"
|
|
|
|
|
|
|
|
|
checkpoint = torch.load(old_model_path, map_location="cpu", weights_only=False)
|
|
|
|
|
|
|
|
|
if isinstance(checkpoint, dict):
|
|
|
|
|
|
if "state_dict" in checkpoint:
|
|
|
state_dict = checkpoint["state_dict"]
|
|
|
elif "model_state_dict" in checkpoint:
|
|
|
state_dict = checkpoint["model_state_dict"]
|
|
|
elif "model" in checkpoint and hasattr(checkpoint["model"], "state_dict"):
|
|
|
state_dict = checkpoint["model"].state_dict()
|
|
|
else:
|
|
|
raise ValueError(f"❌ Tidak ditemukan key state_dict dalam checkpoint: {checkpoint.keys()}")
|
|
|
elif hasattr(checkpoint, "state_dict"):
|
|
|
|
|
|
state_dict = checkpoint.state_dict()
|
|
|
else:
|
|
|
raise ValueError("❌ File tidak berisi model atau dictionary dengan state_dict yang valid.")
|
|
|
|
|
|
|
|
|
torch.save(state_dict, new_model_path)
|
|
|
|
|
|
print(f"✅ State dict berhasil diekstrak dan disimpan ke:\n{new_model_path}")
|
|
|
|