File size: 3,870 Bytes
29a4de2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.models import resnet18


class SSLModel(nn.Module):
    """Backbone + MLP projection head for self-supervised contrastive training.

    The backbone's final fully-connected layer is replaced with an identity so
    the backbone emits pooled features, which a two-layer MLP then maps into
    the contrastive embedding space.
    """

    def __init__(self, backbone, projection_dim=128):
        super(SSLModel, self).__init__()
        # Capture the classifier's input width before the layer is discarded.
        feature_dim = backbone.fc.in_features
        self.backbone = backbone
        self.projection_head = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim),
        )
        # Strip the classification head; the backbone now returns raw features.
        self.backbone.fc = nn.Identity()

    def forward(self, x):
        # Features -> projection space in one pass.
        return self.projection_head(self.backbone(x))


def contrastive_loss(z_i, z_j, temperature=0.5):
    """NT-Xent (SimCLR) contrastive loss over two batches of projections.

    Args:
        z_i: (batch, dim) projections of the first augmented view.
        z_j: (batch, dim) projections of the second view, row-aligned with z_i.
        temperature: softmax temperature; smaller values sharpen the distribution.

    Returns:
        Scalar tensor: mean loss over all 2 * batch anchors.
    """
    batch_size = z_i.shape[0]

    # Stack both views and L2-normalize so similarities are cosine similarities.
    # FIX: the original skipped normalization, so the "similarity" was a raw dot
    # product whose scale depends on embedding magnitude and the temperature
    # loses its meaning; NT-Xent is defined on cosine similarity.
    z = torch.cat([z_i, z_j], dim=0)  # (2 * batch_size, dim)
    z = torch.nn.functional.normalize(z, dim=1)

    sim_matrix = torch.mm(z, z.T) / temperature  # (2B, 2B) cosine / temperature

    # Exclude self-similarity from the softmax denominator.
    mask = torch.eye(2 * batch_size, device=sim_matrix.device, dtype=torch.bool)
    sim_matrix = sim_matrix.masked_fill(mask, float("-inf"))

    # Positive pairs (z_i, z_j) and (z_j, z_i) sit on the +/- batch_size
    # off-diagonals; concatenation order matches row order 0..2B-1.
    pos_sim = torch.cat([
        torch.diag(sim_matrix, batch_size),
        torch.diag(sim_matrix, -batch_size),
    ])

    # -log(exp(pos) / sum(exp(row))) == logsumexp(row) - pos; logsumexp is
    # numerically stable, so the manual max-shift is no longer needed.
    loss = torch.logsumexp(sim_matrix, dim=1) - pos_sim
    return loss.mean()


class TwoCropsTransform:
    """Produce two independently augmented views of the same image.

    Contrastive training needs two *different* random augmentations per sample;
    applying the stochastic pipeline twice yields a distinct positive pair.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        return self.base_transform(x), self.base_transform(x)


if __name__ == "__main__":
    augment = T.Compose([
        T.RandomResizedCrop(32),
        T.RandomHorizontalFlip(),
        T.ColorJitter(0.4, 0.4, 0.4, 0.1),
        T.RandomGrayscale(p=0.2),
        T.GaussianBlur(kernel_size=3),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
    ])

    # Load CIFAR-10.  Each sample now yields a (view_1, view_2) pair; the
    # default collate turns that into two batched tensors per iteration.
    train_dataset = datasets.CIFAR10(
        root="./data", train=True, transform=TwoCropsTransform(augment), download=True
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=256,
        shuffle=True,
        pin_memory=True,
        num_workers=4
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): `pretrained=` is deprecated in newer torchvision in favour
    # of `weights=`; kept as-is for compatibility with the version this file targets.
    model = SSLModel(resnet18(pretrained=False)).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)

    num_epochs = 10
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Training loop
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")

        for (view_i, view_j), _ in progress_bar:
            view_i = view_i.to(device, non_blocking=True)
            view_j = view_j.to(device, non_blocking=True)

            # FIX: the original ran the *same* batch through the model twice, so
            # every positive pair was two copies of one embedding and the
            # contrastive objective was degenerate.  Each view is now an
            # independent augmentation of the same underlying images.
            z_i = model(view_i)
            z_j = model(view_j)

            # FIX: the broad try/except around the loss silently skipped failing
            # batches and hid real bugs; let errors surface.
            loss = contrastive_loss(z_i, z_j)

            optimizer.zero_grad()
            loss.backward()

            # Clip gradients to stabilize early contrastive training.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            epoch_loss += loss.item()
            progress_bar.set_postfix(loss=f"{loss.item():.4f}")

        scheduler.step()
        print(f"Epoch {epoch + 1}, Average Loss: {epoch_loss / len(train_loader):.4f}")

        # Save checkpoint (scheduler state added so training can resume exactly).
        ckpt_path = f"ssl_checkpoint_epoch_{epoch + 1}.pth"
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        }, ckpt_path)
        print(f"Model saved to {ckpt_path}")