# Skin-Cancer/models/convnext.py
# Uploaded by umergohar, commit f94b780 ("1st Commit").
import torch
from torchvision.models import convnext_base, ConvNeXt_Base_Weights
from PIL import Image
import torch.nn.functional as F
class ConvNextBase:
    """Inference wrapper around a fine-tuned torchvision ConvNeXt-Base.

    Builds the ``convnext_base`` architecture, replaces its final linear layer
    with a ``num_classes``-way head, loads fine-tuned weights from disk, and
    exposes :meth:`make_prediction`, which returns softmax class probabilities.
    """

    def __init__(self, weights_path: str, num_classes: int = 7):
        """Build the model and load fine-tuned weights.

        Args:
            weights_path: Path to a ``state_dict`` for a ConvNeXt-Base whose
                final classifier layer has ``num_classes`` outputs.
            num_classes: Output size of the replacement classifier head.
                Defaults to 7, the value this wrapper was originally written for.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Architecture only; the trained parameters come from weights_path.
        self.model = convnext_base()
        # Replace the layer at classifier[2] with a num_classes-way Linear head,
        # keeping its input width. (torchvision's ConvNeXt classifier is
        # LayerNorm2d, Flatten, Linear, so index 2 is the final Linear.)
        in_features = self.model.classifier[2].in_features
        self.model.classifier[2] = torch.nn.Linear(in_features, num_classes)
        # Move the module before loading so parameters live on self.device.
        # Previously the model silently stayed on CPU even when CUDA existed.
        self.model.to(self.device)

        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from a trusted source. On torch >= 1.13, weights_only=True restricts
        # loading to tensors and plain containers.
        state_dict = torch.load(weights_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()

        # Preprocessing bundled with torchvision's IMAGENET1K_V1 ConvNeXt-Base
        # weights. This assumes the fine-tuned checkpoint was trained with the
        # same preprocessing -- TODO confirm.
        self.transform = ConvNeXt_Base_Weights.IMAGENET1K_V1.transforms()

    def make_prediction(self, image: Image.Image) -> torch.Tensor:
        """Return class probabilities for a single image.

        Args:
            image: A PIL image. Non-RGB PIL images (for example RGBA or
                grayscale "L") are converted to RGB first, because the
                normalization step expects three channels.

        Returns:
            A ``(1, num_classes)`` tensor of softmax probabilities, always on
            the CPU regardless of the device used for inference.
        """
        if isinstance(image, Image.Image) and image.mode != "RGB":
            image = image.convert("RGB")

        # Add a batch dimension and place the input on the model's device.
        batch = self.transform(image).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(batch)

        # Bring the result back to CPU so callers can call .numpy()/.tolist()
        # exactly as they could when the model only ever ran on CPU.
        return F.softmax(logits, dim=1).cpu()