# fruits-classifier / train.py
# Last commit c91bc75 by ananthvk: "add train code"
# Make sure to put your kaggle token in the environment variable KAGGLE_API_TOKEN
import kagglehub
from pathlib import Path
from fastai.vision.all import *
# Fetch the latest version of the Kaggle dataset through kagglehub, then point
# `path` at the "Fruits Classification" sub-folder inside the download.
path = Path(
    kagglehub.dataset_download("utkarshsaxenadn/fruits-classification")
) / 'Fruits Classification'
print("Path to dataset files:", path)
# Recipe for turning the dataset folder into (image, category) samples:
#   - items:  every image file found under the source folder,
#   - label:  the name of the image's parent folder,
#   - split:  decided by the grandparent folder name ("train" vs "valid");
#             NOTE(review): images under other grandparents (e.g. "test")
#             are presumably left out of both splits — confirm.
block = DataBlock(
    blocks=(ImageBlock(), CategoryBlock()),
    get_items=get_image_files,
    get_y=parent_label,
    splitter=GrandparentSplitter(train_name='train', valid_name='valid'),
    # Per-item random crop resized to 224 px (min_scale=0.3).
    item_tfms=RandomResizedCrop(224, min_scale=0.3),
    # fastai's default batch-level augmentation pipeline.
    batch_tfms=aug_transforms(),
)
# Build the train/valid DataLoaders (batch size 64) from the dataset folder.
# NOTE(review): `seed` is forwarded to fastai's loader kwargs; confirm it really
# makes runs reproducible (fastai's `set_seed` is the documented mechanism).
dls = block.dataloaders(source=path, bs=64, seed=2026)
# ResNet-34 based image classifier, tracking error_rate as its metric.
learner = vision_learner(dls, arch=resnet34, metrics=error_rate)
# Fine-tune for 5 epochs.
learner.fine_tune(epochs=5)
# Inspect the trained model's mistakes: a confusion matrix plus the 5 images
# with the highest loss.
# When this file runs as a plain script (not in a notebook), matplotlib figures
# are never displayed. Each plot is therefore also saved as a PNG so it is not
# silently lost. `plt` is matplotlib.pyplot, re-exported by the
# `from fastai.vision.all import *` star import.
interpretation = ClassificationInterpretation.from_learner(learner)
interpretation.plot_confusion_matrix()
plt.savefig("confusion_matrix.png", bbox_inches="tight")
interpretation.plot_top_losses(5, nrows=1, figsize=(20, 5))
plt.savefig("top_losses.png", bbox_inches="tight")
# Test the model on the images in the dataset's "test" folder.
# https://forums.fast.ai/t/how-to-evaluate-model-on-test-set/97972/3
test_files = get_image_files(path / "test")
print("Number of test images:", len(test_files))
# with_labels=True keeps the label pipeline defined in `block`
# (get_y=parent_label), so each test image is paired with its folder's class.
test_dl = dls.test_dl(test_files, with_labels=True)
test_dl.show_batch(max_n=5)
# Outside a notebook the figure above is never displayed, so persist it.
# `plt` is matplotlib.pyplot, re-exported by the fastai star import.
plt.savefig("test_batch.png", bbox_inches="tight")
preds, y = learner.get_preds(dl=test_dl)
acc = accuracy(preds, y)
# accuracy() yields a 0-d tensor; .item() prints a plain number.
print("Accuracy: ", acc.item())
# Learner.validate returns [loss, *metrics]. The learner was built with
# metrics=error_rate, so these two values are the test loss and the test
# error rate (1 - accuracy), not a second accuracy figure.
test_loss, test_error_rate = learner.validate(dl=test_dl)
print("Test loss:", test_loss, "| Test error rate:", test_error_rate)
# Persist the trained learner as "fruits-model-v1.pkl" for later inference.
# NOTE(review): the file name is resolved relative to learner.path. That is
# presumably the working directory, since dataloaders() was called without an
# explicit path= — confirm.
learner.export(fname="fruits-model-v1.pkl")