Thompson001 commited on
Commit
5ab20fd
·
verified ·
1 Parent(s): 7fcb8df

Upload deepcrack_model.py

Browse files
Files changed (1) hide show
  1. models/deepcrack_model.py +117 -0
models/deepcrack_model.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Author: Yahui Liu <yahui.liu@uintn.it>
2
+
3
+ import torch
4
+ import numpy as np
5
+ import itertools
6
+ from .base_model import BaseModel
7
+ from .deepcrack_networks import define_deepcrack, BinaryFocalLoss
8
+
9
+ class DeepCrackModel(BaseModel):
10
+ """
11
+ This class implements the DeepCrack model.
12
+ DeepCrack paper: https://www.sciencedirect.com/science/article/pii/S0925231219300566
13
+ """
14
+ @staticmethod
15
+ def modify_commandline_options(parser, is_train=True):
16
+ """Add new dataset-specific options, and rewrite default values for existing options."""
17
+ parser.add_argument('--lambda_side', type=float, default=1.0, help='weight for side output loss')
18
+ parser.add_argument('--lambda_fused', type=float, default=1.0, help='weight for fused loss')
19
+ return parser
20
+
21
+ def __init__(self, opt):
22
+ """Initialize the DeepCrack class.
23
+
24
+ Parameters:
25
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
26
+ """
27
+ BaseModel.__init__(self, opt)
28
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
29
+ self.loss_names = ['side', 'fused', 'total']
30
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
31
+ self.display_sides = opt.display_sides
32
+ self.visual_names = ['image', 'label_viz', 'fused']
33
+ if self.display_sides:
34
+ self.visual_names += ['side1', 'side2', 'side3', 'side4', 'side5']
35
+ # specify the models you want to save to the disk.
36
+ self.model_names = ['G']
37
+
38
+ # define networks
39
+ self.netG = define_deepcrack(opt.input_nc,
40
+ opt.num_classes,
41
+ opt.ngf,
42
+ opt.norm,
43
+ opt.init_type,
44
+ opt.init_gain,
45
+ self.gpu_ids)
46
+
47
+ self.softmax = torch.nn.Softmax(dim=1)
48
+
49
+ if self.isTrain:
50
+ # define loss functions
51
+ #self.weight = torch.from_numpy(np.array([0.0300, 1.0000], dtype='float32')).float().to(self.device)
52
+ #self.criterionSeg = torch.nn.CrossEntropyLoss(weight=self.weight)
53
+ if self.opt.loss_mode == 'focal':
54
+ self.criterionSeg = BinaryFocalLoss()
55
+ else:
56
+ self.criterionSeg = nn.BCEWithLogitsLoss(size_average=True, reduce=True,
57
+ pos_weight=torch.tensor(1.0/3e-2).to(self.device))
58
+ self.weight_side = [0.5, 0.75, 1.0, 0.75, 0.5]
59
+
60
+ # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
61
+ self.optimizer = torch.optim.SGD(self.netG.parameters(), lr=opt.lr, momentum=0.9, weight_decay=2e-4)
62
+ self.optimizers.append(self.optimizer)
63
+
64
+ def set_input(self, input):
65
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
66
+ Parameters:
67
+ input (dict): include the data itself and its metadata information.
68
+ """
69
+ self.image = input['image'].to(self.device)
70
+ self.label = input['label'].to(self.device)
71
+ #self.label3d = self.label.squeeze(1)
72
+ self.image_paths = input['A_paths']
73
+
74
+ def forward(self):
75
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
76
+ self.outputs = self.netG(self.image)
77
+
78
+ # for visualization
79
+ self.label_viz = (self.label.float()-0.5)/0.5
80
+ #self.fused = (self.softmax(self.outputs[-1])[:,1].detach().unsqueeze(1)-0.5)/0.5
81
+ #if self.display_sides:
82
+ # self.side1 = (self.softmax(self.outputs[0])[:,1].detach().unsqueeze(1)-0.5)/0.5
83
+ # self.side2 = (self.softmax(self.outputs[1])[:,1].detach().unsqueeze(1)-0.5)/0.5
84
+ # self.side3 = (self.softmax(self.outputs[2])[:,1].detach().unsqueeze(1)-0.5)/0.5
85
+ # self.side4 = (self.softmax(self.outputs[3])[:,1].detach().unsqueeze(1)-0.5)/0.5
86
+ # self.side5 = (self.softmax(self.outputs[4])[:,1].detach().unsqueeze(1)-0.5)/0.5
87
+ self.fused = (torch.sigmoid(self.outputs[-1])-0.5)/0.5
88
+
89
+ if self.display_sides:
90
+ self.side1 = (torch.sigmoid(self.outputs[0])-0.5)/0.5
91
+ self.side2 = (torch.sigmoid(self.outputs[1])-0.5)/0.5
92
+ self.side3 = (torch.sigmoid(self.outputs[2])-0.5)/0.5
93
+ self.side4 = (torch.sigmoid(self.outputs[3])-0.5)/0.5
94
+ self.side5 = (torch.sigmoid(self.outputs[4])-0.5)/0.5
95
+
96
+ def backward(self):
97
+ """Calculate the loss"""
98
+ lambda_side = self.opt.lambda_side
99
+ lambda_fused = self.opt.lambda_fused
100
+
101
+ self.loss_side = 0.0
102
+ for out, w in zip(self.outputs[:-1], self.weight_side):
103
+ #self.loss_side += self.criterionSeg(out, self.label3d) * w
104
+ self.loss_side += self.criterionSeg(out, self.label) * w
105
+
106
+ #self.loss_fused = self.criterionSeg(self.outputs[-1], self.label3d)
107
+ self.loss_fused = self.criterionSeg(self.outputs[-1], self.label)
108
+ self.loss_total = self.loss_side * lambda_side + self.loss_fused * lambda_fused
109
+ self.loss_total.backward()
110
+
111
+ def optimize_parameters(self, epoch=None):
112
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
113
+ # forward
114
+ self.forward() # compute predictions.
115
+ self.optimizer.zero_grad() # set G's gradients to zero
116
+ self.backward() # calculate gradients for G
117
+ self.optimizer.step() # update G's weights