Spaces:
Build error
Build error
| import torch | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
| from .models.model import ShapeClassifier | |
| from src.configs.model_config import ModelConfig | |
| from src.data.data_loader import train_loader, num_classes | |
def train():
    """Train ``ShapeClassifier`` on ``train_loader`` and save its weights.

    Hyperparameters (``learning_rate`` and ``epochs``) are read from
    ``ModelConfig().get_config()``. The model is optimised with Adam against
    a cross-entropy loss. The mean loss of the batches seen since the
    previous log line is printed every ``log_interval`` batches. After the
    last epoch the accuracy over ``train_loader`` is printed and the model's
    state dict is written to ``model.pth`` in the current working directory.
    """
    config = ModelConfig().get_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ShapeClassifier(num_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    log_interval = 20

    for epoch in range(config.epochs):
        model.train()
        running_loss = 0.0
        batches_in_window = 0
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_in_window += 1
            if batch_idx % log_interval == 0:
                # Divide by the number of batches actually accumulated: the
                # first log of an epoch covers a single batch, so dividing by
                # log_interval would under-report the loss.
                current_loss = running_loss / batches_in_window
                print(
                    f"Epoch [{epoch + 1}/{config.epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {current_loss:.4f}")
                running_loss = 0.0
                batches_in_window = 0

    # Measure accuracy once training has finished.
    # NOTE(review): this iterates train_loader, so the figure is *training*
    # accuracy. Swap in a held-out loader to report a genuine test metric.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            predicted = torch.argmax(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader must not raise before the weights are saved.
    accuracy = 100 * correct / total if total else 0.0
    print(f"Accuracy of the model on the training images: {accuracy} %")

    # Persist only the learned parameters, not the whole module object.
    torch.save(model.state_dict(), "model.pth")