# RDNet: models/arch/classifier.py
# Uploaded by lime-j ("Upload 89 files", commit 347b44e).
import torch.nn as nn
import timm
import torch
import torch.nn.functional as F
class PretrainedConvNext(nn.Module):
    """ConvNeXt backbone (built with timm) followed by a 6-output linear head.

    The input batch is bilinearly resized to ``CLS_INPUT_SIZE`` before it is
    passed to the backbone, so any spatial resolution is accepted.

    Args:
        model_name: timm model identifier for the backbone.
        pretrained: accepted for signature compatibility but ignored -- the
            backbone is always created with ``pretrained=False``. Weights are
            presumably loaded from a checkpoint by the caller (TODO confirm).
    """

    # Spatial size the backbone sees, independent of the input resolution.
    CLS_INPUT_SIZE = (224, 224)

    def __init__(self, model_name='convnext_base', pretrained=True):
        super().__init__()
        # `pretrained` is intentionally not forwarded (the original hard-codes
        # False); forwarding it would start a weight download by default.
        # num_classes=0 makes timm return pooled features instead of logits.
        self.model = timm.create_model(model_name, pretrained=False, num_classes=0)
        # Size the head from the backbone's real feature width rather than a
        # hard-coded 768: convnext_tiny/small emit 768 features, but the default
        # convnext_base emits 1024, which made Linear(768, 6) fail at runtime.
        self.head = nn.Linear(self.model.num_features, 6)

    def forward(self, x):
        """Return the 6 head outputs for every image in ``x``.

        Args:
            x: image batch; the bilinear resize requires a 4-D tensor,
                presumably (batch, channels, H, W).

        Returns:
            The linear head's output, one 6-value row per backbone feature row.
        """
        # NOTE(review): indentation was lost in this copy of the file, so the
        # scope of no_grad (the resize only) is a reconstruction -- confirm
        # against upstream before relying on the gradient behaviour.
        with torch.no_grad():
            cls_input = F.interpolate(
                x, size=self.CLS_INPUT_SIZE, mode='bilinear', align_corners=True
            )
        features = self.model(cls_input)
        return self.head(features)
class PretrainedConvNext_e2e(nn.Module):
    """ConvNeXt backbone (built with timm) that predicts a per-channel affine map.

    The backbone plus linear head emit 6 values per image. The first 3 are a
    per-channel scale ``alpha`` and the last 3 a per-channel shift ``beta``.
    ``forward`` returns ``alpha * x + beta``, which has the same shape as ``x``.

    Args:
        model_name: timm model identifier for the backbone.
        pretrained: forwarded to ``timm.create_model``.
    """

    # Spatial size the backbone sees, independent of the input resolution.
    CLS_INPUT_SIZE = (224, 224)

    def __init__(self, model_name='convnext_base', pretrained=True):
        super().__init__()
        # num_classes=0 makes timm return pooled features instead of logits.
        self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        # Head width follows the backbone instead of a hard-coded 768, which
        # broke the default convnext_base backbone (1024 features).
        self.head = nn.Linear(self.model.num_features, 6)

    def forward(self, x):
        """Apply the predicted per-image, per-channel affine transform to ``x``.

        Args:
            x: image batch; the bilinear resize requires a 4-D tensor, and
                the (batch, 3, 1, 1) scale/shift broadcast assumes x has 3
                channels, i.e. presumably (batch, 3, H, W).

        Returns:
            ``alpha * x + beta``, broadcast over the spatial dimensions.
        """
        # NOTE(review): indentation was lost in this copy of the file. Only the
        # resize is kept under no_grad -- wrapping the head as well would make
        # this "e2e" module untrainable. Confirm against upstream.
        with torch.no_grad():
            cls_input = F.interpolate(
                x, size=self.CLS_INPUT_SIZE, mode='bilinear', align_corners=True
            )
        params = self.head(self.model(cls_input))
        # Split the 6 outputs into scale/shift and add singleton H, W dims so
        # they broadcast over every pixel of the matching channel.
        alpha = params[:, :3, None, None]
        beta = params[:, 3:, None, None]
        return alpha * x + beta
if __name__ == "__main__":
    # Smoke test: build the classifier and push a random batch through it.
    model = PretrainedConvNext('convnext_small_in22k')
    print("Testing PretrainedConvNext model...")
    # Dummy batch of 20 RGB images; forward() resizes to 224x224 internally anyway.
    dummy_input = torch.randn(20, 3, 224, 224)
    # forward() returns a single tensor (the 6-way head output), so it must not
    # be unpacked into two values -- doing so iterates over the batch dimension
    # and raises "too many values to unpack".
    output = model(dummy_input)
    print("Output shape:", output.shape)
    print("Test completed successfully.")