File size: 4,011 Bytes
a6eed2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import sys
from pathlib import Path
# tambahkan parent project ke sys.path sehingga 'src' dapat diimport saat menjalankan skrip langsung
sys.path.append(str(Path(__file__).resolve().parents[1]))

import timm
import torch
from src import config # Kita import config untuk daftar model dan device

def create_model(model_name: str, num_classes: int, pretrained: bool = True, dropout_rate: float = 0.1):
    """

    Membuat model Computer Vision dari library timm.

    

    Args:

        model_name (str): Nama model yang akan dibuat (misal: 'vit_base_patch16_224').

        num_classes (int): Jumlah kelas output (misal: 38 untuk batik).

        pretrained (bool): Apakah akan menggunakan bobot pre-trained ImageNet.

        dropout_rate (float): Dropout rate untuk regularization.

    

    Returns:

        torch.nn.Module: Model yang sudah dibuat.

    """
    print(f"[Model] Membuat model: {model_name}...")
    
    try:
        # timm.create_model adalah fungsi ajaib:
        # 1. 'pretrained=True' akan otomatis men-download bobot ImageNet.
        # 2. 'num_classes=num_classes' akan otomatis MENGGANTI
        #    layer klasifikasi terakhir (misal: 1000 kelas ImageNet)
        #    dengan layer baru yang sesuai jumlah kelas kita (38 kelas).
        model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout_rate  # Tambah dropout untuk regularization
        )
        return model
    
    except Exception as e:
        print(f"[Error] Gagal membuat model {model_name}: {e}")
        return None

# --- Blok Pengujian (Sangat Direkomendasikan) ---
# Kode ini HANYA akan berjalan jika Anda menjalankan file ini secara langsung
# (misal: `python src/models.py`)

if __name__ == "__main__":
    print("Menjalankan pengujian models.py...")
    
    # Kita butuh jumlah kelas untuk pengujian
    # Cara cepat: hitung folder di DATA_PATH dari config
    import os
    try:
        NUM_CLASSES = len(os.listdir(config.DATA_PATH))
        print(f"  > Ditemukan {NUM_CLASSES} kelas dari {config.DATA_PATH}")
    except FileNotFoundError:
        print(f"  > Error: Folder data di {config.DATA_PATH} tidak ditemukan.")
        print("  > Menggunakan 38 sebagai jumlah kelas default untuk tes.")
        NUM_CLASSES = 38 # Default jika data path salah

    # Buat data input palsu (dummy input) untuk tes
    # Ukuran: [Batch, Channel, Height, Width]
    dummy_input = torch.randn(
        2, 3, config.IMAGE_SIZE, config.IMAGE_SIZE
    ).to(config.DEVICE)
    
    print(f"  > Membuat data input palsu ukuran: {dummy_input.shape}")
    print("-" * 30)

    # Loop dan uji setiap model dalam daftar di config.py
    for model_name_key in config.MODEL_LIST:
        
        # Ini adalah nama-nama model yang sebenarnya di library 'timm'
        model_arch_names = {
            "vit": "vit_base_patch16_224",
            "swin_transformer": "swin_base_patch4_window7_224",
            "convnext_tiny": "convnext_tiny"
        }
        
        model_name = model_arch_names.get(model_name_key)
        
        if model_name:
            model = create_model(model_name=model_name, num_classes=NUM_CLASSES)
            
            if model:
                model = model.to(config.DEVICE)
                model.eval() # Set ke mode evaluasi untuk tes
                
                # Coba lewatkan data palsu ke model
                with torch.no_grad():
                    output = model(dummy_input)
                
                print(f"  > Tes Forward Pass... SUKSES")
                print(f"  > Ukuran Output: {output.shape}") # Harusnya [2, 38]
                print(f"  > Tes {model_name_key} selesai.")
                print("-" * 30)
        else:
            print(f"[Warning] Kunci model '{model_name_key}' di config.py tidak dikenali.")

    print("\n[Sukses] models.py berfungsi dengan baik!")