NewBreaker
commited on
Commit
·
b83d9ec
1
Parent(s):
47d7bda
auto git
Browse files- tools/ResNet_MNIST.py +206 -0
- tools/ResNet_cat_vs_dog_Ram.py +15 -0
- tools/data_test.py +80 -0
- tools/data_train.py +47 -0
- tools/将数据集按照比例进行拆分.py +84 -0
tools/ResNet_MNIST.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# In[1] 导入所需工具包
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torchvision
|
| 5 |
+
from torchvision import datasets, transforms
|
| 6 |
+
import time
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from math import floor, ceil
|
| 9 |
+
from torch.utils.data import DataLoader,TensorDataset
|
| 10 |
+
# import torchvision.transforms as transforms
|
| 11 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 12 |
+
print(device)
|
| 13 |
+
# In[1] 设置超参数
|
| 14 |
+
num_epochs = 60
|
| 15 |
+
batch_size = 1000
|
| 16 |
+
learning_rate = 0.001
|
| 17 |
+
|
| 18 |
+
# In[2] 获取数据包括训练数据和测试数据
|
| 19 |
+
|
| 20 |
+
transform = torchvision.transforms.Compose([
|
| 21 |
+
torchvision.transforms.ToTensor(),
|
| 22 |
+
torchvision.transforms.Normalize(
|
| 23 |
+
(0.1307,), (0.3081,))
|
| 24 |
+
])
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
train_set = torchvision.datasets.MNIST(root='MNIST', train=True, download=True)
|
| 28 |
+
train_data = train_set.data.float().unsqueeze(1) / 255.0
|
| 29 |
+
train_labels = train_set.targets
|
| 30 |
+
train_dataset = TensorDataset(train_data,train_labels)
|
| 31 |
+
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
test_set = torchvision.datasets.MNIST(root='MNIST', train=False, download=True)
|
| 35 |
+
test_data = test_set.data.float().unsqueeze(1) / 255.0
|
| 36 |
+
test_labels = test_set.targets
|
| 37 |
+
test_dataset = TensorDataset(test_data,test_labels)
|
| 38 |
+
test_loader = DataLoader(test_dataset,batch_size=batch_size,shuffle=True)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# In[1] 定义卷积核
|
| 43 |
+
def conv3x3(in_channels, out_channels, stride=1):
|
| 44 |
+
return nn.Conv2d(in_channels, out_channels, kernel_size=3,
|
| 45 |
+
stride=stride, padding=1, bias=True)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# In[1] 定义残差块
|
| 49 |
+
class ResidualBlock(nn.Module):
|
| 50 |
+
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
|
| 51 |
+
super(ResidualBlock, self).__init__()
|
| 52 |
+
self.conv1 = conv3x3(in_channels, out_channels, stride)
|
| 53 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
| 54 |
+
self.relu = nn.ReLU(inplace=True)
|
| 55 |
+
self.conv2 = conv3x3(out_channels, out_channels)
|
| 56 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
| 57 |
+
self.downsample = downsample
|
| 58 |
+
|
| 59 |
+
def forward(self, x):
|
| 60 |
+
residual = x
|
| 61 |
+
out = self.conv1(x)
|
| 62 |
+
out = self.bn1(out)
|
| 63 |
+
out = self.relu(out)
|
| 64 |
+
out = self.conv2(out)
|
| 65 |
+
out = self.bn2(out)
|
| 66 |
+
# 下采样
|
| 67 |
+
if self.downsample:
|
| 68 |
+
residual = self.downsample(x)
|
| 69 |
+
out += residual
|
| 70 |
+
out = self.relu(out)
|
| 71 |
+
return out
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# In[1] 搭建残差神经网络
|
| 75 |
+
class ResNet(nn.Module):
|
| 76 |
+
def __init__(self, block, layers, num_classes=10):
|
| 77 |
+
super(ResNet, self).__init__()
|
| 78 |
+
self.in_channels = 16
|
| 79 |
+
self.conv = conv3x3(1, 16)
|
| 80 |
+
self.bn = nn.BatchNorm2d(16)
|
| 81 |
+
self.relu = nn.ReLU(inplace=True)
|
| 82 |
+
# 构建残差块,恒等映射
|
| 83 |
+
# in_channels == out_channels and stride = 1 所以这里我们构建残差块,没有下采样
|
| 84 |
+
self.layer1 = self.make_layer(block, 16, layers[0], stride=1)
|
| 85 |
+
# 不构建残差块,进行了下采样
|
| 86 |
+
# layers中记录的是数字,表示对应位置的残差块数目
|
| 87 |
+
self.layer2 = self.make_layer(block, 32, layers[1], 2)
|
| 88 |
+
# 不构建残差块,进行了下采样
|
| 89 |
+
self.layer3 = self.make_layer(block, 64, layers[2], 2)
|
| 90 |
+
self.avg_pool = nn.AvgPool2d(8)
|
| 91 |
+
self.fc1 = nn.Linear(3136, 128)
|
| 92 |
+
self.normfc12 = nn.LayerNorm((128), eps=1e-5)
|
| 93 |
+
self.fc2 = nn.Linear(128, num_classes)
|
| 94 |
+
|
| 95 |
+
def make_layer(self, block, out_channels, blocks, stride=1):
|
| 96 |
+
downsample = None
|
| 97 |
+
if (stride != 1) or (self.in_channels != out_channels):
|
| 98 |
+
downsample = nn.Sequential(
|
| 99 |
+
conv3x3(self.in_channels, out_channels, stride=stride),
|
| 100 |
+
nn.BatchNorm2d(out_channels))
|
| 101 |
+
layers = []
|
| 102 |
+
layers.append(block(self.in_channels, out_channels, stride, downsample))
|
| 103 |
+
# 当out_channels = 32时,in_channels也变成32了
|
| 104 |
+
self.in_channels = out_channels
|
| 105 |
+
# blocks是残差块的数目
|
| 106 |
+
# 残差块之后的网络结构,是out_channels->out_channels的
|
| 107 |
+
# 可以说,make_layer做的是输出尺寸相同的所有网络结构
|
| 108 |
+
# 由于输出尺寸会改变,我们用make_layers去生成一大块对应尺寸完整网络结构
|
| 109 |
+
for i in range(1, blocks):
|
| 110 |
+
layers.append(block(out_channels, out_channels))
|
| 111 |
+
return nn.Sequential(*layers)
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
out = self.conv(x)
|
| 115 |
+
out = self.bn(out)
|
| 116 |
+
out = self.relu(out)
|
| 117 |
+
# layer1是三块in_channels等于16的网络结构,包括三个恒等映射
|
| 118 |
+
out = self.layer1(out)
|
| 119 |
+
# layer2包括了16->32下采样,然后是32的三个恒等映射
|
| 120 |
+
out = self.layer2(out)
|
| 121 |
+
# layer3包括了32->64的下采样,然后是64的三个恒等映射
|
| 122 |
+
out = self.layer3(out)
|
| 123 |
+
# out = self.avg_pool(out)
|
| 124 |
+
# 全连接压缩
|
| 125 |
+
# out.size(0)可以看作是batch_size
|
| 126 |
+
out = out.view(out.size(0), -1)
|
| 127 |
+
out = self.fc1(out)
|
| 128 |
+
out = self.normfc12(out)
|
| 129 |
+
out = self.relu(out)
|
| 130 |
+
out = self.fc2(out)
|
| 131 |
+
return out
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# In[1] 定义模型和损失函数
|
| 135 |
+
# [2,2,2]表示的是不同in_channels下的恒等映射数目
|
| 136 |
+
model = ResNet(ResidualBlock, [2, 2, 2]).to(device)
|
| 137 |
+
criterion = nn.CrossEntropyLoss()
|
| 138 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# In[1] 设置一个通过优化器更新学习率的函数
|
| 142 |
+
def update_lr(optimizer, lr):
|
| 143 |
+
for param_group in optimizer.param_groups:
|
| 144 |
+
param_group['lr'] = lr
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# In[1] 定义测试函数
|
| 148 |
+
def test(model, test_loader):
|
| 149 |
+
model.eval()
|
| 150 |
+
with torch.no_grad():
|
| 151 |
+
correct = 0
|
| 152 |
+
total = 0
|
| 153 |
+
for images, labels in test_loader:
|
| 154 |
+
images = images.to(device)
|
| 155 |
+
labels = labels.to(device)
|
| 156 |
+
outputs = model(images)
|
| 157 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 158 |
+
total += labels.size(0)
|
| 159 |
+
correct += (predicted == labels).sum().item()
|
| 160 |
+
|
| 161 |
+
print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# In[1] 训练模型更新学习率
|
| 165 |
+
total_step = len(train_loader)
|
| 166 |
+
curr_lr = learning_rate
|
| 167 |
+
for epoch in range(num_epochs):
|
| 168 |
+
in_epoch = time.time()
|
| 169 |
+
for i, (images, labels) in enumerate(train_loader):
|
| 170 |
+
images = images.to(device)
|
| 171 |
+
labels = labels.to(device)
|
| 172 |
+
|
| 173 |
+
# Forward pass
|
| 174 |
+
outputs = model(images)
|
| 175 |
+
loss = criterion(outputs, labels)
|
| 176 |
+
|
| 177 |
+
# Backward and optimize
|
| 178 |
+
optimizer.zero_grad()
|
| 179 |
+
loss.backward()
|
| 180 |
+
optimizer.step()
|
| 181 |
+
|
| 182 |
+
if (i + 1) % 100 == 0:
|
| 183 |
+
print("Epoch [{}/{}], Step [{}/{}] Loss: {:.4f}"
|
| 184 |
+
.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
|
| 185 |
+
test(model, test_loader)
|
| 186 |
+
out_epoch = time.time()
|
| 187 |
+
print(f"use {(out_epoch - in_epoch) // 60}min{(out_epoch - in_epoch) % 60}s")
|
| 188 |
+
if (epoch + 1) % 20 == 0:
|
| 189 |
+
curr_lr /= 3
|
| 190 |
+
update_lr(optimizer, curr_lr)
|
| 191 |
+
# In[1] 测试模型并保存
|
| 192 |
+
model.eval()
|
| 193 |
+
with torch.no_grad():
|
| 194 |
+
correct = 0
|
| 195 |
+
total = 0
|
| 196 |
+
for images, labels in test_loader:
|
| 197 |
+
images = images.to(device)
|
| 198 |
+
labels = labels.to(device)
|
| 199 |
+
outputs = model(images)
|
| 200 |
+
_, predicted = torch.max(outputs.data, 1)
|
| 201 |
+
total += labels.size(0)
|
| 202 |
+
correct += (predicted == labels).sum().item()
|
| 203 |
+
|
| 204 |
+
print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))
|
| 205 |
+
|
| 206 |
+
torch.save(model.state_dict(), '../resnet.ckpt')
|
tools/ResNet_cat_vs_dog_Ram.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torchvision import datasets, transforms
|
| 3 |
+
from torch.utils.data import DataLoader,TensorDataset
|
| 4 |
+
|
| 5 |
+
transform = transforms.Compose([
|
| 6 |
+
transforms.Resize((512, 512)),
|
| 7 |
+
transforms.ToTensor(),
|
| 8 |
+
])
|
| 9 |
+
|
| 10 |
+
# 加载训练集和测试集
|
| 11 |
+
train_set = datasets.ImageFolder(root='data/cat_vs_dog/train', transform=transform)
|
| 12 |
+
test_set = datasets.ImageFolder(root='data/cat_vs_dog/test', transform=transform)
|
| 13 |
+
|
| 14 |
+
train_data = train_set.imgs
|
| 15 |
+
print("train_data:", train_data)
|
tools/data_test.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
import cv2
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
|
| 9 |
+
import torchvision
|
| 10 |
+
from torchvision import models,transforms,datasets
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch import optim
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import shutil
|
| 21 |
+
import random
|
| 22 |
+
def make_dir(path):
|
| 23 |
+
import os
|
| 24 |
+
dir = os.path.exists(path)
|
| 25 |
+
if not dir:
|
| 26 |
+
os.makedirs(path)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def get_filename_and_houzhui(full_path):
|
| 30 |
+
import os
|
| 31 |
+
path, file_full_name = os.path.split(full_path)
|
| 32 |
+
file_name, 后缀名 = os.path.splitext(file_full_name)
|
| 33 |
+
return path,file_name,后缀名
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
dataset_root_path = '../data/cat_vs_dog'
|
| 37 |
+
train_path_cat_new = os.path.join(dataset_root_path, 'new/train/cat')
|
| 38 |
+
train_path_dog_new = os.path.join(dataset_root_path, 'new/train/dog')
|
| 39 |
+
|
| 40 |
+
test_path_cat_new = os.path.join(dataset_root_path, 'new/test/cat')
|
| 41 |
+
test_path_dog_new = os.path.join(dataset_root_path, 'new/test/dog')
|
| 42 |
+
|
| 43 |
+
make_dir(train_path_cat_new)
|
| 44 |
+
make_dir(train_path_dog_new)
|
| 45 |
+
make_dir(test_path_cat_new)
|
| 46 |
+
make_dir(test_path_dog_new)
|
| 47 |
+
|
| 48 |
+
image_dir_path = os.path.join(dataset_root_path,'train')
|
| 49 |
+
image_name_list = os.listdir(image_dir_path)
|
| 50 |
+
for image_name in tqdm(image_name_list):
|
| 51 |
+
image_path = os.path.join(image_dir_path,image_name)
|
| 52 |
+
path, file_name, 后缀名 = get_filename_and_houzhui(full_path=image_path)
|
| 53 |
+
# print("file_name:", file_name)
|
| 54 |
+
# 定义随机数的范围和对应的概率
|
| 55 |
+
nums = [1, 2]
|
| 56 |
+
probs = [0.9, 0.1] #设定训练集和测试集的比率
|
| 57 |
+
|
| 58 |
+
random_nums = random.choices(nums, weights=probs)[0]
|
| 59 |
+
|
| 60 |
+
if(random_nums == 1): #摇筛子如果摇到了1,那么就是训练集
|
| 61 |
+
|
| 62 |
+
if('cat' in file_name):
|
| 63 |
+
shutil.copy(image_path, train_path_cat_new)
|
| 64 |
+
elif('dog' in file_name):
|
| 65 |
+
shutil.copy(image_path, train_path_dog_new)
|
| 66 |
+
elif(random_nums == 2): #摇骰子如果摇到了2,那么就是测试集
|
| 67 |
+
if('cat' in file_name):
|
| 68 |
+
shutil.copy(image_path, test_path_cat_new)
|
| 69 |
+
elif('dog' in file_name):
|
| 70 |
+
shutil.copy(image_path, test_path_dog_new)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
tools/data_train.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import shutil
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def make_dir(path):
|
| 9 |
+
import os
|
| 10 |
+
dir = os.path.exists(path)
|
| 11 |
+
if not dir:
|
| 12 |
+
os.makedirs(path)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_filename_and_houzhui(full_path):
|
| 16 |
+
import os
|
| 17 |
+
path, file_full_name = os.path.split(full_path)
|
| 18 |
+
file_name, 后缀名 = os.path.splitext(file_full_name)
|
| 19 |
+
return path,file_name,后缀名
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
dataset_root_path = '../data/cat_vs_dog'
|
| 23 |
+
train_path_cat_new = os.path.join(dataset_root_path, 'new/train/cat')
|
| 24 |
+
train_path_dog_new = os.path.join(dataset_root_path, 'new/train/dog')
|
| 25 |
+
make_dir(train_path_cat_new)
|
| 26 |
+
make_dir(train_path_dog_new)
|
| 27 |
+
|
| 28 |
+
image_dir_path = os.path.join(dataset_root_path,'train')
|
| 29 |
+
image_name_list = os.listdir(image_dir_path)
|
| 30 |
+
for image_name in image_name_list:
|
| 31 |
+
image_path = os.path.join(image_dir_path,image_name)
|
| 32 |
+
path, file_name, 后缀名 = get_filename_and_houzhui(full_path=image_path)
|
| 33 |
+
print("file_name:", file_name)
|
| 34 |
+
|
| 35 |
+
if('cat' in file_name):
|
| 36 |
+
shutil.copy(image_path,train_path_cat_new)
|
| 37 |
+
elif('dog' in file_name):
|
| 38 |
+
shutil.copy(image_path, train_path_dog_new)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
tools/将数据集按照比例进行拆分.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import cv2
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
|
| 10 |
+
import torchvision
|
| 11 |
+
from torchvision import models,transforms,datasets
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
from torch import optim
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import shutil
|
| 22 |
+
import random
|
| 23 |
+
def make_dir(path):
|
| 24 |
+
import os
|
| 25 |
+
dir = os.path.exists(path)
|
| 26 |
+
if not dir:
|
| 27 |
+
os.makedirs(path)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_filename_and_houzhui(full_path):
|
| 31 |
+
import os
|
| 32 |
+
path, file_full_name = os.path.split(full_path)
|
| 33 |
+
file_name, 后缀名 = os.path.splitext(file_full_name)
|
| 34 |
+
return path,file_name,后缀名
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
dataset_root_path = '../data/cat_vs_dog'
|
| 38 |
+
train_path_cat_new = os.path.join(dataset_root_path, 'new/train/cat')
|
| 39 |
+
train_path_dog_new = os.path.join(dataset_root_path, 'new/train/dog')
|
| 40 |
+
|
| 41 |
+
test_path_cat_new = os.path.join(dataset_root_path, 'new/test/cat')
|
| 42 |
+
test_path_dog_new = os.path.join(dataset_root_path, 'new/test/dog')
|
| 43 |
+
|
| 44 |
+
make_dir(train_path_cat_new)
|
| 45 |
+
make_dir(train_path_dog_new)
|
| 46 |
+
make_dir(test_path_cat_new)
|
| 47 |
+
make_dir(test_path_dog_new)
|
| 48 |
+
|
| 49 |
+
image_dir_path = os.path.join(dataset_root_path,'train')
|
| 50 |
+
image_name_list = os.listdir(image_dir_path)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
for image_name in tqdm(image_name_list):
|
| 55 |
+
image_path = os.path.join(image_dir_path,image_name)
|
| 56 |
+
path, file_name, 后缀名 = get_filename_and_houzhui(full_path=image_path)
|
| 57 |
+
# print("file_name:", file_name)
|
| 58 |
+
# 定义随机数的范围和对应的概率
|
| 59 |
+
nums = [1, 2]
|
| 60 |
+
probs = [0.9, 0.1] #设定训练集和测试集的比率
|
| 61 |
+
|
| 62 |
+
random_nums = random.choices(nums, weights=probs)[0]
|
| 63 |
+
|
| 64 |
+
if(random_nums == 1): #摇筛子如果摇到了1,那么就是训练集
|
| 65 |
+
|
| 66 |
+
if('cat' in file_name):
|
| 67 |
+
shutil.copy(image_path, train_path_cat_new)
|
| 68 |
+
elif('dog' in file_name):
|
| 69 |
+
shutil.copy(image_path, train_path_dog_new)
|
| 70 |
+
elif(random_nums == 2): #摇骰子如果摇到了2,那么就是测试集
|
| 71 |
+
if('cat' in file_name):
|
| 72 |
+
shutil.copy(image_path, test_path_cat_new)
|
| 73 |
+
elif('dog' in file_name):
|
| 74 |
+
shutil.copy(image_path, test_path_dog_new)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|