File size: 1,094 Bytes
c858478
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from src.dataset import RelationshipDataset
from src.model import RelationshipNet

# Train RelationshipNet on the relationship dataset and save its weights
# to models/relationship_model.pth.

# Prefer CUDA, then Apple-silicon MPS, and fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

dataset = RelationshipDataset(
    image_dir="data/relationship_dataset/images",
    label_path="data/relationship_dataset/labels_encoded.json"
)

loader = DataLoader(dataset, batch_size=16, shuffle=True)

# Assumes dataset.data is a sequence of dicts with a "label" key — TODO
# confirm against src.dataset.
labels_seen = {item["label"] for item in dataset.data}
# CrossEntropyLoss treats integer targets as class indices in [0, C). With
# integer-encoded labels the output layer must therefore cover the largest
# index, not just the number of distinct values. For example, labels {0, 2}
# would otherwise give C=2 and an out-of-range target. Non-integer labels
# keep the distinct-count behavior.
if labels_seen and all(isinstance(lbl, int) for lbl in labels_seen):
    num_classes = max(labels_seen) + 1
else:
    num_classes = len(labels_seen)

model = RelationshipNet(num_classes).to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

epochs = 10

# Put the model in training mode explicitly (affects layers such as
# dropout/batch-norm, if RelationshipNet has any).
model.train()

for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # .item() converts to a Python float, so the autograd graph is not
        # kept alive by the accumulator.
        total_loss += loss.item()
        num_batches += 1

    # Report the mean per-batch loss so the value does not grow with the
    # number of batches. The max() guards against an empty loader.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1}/{epochs}, Avg Loss: {avg_loss:.4f}")

# torch.save does not create missing directories, so ensure models/ exists.
model_path = Path("models/relationship_model.pth")
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)