Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import os | |
| from typing import Tuple, Optional, Any | |
| import warnings | |
def safe_load_model(model_path: str, device: torch.device, model_instance: nn.Module) -> Tuple[Optional[nn.Module], bool]:
    """Safely load a model checkpoint with explicit error handling.

    Supports three checkpoint layouts: a dict with 'model_state_dict', a dict
    with 'state_dict', a bare state_dict, or a fully pickled nn.Module.

    Args:
        model_path: Path to the checkpoint file.
        device: Target device (CPU/GPU) to move the model to.
        model_instance: Instance whose architecture matches the saved weights.

    Returns:
        tuple: (loaded model in eval mode, True) on success, (None, False) on failure.
    """
    try:
        # Check the file exists before attempting deserialization
        if not os.path.exists(model_path):
            print(f"โ ู ูู ุงููู ูุฐุฌ ุบูุฑ ู ูุฌูุฏ: {model_path}")
            print("๐ก ุชุฃูุฏ ู ู ูุถุน ู ูู best_model.pth ูู ููุณ ู ุฌูุฏ ุงูุชุทุจูู")
            return None, False
        print(f"๐ ุฌุงุฑู ุชุญู ูู ุงููู ูุฐุฌ ู ู: {model_path}")
        # NOTE(security): weights_only=False is required to support the
        # full-model branch below (PyTorch >= 2.6 defaults to True), but it
        # unpickles arbitrary objects — only load trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=device, weights_only=False)
        # Dispatch on how the checkpoint was saved
        if isinstance(checkpoint, dict):
            if 'model_state_dict' in checkpoint:
                # Training checkpoint: weights nested under 'model_state_dict'
                model_instance.load_state_dict(checkpoint['model_state_dict'])
                print("โ ุชู ุชุญู ูู state_dict ู ู checkpoint")
            elif 'state_dict' in checkpoint:
                model_instance.load_state_dict(checkpoint['state_dict'])
                print("โ ุชู ุชุญู ูู state_dict")
            else:
                # Assume the dict itself is a bare state_dict
                model_instance.load_state_dict(checkpoint)
                print("โ ุชู ุชุญู ูู ุงููู ูุฐุฌ ูู state_dict")
        else:
            # The whole nn.Module was pickled; replace the provided instance
            model_instance = checkpoint
            print("โ ุชู ุชุญู ูู ุงููู ูุฐุฌ ุงููุงู ู")
        # Move to the requested device and switch to inference mode
        model_instance = model_instance.to(device)
        model_instance.eval()
        print(f"โ ุชู ุชุญู ูู ุงููู ูุฐุฌ ุจูุฌุงุญ ุนูู {device}")
        return model_instance, True
    except FileNotFoundError:
        # Kept despite the exists() check above: guards against TOCTOU races
        print(f"โ ู ูู ุงููู ูุฐุฌ ุบูุฑ ู ูุฌูุฏ: {model_path}")
        return None, False
    except RuntimeError as e:
        # Typically a state_dict / architecture mismatch
        print(f"โ ุฎุทุฃ ูู ุชุญู ูู ุงููู ูุฐุฌ: {e}")
        print("๐ก ุชุฃูุฏ ู ู ุฃู ุจููุฉ ุงููู ูุฐุฌ ู ุชุทุงุจูุฉ ู ุน ุงููู ูุฐุฌ ุงูู ุญููุธ")
        return None, False
    except Exception as e:
        print(f"โ ุฎุทุฃ ุบูุฑ ู ุชููุน ูู ุชุญู ูู ุงููู ูุฐุฌ: {e}")
        return None, False
def validate_model_architecture(model: nn.Module, expected_input_size: int = 193) -> bool:
    """Validate the model architecture via a dry-run forward pass.

    Args:
        model: The loaded model to probe.
        expected_input_size: Expected flat feature-vector length.

    Returns:
        bool: True if a (1, expected_input_size) batch passes through the model.
    """
    try:
        # Build the probe on the model's own device. The original always used
        # CPU, which made validation falsely fail for GPU-resident models.
        try:
            probe_device = next(model.parameters()).device
        except StopIteration:
            # Parameter-free module: CPU is a safe default
            probe_device = torch.device("cpu")
        dummy_input = torch.randn(1, expected_input_size, device=probe_device)
        # Inference-only forward pass; no gradients needed for a shape check
        with torch.no_grad():
            output = model(dummy_input)
        print(f"โ ุจููุฉ ุงููู ูุฐุฌ ุตุญูุญุฉ - ุงูุฅุฏุฎุงู: {expected_input_size}, ุงูุฅุฎุฑุงุฌ: {output.shape}")
        return True
    except Exception as e:
        # Any failure (shape mismatch, dtype, etc.) means the architecture
        # does not accept the expected input size
        print(f"โ ุจููุฉ ุงููู ูุฐุฌ ุบูุฑ ุตุญูุญุฉ: {e}")
        return False
def create_dummy_model(num_classes: int = 8, input_size: int = 193) -> nn.Module:
    """Create a minimal stand-in model for testing when no checkpoint exists.

    Args:
        num_classes: Number of output classes.
        input_size: Length of the input feature vector (193 is the project's
            extracted-feature size; parameterized instead of hard-coded).

    Returns:
        A single-layer linear model mapping input_size -> num_classes.
    """
    print("โ ๏ธ ุฅูุดุงุก ูู ูุฐุฌ ููู ู ููุงุฎุชุจุงุฑ...")

    class DummyEmotionNet(nn.Module):
        def __init__(self, num_classes: int = 8, input_size: int = 193):
            super().__init__()
            self.fc = nn.Linear(input_size, num_classes)

        def forward(self, x):
            return self.fc(x)

    model = DummyEmotionNet(num_classes, input_size)
    print("โ ุชู ุฅูุดุงุก ุงููู ูุฐุฌ ุงูููู ู")
    return model
def check_model_file(model_path: str = 'best_model.pth') -> bool:
    """Inspect a checkpoint file and print diagnostics about it.

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        bool: True if the file exists and torch can deserialize it.
    """
    print(f"๐ ูุญุต ู ูู ุงููู ูุฐุฌ: {model_path}")
    if not os.path.exists(model_path):
        print("โ ู ูู ุงููู ูุฐุฌ ุบูุฑ ู ูุฌูุฏ!")
        print("๐ก ุชุฃูุฏ ู ู:")
        print(" 1. ูุถุน ู ูู best_model.pth ูู ููุณ ู ุฌูุฏ ุงูุชุทุจูู")
        print(" 2. ุฃู ุงุณู ุงูู ูู ุตุญูุญ")
        print(" 3. ุฃู ุงูู ูู ุบูุฑ ุชุงูู")
        return False
    # Report the on-disk size before attempting to deserialize
    file_size = os.path.getsize(model_path)
    print(f"๐ ุญุฌู ุงูู ูู: {file_size / (1024*1024):.2f} MB")
    try:
        # NOTE(security): weights_only=False keeps full-model checkpoints
        # readable (PyTorch >= 2.6 defaults to True) but unpickles arbitrary
        # objects — only inspect trusted files.
        checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)
        print("โ ูู ูู ูุฑุงุกุฉ ุงูู ูู")
        if isinstance(checkpoint, dict):
            # List top-level keys so the caller can see the checkpoint layout
            print("๐ ู ุญุชููุงุช ุงูู ูู:")
            for key in checkpoint:
                print(f" - {key}")
        return True
    except Exception as e:
        print(f"โ ุฎุทุฃ ูู ูุฑุงุกุฉ ุงูู ูู: {e}")
        return False
if __name__ == "__main__":
    # Smoke-test the checkpoint inspection path with the default file name
    print("๐งช ุงุฎุชุจุงุฑ ูุธุงู ุชุญู ูู ุงููู ูุฐุฌ...")
    check_model_file()