# npv2k1's picture
# feat: ShapeClassifier
# c9e0c1d verified
# raw
# history blame
# 2 kB
import torch
import torch.optim as optim
import torch.nn.functional as F
from .models.model import ShapeClassifier
from src.configs.model_config import ModelConfig
from src.data.data_loader import train_loader, num_classes
def _accuracy_percent(model, loader, device) -> float:
    """Return the percentage of samples in *loader* that *model* labels correctly.

    The model is switched to eval mode and gradient tracking is disabled
    while scoring. Returns 0.0 when the loader yields no samples, so an
    empty loader cannot trigger a division by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # The predicted class is the index of the largest score per row.
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def train() -> None:
    """Train a ShapeClassifier on ``train_loader`` and save it to ``model.pth``.

    Hyperparameters (``learning_rate``, ``epochs``) are read from
    ``ModelConfig().get_config()``. The model is optimised with Adam on a
    cross-entropy loss, on CUDA when available and on CPU otherwise. The
    mean loss since the previous report is printed every ``log_interval``
    batches, the accuracy over ``train_loader`` is printed after each epoch,
    and the final ``state_dict`` is written to ``model.pth`` in the current
    working directory.
    """
    config = ModelConfig().get_config()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ShapeClassifier(num_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    log_interval = 20

    for epoch in range(config.epochs):
        # Evaluation below leaves the model in eval mode; switch back each epoch.
        model.train()
        running_loss = 0.0
        batches_since_log = 0
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batches_since_log += 1
            if batch_idx % log_interval == 0:
                # Divide by the number of batches actually accumulated: the
                # first report of an epoch covers a single batch, not
                # log_interval of them.
                current_loss = running_loss / batches_since_log
                print(
                    f"Epoch [{epoch + 1}/{config.epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {current_loss:.4f}")
                running_loss = 0.0
                batches_since_log = 0

        # NOTE(review): the flattened source does not show whether evaluation
        # ran per epoch or once after training; per-epoch is assumed here.
        # The only loader in scope is train_loader, so this is training-set
        # accuracy and is labelled as such (no held-out data is visible).
        accuracy = _accuracy_percent(model, train_loader, device)
        print(f"Accuracy of the model on the training images: {accuracy} %")

    # Persist only the learned parameters, not the whole module object.
    torch.save(model.state_dict(), "model.pth")