Spaces:
Build error
Build error
| from collections import Counter | |
| from tqdm import tqdm | |
| from time import time | |
| import torch | |
| from torch import nn, Tensor | |
| from S3_intermediateDataset import build_intermediate_dataset_if_not_exists, intermediate_dataset | |
| from S2_TimberDataset import build_dataloader | |
| from S1_CNN_Model import build_model | |
def unbatch(images, labels):
    """Strip the leading length-1 axis from one cached ``(images, labels)`` item.

    ``reshape(shape[1:])`` only succeeds when that leading axis has size 1, so
    this raises if an item unexpectedly carries more than one outer element.
    Presumably each item yielded by ``intermediate_dataset`` is an
    already-batched tensor wrapped in an outer batch of 1 -- TODO confirm in
    S3_intermediateDataset.
    """
    return images.reshape(images.shape[1:]), labels.reshape(labels.shape[1:])


def format_tally(tally, max_len=80):
    """Render prediction counts as ``"class:count class:count ..."``, most frequent first.

    The text is cut to ``max_len`` characters and suffixed with ``"..."`` only
    when it was actually truncated.
    """
    text = " ".join(f"{label}:{count}" for label, count in tally.most_common())
    return text if len(text) <= max_len else text[:max_len] + "..."


def evaluate(model, batches, on_batch=None):
    """Count correct top-1 predictions of ``model`` over ``batches``.

    Args:
        model: classifier module; its output is reduced with ``torch.max(..., 1)``,
            so dimension 1 of the output is treated as the class axis.
        batches: iterable of ``(images, labels)`` pairs that still carry the
            outer length-1 axis removed by ``unbatch``.
        on_batch: optional callback ``(n_correct, n_samples, tally)`` invoked
            after every batch, e.g. to refresh a progress display.

    Returns:
        ``(n_correct, n_samples, tally)`` where ``tally`` is a ``Counter`` of how
        often each class index was predicted.

    The model is put in eval mode for the pass (so dropout / batch-norm layers,
    if the model has any, use inference behaviour) and restored to whatever mode
    it was in before, even if an exception is raised.
    """
    was_training = model.training
    model.eval()
    n_correct = 0
    n_samples = 0
    tally = Counter()
    try:
        with torch.no_grad():
            for images, labels in batches:
                images, labels = unbatch(images, labels)
                _, predictions = torch.max(model(images), 1)
                tally.update(predictions.tolist())
                n_samples += labels.shape[0]
                n_correct += (predictions == labels).sum().item()
                if on_batch is not None:
                    on_batch(n_correct, n_samples, tally)
    finally:
        model.train(was_training)
    return n_correct, n_samples, tally


if __name__ == '__main__':
    img_size = (320, 320)
    n_epoch = 40

    train_loader, val_loader = build_dataloader(
        # train_ratio=0.005,
        img_size=img_size,
        batch_size=16,
    )
    # Build each split's intermediate dataset once (the helper skips work if it
    # already exists), passing an identity function -- presumably the transform
    # applied before caching; confirm in S3_intermediateDataset. Training then
    # reads from those intermediate datasets instead of the original loaders.
    build_intermediate_dataset_if_not_exists(lambda x: x, "train", train_loader)
    build_intermediate_dataset_if_not_exists(lambda x: x, "val", val_loader)
    train_loader = intermediate_dataset("train")
    val_loader = intermediate_dataset("val")

    model = build_model(img_size=img_size)
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

    # Validation accuracy per epoch, as a percentage (float, so comparisons
    # are numeric rather than lexicographic).
    accuracies = []
    timer = time()
    pbar_0 = tqdm(range(n_epoch), position=0, ncols=100)
    pbar_0.set_description(f"epoch 1/{n_epoch}")
    for epoch in pbar_0:
        # --- training pass ---
        pbar_1 = tqdm(enumerate(train_loader), total=len(train_loader), position=1, ncols=100, leave=False)
        for i, (images, labels) in pbar_1:
            images, labels = unbatch(images, labels)
            out = model(images)
            loss = criterion(out, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if i % 10 == 0:
                # Elapsed wall time since the previous report (i.e. ~10 batches).
                pbar_1.set_description(f" loss = {loss.item():.4f} ({(time() - timer)*1000:.4f} ms)")
                timer = time()

        # --- validation pass (eval mode + no_grad handled by `evaluate`) ---
        pbar_0.set_description(f"epoch {epoch+1}/{n_epoch}, Validating . . .")
        with tqdm(bar_format='{desc}{postfix}', position=1, leave=False) as val_desc:
            def show_progress(correct, seen, tally):
                val_desc.set_description(f"{correct}/{seen} correct")
                val_desc.set_postfix_str(format_tally(tally))

            n_correct, n_samples, _ = evaluate(
                model,
                tqdm(val_loader, position=2, ncols=100, leave=False),
                on_batch=show_progress,
            )
        if n_samples == 0:
            raise RuntimeError("validation set produced no samples; cannot compute accuracy")
        accuracy = n_correct / n_samples * 100

        # Announce the upcoming epoch, clamped so the last epoch doesn't show n_epoch+1.
        pbar_0.set_description(f"epoch {min(epoch + 2, n_epoch)}/{n_epoch}, accuracy: {accuracy:.2f}%")
        # After a 3-epoch warm-up, snapshot the model whenever it beats every
        # previous epoch's validation accuracy.
        if len(accuracies) >= 3 and accuracy > max(accuracies):
            torch.save(model, f"model_{epoch}.pt")
        accuracies.append(accuracy)

    torch.save(model, "model.pt")
    print([f"{a:.2f}%" for a in accuracies])