ananthvk commited on
Commit
c91bc75
·
1 Parent(s): 1f7ece1

add train code

Browse files
Files changed (1) hide show
  1. train.py +47 -0
train.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Make sure to put your kaggle token in the environment variable KAGGLE_API_TOKEN
2
+
3
+ import kagglehub
4
+ from pathlib import Path
5
+ from fastai.vision.all import *
6
+
7
+ # Download latest version
8
+ path = kagglehub.dataset_download("utkarshsaxenadn/fruits-classification")
9
+ path = Path(path) / 'Fruits Classification'
10
+
11
+ print("Path to dataset files:", path)
12
+
13
+ block = DataBlock(
14
+ blocks = (ImageBlock(), CategoryBlock()),
15
+ get_items=get_image_files,
16
+ splitter = GrandparentSplitter(train_name='train', valid_name='valid'),
17
+ get_y = parent_label,
18
+ item_tfms=RandomResizedCrop(224, min_scale=0.3),
19
+ batch_tfms = aug_transforms()
20
+ )
21
+
22
+ dls = block.dataloaders(path, bs = 64, seed=2026)
23
+
24
+ learner = vision_learner(dls, resnet34, metrics=error_rate)
25
+ learner.fine_tune(5)
26
+
27
+ # Show classification accuracy
28
+
29
+ interpretation = ClassificationInterpretation.from_learner(learner)
30
+ interpretation.plot_confusion_matrix()
31
+ interpretation.plot_top_losses(5, nrows=1, figsize=(20,5))
32
+
33
+ # Test the model
34
+ # https://forums.fast.ai/t/how-to-evaluate-model-on-test-set/97972/3
35
+
36
+ test_files = get_image_files(path / "test")
37
+ print(len(test_files))
38
+ test_dl = dls.test_dl(test_files, with_labels=True)
39
+ test_dl.show_batch(max_n=5)
40
+ preds, y = learner.get_preds(dl=test_dl)
41
+ acc = accuracy(preds, y)
42
+ print("Accuracy: ", acc)
43
+ acc2 = learner.validate(dl=test_dl)
44
+ print(acc2)
45
+
46
+ # Export the model
47
+ learner.export("fruits-model-v1.pkl")