Spaces:
Running
Running
"""Train RelationshipNet on the relationship image dataset and save its weights.

Run as a script: builds the dataset and loader, trains for ``epochs`` epochs
with Adam + cross-entropy, prints the mean per-batch loss after each epoch,
and writes the model's ``state_dict`` to ``models/relationship_model.pth``.
"""
from pathlib import Path

import torch
from torch.utils.data import DataLoader

from src.dataset import RelationshipDataset
from src.model import RelationshipNet

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

dataset = RelationshipDataset(
    image_dir="data/relationship_dataset/images",
    label_path="data/relationship_dataset/labels_encoded.json",
)

# NOTE(review): assumes dataset.data is a sequence of dicts carrying an
# encoded "label" key (per labels_encoded.json) — confirm in src/dataset.py.
labels_seen = {item["label"] for item in dataset.data}
if not labels_seen:
    # Fail early with a clear message; DataLoader(shuffle=True) would
    # otherwise raise an unrelated-looking sampler error on an empty dataset.
    raise ValueError("Relationship dataset is empty; nothing to train on.")

num_classes = len(labels_seen)
# CrossEntropyLoss requires every target to be < num_classes. If the encoded
# integer IDs have gaps (a class missing from this data), the distinct-label
# count is too small, so size the classifier head by the largest ID instead.
if all(isinstance(label, int) for label in labels_seen):
    num_classes = max(num_classes, max(labels_seen) + 1)

loader = DataLoader(dataset, batch_size=16, shuffle=True)

model = RelationshipNet(num_classes).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

epochs = 10

# Explicitly enter training mode (matters if RelationshipNet has dropout or
# batch-norm layers).
model.train()
for epoch in range(epochs):
    total_loss = 0.0
    num_batches = 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Report the mean per-batch loss; the raw sum grows with the number of
    # batches, which makes it hard to compare across dataset/batch sizes.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")

# torch.save does not create missing directories. Without this, a fresh
# checkout would crash only after every epoch had run, losing the weights.
model_path = Path("models/relationship_model.pth")
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)