ResNet18 on Tiny-ImageNet(64x64)模型说明

概述

  • 该仓库包含基于 ResNet18 的训练实现,目标数据集为 zh-plus/tiny-imagenet(Tiny ImageNet,200 类,64x64)。
  • 训练脚本使用 PyTorch、huggingface datasets 与 torchvision transforms,支持 TensorBoard 日志和 ONNX 导出。

模型结构

  • 使用自定义 ResNet 实现,ResNet18 由 BasicBlock 组成(layers = [2,2,2,2])。
  • 输入尺寸:3x64x64。去掉原始 ResNet 的大卷积与最大池化层以适配小尺寸输入;增加 Dropout(0.5)。

数据与增强

  • 数据来源:load_dataset('zh-plus/tiny-imagenet'),使用 cache_dir 指定数据缓存路径(示例代码中为 ./data)。
  • 训练增强(示例):
    • RandomCrop(64, padding=4)
    • RandomHorizontalFlip()
    • RandomRotation(15)
    • ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
    • ToTensor() + Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
  • 验证集仅做 ToTensor() + Normalize。

训练细节(示例超参数)

  • 优化器:SGD(lr=0.1, momentum=0.9, weight_decay=5e-4)
  • 损失:CrossEntropyLoss(label_smoothing=0.1)
  • 学习率调度:MultiStepLR(milestones=[30,60,90], gamma=0.1)
  • Batch size:128(可根据显存调整)
  • Epochs:150(示例值)
  • 保存:每个 epoch 保存 checkpoint(./model/checkpoint_epoch.pth),在验证准确率提升时保存最佳模型,训练结束导出 ONNX(./model/resnet18_tiny_imagenet.onnx)。

日志与可视化

  • 使用 TensorBoard 记录:
    • Training/Iter_Loss、Training/Epoch_Loss、Training/Epoch_Accuracy、Training/Learning_Rate
    • Validation/Accuracy
  • 日志目录示例:./logs/resnet18_tiny_imagenet
  • 启动 TensorBoard:
    • tensorboard --logdir ./logs/resnet18_tiny_imagenet

运行示例

  • 假设训练代码文件名为 train_resnet18_tiny_imagenet.py:
    • python train_resnet18_tiny_imagenet.py
  • 若使用其它文件名,请替换为对应脚本名。

输出说明

  • Checkpoint:./model/checkpoint_epoch.pth(包含 epoch、model_state_dict、optimizer_state_dict、loss、val_acc)
  • ONNX:./model/resnet18_tiny_imagenet.onnx(导出时使用输入尺寸 1x3x64x64)

注意事项

  • 请核实环境依赖(PyTorch、torchvision、datasets、tensorboard)。
  • 根据显存调整 batch_size 与是否使用混合精度(本示例未启用 AMP)。
  • 若数据集 split 名称不同,请修改脚本中对 dataset['train']/dataset['valid'] 的访问。

联系方式

  • 仓库维护者信息或问题请在 Issues 中反馈。
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Dataset used to train zhouxzh/resnet18_tiny_imagenet