Voidljc
Your commit message
aa24fe8
Raw
History Blame Contribute Delete
5.12 kB
import warnings
import torch
from torch import nn
from torchvision.ops import RoIPool
# Installs a blanket "ignore" filter: importing this module silences every
# warning category for the rest of the process.
warnings.filterwarnings("ignore")
class VGG16RoIHead(nn.Module):
    """RoI head for a VGG16 backbone.

    Pools each proposal from the shared feature map, flattens the pooled
    features, runs them through ``classifier`` and then through two linear
    heads (box regression and classification).

    Args:
        n_class: Number of classes; sizes the score head (``n_class`` outputs)
            and the box-regression head (``n_class * 4`` outputs).
        roi_size: Side length of the square RoI pooling output.
        spatial_scale: Scale factor forwarded to ``RoIPool``.
            NOTE(review): ``forward`` already rescales the boxes to
            feature-map coordinates, so callers presumably pass 1.0 here —
            confirm against the call site.
        classifier: Module applied to the flattened pooled features; it must
            emit 4096 features per RoI to match the linear heads.
    """

    def __init__(self, n_class, roi_size, spatial_scale, classifier):
        super().__init__()
        self.classifier = classifier
        # Box regression on the RoI-pooled features: 4 offsets per class.
        self.cls_loc = nn.Linear(4096, n_class * 4)
        # Classification on the RoI-pooled features.
        self.score = nn.Linear(4096, n_class)
        # Weight initialisation.
        normal_init(self.cls_loc, 0, 0.001)
        normal_init(self.score, 0, 0.01)

        self.roi = RoIPool((roi_size, roi_size), spatial_scale)

    def forward(self, x, rois, roi_indices, img_size):
        """Predict per-class box offsets and class scores for every proposal.

        Args:
            x: Shared feature map, 4-D with the batch in dimension 0 and the
                spatial height/width in dimensions 2/3.
            rois: Proposal boxes in input-image coordinates. Its first two
                dimensions are flattened together and the last dimension is
                indexed as 4 box coordinates (columns 0/2 scaled by width,
                columns 1/3 by height).
            roi_indices: Batch index of each proposal, laid out like the
                first two dimensions of ``rois``.
            img_size: Input image size; index 0 is used as the height and
                index 1 as the width.

        Returns:
            Tuple ``(roi_cls_locs, roi_scores)`` reshaped to
            ``(n, -1, n_class * 4)`` and ``(n, -1, n_class)`` where ``n`` is
            the batch size of ``x``.
        """
        n, _, feat_h, feat_w = x.shape

        # Follow the device x actually lives on. A bare ``.cuda()`` would
        # target the *current* CUDA device, which is wrong when the model
        # sits on another GPU, and would leave CUDA boxes untouched when x
        # is on the CPU.
        rois = rois.to(x.device)
        roi_indices = roi_indices.to(x.device)

        rois = torch.flatten(rois, 0, 1)
        roi_indices = torch.flatten(roi_indices, 0, 1)

        # Map the boxes from image coordinates onto the feature map.
        rois_feature_map = torch.zeros_like(rois)
        rois_feature_map[:, [0, 2]] = rois[:, [0, 2]] / img_size[1] * feat_w
        rois_feature_map[:, [1, 3]] = rois[:, [1, 3]] / img_size[0] * feat_h

        # RoIPool takes one row per box: batch index followed by the 4 coords.
        indices_and_rois = torch.cat([roi_indices[:, None], rois_feature_map], dim=1)

        # Crop the shared feature map with the proposal boxes.
        pool = self.roi(x, indices_and_rois)

        # Flatten each pooled RoI and extract features with the classifier.
        pool = pool.view(pool.size(0), -1)
        # With a single input image and e.g. 300 proposals, fc7 is [300, 4096].
        fc7 = self.classifier(pool)

        roi_cls_locs = self.cls_loc(fc7)
        roi_scores = self.score(fc7)

        # Restore the batch dimension: (n, proposals per image, outputs).
        roi_cls_locs = roi_cls_locs.view(n, -1, roi_cls_locs.size(1))
        roi_scores = roi_scores.view(n, -1, roi_scores.size(1))
        return roi_cls_locs, roi_scores
class Resnet50RoIHead(nn.Module):
    """RoI head for a ResNet50 backbone.

    Pools each proposal from the shared feature map, runs the pooled (still
    spatial) features through ``classifier``, flattens the result and feeds
    it to two linear heads (box regression and classification).

    Args:
        n_class: Number of classes; sizes the score head (``n_class`` outputs)
            and the box-regression head (``n_class * 4`` outputs).
        roi_size: Side length of the square RoI pooling output.
        spatial_scale: Scale factor forwarded to ``RoIPool``.
            NOTE(review): ``forward`` already rescales the boxes to
            feature-map coordinates, so callers presumably pass 1.0 here —
            confirm against the call site.
        classifier: Module applied to the pooled feature maps; its output
            must flatten to 2048 features per RoI to match the linear heads.
    """

    def __init__(self, n_class, roi_size, spatial_scale, classifier):
        super().__init__()
        self.classifier = classifier
        # Box regression on the RoI-pooled features: 4 offsets per class.
        self.cls_loc = nn.Linear(2048, n_class * 4)
        # Classification on the RoI-pooled features.
        self.score = nn.Linear(2048, n_class)
        # Weight initialisation.
        normal_init(self.cls_loc, 0, 0.001)
        normal_init(self.score, 0, 0.01)

        self.roi = RoIPool((roi_size, roi_size), spatial_scale)

    def forward(self, x, rois, roi_indices, img_size):
        """Predict per-class box offsets and class scores for every proposal.

        Args:
            x: Shared feature map, 4-D with the batch in dimension 0 and the
                spatial height/width in dimensions 2/3.
            rois: Proposal boxes in input-image coordinates. Its first two
                dimensions are flattened together and the last dimension is
                indexed as 4 box coordinates (columns 0/2 scaled by width,
                columns 1/3 by height).
            roi_indices: Batch index of each proposal, laid out like the
                first two dimensions of ``rois``.
            img_size: Input image size; index 0 is used as the height and
                index 1 as the width.

        Returns:
            Tuple ``(roi_cls_locs, roi_scores)`` reshaped to
            ``(n, -1, n_class * 4)`` and ``(n, -1, n_class)`` where ``n`` is
            the batch size of ``x``.
        """
        n, _, feat_h, feat_w = x.shape

        # Follow the device x actually lives on. A bare ``.cuda()`` would
        # target the *current* CUDA device, which is wrong when the model
        # sits on another GPU, and would leave CUDA boxes untouched when x
        # is on the CPU.
        rois = rois.to(x.device)
        roi_indices = roi_indices.to(x.device)

        rois = torch.flatten(rois, 0, 1)
        roi_indices = torch.flatten(roi_indices, 0, 1)

        # Map the boxes from image coordinates onto the feature map.
        rois_feature_map = torch.zeros_like(rois)
        rois_feature_map[:, [0, 2]] = rois[:, [0, 2]] / img_size[1] * feat_w
        rois_feature_map[:, [1, 3]] = rois[:, [1, 3]] / img_size[0] * feat_h

        # RoIPool takes one row per box: batch index followed by the 4 coords.
        indices_and_rois = torch.cat([roi_indices[:, None], rois_feature_map], dim=1)

        # Crop the shared feature map with the proposal boxes.
        pool = self.roi(x, indices_and_rois)

        # Extract features with the classifier, then flatten per RoI.
        fc7 = self.classifier(pool)
        # With a single input image and e.g. 300 proposals, fc7 is [300, 2048].
        fc7 = fc7.view(fc7.size(0), -1)

        roi_cls_locs = self.cls_loc(fc7)
        roi_scores = self.score(fc7)

        # Restore the batch dimension: (n, proposals per image, outputs).
        roi_cls_locs = roi_cls_locs.view(n, -1, roi_cls_locs.size(1))
        roi_scores = roi_scores.view(n, -1, roi_scores.size(1))
        return roi_cls_locs, roi_scores
def normal_init(m, mean, stddev, truncated=False):
    """Initialise a layer's weight from a normal distribution and zero its bias.

    Args:
        m: Module exposing a ``weight`` tensor and, optionally, a ``bias``
            tensor (for example ``nn.Linear``). A missing or ``None`` bias is
            skipped instead of raising.
        mean: Mean of the weight distribution.
        stddev: Standard deviation of the weight distribution.
        truncated: When true, draw standard-normal samples, fold them into
            (-2, 2) with ``fmod`` and then scale by ``stddev`` and shift by
            ``mean``. This is only a rough stand-in for a truncated normal.
    """
    # In-place initialisation must not be recorded by autograd; no_grad is
    # the supported replacement for mutating through ``.data``.
    with torch.no_grad():
        if truncated:
            # Not a perfect approximation of a truncated normal.
            m.weight.normal_().fmod_(2).mul_(stddev).add_(mean)
        else:
            m.weight.normal_(mean, stddev)

        # Reset the bias in both cases so the result does not depend on the
        # sampling path; layers built with bias=False have nothing to reset.
        bias = getattr(m, "bias", None)
        if bias is not None:
            bias.zero_()