File size: 2,485 Bytes
4b26dc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
'''

File: train.py

Project: PatternsRecognizer

Author: Milko Videv (milko.videv@thalesgroup.com)

-----

Last Modified: Tuesday, 5th March 2024 10:02:08 am

Modified By: Milko Videv (milko.videv@thalesgroup.com)

-----

Copyright 2017 - 2024, Thales DIS, MCS SSH

-----

HISTORY:

Date      	By	Comments

----------	---	---------------------------------------------------------

'''

from fastai.vision.all import *
from fastcore.all import *
from tools import *

def train(src, count):
    """Fine-tune a resnet18 image classifier on a folder of labelled images.

    Image files are gathered recursively from ``src`` and each one is
    labelled with the name of its parent directory. 30% of the items
    (fixed seed 42) are held out for validation. When
    ``./models/current.pth`` exists, the ``current`` weights are loaded
    before training so that learning continues from the previous run.
    After training the weights are saved again as ``current`` and the
    learner is exported as ``trained.pkl`` with its path set to ``./models``.

    Args:
        src: Root directory containing the training images.
        count: Epoch count handed to ``fine_tune``.
    """
    image_root = Path(src)

    # Describe how raw files become (image, category) samples.
    block = DataBlock(
        blocks=(ImageBlock, CategoryBlock),  # image inputs, category targets
        get_items=get_image_files,  # recursively collect files under the root
        splitter=RandomSplitter(valid_pct=0.3, seed=42),  # 30% kept for validation
        get_y=parent_label,  # the parent directory name is the label
        item_tfms=[Resize(192, method='squish')],  # resize all images uniformly
    )
    dls = block.dataloaders(image_root, bs=32, verbose=True)  # batch size 32

    # The resnet18 pretrained network is used.
    # Note: on the first run it is downloaded from
    # https://download.pytorch.org/models, so set up the proxy beforehand:
    #   set HTTP_PROXY=http://proxy-us-austin.gemalto.com:8080
    #   set HTTPS_PROXY=http://proxy-us-austin.gemalto.com:8080

    models_dir = Path('./models')
    models_dir.mkdir(exist_ok=True, parents=True)

    learn = vision_learner(dls, resnet18, metrics=error_rate)

    # Resume from the checkpoint written by a previous run, if there is one.
    if Path("./models/current.pth").exists():
        print("Loading current.pth to continue learning ...")
        learn.load("current")

    # Time only the fine-tuning step itself.
    timer = Stopwatch()
    timer.start()
    print("Training ...")
    learn.fine_tune(count)
    timer.stop()
    print("Training took ", timer.elapsed_time(), "seconds")

    # Keep a checkpoint so the next run can pick up where this one stopped.
    learn.save('current')

    # Point the learner at the models directory before exporting it.
    learn.path = models_dir
    learn.export('trained.pkl')

if __name__ == "__main__":
    # Script entry point.
    #   Usage:   python train.py <images path> <epochs count>
    #   Example: python train.py patterns 3
    # debugger_is_active() and get_yes_no() are not defined in this file
    # (presumably they come from the `tools` star-import).
    if debugger_is_active():
        # Debug convenience run. train() requires both the image directory and
        # the epoch count; the directory here mirrors the usage example below
        # (NOTE(review): adjust "patterns" if the debug data lives elsewhere).
        train("patterns", 3)
        sys.exit()
    else:
        # isdecimal() instead of isdigit(): isdigit() also accepts characters
        # such as superscript digits that int() cannot parse.
        if len(sys.argv) < 3 or not sys.argv[2].isdecimal():
            print("Use: python train.py <images path> <epochs count>")   
            print("Example: python train.py patterns 3")
            # Non-zero status so callers can detect the invalid invocation.
            sys.exit(1)
        elif int(sys.argv[2]) > 5:
            # Large epoch counts need explicit confirmation from the user.
            if not get_yes_no("Epoch count is too big, might take time. Continue?"):
                sys.exit()

        src = sys.argv[1]
        count = sys.argv[2]
        train(src, int(count))