| --- | |
| language: zh | |
| tags: | |
| - pytorch | |
| - image-classification | |
| - cifar10 | |
| datasets: | |
| - cifar10 | |
| metrics: | |
| - accuracy | |
| --- | |
| # CIFAR10 图像分类模型 | |
| 这个模型是在CIFAR10数据集上训练的CNN分类器。 | |
| ## 模型描述 | |
| - 输入: 3x32x32 的RGB图像 | |
| - 输出: 10个类别的概率分布 | |
| - 架构: 3层CNN + 全连接层 | |
| ## 使用方式 | |
| ```python | |
| from torchvision import transforms | |
| transform = transforms.Compose([ | |
| transforms.ToTensor(), | |
| ]) | |
| # 预处理图像 | |
| image = transform(image) | |
| # 进行预测 | |
| outputs = model(image.unsqueeze(0)) | |
| predicted_class = outputs.argmax(1).item() | |
| ``` | |