"""
Model architectures for Indonesian herbal plant classification.

7 Latest Models (2025):

1. YOLOv11 Classification
2. EfficientNetV2-S
3. ConvNeXt V2
4. Vision Transformer (ViT)
5. Hybrid CNN + ViT (CoAtNet-style)
6. InternImage-style (ConvNeXt backbone + global context gating)
7. ConvFormer (CNN + self-attention hybrid)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

import config  # local project config; expected to define MODEL_NAMES (used in __main__)


def get_model(model_name: str, num_classes: int, pretrained: bool = True) -> nn.Module:
    """Factory function to create models by name."""
    model_name = model_name.lower()

    if model_name == "yolov11":
        return YOLOv11Classifier(num_classes, pretrained)
    elif model_name == "efficientnetv2":
        return EfficientNetV2Classifier(num_classes, pretrained)
    elif model_name == "convnextv2":
        return ConvNeXtV2Classifier(num_classes, pretrained)
    elif model_name == "vit":
        return ViTClassifier(num_classes, pretrained)
    elif model_name == "hybrid_cnn_vit":
        return HybridCNNViT(num_classes, pretrained)
    elif model_name == "internimage":
        return InternImageClassifier(num_classes, pretrained)
    elif model_name == "convformer":
        return ConvFormerClassifier(num_classes, pretrained)
    else:
        raise ValueError(f"Unknown model: {model_name}")
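

# Example usage of the factory (illustrative; the 31-class count and 224x224
# input size follow the smoke test in __main__ below):
#
#     model = get_model("convnextv2", num_classes=31, pretrained=True)
#     logits = model(torch.randn(1, 3, 224, 224))   # -> shape (1, 31)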


class YOLOv11Classifier(nn.Module):
    """YOLOv11-style image classifier.

    Note: this implementation uses a timm EfficientNetV2-S backbone with a
    lightweight classification head rather than the Ultralytics YOLO weights.
    """

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()
        self.model_name = "YOLOv11-cls"

        # Feature extractor (classifier head removed with num_classes=0)
        self.backbone = timm.create_model(
            'tf_efficientnetv2_s',
            pretrained=pretrained,
            num_classes=0
        )

        self.feature_dim = self.backbone.num_features

        # Pooling + dropout + linear classification head
        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(0.2),
            nn.Linear(self.feature_dim, num_classes)
        )

    def forward(self, x):
        features = self.backbone.forward_features(x)  # (B, C, H, W) feature map
        return self.head(features)
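

# Despite the class name, YOLOv11Classifier above is backed by EfficientNetV2-S.
# A minimal sketch of using the actual Ultralytics YOLO11 classification model
# instead (assumes the `ultralytics` package and its published
# `yolo11n-cls.pt` checkpoint):
#
#     from ultralytics import YOLO
#     yolo_cls = YOLO("yolo11n-cls.pt")   # pretrained classification checkpoint
#     results = yolo_cls("leaf.jpg")      # classification Results (see results[0].probs)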


class EfficientNetV2Classifier(nn.Module):
    """EfficientNetV2-S Classifier"""

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()
        self.model_name = "EfficientNetV2-S"

        self.model = timm.create_model(
            'tf_efficientnetv2_s',
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=0.3,
            drop_path_rate=0.2
        )

    def forward(self, x):
        return self.model(x)


class ConvNeXtV2Classifier(nn.Module):
    """ConvNeXt V2 Classifier - modern pure-CNN architecture"""

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()
        self.model_name = "ConvNeXtV2-Tiny"

        self.model = timm.create_model(
            'convnextv2_tiny',
            pretrained=pretrained,
            num_classes=num_classes,
            drop_path_rate=0.1
        )

    def forward(self, x):
        return self.model(x)


class ViTClassifier(nn.Module):
    """Vision Transformer (ViT) Classifier"""

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()
        self.model_name = "ViT-Base-16"

        self.model = timm.create_model(
            'vit_base_patch16_224',
            pretrained=pretrained,
            num_classes=num_classes,
            drop_rate=0.1,
            attn_drop_rate=0.1
        )

    def forward(self, x):
        return self.model(x)


class HybridCNNViT(nn.Module):
    """
    Hybrid CNN + Vision Transformer (CoAtNet-style architecture).
    Combines the local feature extraction of a CNN with the global attention of a ViT.
    """

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()
        self.model_name = "Hybrid-CNN-ViT"

        # CNN stem: EfficientNet-B0 feature maps from stages 2 and 3
        self.cnn_backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=pretrained,
            features_only=True,
            out_indices=[2, 3]
        )

        # Channel counts of the selected EfficientNet-B0 stages
        self.cnn_channels = [40, 112]

        # 1x1 projection from CNN channels to the transformer embedding dim
        self.proj = nn.Conv2d(self.cnn_channels[1], 768, kernel_size=1)

        # Lightweight transformer encoder over the projected CNN tokens
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=768,
                nhead=12,
                dim_feedforward=3072,
                dropout=0.1,
                activation='gelu',
                batch_first=True
            ),
            num_layers=4
        )

        # Learnable [CLS] token and positional embeddings (197 = 14*14 patches + CLS)
        self.cls_token = nn.Parameter(torch.randn(1, 1, 768))
        self.pos_embed = nn.Parameter(torch.randn(1, 197, 768))

        self.norm = nn.LayerNorm(768)
        self.head = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(768, num_classes)
        )

    def forward(self, x):
        batch_size = x.shape[0]

        # Local features from the CNN stem; use the deepest selected stage
        features = self.cnn_backbone(x)
        x = features[-1]

        # Project to the transformer embedding dimension
        x = self.proj(x)

        # Flatten spatial positions into a token sequence: (B, C, H, W) -> (B, H*W, C)
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)

        # Prepend the [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Interpolate positional embeddings if the sequence length differs
        # (e.g. for non-224 input sizes); linear interpolation needs a 3D tensor
        if x.shape[1] != self.pos_embed.shape[1]:
            pos_embed = F.interpolate(
                self.pos_embed.transpose(1, 2),  # (1, 768, 197)
                size=x.shape[1],
                mode='linear'
            ).transpose(1, 2)
        else:
            pos_embed = self.pos_embed

        x = x + pos_embed[:, :x.shape[1], :]

        # Global attention over the token sequence
        x = self.transformer(x)

        # Classify from the [CLS] token
        x = self.norm(x[:, 0])
        x = self.head(x)

        return x
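

# Shape walk-through for HybridCNNViT (illustrative, assumes a 224x224 input):
#   EfficientNet-B0 stage-3 feature map : (B, 112, 14, 14)
#   1x1 projection                      : (B, 768, 14, 14)
#   flatten + transpose                 : (B, 196, 768)
#   prepend [CLS] token                 : (B, 197, 768)  -> matches the 197-token pos_embed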


class InternImageClassifier(nn.Module):
    """
    InternImage-style Classifier.
    Paper: "InternImage: Exploring Large-Scale Vision Foundation Models with
    Deformable Convolutions" (https://arxiv.org/abs/2211.05778)
    Approximation: uses a timm ConvNeXt backbone plus a global-context gate
    instead of true deformable convolutions.
    """

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()
        self.model_name = "InternImage-Tiny"

        # ConvNeXt backbone as the feature extractor (classifier head removed)
        self.backbone = timm.create_model(
            'convnext_tiny',
            pretrained=pretrained,
            num_classes=0,
            drop_path_rate=0.1
        )

        self.feature_dim = self.backbone.num_features

        # Squeeze-and-excitation style global context gate over channels
        self.global_context = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(self.feature_dim, self.feature_dim // 4, 1),
            nn.GELU(),
            nn.Conv2d(self.feature_dim // 4, self.feature_dim, 1),
            nn.Sigmoid()
        )

        self.head = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.2),
            nn.Linear(self.feature_dim, num_classes)
        )

    def forward(self, x):
        # Backbone feature map: (B, C, H, W)
        features = self.backbone.forward_features(x)

        # Re-weight channels with the global context gate
        context = self.global_context(features)
        features = features * context

        # Global average pool over the spatial dimensions
        x = features.mean(dim=[-2, -1])

        return self.head(x)
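

# The global_context block in InternImageClassifier is a squeeze-and-excitation
# style gate; per channel it computes
#   g  = sigmoid(W2(GELU(W1(GlobalAvgPool(F)))))   with g in (0, 1)
#   F' = F * g                                     (broadcast over H and W)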


class ConvFormerClassifier(nn.Module):
    """
    ConvFormer Classifier - Efficient CNN + Self-Attention Hybrid.
    Paper: "MetaFormer Baselines for Vision" (https://arxiv.org/abs/2210.13452)
    Uses a timm CAFormer backbone when available, with a ConvNeXt fallback,
    followed by one self-attention refinement block.
    """

    def __init__(self, num_classes: int, pretrained: bool = True):
        super().__init__()
        self.model_name = "ConvFormer-S"

        try:
            # Preferred backbone: CAFormer-S18 from the MetaFormer family
            self.backbone = timm.create_model(
                'caformer_s18',
                pretrained=pretrained,
                num_classes=0,
                drop_path_rate=0.1
            )
        except Exception:
            # Fallback if the CAFormer architecture/weights are unavailable
            print("Using ConvNeXt with attention as ConvFormer alternative")
            self.backbone = timm.create_model(
                'convnext_small',
                pretrained=pretrained,
                num_classes=0,
                drop_path_rate=0.1
            )

        self.feature_dim = self.backbone.num_features

        # Single self-attention refinement block over the backbone tokens
        self.attention = nn.MultiheadAttention(
            embed_dim=self.feature_dim,
            num_heads=8,
            dropout=0.1,
            batch_first=True
        )

        self.norm1 = nn.LayerNorm(self.feature_dim)
        self.norm2 = nn.LayerNorm(self.feature_dim)

        # Feed-forward network (transformer MLP block)
        self.ffn = nn.Sequential(
            nn.Linear(self.feature_dim, self.feature_dim * 4),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(self.feature_dim * 4, self.feature_dim),
            nn.Dropout(0.1)
        )

        self.head = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.feature_dim, num_classes)
        )

    def forward(self, x):
        # Backbone feature map: (B, C, H, W)
        features = self.backbone.forward_features(x)

        # Flatten spatial positions into tokens: (B, H*W, C)
        x = features.flatten(2).transpose(1, 2)

        # Pre-norm self-attention with a residual connection
        x_norm = self.norm1(x)
        attn_out, _ = self.attention(x_norm, x_norm, x_norm)
        x = x + attn_out

        # Pre-norm feed-forward with a residual connection
        x = x + self.ffn(self.norm2(x))

        # Mean pool over tokens
        x = x.mean(dim=1)

        return self.head(x)
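

# The refinement in ConvFormerClassifier.forward follows the standard pre-norm
# transformer block layout applied to the flattened backbone tokens:
#   x = x + MHSA(LN(x))
#   x = x + FFN(LN(x))
# followed by mean pooling over the tokens before the classification head.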


def print_model_summary():
    """Print a summary of all available models."""
    print("\n" + "=" * 60)
    print("7 LATEST MODELS FOR CLASSIFICATION (2025)")
    print("=" * 60)

    models_info = [
        ("YOLOv11-cls", "YOLOv11 Classification - Fast and efficient"),
        ("EfficientNetV2-S", "EfficientNetV2 - Optimized CNN architecture"),
        ("ConvNeXtV2-Tiny", "ConvNeXt V2 - Pure CNN with modern design"),
        ("ViT-Base-16", "Vision Transformer - Pure attention-based"),
        ("Hybrid-CNN-ViT", "CNN + Transformer hybrid (CoAtNet-style)"),
        ("InternImage-Tiny", "ConvNeXt backbone + global context gating"),
        ("ConvFormer-S", "Efficient CNN + Self-Attention hybrid")
    ]

    for i, (name, desc) in enumerate(models_info, 1):
        print(f"{i}. {name}")
        print(f"   {desc}\n")


if __name__ == "__main__":
    print_model_summary()

    # Smoke test: build each configured model and run a dummy batch through it
    num_classes = 31
    batch = torch.randn(2, 3, 224, 224)

    for model_name in config.MODEL_NAMES:
        print(f"\nTesting {model_name}...")
        model = get_model(model_name, num_classes, pretrained=False)
        model.eval()
        with torch.no_grad():
            output = model(batch)
        print(f"  Input: {batch.shape}")
        print(f"  Output: {output.shape}")
        params = sum(p.numel() for p in model.parameters())
        print(f"  Parameters: {params:,}")