# Cataract-ViT / src/convert_to_state_dict.py
# Uploaded by Decoder24 ("Upload folder using huggingface_hub"), commit a080b32 (verified)
"""Extract a bare ``state_dict`` from a saved PyTorch checkpoint.

The source ``.pth`` file may hold a whole pickled ``nn.Module``, a raw state
dict, or a training-checkpoint dict that wraps the weights under
``"state_dict"``, ``"model_state_dict"`` or ``"model"``. Only the weights
mapping is re-saved to the destination file.

Usage::

    python convert_to_state_dict.py [SRC] [DST]

Both paths are optional and default to the original project locations.
"""
import argparse
from collections.abc import Mapping
from pathlib import Path

import torch
from torch.nn.modules.loss import CrossEntropyLoss

try:
    from torch.serialization import add_safe_globals
except ImportError:  # torch < 2.4 has no safe-globals allowlist
    add_safe_globals = None

# Allowlist CrossEntropyLoss for torch.load(weights_only=True), presumably
# because the checkpoint pickles the training criterion. NOTE: this has no
# effect on the default weights_only=False load performed below.
if add_safe_globals is not None:
    add_safe_globals([CrossEntropyLoss])

old_model_path = r"C:\Users\user\Documents\Project\Cataract-ViT\outputs\models\best_swin_model_final.pth"
new_model_path = r"C:\Users\user\Documents\Project\Cataract-ViT\outputs\models\best_swin_weights_only.pth"

# Checkpoint keys that commonly wrap the weights, tried in this order.
_WRAPPER_KEYS = ("state_dict", "model_state_dict", "model")


def _coerce_state_dict(obj):
    """Return *obj* as a state dict, or ``None`` if it cannot be one.

    Objects with a callable ``state_dict()`` (e.g. ``nn.Module``) are asked
    for it; mappings are assumed to already be a state dict.
    """
    if callable(getattr(obj, "state_dict", None)):
        return obj.state_dict()
    if isinstance(obj, Mapping):
        return obj
    return None


def extract_state_dict(checkpoint):
    """Return the model weights stored in *checkpoint*.

    Args:
        checkpoint: Object returned by ``torch.load``. Supported forms:
            an object with a ``state_dict()`` method (a pickled model);
            a mapping wrapping the weights under ``"state_dict"``,
            ``"model_state_dict"`` or ``"model"`` (value may be a model or a
            mapping); or a non-empty mapping whose values are all tensors
            (already a raw state dict).

    Returns:
        The state-dict mapping.

    Raises:
        ValueError: If no weights can be located in *checkpoint*.
    """
    if isinstance(checkpoint, Mapping):
        for key in _WRAPPER_KEYS:
            if key in checkpoint:
                state_dict = _coerce_state_dict(checkpoint[key])
                if state_dict is not None:
                    return state_dict
        # Saved directly via torch.save(model.state_dict()): every value is a tensor.
        if checkpoint and all(torch.is_tensor(v) for v in checkpoint.values()):
            return checkpoint
        raise ValueError(f"❌ Tidak ditemukan key state_dict dalam checkpoint: {checkpoint.keys()}")
    if callable(getattr(checkpoint, "state_dict", None)):
        # The whole model object was pickled.
        return checkpoint.state_dict()
    raise ValueError("❌ File tidak berisi model atau dictionary dengan state_dict yang valid.")


def convert_checkpoint(src, dst, *, weights_only=False):
    """Load the checkpoint at *src* and save only its state dict to *dst*.

    The checkpoint is loaded onto the CPU. Parent directories of *dst* are
    created if missing.

    SECURITY: with ``weights_only=False`` (the default, needed for files that
    pickle whole models) ``torch.load`` unpickles arbitrary objects and can
    execute code. Only use it on checkpoints you trust.

    Returns:
        The extracted state dict that was written to *dst*.
    """
    checkpoint = torch.load(src, map_location="cpu", weights_only=weights_only)
    state_dict = extract_state_dict(checkpoint)
    dst_path = Path(dst)
    dst_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(state_dict, dst_path)
    return state_dict


def main(argv=None):
    """Command-line entry point: convert SRC (trusted) to a weights-only DST."""
    parser = argparse.ArgumentParser(
        description="Extract the state_dict from a trusted PyTorch checkpoint."
    )
    parser.add_argument("src", nargs="?", default=old_model_path,
                        help="input checkpoint (.pth); must be trusted")
    parser.add_argument("dst", nargs="?", default=new_model_path,
                        help="output path for the weights-only file")
    args = parser.parse_args(argv)
    convert_checkpoint(args.src, args.dst)
    print(f"✅ State dict berhasil diekstrak dan disimpan ke:\n{args.dst}")


if __name__ == "__main__":
    main()