# Uprm-i1 / test_v_01.py
# (renamed from test3.py to test_v_01.py, commit 5c861f2)
# CIFAR-10 training set (code below uses 50k samples, not the 5k the original header claimed) + 2-layer CNN image branch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertModel
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10  # example image dataset
import swanlab  # experiment tracking
# 1. Multimodal model framework
class MultimodalFramework(nn.Module):
    """Two-branch multimodal classifier: BERT text encoder + 2-layer CNN image encoder.

    Text and image features are concatenated and passed through a fusion MLP
    to produce ``num_classes`` logits.

    Args:
        text_hidden: BERT hidden size (768 for bert-base-uncased).
        image_hidden: size of the image feature vector after the CNN.
        fusion_hidden: hidden size of the fusion layer.
        num_classes: number of output classes (10 for CIFAR-10).
    """

    def __init__(self, text_hidden=768, image_hidden=512, fusion_hidden=256, num_classes=10):
        super(MultimodalFramework, self).__init__()
        # Pretrained BERT yields [batch, seq_len, text_hidden] hidden states.
        self.text_encoder = BertModel.from_pretrained('bert-base-uncased')
        # CIFAR-10 images: 3 channels, 32x32 -> two conv+pool stages -> 128 x 8 x 8.
        self.image_conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.image_conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.image_pool = nn.MaxPool2d(2, 2)
        self.image_fc = nn.Linear(128 * 8 * 8, image_hidden)
        self.fusion_fc = nn.Linear(text_hidden + image_hidden, fusion_hidden)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(fusion_hidden, num_classes)

    def forward(self, text_input, attention_mask, image_input):
        """Return [batch, num_classes] logits for (token ids, mask, image) batches."""
        text_outputs = self.text_encoder(
            input_ids=text_input,
            attention_mask=attention_mask
        ).last_hidden_state  # [batch, seq_len, 768]
        # FIX: a plain .mean(dim=1) averaged over [PAD] positions too, diluting
        # the sentence representation. Use an attention-mask-weighted mean so
        # only real tokens contribute.
        mask = attention_mask.unsqueeze(-1).to(text_outputs.dtype)       # [batch, seq_len, 1]
        text_features = (text_outputs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        x = self.image_pool(F.relu(self.image_conv1(image_input)))       # [batch, 64, 16, 16]
        x = self.image_pool(F.relu(self.image_conv2(x)))                 # [batch, 128, 8, 8]
        x = x.view(x.size(0), -1)                                        # [batch, 8192]
        image_features = F.relu(self.image_fc(x))                        # [batch, image_hidden]
        fused = torch.cat([text_features, image_features], dim=1)        # [batch, text+image hidden]
        fused = self.dropout(F.relu(self.fusion_fc(fused)))              # [batch, fusion_hidden]
        return self.classifier(fused)                                    # [batch, num_classes]
class MultimodalTrainDataset(torch.utils.data.Dataset):
    """CIFAR-10 training images paired with synthetic caption text.

    With probability 0.5 the caption names the true class; otherwise a
    uniformly random class name is substituted, so the text channel is
    deliberately noisy. Returns (image, input_ids, attention_mask, label).
    """

    def __init__(self, root='E:/temp/data', num_samples=50000, transform=None):
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck']
        self.dataset = CIFAR10(root=root, train=True, download=True, transform=transform)
        # Cap the usable sample count at the dataset size.
        self.num_samples = min(num_samples, len(self.dataset))
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_length = 16

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        assert isinstance(image, torch.Tensor), f"Expected Tensor, got {type(image)}"
        # Coin flip: keep the true class name, or substitute a random one.
        keep_true = torch.rand(1).item() < 0.5
        text_label = label if keep_true else torch.randint(0, 10, (1,)).item()
        text = f"a photo of a {self.classes[text_label]}"
        encoded = self.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )
        token_ids = encoded['input_ids'].squeeze(0)
        token_mask = encoded['attention_mask'].squeeze(0)
        return image, token_ids, token_mask, label
# Test-time dataset
class MultimodalTestDataset(torch.utils.data.Dataset):
    """CIFAR-10 test images paired with a caption naming the true class.

    Returns (image, input_ids, attention_mask, label).

    NOTE(review): unlike the training set, the test caption is never
    randomized, so the text channel effectively leaks the label at
    evaluation time — confirm this asymmetry is intended.
    """

    def __init__(self, root='E:/temp/data', transform=None):
        self.classes = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                        'dog', 'frog', 'horse', 'ship', 'truck']
        self.dataset = CIFAR10(root=root, train=False, download=True, transform=transform)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_length = 16
        # FIX: the original pre-tokenized a randomized caption for every test
        # sample into self.text_inputs here, but __getitem__ never read them —
        # a full dead tokenization pass over the whole test set at startup.
        # The dead precompute has been removed; __getitem__ is unchanged.

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # Caption always names the true class at test time.
        text = f"a photo of a {self.classes[label]}"
        assert isinstance(image, torch.Tensor), f"Expected Tensor, got {type(image)}"
        encoded = self.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_length,
            truncation=True,
            return_tensors='pt'
        )
        return image, encoded['input_ids'].squeeze(0), encoded['attention_mask'].squeeze(0), label
def train_epoch(model, dataloader, optimizer, criterion, epoch):
    """Run one training epoch and log running metrics to SwanLab.

    Args:
        model: multimodal classifier, called as model(text_ids, attention_mask, images).
        dataloader: yields (images, text_ids, attention_mask, labels) batches.
        optimizer: optimizer stepping the model's trainable parameters.
        criterion: loss on (logits, labels), e.g. nn.CrossEntropyLoss.
        epoch: 0-based epoch index, used only for display/logging.
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    for batch_idx, (images, text_ids, attention_mask, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        outputs = model(text_ids, attention_mask, images)  # [batch, num_classes]
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pred = outputs.argmax(dim=1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        # Log scalar metrics every 50 steps (images/text intentionally not logged).
        if batch_idx % 50 == 0:
            swanlab.log({
                "epoch": epoch,
                "batch": batch_idx,
                "loss": loss.item(),
                "accuracy": 100. * correct / total,
                "learning_rate": optimizer.param_groups[0]['lr']
            })
    avg_loss = total_loss / len(dataloader)
    avg_acc = 100. * correct / total
    # FIX: was `epoch+2`, which printed epochs 2..6 for the 0..4 the caller
    # passes; `epoch+1` gives the conventional 1-based epoch number.
    print(f'Epoch {epoch+1}: Loss: {avg_loss:.4f}, Acc: {avg_acc:.2f}%')
    # End-of-epoch summary metrics.
    swanlab.log({
        "epoch_end_loss": avg_loss,
        "epoch_end_acc": avg_acc
    })
if __name__ == '__main__':
    # Initialize SwanLab experiment tracking (anonymous mode; no token required).
    swanlab.init(project="multimodal-object-detection", anonymous=True)
    # 2. Example data pipeline (CIFAR-10 images + simulated caption text).
    # NOTE(review): this module-level `tokenizer` appears unused — each dataset
    # constructs its own; presumably safe to remove, verify.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    # Training set: 50k CIFAR-10 train split with synthetic (noisy) captions.
    dataset = MultimodalTrainDataset(num_samples=50000, transform=transform)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)
    # 3. Training loop (with SwanLab logging).
    model = MultimodalFramework()
    criterion = nn.CrossEntropyLoss()
    # Phase 1: freeze BERT and train only the CNN/fusion head.
    # (The original comment claimed 6 epochs; the loop below runs 2.)
    for param in model.text_encoder.parameters():
        param.requires_grad = False
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
    for epoch in range(2):
        train_epoch(model, dataloader, optimizer, criterion, epoch)
    # Phase 2: unfreeze BERT for end-to-end fine-tuning.
    for param in model.text_encoder.parameters():
        param.requires_grad = True
    # 4. Rebuild the optimizer over ALL parameters with a much smaller LR.
    # (Runs 3 more epochs — indices 2..4 — not the 30 the original comment claimed.)
    optimizer = optim.Adam(model.parameters(), lr = 1e-5)
    for epoch in range(2, 5):
        train_epoch(model, dataloader, optimizer, criterion, epoch)
    # Evaluation on the CIFAR-10 test split.
    test_dataset = MultimodalTestDataset(root='E:/temp/data', transform=transform)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    model.eval()
    test_correct = 0
    test_total = 0
    with torch.no_grad():
        for images, text_ids, attention_mask, labels in test_loader:
            outputs = model(text_ids, attention_mask, images)
            pred = outputs.argmax(dim = 1)
            test_correct += (pred == labels).sum().item()
            test_total += labels.size(0)
    test_acc = 100. * test_correct / test_total  # percentage accuracy
    print(f"Test Accuracy: {test_acc:.2f}%")
    swanlab.log({"test_accuracy": test_acc})
    # Save weights locally (metrics already recorded on the SwanLab web UI).
    # NOTE(review): filename says 'epuch50' — typo for 'epoch', and the run is
    # 5 epochs, not 50; left unchanged to avoid altering runtime behavior.
    torch.save(model.state_dict(), 'multimodal_model_epuch50_1.pth')
    print('模型保存完成!')
    swanlab.finish()  # close the run so the web dashboard finalizes its curves