|
|
import sys
|
|
|
from pathlib import Path
|
|
|
|
|
|
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
|
|
|
|
|
import timm
|
|
|
import torch
|
|
|
from src import config
|
|
|
|
|
|
def create_model(
    model_name: str,
    num_classes: int,
    pretrained: bool = True,
    dropout_rate: float = 0.1,
    **timm_kwargs,
) -> "torch.nn.Module | None":
    """Create a computer-vision model from the ``timm`` library.

    Args:
        model_name (str): Name of the timm architecture to build
            (e.g. ``'vit_base_patch16_224'``).
        num_classes (int): Number of output classes (e.g. 38 for batik).
        pretrained (bool): Whether to load pre-trained (ImageNet) weights.
        dropout_rate (float): Dropout rate for regularization; forwarded to
            timm as ``drop_rate``.
        **timm_kwargs: Any additional keyword arguments, forwarded unchanged
            to ``timm.create_model``. Passing ``drop_rate`` here as well is an
            error, because it is already set from ``dropout_rate``.

    Returns:
        torch.nn.Module | None: The constructed model. Returns ``None`` if
        ``timm.create_model`` raises any exception; in that case the error
        is printed rather than propagated.
    """
    print(f"[Model] Membuat model: {model_name}...")

    try:
        model = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=dropout_rate,
            **timm_kwargs,
        )
    except Exception as e:
        # Deliberate best-effort: the __main__ smoke test in this file checks
        # the return value instead of catching exceptions, so report the
        # failure and signal it with None.
        print(f"[Error] Gagal membuat model {model_name}: {e}")
        return None

    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
print("Menjalankan pengujian models.py...")
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
try:
|
|
|
NUM_CLASSES = len(os.listdir(config.DATA_PATH))
|
|
|
print(f" > Ditemukan {NUM_CLASSES} kelas dari {config.DATA_PATH}")
|
|
|
except FileNotFoundError:
|
|
|
print(f" > Error: Folder data di {config.DATA_PATH} tidak ditemukan.")
|
|
|
print(" > Menggunakan 38 sebagai jumlah kelas default untuk tes.")
|
|
|
NUM_CLASSES = 38
|
|
|
|
|
|
|
|
|
|
|
|
dummy_input = torch.randn(
|
|
|
2, 3, config.IMAGE_SIZE, config.IMAGE_SIZE
|
|
|
).to(config.DEVICE)
|
|
|
|
|
|
print(f" > Membuat data input palsu ukuran: {dummy_input.shape}")
|
|
|
print("-" * 30)
|
|
|
|
|
|
|
|
|
for model_name_key in config.MODEL_LIST:
|
|
|
|
|
|
|
|
|
model_arch_names = {
|
|
|
"vit": "vit_base_patch16_224",
|
|
|
"swin_transformer": "swin_base_patch4_window7_224",
|
|
|
"convnext_tiny": "convnext_tiny"
|
|
|
}
|
|
|
|
|
|
model_name = model_arch_names.get(model_name_key)
|
|
|
|
|
|
if model_name:
|
|
|
model = create_model(model_name=model_name, num_classes=NUM_CLASSES)
|
|
|
|
|
|
if model:
|
|
|
model = model.to(config.DEVICE)
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
with torch.no_grad():
|
|
|
output = model(dummy_input)
|
|
|
|
|
|
print(f" > Tes Forward Pass... SUKSES")
|
|
|
print(f" > Ukuran Output: {output.shape}")
|
|
|
print(f" > Tes {model_name_key} selesai.")
|
|
|
print("-" * 30)
|
|
|
else:
|
|
|
print(f"[Warning] Kunci model '{model_name_key}' di config.py tidak dikenali.")
|
|
|
|
|
|
print("\n[Sukses] models.py berfungsi dengan baik!") |