Spaces:
Running
Running
File size: 1,094 Bytes
c858478 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 | import torch
from torch.utils.data import DataLoader
from src.dataset import RelationshipDataset
from src.model import RelationshipNet
# Training script: fits RelationshipNet on the relationship image dataset and
# writes the learned weights to models/relationship_model.pth.
from pathlib import Path  # imported here so this section stands on its own

# Use the MPS backend when PyTorch reports it as available; otherwise the CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

dataset = RelationshipDataset(
    image_dir="data/relationship_dataset/images",
    label_path="data/relationship_dataset/labels_encoded.json",
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# The model is sized by the number of distinct label values in the records.
# NOTE(review): this presumes every class occurs at least once in dataset.data
# and that the encoded labels index into that many outputs - confirm against
# how labels_encoded.json is produced.
num_classes = len({item["label"] for item in dataset.data})

model = RelationshipNet(num_classes).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

epochs = 10
for epoch in range(epochs):
    # Running sum of the per-batch loss values for this epoch (not an average).
    total_loss = 0
    for images, labels in loader:
        # Move the batch onto the same device as the model.
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

# Create the output directory first: torch.save does not create missing parent
# directories, and failing here would throw away the whole training run.
checkpoint_path = "models/relationship_model.pth"
Path(checkpoint_path).parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)