File size: 1,464 Bytes
a080b32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import argparse
import pickle
import warnings

import torch
from torch.nn.modules.loss import CrossEntropyLoss
from torch.serialization import add_safe_globals

# Tambahkan whitelist agar bisa load aman
add_safe_globals([CrossEntropyLoss])

old_model_path = r"C:\Users\user\Documents\Project\Cataract-ViT\outputs\models\best_swin_model_final.pth"
new_model_path = r"C:\Users\user\Documents\Project\Cataract-ViT\outputs\models\best_swin_weights_only.pth"

# Load dengan weights_only=False karena ini file trusted
checkpoint = torch.load(old_model_path, map_location="cpu", weights_only=False)

# Deteksi isi file
if isinstance(checkpoint, dict):
    # Coba ambil beberapa kemungkinan key umum
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    elif "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    elif "model" in checkpoint and hasattr(checkpoint["model"], "state_dict"):
        state_dict = checkpoint["model"].state_dict()
    else:
        raise ValueError(f"❌ Tidak ditemukan key state_dict dalam checkpoint: {checkpoint.keys()}")
elif hasattr(checkpoint, "state_dict"):
    # Kalau model langsung
    state_dict = checkpoint.state_dict()
else:
    raise ValueError("❌ File tidak berisi model atau dictionary dengan state_dict yang valid.")

# Simpan ulang hanya weight-nya
torch.save(state_dict, new_model_path)

print(f"✅ State dict berhasil diekstrak dan disimpan ke:\n{new_model_path}")