# Uprm-i1 / test_v_02.py
# (renamed from test4.py; commit b438656, verified)
# CIFAR-10 training set, 3-layer CNN (header said 5k; code requests num_samples=50000)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10 # 示例图像数据集
import swanlab # 新增:实验跟踪
import os
# 1. Enhanced custom CNN with Dropout
class MultimodalFramework(nn.Module):
    """Two-branch multimodal classifier: BERT text encoder + 3-layer CNN image branch.

    The text branch mean-pools BERT's last hidden state over the sequence axis;
    the image branch is three conv/bn/relu/pool stages. Both feature vectors are
    concatenated, passed through a fused hidden layer with dropout, and classified.
    """

    def __init__(self, text_hidden=768, image_hidden=512, fusion_hidden=256, num_classes=10):
        # Initialize nn.Module bookkeeping (_parameters, _modules, _buffers)
        # before any submodules are registered.
        super().__init__()
        # BERT text encoder (bidirectional transformer).
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        # --- Image branch: three conv blocks, each followed by 2x2 max-pooling ---
        # Block 1: 3 -> 64 channels (CIFAR-10 RGB input); pooled output 64 x 16 x 16.
        self.image_conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.image_bn1 = nn.BatchNorm2d(64)
        self.image_pool = nn.MaxPool2d(2, 2)
        # Block 2: 64 -> 128 channels; pooled output 128 x 8 x 8.
        self.image_conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.image_bn2 = nn.BatchNorm2d(128)
        # Block 3: 128 -> 256 channels; pooled output 256 x 4 x 4.
        self.image_conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.image_bn3 = nn.BatchNorm2d(256)
        # Flattened conv output (256 * 4 * 4 = 4096) projected to image_hidden.
        self.image_fc = nn.Linear(256 * 4 * 4, image_hidden)
        # --- Fusion and classification head ---
        self.fusion_fc = nn.Linear(text_hidden + image_hidden, fusion_hidden)
        self.dropout = nn.Dropout(0.5)  # regularization raised from 0.3 to 0.5
        self.classifier = nn.Linear(fusion_hidden, num_classes)

    def forward(self, text_input, attention_mask, image_input):
        """Return class logits of shape [batch, num_classes]."""
        # Text branch: mean-pool BERT token embeddings along seq_len -> [batch, 768].
        bert_out = self.text_encoder(input_ids=text_input, attention_mask=attention_mask)
        text_features = bert_out.last_hidden_state.mean(dim=1)
        # Image branch: conv -> bn -> relu -> pool, three times.
        feat = self.image_conv1(image_input)
        feat = self.image_pool(F.relu(self.image_bn1(feat)))   # [batch, 64, 16, 16]
        feat = self.image_conv2(feat)
        feat = self.image_pool(F.relu(self.image_bn2(feat)))   # [batch, 128, 8, 8]
        feat = self.image_conv3(feat)
        feat = self.image_pool(F.relu(self.image_bn3(feat)))   # [batch, 256, 4, 4]
        feat = feat.view(feat.size(0), -1)                     # flatten -> [batch, 4096]
        image_features = F.relu(self.image_fc(feat))           # [batch, image_hidden]
        # Fuse both modalities and classify.
        combined = torch.cat([text_features, image_features], dim=1)
        hidden = self.dropout(F.relu(self.fusion_fc(combined)))
        return self.classifier(hidden)
# Dataset (corrected)
class MultimodalTrainDataset(torch.utils.data.Dataset):
    """CIFAR-10 train split where each image is paired with a tokenized caption.

    The caption is "a photo of a <class name>" for the sample's true label
    (label noise removed), tokenized with bert-base-uncased to a fixed length.
    """

    def __init__(self, root='E:/temp/data', num_samples=50000, transform=None):
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck']
        self.dataset = CIFAR10(root=root, train=True, download=True, transform=transform)
        # Cap the requested subset size at the actual dataset length.
        self.num_samples = min(num_samples, len(self.dataset))
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # Fixed tokenized caption length.
        self.max_length = 16

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # Caption built from the ground-truth class name.
        caption = f"a photo of a {self.classes[label]}"
        tokens = self.tokenizer(
            caption,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        )
        return (image,
                tokens['input_ids'].squeeze(0),
                tokens['attention_mask'].squeeze(0),
                label)
# Test dataset
class MultimodalTestDataset(torch.utils.data.Dataset):
    """CIFAR-10 test split where each image is paired with a tokenized caption.

    Captions always use the true class name ("a photo of a <class>"), so test
    accuracy measures the image-plus-matching-text setting.
    """

    def __init__(self, root='E:/temp/data', transform=None):
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck']
        self.dataset = CIFAR10(root=root, train=False, download=True, transform=transform)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # Fixed tokenized caption length; must match the training dataset.
        self.max_length = 16

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        text = f"a photo of a {self.classes[label]}"
        # A tensor-producing transform is required downstream. Fail loudly with
        # an explicit exception instead of `assert`, which is stripped under -O.
        if not isinstance(image, torch.Tensor):
            raise TypeError(f"Expected Tensor, got {type(image)}")
        encoded = self.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt',
        )
        return image, encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0), label
def train_epoch(model, dataloader, optimizer, criterion, epoch, device, log_step=50):
    """Train `model` for one epoch and log running metrics to SwanLab.

    Args:
        model: multimodal model called as model(text_ids, attention_mask, images).
        dataloader: yields (images, text_ids, attention_mask, labels) batches.
        optimizer: optimizer stepping the trainable parameters.
        criterion: classification loss (e.g. CrossEntropyLoss) over logits/labels.
        epoch: zero-based epoch index (used for logging and printing only).
        device: target device for all batch tensors.
        log_step: log scalar metrics to SwanLab every `log_step` batches.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (images, text_ids, attention_mask, labels) in enumerate(dataloader):
        # Move the whole batch onto the training device.
        images = images.to(device)
        text_ids = text_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(text_ids, attention_mask, images)  # [batch, num_classes]
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        # Periodic SwanLab logging (scalar metrics only; no image/text payloads).
        if batch_idx % log_step == 0:
            swanlab.log({
                "epoch": epoch,
                "batch": batch_idx,
                "loss": loss.item(),
                "accuracy": 100. * correct / total,
                "learning_rate": optimizer.param_groups[0]['lr'],
            })
    avg_loss = total_loss / len(dataloader)
    avg_acc = 100. * correct / total
    print(f'Epoch {epoch+1}: Loss: {avg_loss:.4f}, Acc: {avg_acc:.2f}%')
    # End-of-epoch summary metrics.
    swanlab.log({
        "epoch_end_loss": avg_loss,
        "epoch_end_acc": avg_acc,
    })
# Evaluate the model on the test set
def test_model(model, test_loader, device):
    """Evaluate `model` on `test_loader`; print, log, and return accuracy in %."""
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for images, text_ids, attention_mask, labels in test_loader:
            # Move the batch to the evaluation device.
            images, text_ids, attention_mask, labels = (
                t.to(device) for t in (images, text_ids, attention_mask, labels)
            )
            logits = model(text_ids, attention_mask, images)
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += labels.size(0)
    test_acc = 100. * correct / seen
    print(f"Test Accuracy: {test_acc:.2f}%")
    swanlab.log({"test_accuracy": test_acc})
    return test_acc
def _main():
    """Entry point: build data and model, run two-stage training, evaluate, save."""
    # Experiment tracking via SwanLab.
    swanlab.init(project="multimodal-object-detection", anonymous=True)

    # 2. Data: convert to tensors and normalize CIFAR-10 images to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    train_dataset = MultimodalTrainDataset(num_samples=50000, transform=transform)
    dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True,
                            num_workers=4, pin_memory=True)
    test_dataset = MultimodalTestDataset(root='E:/temp/data', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False,
                             num_workers=4, pin_memory=True)

    # 3. Model, loss, and device placement (with SwanLab logging in train_epoch).
    model = MultimodalFramework()
    criterion = nn.CrossEntropyLoss()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Two-stage schedule: frozen-BERT warmup, then global fine-tuning.
    TOTAL_EPOCHS = 10
    FROZEN_EPOCHS = 4

    # Stage 1: freeze BERT; train only the CNN and fusion/classifier layers.
    print("-" * 20 + " Stage 1: Frozen BERT Training " + "-" * 20)
    for param in model.text_encoder.parameters():
        param.requires_grad = False
    optimizer1 = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
    for epoch in range(FROZEN_EPOCHS):
        train_epoch(model, dataloader, optimizer1, criterion, epoch, device)

    # Stage 2: unfreeze BERT and fine-tune the whole model at a lower LR.
    print("-" * 20 + " Stage 2: Global Fine-Tuning " + "-" * 20)
    for param in model.text_encoder.parameters():
        param.requires_grad = True
    optimizer2 = optim.Adam(model.parameters(), lr=2e-5)
    for epoch in range(FROZEN_EPOCHS, TOTAL_EPOCHS):
        train_epoch(model, dataloader, optimizer2, criterion, epoch, device)

    # Evaluate on the CIFAR-10 test split.
    test_model(model, test_loader, device)

    # Save weights locally (metrics are already recorded in the SwanLab run).
    torch.save(model.state_dict(), 'multimodal_cifar10_epoch10.pth')
    print('模型保存完成!')
    swanlab.finish()


if __name__ == '__main__':
    _main()