|
|
import torch.nn as nn |
|
|
import timm |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
class PretrainedConvNext(nn.Module):
    """ConvNeXt backbone followed by a linear head producing 6 values per image.

    The backbone is built with ``timm.create_model(..., num_classes=0)`` and its
    output is fed to ``self.head``; ``forward`` returns the head output as-is.
    """

    def __init__(self, model_name='convnext_base', pretrained=True):
        """Create the backbone and the 6-output linear head.

        Args:
            model_name: Model identifier handed to ``timm.create_model``.
            pretrained: Accepted for interface compatibility only. The backbone
                is always created with ``pretrained=False``, so no weights are
                fetched here.
        """
        super().__init__()

        # NOTE(review): ``pretrained`` is intentionally left unused to keep the
        # existing behavior (no weight download at construction); presumably
        # weights are loaded from a checkpoint by the caller - confirm.
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0)

        # Size the head from the backbone instead of hard-coding 768: timm
        # models expose their pooled feature width as ``num_features`` (e.g.
        # 1024 for ``convnext_base``, 768 for the small/tiny variants). The
        # fallback keeps the previous width if the attribute is missing.
        in_features = getattr(self.model, 'num_features', 768)
        self.head = nn.Linear(in_features, 6)

    def forward(self, x):
        """Resize ``x`` to 224x224, run the backbone and return the head output.

        Args:
            x: Image batch; bilinear ``F.interpolate`` requires a 4-D tensor.

        Returns:
            The output of ``self.head`` applied to the backbone features
            (last dimension of size 6).
        """
        # The resize runs under no_grad, so no gradient flows back to ``x``
        # through the interpolation.
        with torch.no_grad():
            cls_input = F.interpolate(
                x, size=(224, 224), mode='bilinear', align_corners=True
            )

        features = self.model(cls_input)
        return self.head(features)
|
|
class PretrainedConvNext_e2e(nn.Module):
    """ConvNeXt backbone whose 6 head outputs are applied to the input as an
    affine map ``alpha * x + beta``.

    The first three head outputs are used as ``alpha`` and the last three as
    ``beta``; each is expanded with two trailing singleton dimensions before
    being combined with the original (un-resized) input.
    """

    def __init__(self, model_name='convnext_base', pretrained=True):
        """Create the backbone and the 6-output linear head.

        Args:
            model_name: Model identifier handed to ``timm.create_model``.
            pretrained: Forwarded to ``timm.create_model``.
        """
        super().__init__()

        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0
        )

        # Size the head from the backbone instead of hard-coding 768: timm
        # models expose their pooled feature width as ``num_features`` (e.g.
        # 1024 for ``convnext_base``, 768 for the small/tiny variants). The
        # fallback keeps the previous width if the attribute is missing.
        in_features = getattr(self.model, 'num_features', 768)
        self.head = nn.Linear(in_features, 6)

    def forward(self, x):
        """Predict 6 coefficients from a 224x224 copy of ``x`` and apply them.

        Args:
            x: Image batch; bilinear ``F.interpolate`` requires a 4-D tensor.
                The broadcast below assumes 3 channels - TODO confirm with
                callers.

        Returns:
            ``alpha * x + beta`` computed on the original-resolution ``x``.
        """
        # The resize runs under no_grad, so no gradient flows back to ``x``
        # through the interpolation (it still does through ``alpha * x``).
        with torch.no_grad():
            cls_input = F.interpolate(
                x, size=(224, 224), mode='bilinear', align_corners=True
            )

        coeffs = self.head(self.model(cls_input))

        # Split the 6 outputs into scale (first 3) and offset (last 3), adding
        # two trailing singleton dims so they broadcast against ``x``.
        alpha = coeffs[..., :3].unsqueeze(-1).unsqueeze(-1)
        beta = coeffs[..., 3:].unsqueeze(-1).unsqueeze(-1)

        return alpha * x + beta
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model and push one random batch through it.
    model = PretrainedConvNext('convnext_small_in22k')
    print("Testing PretrainedConvNext model...")

    dummy_input = torch.randn(20, 3, 224, 224)

    # forward() returns a single tensor (the head output), so it must not be
    # unpacked into two names. No gradients are needed for a shape check.
    with torch.no_grad():
        output = model(dummy_input)

    print("Output shape:", output.shape)
    print("Test completed successfully.")
|
|
|