WeekendZhou's picture
add cpu channel
711e925
import torch
from torch.utils.data import Subset
from PyQt5.QtGui import QImage
import numpy as np
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
# Pick the compute device once at import time: CUDA when it is available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Seed the PyTorch and NumPy global RNGs at import time so that anything
# drawing from them afterwards is reproducible.
torch.manual_seed(123456)
np.random.seed(123456)
class MyModel(torch.nn.Module):
    """Small CNN classifier: conv -> ReLU -> max-pool -> linear -> softmax.

    The linear layer expects 720 flattened features. For a single-channel
    28x28 input the 5x5 / stride-2 convolution yields 20 maps of 12x12, and
    the 2x2 max-pool reduces them to 6x6, i.e. 20 * 6 * 6 = 720.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order (conv1, pool1, fc1) so that
        # random weight initialisation is reproducible under a fixed seed.
        self.conv1 = torch.nn.Conv2d(1, 20, kernel_size=5, stride=2)
        self.pool1 = torch.nn.MaxPool2d(2)
        self.fc1 = torch.nn.Linear(720, 10)

    def forward(self, x):
        """Return per-class probabilities (each row sums to 1) for batch ``x``."""
        features = self.pool1(torch.relu(self.conv1(x)))
        # Keep the batch dimension, flatten everything else.
        logits = self.fc1(features.flatten(start_dim=1))
        return torch.softmax(logits, dim=1)
def train_and_save(save_path='mnist_cnn.pth', *, nepoch=30, batch_size=200, lr=0.001):
    """Train ``MyModel`` on the MNIST training split and save a checkpoint.

    The whole training split is materialised as one pair of tensors, then
    optimised with Adam on shuffled mini-batches. Training runs on whatever
    device the freshly constructed model lives on (it is not moved here).

    Args:
        save_path: Destination passed to ``torch.save``.
        nepoch: Number of passes over the data set (must be >= 0).
        batch_size: Number of observations per mini-batch (must be >= 1).
        lr: Adam learning rate.

    Returns:
        The list of per-mini-batch loss values (Python floats), in the order
        they were computed.

    Raises:
        ValueError: If ``nepoch`` is negative or ``batch_size`` is below 1.
    """
    if nepoch < 0:
        raise ValueError(f"nepoch must be >= 0, got {nepoch}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Data loading: a single batch of 60000 pulls the entire training split
    # into memory as one (x, y) tensor pair.
    mnist = datasets.MNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor(),
    )
    subset = Subset(mnist, indices=range(60000))
    loader = DataLoader(subset, batch_size=60000, shuffle=True)
    x, y = next(iter(loader))

    # Re-seed so weight initialisation and batch order are reproducible
    # regardless of how much randomness the data loading consumed.
    np.random.seed(123)
    torch.manual_seed(123)

    model = MyModel()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    # The loss module is stateless, so build it once instead of per batch.
    lossfn = torch.nn.NLLLoss()
    losses = []

    n = x.shape[0]
    obs_id = np.arange(n)  # [0, 1, ..., n-1]

    # Run the whole data set `nepoch` times
    for epoch in range(nepoch):
        # Shuffle observation IDs
        np.random.shuffle(obs_id)
        # Update on mini-batches
        for start in range(0, n, batch_size):
            batch_ids = obs_id[start:start + batch_size]
            x_mini_batch = x[batch_ids]
            y_mini_batch = y[batch_ids]

            # The model outputs probabilities, so NLLLoss needs their log.
            # Clamp first: a probability that underflows to 0 would give
            # log(0) = -inf and poison the gradients with NaN.
            pred = model(x_mini_batch)
            loss = lossfn(torch.log(torch.clamp(pred, min=1e-12)), y_mini_batch)

            # Compute gradient and update parameters
            opt.zero_grad()
            loss.backward()
            opt.step()

            loss_value = loss.item()
            losses.append(loss_value)
            batch_no = start // batch_size
            if batch_no % 20 == 0:
                print(f"epoch {epoch}, batch {batch_no}, loss = {loss_value}")

    torch.save({
        'model_state': model.state_dict(),
        'input_size': (1, 28, 28),
        'output_size': 10
    }, save_path)
    return losses
# Function: load a previously trained model
def load_trained_model(model_path='mnist_cnn.pth'):
    """Rebuild ``MyModel`` from a checkpoint file and return it in eval mode.

    The checkpoint is expected to be a mapping with a ``'model_state'`` entry
    holding the state dict. Stored tensors are mapped to the module-level
    ``device`` while loading and then copied into a freshly constructed
    model; the model itself is not moved to ``device`` here.

    Args:
        model_path: Path of the checkpoint file to read.

    Returns:
        The ``MyModel`` instance with loaded weights, switched to eval mode.
    """
    net = MyModel()
    saved = torch.load(model_path, map_location=device)
    net.load_state_dict(saved['model_state'])
    # Module.eval() returns the module itself.
    return net.eval()
def predict_user_image(grid_matrix, model, device):
    """Classify one 28x28 image and return the predicted class and scores.

    Args:
        grid_matrix: 28x28 array-like of pixel values (NumPy array, array
            view, or nested list). Values are divided by 255.0 before
            inference, so presumably they are 0-255 intensities -- confirm
            against the caller.
        model: Callable mapping a ``(1, 1, 28, 28)`` float tensor to a tensor
            of per-class scores. If it is a ``torch.nn.Module`` it is moved
            to ``device`` so that it always matches the input tensor.
        device: Device on which inference is run.

    Returns:
        ``(pred_class, probabilities)``: the index of the highest score as a
        Python ``int`` and the flattened model output as a NumPy array.

    Raises:
        ValueError: If ``grid_matrix`` does not have shape ``(28, 28)``.
    """
    # Input validation (np.asarray lets nested lists through as well)
    pixels = np.asarray(grid_matrix)
    if pixels.shape != (28, 28):
        raise ValueError(f"输入形状应为(28,28),实际获取{pixels.shape}")

    # Preprocessing: torch.from_numpy rejects negative-stride views (for
    # example a flipped grid), so normalise to a contiguous float32 array
    # first, then add batch and channel dimensions and scale by 1/255.
    pixels = np.ascontiguousarray(pixels, dtype=np.float32)
    input_tensor = torch.from_numpy(pixels).unsqueeze(0).unsqueeze(0) / 255.0
    input_tensor = input_tensor.to(device)

    # Keep model and input on the same device; a model left on the CPU would
    # otherwise fail with a device-mismatch error when `device` is a GPU.
    if isinstance(model, torch.nn.Module):
        model.to(device)

    # Inference without gradient tracking
    with torch.no_grad():
        outputs = model(input_tensor)

    # Collect the result on the CPU as a flat NumPy vector
    probabilities = outputs.cpu().numpy().flatten()
    pred_class = int(np.argmax(probabilities))
    # Callers receive both the class index and the full score vector.
    return pred_class, probabilities
if __name__ == '__main__':
    # Script entry point: train the model and save the checkpoint to the
    # default path.
    train_and_save()
    # Load the checkpoint back from the default path.
    model = load_trained_model()