File size: 5,926 Bytes
619672b
 
b726d23
 
619672b
 
b726d23
619672b
b726d23
 
 
 
 
 
 
 
 
619672b
 
b726d23
 
 
 
 
619672b
b726d23
619672b
b726d23
 
619672b
b726d23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
619672b
b726d23
 
 
 
 
619672b
b726d23
619672b
b726d23
 
 
 
 
619672b
 
b726d23
619672b
 
b726d23
619672b
b726d23
 
 
 
 
 
 
 
619672b
 
b726d23
 
 
 
 
 
 
 
 
 
619672b
b726d23
 
 
 
 
 
 
 
 
619672b
b726d23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import torch
import torch.nn as nn
import os
from typing import Tuple, Optional, Any
import warnings

def safe_load_model(model_path: str, device: torch.device, model_instance: nn.Module) -> Tuple[Optional[nn.Module], bool]:
    """
    ุชุญู…ูŠู„ ุขู…ู† ู„ู„ู†ู…ูˆุฐุฌ ู…ุน ู…ุนุงู„ุฌุฉ ุงู„ุฃุฎุทุงุก
    
    Args:
        model_path: ู…ุณุงุฑ ู…ู„ู ุงู„ู†ู…ูˆุฐุฌ
        device: ุงู„ุฌู‡ุงุฒ (CPU/GPU)
        model_instance: instance ู…ู† ุงู„ู†ู…ูˆุฐุฌ
    
    Returns:
        tuple: (ุงู„ู†ู…ูˆุฐุฌ ุงู„ู…ุญู…ู„ุŒ ู†ุฌุญ ุงู„ุชุญู…ูŠู„ ุฃู… ู„ุง)
    """
    try:
        # ุงู„ุชุญู‚ู‚ ู…ู† ูˆุฌูˆุฏ ุงู„ู…ู„ู
        if not os.path.exists(model_path):
            print(f"โŒ ู…ู„ู ุงู„ู†ู…ูˆุฐุฌ ุบูŠุฑ ู…ูˆุฌูˆุฏ: {model_path}")
            print("๐Ÿ’ก ุชุฃูƒุฏ ู…ู† ูˆุถุน ู…ู„ู best_model.pth ููŠ ู†ูุณ ู…ุฌู„ุฏ ุงู„ุชุทุจูŠู‚")
            return None, False
        
        print(f"๐Ÿ“‚ ุฌุงุฑูŠ ุชุญู…ูŠู„ ุงู„ู†ู…ูˆุฐุฌ ู…ู†: {model_path}")
        
        # ุชุญู…ูŠู„ ุงู„ู†ู…ูˆุฐุฌ
        checkpoint = torch.load(model_path, map_location=device)
        
        # ุงู„ุชุญู‚ู‚ ู…ู† ู†ูˆุน checkpoint
        if isinstance(checkpoint, dict):
            # ุฅุฐุง ูƒุงู† ุงู„ู†ู…ูˆุฐุฌ ู…ุญููˆุธ ูƒู€ state_dict
            if 'model_state_dict' in checkpoint:
                model_instance.load_state_dict(checkpoint['model_state_dict'])
                print("โœ… ุชู… ุชุญู…ูŠู„ state_dict ู…ู† checkpoint")
            elif 'state_dict' in checkpoint:
                model_instance.load_state_dict(checkpoint['state_dict'])
                print("โœ… ุชู… ุชุญู…ูŠู„ state_dict")
            else:
                # ู…ุญุงูˆู„ุฉ ุชุญู…ูŠู„ dict ู…ุจุงุดุฑุฉ
                model_instance.load_state_dict(checkpoint)
                print("โœ… ุชู… ุชุญู…ูŠู„ ุงู„ู†ู…ูˆุฐุฌ ูƒู€ state_dict")
        else:
            # ุฅุฐุง ูƒุงู† ุงู„ู†ู…ูˆุฐุฌ ู…ุญููˆุธ ูƒู€ full model
            model_instance = checkpoint
            print("โœ… ุชู… ุชุญู…ูŠู„ ุงู„ู†ู…ูˆุฐุฌ ุงู„ูƒุงู…ู„")
        
        # ู†ู‚ู„ ุฅู„ู‰ ุงู„ุฌู‡ุงุฒ ุงู„ู…ู†ุงุณุจ
        model_instance = model_instance.to(device)
        model_instance.eval()  # ูˆุถุน ุงู„ุชู‚ูŠูŠู…
        
        print(f"โœ… ุชู… ุชุญู…ูŠู„ ุงู„ู†ู…ูˆุฐุฌ ุจู†ุฌุงุญ ุนู„ู‰ {device}")
        return model_instance, True
        
    except FileNotFoundError:
        print(f"โŒ ู…ู„ู ุงู„ู†ู…ูˆุฐุฌ ุบูŠุฑ ู…ูˆุฌูˆุฏ: {model_path}")
        return None, False
    except RuntimeError as e:
        print(f"โŒ ุฎุทุฃ ููŠ ุชุญู…ูŠู„ ุงู„ู†ู…ูˆุฐุฌ: {e}")
        print("๐Ÿ’ก ุชุฃูƒุฏ ู…ู† ุฃู† ุจู†ูŠุฉ ุงู„ู†ู…ูˆุฐุฌ ู…ุชุทุงุจู‚ุฉ ู…ุน ุงู„ู†ู…ูˆุฐุฌ ุงู„ู…ุญููˆุธ")
        return None, False
    except Exception as e:
        print(f"โŒ ุฎุทุฃ ุบูŠุฑ ู…ุชูˆู‚ุน ููŠ ุชุญู…ูŠู„ ุงู„ู†ู…ูˆุฐุฌ: {e}")
        return None, False

def validate_model_architecture(model: nn.Module, expected_input_size: int = 193) -> bool:
    """
    ุงู„ุชุญู‚ู‚ ู…ู† ุตุญุฉ ุจู†ูŠุฉ ุงู„ู†ู…ูˆุฐุฌ
    
    Args:
        model: ุงู„ู†ู…ูˆุฐุฌ ุงู„ู…ุญู…ู„
        expected_input_size: ุญุฌู… ุงู„ุฅุฏุฎุงู„ ุงู„ู…ุชูˆู‚ุน
    
    Returns:
        bool: ุตุญูŠุญ ุฅุฐุง ูƒุงู†ุช ุงู„ุจู†ูŠุฉ ุตุญูŠุญุฉ
    """
    try:
        # ุฅู†ุดุงุก tensor ุชุฌุฑูŠุจูŠ
        dummy_input = torch.randn(1, expected_input_size)
        
        # ุชุฌุฑุจุฉ forward pass
        with torch.no_grad():
            output = model(dummy_input)
        
        print(f"โœ… ุจู†ูŠุฉ ุงู„ู†ู…ูˆุฐุฌ ุตุญูŠุญุฉ - ุงู„ุฅุฏุฎุงู„: {expected_input_size}, ุงู„ุฅุฎุฑุงุฌ: {output.shape}")
        return True
        
    except Exception as e:
        print(f"โŒ ุจู†ูŠุฉ ุงู„ู†ู…ูˆุฐุฌ ุบูŠุฑ ุตุญูŠุญุฉ: {e}")
        return False

def create_dummy_model(num_classes: int = 8) -> nn.Module:
    """
    ุฅู†ุดุงุก ู†ู…ูˆุฐุฌ ูˆู‡ู…ูŠ ู„ู„ุงุฎุชุจุงุฑ ุนู†ุฏ ุนุฏู… ูˆุฌูˆุฏ ุงู„ู†ู…ูˆุฐุฌ ุงู„ุฃุตู„ูŠ
    
    Args:
        num_classes: ุนุฏุฏ ุงู„ูุฆุงุช
    
    Returns:
        ู†ู…ูˆุฐุฌ ูˆู‡ู…ูŠ
    """
    print("โš ๏ธ  ุฅู†ุดุงุก ู†ู…ูˆุฐุฌ ูˆู‡ู…ูŠ ู„ู„ุงุฎุชุจุงุฑ...")
    
    class DummyEmotionNet(nn.Module):
        def __init__(self, num_classes=8):
            super(DummyEmotionNet, self).__init__()
            self.fc = nn.Linear(193, num_classes)  # 193 ู‡ูˆ ุญุฌู… ุงู„ู…ูŠุฒุงุช ุงู„ู…ุณุชุฎุฑุฌุฉ
        
        def forward(self, x):
            return self.fc(x)
    
    model = DummyEmotionNet(num_classes)
    print("โœ… ุชู… ุฅู†ุดุงุก ุงู„ู†ู…ูˆุฐุฌ ุงู„ูˆู‡ู…ูŠ")
    return model

def check_model_file(model_path: str = 'best_model.pth'):
    """
    ูุญุต ู…ู„ู ุงู„ู†ู…ูˆุฐุฌ ูˆุฅุนุทุงุก ู…ุนู„ูˆู…ุงุช ุนู†ู‡
    
    Args:
        model_path: ู…ุณุงุฑ ู…ู„ู ุงู„ู†ู…ูˆุฐุฌ
    """
    print(f"๐Ÿ” ูุญุต ู…ู„ู ุงู„ู†ู…ูˆุฐุฌ: {model_path}")
    
    if not os.path.exists(model_path):
        print("โŒ ู…ู„ู ุงู„ู†ู…ูˆุฐุฌ ุบูŠุฑ ู…ูˆุฌูˆุฏ!")
        print("๐Ÿ’ก ุชุฃูƒุฏ ู…ู†:")
        print("   1. ูˆุถุน ู…ู„ู best_model.pth ููŠ ู†ูุณ ู…ุฌู„ุฏ ุงู„ุชุทุจูŠู‚")
        print("   2. ุฃู† ุงุณู… ุงู„ู…ู„ู ุตุญูŠุญ")
        print("   3. ุฃู† ุงู„ู…ู„ู ุบูŠุฑ ุชุงู„ู")
        return False
    
    # ู…ุนู„ูˆู…ุงุช ุงู„ู…ู„ู
    file_size = os.path.getsize(model_path)
    print(f"๐Ÿ“Š ุญุฌู… ุงู„ู…ู„ู: {file_size / (1024*1024):.2f} MB")
    
    # ู…ุญุงูˆู„ุฉ ู‚ุฑุงุกุฉ ุงู„ู…ู„ู
    try:
        checkpoint = torch.load(model_path, map_location='cpu')
        print("โœ… ูŠู…ูƒู† ู‚ุฑุงุกุฉ ุงู„ู…ู„ู")
        
        if isinstance(checkpoint, dict):
            print("๐Ÿ“‹ ู…ุญุชูˆูŠุงุช ุงู„ู…ู„ู:")
            for key in checkpoint.keys():
                print(f"   - {key}")
        
        return True
        
    except Exception as e:
        print(f"โŒ ุฎุทุฃ ููŠ ู‚ุฑุงุกุฉ ุงู„ู…ู„ู: {e}")
        return False

if __name__ == "__main__":
    # ุงุฎุชุจุงุฑ ุงู„ู†ุธุงู…
    print("๐Ÿงช ุงุฎุชุจุงุฑ ู†ุธุงู… ุชุญู…ูŠู„ ุงู„ู†ู…ูˆุฐุฌ...")
    check_model_file()