penguin218's picture
upload 4 files
3238d33 verified
# model.py
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models
from torch.optim import lr_scheduler
def get_mobilenet_model(
    num_classes=16,
    *,
    pretrained=True,
    lr=2e-4,
    weight_decay=5e-5,
    t_max=50,
    eta_min=1e-6,
):
    """Build a partially fine-tunable MobileNetV3-Large plus its optimizer and LR scheduler.

    The backbone is frozen except for ``features[13:16]`` (the last three
    inverted-residual blocks). The stock classifier is replaced with a new
    MLP head that ends in ``num_classes`` outputs.

    Args:
        num_classes: Number of output logits of the new classification head.
        pretrained: Load torchvision's ImageNet ``IMAGENET1K_V1`` weights.
            These are the weights the legacy ``pretrained=True`` flag
            selected. Pass ``False`` to start from random initialisation
            (e.g. offline).
        lr: Initial AdamW learning rate.
        weight_decay: AdamW decoupled weight decay.
        t_max: ``T_max`` of the cosine schedule, i.e. the number of
            ``scheduler.step()`` calls in one half-cosine period.
        eta_min: Minimum learning rate reached by the cosine schedule.

    Returns:
        A ``(model, optimizer, scheduler)`` tuple. The optimizer only holds
        parameters with ``requires_grad=True`` (the unfrozen blocks and the
        new head).
    """
    weights_enum = getattr(models, "MobileNet_V3_Large_Weights", None)
    if weights_enum is None:
        # torchvision < 0.13: only the legacy boolean flag exists.
        model = models.mobilenet_v3_large(pretrained=pretrained)
    else:
        # torchvision >= 0.13 deprecates `pretrained=`. IMAGENET1K_V1 is the
        # checkpoint the legacy `pretrained=True` mapped to (DEFAULT is V2).
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = models.mobilenet_v3_large(weights=weights)

    # Freeze the whole network, then re-enable gradients for the last three
    # inverted-residual blocks. Indexing the Sequential replaces the earlier
    # substring match on parameter names ('features.13', 'features.14',
    # 'features.15'). It selects the same parameters and is less fragile.
    for param in model.parameters():
        param.requires_grad = False
    for block in model.features[13:16]:
        for param in block.parameters():
            param.requires_grad = True
    # NOTE(review): requires_grad=False does not stop BatchNorm layers in the
    # frozen blocks from updating their running statistics while the model is
    # in train() mode. Switch those layers to eval() if that is not intended.

    # Read the head's input width from the stock classifier instead of
    # hard-coding 960.
    in_features = model.classifier[0].in_features
    # Newly created modules default to requires_grad=True, so the whole head
    # is trainable and is picked up by the optimizer below.
    model.classifier = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.Hardswish(inplace=True),
        nn.Dropout(0.5),
        nn.Linear(512, 256),
        # NOTE(review): there is no activation between Linear(512, 256) and
        # the final Linear, so the two layers collapse to one linear map.
        # Inserting one would shift the Sequential indices and break the
        # state_dict keys of existing checkpoints, so the layout is unchanged.
        nn.Dropout(0.3),
        nn.Linear(256, num_classes),
    )

    # Optimise only the trainable parameters (unfrozen blocks + new head).
    optimizer = optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=lr,
        weight_decay=weight_decay,
        eps=1e-6,
    )
    # Cosine annealing from `lr` down to `eta_min` over `t_max` scheduler steps.
    scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=t_max,
        eta_min=eta_min,
    )
    return model, optimizer, scheduler