Spaces:
Sleeping
Sleeping
add train code
Browse files
train.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Make sure to put your kaggle token in the environment variable KAGGLE_API_TOKEN
|
| 2 |
+
|
| 3 |
+
import kagglehub
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from fastai.vision.all import *
|
| 6 |
+
|
| 7 |
+
# Download latest version
|
| 8 |
+
path = kagglehub.dataset_download("utkarshsaxenadn/fruits-classification")
|
| 9 |
+
path = Path(path) / 'Fruits Classification'
|
| 10 |
+
|
| 11 |
+
print("Path to dataset files:", path)
|
| 12 |
+
|
| 13 |
+
block = DataBlock(
|
| 14 |
+
blocks = (ImageBlock(), CategoryBlock()),
|
| 15 |
+
get_items=get_image_files,
|
| 16 |
+
splitter = GrandparentSplitter(train_name='train', valid_name='valid'),
|
| 17 |
+
get_y = parent_label,
|
| 18 |
+
item_tfms=RandomResizedCrop(224, min_scale=0.3),
|
| 19 |
+
batch_tfms = aug_transforms()
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
dls = block.dataloaders(path, bs = 64, seed=2026)
|
| 23 |
+
|
| 24 |
+
learner = vision_learner(dls, resnet34, metrics=error_rate)
|
| 25 |
+
learner.fine_tune(5)
|
| 26 |
+
|
| 27 |
+
# Show classification accuracy
|
| 28 |
+
|
| 29 |
+
interpretation = ClassificationInterpretation.from_learner(learner)
|
| 30 |
+
interpretation.plot_confusion_matrix()
|
| 31 |
+
interpretation.plot_top_losses(5, nrows=1, figsize=(20,5))
|
| 32 |
+
|
| 33 |
+
# Test the model
|
| 34 |
+
# https://forums.fast.ai/t/how-to-evaluate-model-on-test-set/97972/3
|
| 35 |
+
|
| 36 |
+
test_files = get_image_files(path / "test")
|
| 37 |
+
print(len(test_files))
|
| 38 |
+
test_dl = dls.test_dl(test_files, with_labels=True)
|
| 39 |
+
test_dl.show_batch(max_n=5)
|
| 40 |
+
preds, y = learner.get_preds(dl=test_dl)
|
| 41 |
+
acc = accuracy(preds, y)
|
| 42 |
+
print("Accuracy: ", acc)
|
| 43 |
+
acc2 = learner.validate(dl=test_dl)
|
| 44 |
+
print(acc2)
|
| 45 |
+
|
| 46 |
+
# Export the model
|
| 47 |
+
learner.export("fruits-model-v1.pkl")
|