# AUDIO_Emotion_Recognition / model_loader.py
# Source: Hugging Face Space by mohannad125 (commit b726d23, verified).
import torch
import torch.nn as nn
import os
from typing import Tuple, Optional, Any
import warnings
def safe_load_model(model_path: str, device: torch.device, model_instance: nn.Module) -> Tuple[Optional[nn.Module], bool]:
    """
    Safely load a model checkpoint with error handling.

    Supports three checkpoint layouts: a dict containing 'model_state_dict',
    a dict containing 'state_dict', a bare state_dict, or a fully pickled
    nn.Module object.

    Args:
        model_path: Path to the checkpoint file.
        device: Target device (CPU/GPU).
        model_instance: Instance of the model architecture to load weights into.

    Returns:
        tuple: (loaded model or None, True on success / False on failure)
    """
    try:
        # Check the file exists before attempting to load it.
        if not os.path.exists(model_path):
            print(f"❌ ملف النموذج غير موجود: {model_path}")
            print("💡 تأكد من وضع ملف best_model.pth في نفس مجلد التطبيق")
            return None, False
        print(f"📂 جاري تحميل النموذج من: {model_path}")
        # SECURITY: torch.load with weights_only=False unpickles arbitrary
        # objects — only ever load checkpoints from trusted sources.
        try:
            # torch >= 2.6 defaults weights_only=True, which breaks the
            # full-model-pickle branch below; pass False explicitly to keep
            # the original loading behavior.
            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        except TypeError:
            # Older torch versions do not accept the weights_only kwarg.
            checkpoint = torch.load(model_path, map_location=device)
        # Determine which checkpoint layout we were given.
        if isinstance(checkpoint, dict):
            if 'model_state_dict' in checkpoint:
                # Training-checkpoint style: weights nested under a key.
                model_instance.load_state_dict(checkpoint['model_state_dict'])
                print("✅ تم تحميل state_dict من checkpoint")
            elif 'state_dict' in checkpoint:
                model_instance.load_state_dict(checkpoint['state_dict'])
                print("✅ تم تحميل state_dict")
            else:
                # Assume the dict itself is a bare state_dict.
                model_instance.load_state_dict(checkpoint)
                print("✅ تم تحميل النموذج كـ state_dict")
        else:
            # A fully pickled nn.Module was saved; use it directly.
            model_instance = checkpoint
            print("✅ تم تحميل النموذج الكامل")
        # Move to the requested device and switch to inference mode.
        model_instance = model_instance.to(device)
        model_instance.eval()
        print(f"✅ تم تحميل النموذج بنجاح على {device}")
        return model_instance, True
    except FileNotFoundError:
        print(f"❌ ملف النموذج غير موجود: {model_path}")
        return None, False
    except RuntimeError as e:
        # Typically a state_dict/architecture mismatch.
        print(f"❌ خطأ في تحميل النموذج: {e}")
        print("💡 تأكد من أن بنية النموذج متطابقة مع النموذج المحفوظ")
        return None, False
    except Exception as e:
        print(f"❌ خطأ غير متوقع في تحميل النموذج: {e}")
        return None, False
def validate_model_architecture(model: nn.Module, expected_input_size: int = 193) -> bool:
    """
    Validate the model architecture by running a probe forward pass.

    Args:
        model: The loaded model.
        expected_input_size: Expected flat feature-vector size.

    Returns:
        bool: True if a (1, expected_input_size) input passes through the model.
    """
    try:
        # Put the probe tensor on the same device as the model so the check
        # does not spuriously fail for GPU-resident models.
        try:
            probe_device = next(model.parameters()).device
        except StopIteration:
            # Parameter-less model: fall back to CPU.
            probe_device = torch.device('cpu')
        dummy_input = torch.randn(1, expected_input_size, device=probe_device)
        # Probe forward pass without building an autograd graph.
        with torch.no_grad():
            output = model(dummy_input)
        print(f"✅ بنية النموذج صحيحة - الإدخال: {expected_input_size}, الإخراج: {output.shape}")
        return True
    except Exception as e:
        # Any failure (shape mismatch, device error, ...) means invalid architecture.
        print(f"❌ بنية النموذج غير صحيحة: {e}")
        return False
def create_dummy_model(num_classes: int = 8, input_size: int = 193) -> nn.Module:
    """
    Create a dummy model for testing when the real model file is missing.

    Args:
        num_classes: Number of output emotion classes.
        input_size: Flat feature-vector size (193 matches the app's
            extracted audio features; parameterized for reuse).

    Returns:
        A minimal single-layer nn.Module.
    """
    print("⚠️ إنشاء نموذج وهمي للاختبار...")

    class DummyEmotionNet(nn.Module):
        """Single linear layer mapping features to class logits."""

        def __init__(self, num_classes: int = 8, input_size: int = 193):
            super().__init__()
            self.fc = nn.Linear(input_size, num_classes)

        def forward(self, x):
            return self.fc(x)

    model = DummyEmotionNet(num_classes, input_size)
    print("✅ تم إنشاء النموذج الوهمي")
    return model
def check_model_file(model_path: str = 'best_model.pth') -> bool:
    """
    Inspect a model file and report information about it.

    Args:
        model_path: Path to the model checkpoint file.

    Returns:
        bool: True if the file exists and is readable by torch.load.
    """
    print(f"🔍 فحص ملف النموذج: {model_path}")
    if not os.path.exists(model_path):
        print("❌ ملف النموذج غير موجود!")
        print("💡 تأكد من:")
        print(" 1. وضع ملف best_model.pth في نفس مجلد التطبيق")
        print(" 2. أن اسم الملف صحيح")
        print(" 3. أن الملف غير تالف")
        return False
    # Report file size in megabytes.
    file_size = os.path.getsize(model_path)
    print(f"📊 حجم الملف: {file_size / (1024*1024):.2f} MB")
    # Try to deserialize the file to confirm it is a valid checkpoint.
    try:
        # SECURITY: weights_only=False unpickles arbitrary objects — only
        # inspect trusted files. torch >= 2.6 would otherwise default to
        # weights_only=True and reject full-model pickles.
        try:
            checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        except TypeError:
            # Older torch versions do not accept the weights_only kwarg.
            checkpoint = torch.load(model_path, map_location='cpu')
        print("✅ يمكن قراءة الملف")
        if isinstance(checkpoint, dict):
            # List the top-level keys so the user can see the layout.
            print("📋 محتويات الملف:")
            for key in checkpoint.keys():
                print(f" - {key}")
        return True
    except Exception as e:
        print(f"❌ خطأ في قراءة الملف: {e}")
        return False
if __name__ == "__main__":
    # Smoke-test the model-loading utilities against the default checkpoint path.
    print("🧪 اختبار نظام تحميل النموذج...")
    check_model_file()