fruits / models.py
ivanm151's picture
mobilesam v1.4
fbb0759
import torch
import torchvision.models as models
import torch.nn as nn
from mobile_sam import sam_model_registry, SamPredictor
# Every model in this module is placed on this device.
DEVICE = torch.device('cpu')

# Lazily-created model caches: each starts as None and is filled in by the
# matching load_* function on its first call.
sam_predictor: "SamPredictor | None" = None  # MobileSAM wrapped in a SamPredictor
model2: "nn.Module | None" = None  # fruit-variety classifier
model3: "nn.Module | None" = None  # freshness classifier
def load_sam(weights_path='weights/mobile_sam.pt'):
    """Return the shared MobileSAM predictor, building it on the first call.

    On the first call the ``vit_t`` MobileSAM model is constructed from the
    checkpoint at ``weights_path``, moved to ``DEVICE``, switched to eval mode
    and wrapped in a ``SamPredictor``. The result is cached in the
    module-level ``sam_predictor``. Later calls return the cached predictor
    and ignore ``weights_path``.

    Args:
        weights_path: Path to the MobileSAM checkpoint file.

    Returns:
        The cached ``SamPredictor`` instance.
    """
    global sam_predictor
    if sam_predictor is not None:
        return sam_predictor

    build_sam = sam_model_registry["vit_t"]
    sam_model = build_sam(checkpoint=weights_path)
    sam_model.to(DEVICE)
    sam_model.eval()
    # Publish to the cache only after construction finished.
    sam_predictor = SamPredictor(sam_model)
    return sam_predictor
def _load_mobilenet_v2(weights_path, num_classes):
    """Build a MobileNetV2 classifier and load saved weights into it.

    Args:
        weights_path: Path to a ``state_dict`` file readable by ``torch.load``.
        num_classes: Output size of the final ``Linear`` layer; must match the
            checkpoint.

    Returns:
        The model on ``DEVICE``, in eval mode, with the parameters of its
        ``features`` backbone set to ``requires_grad=False``.

    Raises:
        Whatever ``torch.load`` / ``load_state_dict`` raise, e.g.
        ``FileNotFoundError`` for a missing file or ``RuntimeError`` on a
        key/shape mismatch.
    """
    # Passing num_classes yields the same head (Dropout -> Linear(last_channel,
    # num_classes)) as replacing classifier[1] by hand, so the checkpoint keys
    # match. It also avoids torchvision's deprecated ``pretrained`` argument
    # while still starting from random (non-downloaded) weights.
    model = models.mobilenet_v2(num_classes=num_classes)
    # Freeze the backbone as the original loader did; the head keeps
    # requires_grad=True.
    model.features.requires_grad_(False)
    # NOTE(review): depending on the torch version's ``weights_only`` default,
    # torch.load may unpickle arbitrary objects. Only load trusted files.
    state_dict = torch.load(weights_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    # Keep the classifiers on the same device as the SAM model.
    model.to(DEVICE)
    model.eval()
    return model


def load_model2(weights_path='weights/class.pth'):
    """Return the cached fruit-variety classifier, loading it on the first call.

    Args:
        weights_path: Path to the 10-class MobileNetV2 ``state_dict``.
            It is ignored once the model is cached.

    Returns:
        The cached MobileNetV2 model (10 outputs) in eval mode.
    """
    global model2
    if model2 is None:
        # The global is bound only after loading fully succeeds. A failed load
        # is therefore retried on the next call instead of caching a randomly
        # initialised, non-eval model.
        model2 = _load_mobilenet_v2(weights_path, num_classes=10)
    return model2


def load_model3(weights_path='weights/class2.pth'):
    """Return the cached freshness classifier, loading it on the first call.

    Args:
        weights_path: Path to the 6-class MobileNetV2 ``state_dict``.
            It is ignored once the model is cached.

    Returns:
        The cached MobileNetV2 model (6 outputs) in eval mode.
    """
    global model3
    if model3 is None:
        # Bind the global only on success (see load_model2 for the rationale).
        model3 = _load_mobilenet_v2(weights_path, num_classes=6)
    return model3