# Aircraft_detection / model.py
# Last commit 0ecc700 by mansh: "new model" (file size: 575 bytes)
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def load_model(model_path="fasterrcnn_gtr_sar_new.pth", num_classes=8):  # 7 aircraft classes + 1 background
    """Build a Faster R-CNN (ResNet-50 FPN) detector and load fine-tuned weights.

    The detection head's box predictor is replaced by a ``FastRCNNPredictor``
    with ``num_classes`` outputs, so its shape can match the checkpoint. The
    saved state dict is then loaded and the model is switched to eval mode.

    Args:
        model_path: Path to a file readable by ``torch.load``. It is presumed
            to hold the detector's ``state_dict()``; confirm against the
            training script.
        num_classes: Number of output classes, including background
            (default 8 = 7 aircraft classes + 1 background).

    Returns:
        The detector module, with CPU-resident weights, in eval mode.

    Raises:
        RuntimeError: Raised by ``load_state_dict`` (strict by default) when
            the checkpoint's keys or tensor shapes do not match the model,
            e.g. when ``num_classes`` differs from the value used in training.
    """
    # weights=None: no pretrained detection weights are requested, because
    # the checkpoint loaded below supplies the parameters.
    # NOTE(review): torchvision's default ``weights_backbone`` may still fetch
    # ImageNet backbone weights here (network access on first call). Confirm
    # for the installed torchvision version before changing it: that argument
    # also affects how the backbone's normalization layers are built, and so
    # which keys the checkpoint must contain.
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)

    # Replace the classification/box-regression head with one sized to our classes.
    roi_heads = detector.roi_heads
    feature_dim = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)

    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider weights_only=True if the installed torch supports it.
    weights = torch.load(model_path, map_location="cpu")
    detector.load_state_dict(weights)

    # nn.Module.eval() returns the module itself, so it can be returned directly.
    return detector.eval()