import torch
import torch.optim as optim
import torch.nn.functional as F
from src.models.model import ShapeClassifier
from src.configs.model_config import ModelConfig
from src.data.data_loader import train_loader, num_classes, val_loader
from src.utils.train import train
from src.utils.test import test
from src.utils.wandb import wandb
from src.utils.logs import logging
from src.utils.model import save_model
def main():
    """Train a ShapeClassifier and track the run with Weights & Biases.

    Reads hyperparameters from ModelConfig, trains for ``config.epochs``
    epochs with Adam + cross-entropy, evaluates on the validation split
    each epoch, logs the training loss to W&B, and saves the final
    weights to ``results/models/last.pth``.
    """
    config = ModelConfig().get_config()
    # Prefer the GPU when one is available; fall back to CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ShapeClassifier(num_classes=num_classes).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    # Record the run's hyperparameters in W&B so they appear next to the metrics.
    wandb.config.update(config)
    for epoch in range(config.epochs):
        print(f"Epoch {epoch+1}\n-------------------------------")
        loss = train(train_loader, model=model, loss_fn=F.cross_entropy,
                     optimizer=optimizer)
        # Evaluate on the held-out validation split every epoch.
        # NOTE(review): test()'s return value (if any) is not captured, so
        # validation metrics are not logged to W&B — confirm whether that
        # is intentional.
        test(val_loader, model=model, loss_fn=F.cross_entropy)
        # Include the epoch index so W&B charts line up with the printed
        # epoch numbers (previously only the bare loss was logged).
        wandb.log({"loss": loss, "epoch": epoch + 1})
    # Persist the final ("last") weights, overwriting any previous checkpoint.
    save_model(model, "results/models/last.pth")
# Script entry point: log start-up, then run the full training loop.
if __name__ == "__main__":
    logging.info("Training model")
    main()