更强的模型
Browse files- cyclegan.py +19 -3
- model_data/G_model_B2A_last_epoch_weights.pth +1 -1
cyclegan.py
CHANGED
|
@@ -19,6 +19,10 @@ class CYCLEGAN(object):
|
|
| 19 |
#-----------------------------------------------#
|
| 20 |
"input_shape" : [112, 112],
|
| 21 |
#-------------------------------#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
# 是否使用Cuda
|
| 23 |
# 没有GPU可以设置成False
|
| 24 |
#-------------------------------#
|
|
@@ -64,9 +68,14 @@ class CYCLEGAN(object):
|
|
| 64 |
#---------------------------------------------------------#
|
| 65 |
image = cvtColor(image)
|
| 66 |
#---------------------------------------------------------#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
# 添加上batch_size维度
|
| 68 |
#---------------------------------------------------------#
|
| 69 |
-
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(
|
| 70 |
|
| 71 |
with torch.no_grad():
|
| 72 |
images = torch.from_numpy(image_data)
|
|
@@ -80,10 +89,17 @@ class CYCLEGAN(object):
|
|
| 80 |
#---------------------------------------------------#
|
| 81 |
# 转为numpy
|
| 82 |
#---------------------------------------------------#
|
| 83 |
-
pr = pr.permute(1, 2, 0).cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
image = postprocess_output(pr)
|
| 86 |
-
image = np.clip(image, 0, 255)
|
| 87 |
image = Image.fromarray(np.uint8(image))
|
| 88 |
|
| 89 |
return image
|
|
|
|
| 19 |
#-----------------------------------------------#
|
| 20 |
"input_shape" : [112, 112],
|
| 21 |
#-------------------------------#
|
| 22 |
+
# 是否进行不失真的resize
|
| 23 |
+
#-------------------------------#
|
| 24 |
+
"letterbox_image" : True,
|
| 25 |
+
#-------------------------------#
|
| 26 |
# 是否使用Cuda
|
| 27 |
# 没有GPU可以设置成False
|
| 28 |
#-------------------------------#
|
|
|
|
| 68 |
#---------------------------------------------------------#
|
| 69 |
image = cvtColor(image)
|
| 70 |
#---------------------------------------------------------#
|
| 71 |
+
# 给图像增加灰条,实现不失真的resize
|
| 72 |
+
# 也可以直接resize进行识别
|
| 73 |
+
#---------------------------------------------------------#
|
| 74 |
+
image_data, nw, nh = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
|
| 75 |
+
#---------------------------------------------------------#
|
| 76 |
# 添加上batch_size维度
|
| 77 |
#---------------------------------------------------------#
|
| 78 |
+
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
|
| 79 |
|
| 80 |
with torch.no_grad():
|
| 81 |
images = torch.from_numpy(image_data)
|
|
|
|
| 89 |
#---------------------------------------------------#
|
| 90 |
# 转为numpy
|
| 91 |
#---------------------------------------------------#
|
| 92 |
+
pr = pr.permute(1, 2, 0).cpu().numpy()
|
| 93 |
+
|
| 94 |
+
#--------------------------------------#
|
| 95 |
+
# 将灰条部分截取掉
|
| 96 |
+
#--------------------------------------#
|
| 97 |
+
if nw is not None:
|
| 98 |
+
pr = pr[int((self.input_shape[0] - nh) // 2) : int((self.input_shape[0] - nh) // 2 + nh), \
|
| 99 |
+
int((self.input_shape[1] - nw) // 2) : int((self.input_shape[1] - nw) // 2 + nw)]
|
| 100 |
+
|
| 101 |
|
| 102 |
image = postprocess_output(pr)
|
|
|
|
| 103 |
image = Image.fromarray(np.uint8(image))
|
| 104 |
|
| 105 |
return image
|
model_data/G_model_B2A_last_epoch_weights.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 11888773
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1815cd8f77471a8712b9a80b20da4cd7afe7aad2b32ad48cd205d1c370a65dc2
|
| 3 |
size 11888773
|