Thompson001 commited on
Commit
7fcb8df
·
verified ·
1 Parent(s): 7e2323f

Delete model

Browse files
Files changed (1) hide show
  1. model/deepcrack_model.py +0 -117
model/deepcrack_model.py DELETED
@@ -1,117 +0,0 @@
1
- # Author: Yahui Liu <yahui.liu@uintn.it>
2
-
3
- import torch
4
- import numpy as np
5
- import itertools
6
- from .base_model import BaseModel
7
- from .deepcrack_networks import define_deepcrack, BinaryFocalLoss
8
-
9
- class DeepCrackModel(BaseModel):
10
- """
11
- This class implements the DeepCrack model.
12
- DeepCrack paper: https://www.sciencedirect.com/science/article/pii/S0925231219300566
13
- """
14
- @staticmethod
15
- def modify_commandline_options(parser, is_train=True):
16
- """Add new dataset-specific options, and rewrite default values for existing options."""
17
- parser.add_argument('--lambda_side', type=float, default=1.0, help='weight for side output loss')
18
- parser.add_argument('--lambda_fused', type=float, default=1.0, help='weight for fused loss')
19
- return parser
20
-
21
- def __init__(self, opt):
22
- """Initialize the DeepCrack class.
23
-
24
- Parameters:
25
- opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
26
- """
27
- BaseModel.__init__(self, opt)
28
- # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
29
- self.loss_names = ['side', 'fused', 'total']
30
- # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
31
- self.display_sides = opt.display_sides
32
- self.visual_names = ['image', 'label_viz', 'fused']
33
- if self.display_sides:
34
- self.visual_names += ['side1', 'side2', 'side3', 'side4', 'side5']
35
- # specify the models you want to save to the disk.
36
- self.model_names = ['G']
37
-
38
- # define networks
39
- self.netG = define_deepcrack(opt.input_nc,
40
- opt.num_classes,
41
- opt.ngf,
42
- opt.norm,
43
- opt.init_type,
44
- opt.init_gain,
45
- self.gpu_ids)
46
-
47
- self.softmax = torch.nn.Softmax(dim=1)
48
-
49
- if self.isTrain:
50
- # define loss functions
51
- #self.weight = torch.from_numpy(np.array([0.0300, 1.0000], dtype='float32')).float().to(self.device)
52
- #self.criterionSeg = torch.nn.CrossEntropyLoss(weight=self.weight)
53
- if self.opt.loss_mode == 'focal':
54
- self.criterionSeg = BinaryFocalLoss()
55
- else:
56
- self.criterionSeg = nn.BCEWithLogitsLoss(size_average=True, reduce=True,
57
- pos_weight=torch.tensor(1.0/3e-2).to(self.device))
58
- self.weight_side = [0.5, 0.75, 1.0, 0.75, 0.5]
59
-
60
- # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
61
- self.optimizer = torch.optim.SGD(self.netG.parameters(), lr=opt.lr, momentum=0.9, weight_decay=2e-4)
62
- self.optimizers.append(self.optimizer)
63
-
64
- def set_input(self, input):
65
- """Unpack input data from the dataloader and perform necessary pre-processing steps.
66
- Parameters:
67
- input (dict): include the data itself and its metadata information.
68
- """
69
- self.image = input['image'].to(self.device)
70
- self.label = input['label'].to(self.device)
71
- #self.label3d = self.label.squeeze(1)
72
- self.image_paths = input['A_paths']
73
-
74
- def forward(self):
75
- """Run forward pass; called by both functions <optimize_parameters> and <test>."""
76
- self.outputs = self.netG(self.image)
77
-
78
- # for visualization
79
- self.label_viz = (self.label.float()-0.5)/0.5
80
- #self.fused = (self.softmax(self.outputs[-1])[:,1].detach().unsqueeze(1)-0.5)/0.5
81
- #if self.display_sides:
82
- # self.side1 = (self.softmax(self.outputs[0])[:,1].detach().unsqueeze(1)-0.5)/0.5
83
- # self.side2 = (self.softmax(self.outputs[1])[:,1].detach().unsqueeze(1)-0.5)/0.5
84
- # self.side3 = (self.softmax(self.outputs[2])[:,1].detach().unsqueeze(1)-0.5)/0.5
85
- # self.side4 = (self.softmax(self.outputs[3])[:,1].detach().unsqueeze(1)-0.5)/0.5
86
- # self.side5 = (self.softmax(self.outputs[4])[:,1].detach().unsqueeze(1)-0.5)/0.5
87
- self.fused = (torch.sigmoid(self.outputs[-1])-0.5)/0.5
88
-
89
- if self.display_sides:
90
- self.side1 = (torch.sigmoid(self.outputs[0])-0.5)/0.5
91
- self.side2 = (torch.sigmoid(self.outputs[1])-0.5)/0.5
92
- self.side3 = (torch.sigmoid(self.outputs[2])-0.5)/0.5
93
- self.side4 = (torch.sigmoid(self.outputs[3])-0.5)/0.5
94
- self.side5 = (torch.sigmoid(self.outputs[4])-0.5)/0.5
95
-
96
- def backward(self):
97
- """Calculate the loss"""
98
- lambda_side = self.opt.lambda_side
99
- lambda_fused = self.opt.lambda_fused
100
-
101
- self.loss_side = 0.0
102
- for out, w in zip(self.outputs[:-1], self.weight_side):
103
- #self.loss_side += self.criterionSeg(out, self.label3d) * w
104
- self.loss_side += self.criterionSeg(out, self.label) * w
105
-
106
- #self.loss_fused = self.criterionSeg(self.outputs[-1], self.label3d)
107
- self.loss_fused = self.criterionSeg(self.outputs[-1], self.label)
108
- self.loss_total = self.loss_side * lambda_side + self.loss_fused * lambda_fused
109
- self.loss_total.backward()
110
-
111
- def optimize_parameters(self, epoch=None):
112
- """Calculate losses, gradients, and update network weights; called in every training iteration"""
113
- # forward
114
- self.forward() # compute predictions.
115
- self.optimizer.zero_grad() # set G's gradients to zero
116
- self.backward() # calculate gradients for G
117
- self.optimizer.step() # update G's weights