File size: 2,992 Bytes
eef8873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
import torch
import torch.nn as nn
from torchvision import models
from transformers import ConvNextModel

logger = logging.getLogger(__name__)


class FusionClassifier(nn.Module):
    """Image classifier that fuses EfficientNet-V2-S and ConvNeXt features.

    Each backbone receives its own image tensor. The pooled feature vectors of
    both backbones are concatenated and fed through a small MLP head
    (``fusion_head``) that produces ``num_classes`` logits.

    Both backbones are loaded with pretrained weights and frozen, except for
    the blocks/stages selected by ``eff_trainable_blocks`` and
    ``cnx_trainable_stages`` plus ConvNeXt's final ``layernorm``. The fusion
    head is always trainable.

    Args:
        num_classes: Number of output logits.
        convnext_model_name: Model id or path passed to
            ``ConvNextModel.from_pretrained``.
        eff_trainable_blocks: Indices into the EfficientNet ``features``
            container whose parameters remain trainable. Defaults to
            ``(5, 6, 7)``.
        cnx_trainable_stages: Indices into ``ConvNextModel.encoder.stages``
            whose parameters remain trainable. Defaults to ``(2, 3)``.
    """

    def __init__(
        self,
        num_classes,
        convnext_model_name="facebook/convnext-small-224",
        *,
        eff_trainable_blocks=(5, 6, 7),
        cnx_trainable_stages=(2, 3),
    ):
        super().__init__()

        logger.info("Initializing Fusion model...")

        # ---- EfficientNet-V2-S branch ------------------------------------
        eff = models.efficientnet_v2_s(
            weights=models.EfficientNet_V2_S_Weights.IMAGENET1K_V1
        )
        # Freeze everything, then re-enable only the selected blocks.
        self._set_requires_grad(eff, False)
        for idx in eff_trainable_blocks:
            self._set_requires_grad(eff.features[idx], True)
        # NOTE(review): frozen BatchNorm layers still update their running
        # statistics whenever the module is in train() mode; put them in
        # eval() if fully frozen statistics are intended.

        # Keep the convolutional trunk and pooling only. The ImageNet
        # classifier is dropped, but its Linear input width is the size of
        # the pooled feature vector.
        self.eff_features = eff.features
        self.eff_avgpool = eff.avgpool
        self.eff_out_dim = eff.classifier[1].in_features

        # ---- ConvNeXt branch ---------------------------------------------
        cnx = ConvNextModel.from_pretrained(convnext_model_name)
        self._set_requires_grad(cnx, False)
        for idx in cnx_trainable_stages:
            self._set_requires_grad(cnx.encoder.stages[idx], True)
        # The final LayerNorm (applied to the pooled features) is always tuned.
        self._set_requires_grad(cnx.layernorm, True)

        self.cnx_backbone = cnx
        # pooler_output has the channel width of the last encoder stage.
        # Read it from the config rather than hard-coding 768, so variants
        # other than convnext-small/tiny get a correctly sized fusion head.
        self.cnx_out_dim = cnx.config.hidden_sizes[-1]

        fused_dim = self.eff_out_dim + self.cnx_out_dim

        # MLP head: progressively narrower layers with decreasing dropout.
        self.fusion_head = nn.Sequential(
            nn.Dropout(0.4),
            nn.Linear(fused_dim, 512),
            nn.LayerNorm(512),
            nn.GELU(),

            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.GELU(),

            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

        logger.info("Fusion model initialized successfully.")

    @staticmethod
    def _set_requires_grad(module, flag):
        """Set ``requires_grad`` to *flag* on every parameter of *module*."""
        for param in module.parameters():
            param.requires_grad = flag

    def forward(self, pixel_values_eff, pixel_values_cnx):
        """Compute class logits from the two backbone inputs.

        Args:
            pixel_values_eff: Image batch for the EfficientNet branch
                (presumably ``(batch, 3, H, W)``; confirm against the data
                pipeline).
            pixel_values_cnx: Image batch for the ConvNeXt branch. It must
                have the same batch size as ``pixel_values_eff``, because the
                two feature vectors are concatenated per sample.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        # EfficientNet: conv trunk -> global pool -> (batch, eff_out_dim).
        x_eff = self.eff_features(pixel_values_eff)
        x_eff = self.eff_avgpool(x_eff)
        x_eff = torch.flatten(x_eff, 1)

        # ConvNeXt: pooled, layer-normalised features -> (batch, cnx_out_dim).
        cnx_out = self.cnx_backbone(
            pixel_values=pixel_values_cnx,
            return_dict=True
        )
        x_cnx = cnx_out.pooler_output

        fused = torch.cat([x_eff, x_cnx], dim=1)

        logits = self.fusion_head(fused)

        return logits


if __name__ == "__main__":
    # Smoke test: build the model and check the output shape of the fusion
    # head. `logging` is already imported at module level.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s"
    )

    model = FusionClassifier(num_classes=6)
    # eval() disables dropout and no_grad() skips building an autograd
    # graph; neither is needed just to inspect the output shape.
    model.eval()

    # The two branches are fed independently, so their spatial resolutions
    # may differ.
    eff_dummy = torch.randn(2, 3, 260, 260)
    cnx_dummy = torch.randn(2, 3, 224, 224)

    with torch.no_grad():
        output = model(eff_dummy, cnx_dummy)

    print("Fusion output shape:", output.shape)