# FasterRCNN / detection / roi_heads.py
# Mjolnir65's picture
# uploaded the weights
# 0e51835 verified
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
import torchvision
from torch import nn, Tensor
from torchvision.ops import boxes as box_ops, roi_align
from . import _utils as det_utils
def compute_theta_loss(preds, targets):
    """
    Computes the orientation (theta) loss.

    The number of columns in ``preds`` selects the mode: a single column is
    treated as a direct regression of theta (MSE loss); more than one column
    is treated as logits over equally sized angle bins covering ``[0, pi)``
    (cross-entropy loss).

    Args:
        preds (Tensor): predictions of shape (num_samples, num_bins)
        targets (Tensor): target angle for each sample, shape (num_samples,)

    Returns:
        loss (Tensor): scalar loss. When there are no samples a zero that is
            still attached to ``preds``' autograd graph is returned.
    """
    if preds.shape[0] == 0:
        # Mean-reduced MSE / cross-entropy over zero samples evaluates to NaN,
        # which would poison the summed training loss. Return a zero that
        # still depends on ``preds`` so backward() keeps working.
        return preds.sum() * 0

    num_bins = preds.shape[1]
    if num_bins == 1:
        # regression
        return F.mse_loss(preds[:, 0], targets)

    # classification: map each angle to the index of the bin containing it
    bin_size = torch.pi / num_bins
    targets_bins = (targets / bin_size).long()
    # angles outside [0, pi) are folded into the first / last bin
    targets_bins = torch.clamp(targets_bins, 0, num_bins - 1)
    return F.cross_entropy(preds, targets_bins)
def fastrcnn_loss(class_logits, box_regression, theta_preds, labels, regression_targets, theta_targets):
    # type: (Tensor, Tensor, Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor, Tensor]
    """
    Computes the loss for Faster R-CNN with an additional theta head.

    Args:
        class_logits (Tensor): class scores, shape (N, num_classes)
        box_regression (Tensor): per-class box deltas, reshaped here to (N, -1, 4)
        theta_preds (Tensor): theta head output; column 0 is dropped before
            the theta loss is computed
        labels (list[Tensor]): per-image class labels of the sampled proposals
        regression_targets (list[Tensor]): per-image box regression targets
        theta_targets (list[Tensor]): per-image theta targets

    Returns:
        classification_loss (Tensor)
        box_loss (Tensor)
        theta_loss (Tensor)
    """
    labels = torch.cat(labels, dim=0)
    regression_targets = torch.cat(regression_targets, dim=0)
    theta_targets = torch.cat(theta_targets, dim=0)

    classification_loss = F.cross_entropy(class_logits, labels)

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.where(labels > 0)[0]
    labels_pos = labels[sampled_pos_inds_subset]
    N = class_logits.shape[0]
    box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

    box_loss = F.smooth_l1_loss(
        box_regression[sampled_pos_inds_subset, labels_pos],
        regression_targets[sampled_pos_inds_subset],
        beta=1 / 9,
        reduction="sum",
    )
    # normalised by the number of sampled proposals, not only the positives
    box_loss = box_loss / labels.numel()

    # theta loss, computed on positive samples only
    if sampled_pos_inds_subset.numel() == 0:
        # No positive sample in the batch: a mean-reduced loss over zero
        # elements would be NaN, so contribute a graph-connected zero instead.
        theta_loss = theta_preds.sum() * 0
    else:
        # indices are moved explicitly because predictions and targets may
        # live on different devices
        preds = theta_preds[sampled_pos_inds_subset.to(theta_preds.device)][:, 1:]
        targets = theta_targets[sampled_pos_inds_subset.to(theta_targets.device)].to(preds.device)
        theta_loss = compute_theta_loss(preds, targets)

    return classification_loss, box_loss, theta_loss
def fastrcnn_theta_loss(theta_preds, theta_targets):
    """
    Placeholder theta loss.

    Both arguments are ignored and the constant ``0`` is returned.
    """
    return 0
def convert_xyxytheta_to_xywha(boxes):
    """
    Converts rows of (x1, y1, x2, y2, theta) into (xc, yc, w, h, angle).

    Column 4 of ``boxes`` is passed through ``torch.deg2rad``, i.e. the input
    angle is read as degrees and the output angle is in radians.

    Args:
        boxes (Tensor): 2-D tensor whose columns 0-4 are read

    Returns:
        Tensor of shape (boxes.shape[0], 5)
    """
    x_min, y_min = boxes[:, 0], boxes[:, 1]
    x_max, y_max = boxes[:, 2], boxes[:, 3]

    center_x = (x_min + x_max) / 2
    center_y = (y_min + y_max) / 2
    width = x_max - x_min
    height = y_max - y_min
    angle = torch.deg2rad(boxes[:, 4])

    return torch.stack((center_x, center_y, width, height, angle), dim=1)
class RoIHeads(nn.Module):
    """
    RoI heads producing class scores, box deltas and a theta prediction for
    every proposal.

    In training mode ``forward`` samples proposals against the ground truth
    and returns the classification, box-regression and theta losses. In
    evaluation mode it decodes, filters and NMS-es the predictions and returns
    per-image ``boxes`` / ``labels`` / ``scores`` / ``thetas``.
    """

    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
    }

    def __init__(
        self,
        box_roi_pool,
        box_head,
        box_predictor,
        # Faster R-CNN training
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        bbox_reg_weights,
        # Faster R-CNN inference
        score_thresh,
        nms_thresh,
        detections_per_img,
    ):
        """
        Args:
            box_roi_pool: module called as ``box_roi_pool(features, proposals, image_shapes)``
            box_head: module applied to the pooled features
            box_predictor: module returning ``(class_logits, box_regression, theta_preds)``
            fg_iou_thresh, bg_iou_thresh: thresholds handed to the proposal matcher
            batch_size_per_image, positive_fraction: handed to the fg/bg sampler
            bbox_reg_weights: box coder weights; ``None`` selects (10, 10, 5, 5)
            score_thresh: detections must score strictly above this value
            nms_thresh: IoU threshold handed to ``batched_nms``
            detections_per_img: maximum number of detections kept per image
        """
        super().__init__()

        self.box_similarity = box_ops.box_iou
        # assign ground-truth boxes for each proposal
        self.proposal_matcher = det_utils.Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

        if bbox_reg_weights is None:
            bbox_reg_weights = (10.0, 10.0, 5.0, 5.0)
        self.box_coder = det_utils.BoxCoder(bbox_reg_weights)

        self.box_roi_pool = box_roi_pool
        self.box_head = box_head
        self.box_predictor = box_predictor

        self.score_thresh = score_thresh
        self.nms_thresh = nms_thresh
        self.detections_per_img = detections_per_img

    def assign_targets_to_proposals(self, proposals, gt_boxes, gt_labels):
        # type: (List[Tensor], List[Tensor], List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        """
        Matches every proposal to a ground-truth box, image by image.

        Returns:
            matched_idxs (list[Tensor]): index of the matched ground-truth box
                per proposal, clamped to be >= 0
            labels (list[Tensor]): int64 label per proposal; 0 marks
                background and -1 marks proposals the sampler should ignore
        """
        matched_idxs = []
        labels = []
        for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
            if gt_boxes_in_image.numel() == 0:
                # Background image: every proposal is labelled background
                device = proposals_in_image.device
                clamped_matched_idxs_in_image = torch.zeros(
                    (proposals_in_image.shape[0],), dtype=torch.int64, device=device
                )
                labels_in_image = torch.zeros((proposals_in_image.shape[0],), dtype=torch.int64, device=device)
            else:
                # set to self.box_similarity when https://github.com/pytorch/pytorch/issues/27495 lands
                # Matching uses plain box_iou on the boxes; theta is not involved here.
                match_quality_matrix = box_ops.box_iou(gt_boxes_in_image, proposals_in_image)
                matched_idxs_in_image = self.proposal_matcher(match_quality_matrix)

                # negative matcher codes are clamped so they can be used as indices
                clamped_matched_idxs_in_image = matched_idxs_in_image.clamp(min=0)

                labels_in_image = gt_labels_in_image[clamped_matched_idxs_in_image]
                labels_in_image = labels_in_image.to(dtype=torch.int64)

                # Label background (below the low threshold)
                bg_inds = matched_idxs_in_image == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_in_image[bg_inds] = 0

                # Label ignore proposals (between low and high thresholds)
                ignore_inds = matched_idxs_in_image == self.proposal_matcher.BETWEEN_THRESHOLDS
                labels_in_image[ignore_inds] = -1  # -1 is ignored by sampler

            matched_idxs.append(clamped_matched_idxs_in_image)
            labels.append(labels_in_image)
        return matched_idxs, labels

    def subsample(self, labels):
        # type: (List[Tensor]) -> List[Tensor]
        """
        Runs the fg/bg sampler and returns, per image, the indices of the
        proposals selected as either positive or negative.
        """
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_inds = []
        for pos_inds_img, neg_inds_img in zip(sampled_pos_inds, sampled_neg_inds):
            img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
            sampled_inds.append(img_sampled_inds)
        return sampled_inds

    def add_gt_proposals(self, proposals, gt_boxes):
        # type: (List[Tensor], List[Tensor]) -> List[Tensor]
        """
        Appends each image's ground-truth boxes to that image's proposals.
        """
        proposals = [torch.cat((proposal, gt_box)) for proposal, gt_box in zip(proposals, gt_boxes)]

        return proposals

    def check_targets(self, targets):
        # type: (Optional[List[Dict[str, Tensor]]]) -> None
        """
        Validates that targets are present and carry every key used in training.

        Raises:
            ValueError: if ``targets`` is None or an element lacks the
                ``boxes``, ``labels`` or ``thetas`` key.
        """
        if targets is None:
            raise ValueError("targets should not be None")
        # list comprehensions (rather than generators) are kept on purpose,
        # matching the TorchScript-style type comments used in this class
        if not all(["boxes" in t for t in targets]):
            raise ValueError("Every element of targets should have a boxes key")
        if not all(["labels" in t for t in targets]):
            raise ValueError("Every element of targets should have a labels key")
        # "thetas" is read unconditionally by select_training_samples
        if not all(["thetas" in t for t in targets]):
            raise ValueError("Every element of targets should have a thetas key")

    def select_training_samples(
        self,
        proposals,  # type: List[Tensor]
        targets,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        """
        Samples training proposals and builds their targets.

        Returns:
            proposals (list[Tensor]): sampled proposals per image
            matched_idxs (list[Tensor]): matched ground-truth index per sampled proposal
            labels (list[Tensor]): class label per sampled proposal
            regression_targets: box targets encoded by ``self.box_coder``
            theta_targets (list[Tensor]): theta of the matched ground truth
                per sampled proposal

        Raises:
            ValueError: if ``targets`` is None or lacks a required key.
        """
        self.check_targets(targets)
        if targets is None:
            raise ValueError("targets should not be None")
        dtype = proposals[0].dtype
        device = proposals[0].device

        gt_boxes = [t["boxes"].to(dtype) for t in targets]
        gt_labels = [t["labels"] for t in targets]
        gt_thetas = [t["thetas"] for t in targets]

        # append ground-truth bboxes to proposals
        proposals = self.add_gt_proposals(proposals, gt_boxes)

        # get matching gt indices for each proposal
        matched_idxs, labels = self.assign_targets_to_proposals(proposals, gt_boxes, gt_labels)
        # sample a fixed proportion of positive-negative proposals
        sampled_inds = self.subsample(labels)
        matched_gt_boxes = []
        matched_gt_thetas = []
        num_images = len(proposals)
        for img_id in range(num_images):
            img_sampled_inds = sampled_inds[img_id]
            proposals[img_id] = proposals[img_id][img_sampled_inds]
            labels[img_id] = labels[img_id][img_sampled_inds]
            matched_idxs[img_id] = matched_idxs[img_id][img_sampled_inds]

            gt_boxes_in_image = gt_boxes[img_id]
            gt_thetas_in_image = gt_thetas[img_id]
            if gt_boxes_in_image.numel() == 0:
                # background image: matched indices are all 0, so provide one dummy box
                gt_boxes_in_image = torch.zeros((1, 4), dtype=dtype, device=device)
            if gt_thetas_in_image.numel() == 0:
                # Same for theta: without a dummy entry, indexing the empty
                # tensor with the all-zero matched indices raises.
                # Assumes thetas is 1-D with one angle per box - TODO confirm.
                gt_thetas_in_image = torch.zeros(
                    (1,), dtype=gt_thetas_in_image.dtype, device=gt_thetas_in_image.device
                )
            matched_gt_boxes.append(gt_boxes_in_image[matched_idxs[img_id]])
            matched_gt_thetas.append(gt_thetas_in_image[matched_idxs[img_id].to(gt_thetas_in_image.device)])

        regression_targets = self.box_coder.encode(matched_gt_boxes, proposals)
        theta_targets = matched_gt_thetas
        return proposals, matched_idxs, labels, regression_targets, theta_targets

    def postprocess_detections(
        self,
        class_logits,  # type: Tensor
        box_regression,  # type: Tensor
        theta_preds,  # type: Tensor
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
    ):
        # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
        """
        Turns raw head outputs into final per-image detections.

        Boxes are decoded and clipped to the image; the background column is
        dropped; every (proposal, foreground class) pair becomes a candidate;
        candidates are filtered by score and size, NMS-ed per class and cut to
        ``self.detections_per_img``.

        Returns:
            all_boxes, all_scores, all_labels, all_thetas (list[Tensor]):
                one tensor per image each
        """
        device = class_logits.device
        num_classes = class_logits.shape[-1]

        boxes_per_image = [boxes_in_image.shape[0] for boxes_in_image in proposals]
        pred_boxes = self.box_coder.decode(box_regression, proposals)

        pred_scores = F.softmax(class_logits, -1)

        pred_boxes_list = pred_boxes.split(boxes_per_image, 0)
        pred_scores_list = pred_scores.split(boxes_per_image, 0)
        pred_theta_list = theta_preds.split(boxes_per_image, 0)

        all_boxes = []
        all_scores = []
        all_labels = []
        all_thetas = []
        for boxes, scores, thetas, image_shape in zip(pred_boxes_list, pred_scores_list, pred_theta_list, image_shapes):
            boxes = box_ops.clip_boxes_to_image(boxes, image_shape)

            # create labels for each prediction
            labels = torch.arange(num_classes, device=device)
            labels = labels.view(1, -1).expand_as(scores)

            # remove predictions with the background label
            boxes = boxes[:, 1:]
            scores = scores[:, 1:]
            labels = labels[:, 1:]
            # column 0 of the theta head is dropped, as in the training loss
            thetas = thetas[:, 1:]

            if thetas.shape[1] != 1:
                # bin classification: report the lower edge of the arg-max bin
                nbins = thetas.shape[1]
                angle_per_bin = torch.pi / nbins
                max_val_idx = torch.argmax(thetas, dim=1)
                thetas = angle_per_bin * max_val_idx.float()

            # There is a single theta per proposal, whereas boxes/scores/labels
            # hold one entry per (proposal, foreground class). Repeat theta over
            # the class axis so the flattened tensors below stay index-aligned.
            thetas = thetas.reshape(-1, 1).expand(-1, scores.shape[1])

            # batch everything, by making every class prediction be a separate instance
            boxes = boxes.reshape(-1, 4)
            scores = scores.reshape(-1)
            labels = labels.reshape(-1)
            thetas = thetas.reshape(-1)

            # remove low scoring boxes
            inds = torch.where(scores > self.score_thresh)[0]
            boxes, scores, labels, thetas = boxes[inds], scores[inds], labels[inds], thetas[inds]

            # remove empty boxes
            keep = box_ops.remove_small_boxes(boxes, min_size=1e-2)
            boxes, scores, labels, thetas = boxes[keep], scores[keep], labels[keep], thetas[keep]

            # non-maximum suppression, independently done per class
            keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.detections_per_img]
            boxes, scores, labels, thetas = boxes[keep], scores[keep], labels[keep], thetas[keep]

            all_boxes.append(boxes)
            all_scores.append(scores)
            all_labels.append(labels)
            all_thetas.append(thetas)

        return all_boxes, all_scores, all_labels, all_thetas

    def forward(
        self,
        features,  # type: Dict[str, Tensor]
        proposals,  # type: List[Tensor]
        image_shapes,  # type: List[Tuple[int, int]]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Dict[str, Tensor]], Dict[str, Tensor]]
        """
        Args:
            features (Dict[str, Tensor])
            proposals (List[Tensor[N, 4]])
            image_shapes (List[Tuple[H, W]])
            targets (List[Dict])

        Returns:
            result (List[Dict[str, Tensor]]): empty in training mode; otherwise
                one dict per image with "boxes", "labels", "scores", "thetas"
            losses (Dict[str, Tensor]): empty in evaluation mode; otherwise
                "loss_classifier", "loss_box_reg" and "loss_theta"

        Raises:
            TypeError: if target boxes are not floating point or target labels
                are not int64.
            ValueError: in training mode, if targets or any derived training
                target is missing.
        """
        if targets is not None:
            for t in targets:
                # TODO: https://github.com/pytorch/pytorch/issues/26731
                floating_point_types = (torch.float, torch.double, torch.half)
                if not t["boxes"].dtype in floating_point_types:
                    raise TypeError(f"target boxes must of float type, instead got {t['boxes'].dtype}")
                if not t["labels"].dtype == torch.int64:
                    raise TypeError(f"target labels must of int64 type, instead got {t['labels'].dtype}")

        if self.training:
            proposals, matched_idxs, labels, regression_targets, theta_targets = self.select_training_samples(
                proposals, targets
            )
        else:
            labels = None
            regression_targets = None
            theta_targets = None
            matched_idxs = None

        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression, theta_preds = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        if self.training:
            if labels is None:
                raise ValueError("labels cannot be None")
            if regression_targets is None:
                raise ValueError("regression_targets cannot be None")
            if theta_targets is None:
                raise ValueError("theta_targets cannot be None")
            loss_classifier, loss_box_reg, loss_theta = fastrcnn_loss(
                class_logits, box_regression, theta_preds, labels, regression_targets, theta_targets
            )
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg, "loss_theta": loss_theta}
        else:
            boxes, scores, labels, thetas = self.postprocess_detections(
                class_logits, box_regression, theta_preds, proposals, image_shapes
            )
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                        "thetas": thetas[i],
                    }
                )
        return result, losses