'''
File: train.py
Project: PatternsRecognizer
Author: Milko Videv (milko.videv@thalesgroup.com)
-----
Last Modified: Tuesday, 5th March 2024 10:02:08 am
Modified By: Milko Videv (milko.videv@thalesgroup.com)
-----
Copyright 2017 - 2024, Thales DIS, MCS SSH
-----
HISTORY:
Date        By   Comments
----------  ---  ---------------------------------------------------------
'''
import os
import sys

from fastai.vision.all import *
from fastcore.all import *

from tools import *


def train(src, count):
    """Train (or resume training) a resnet18 image classifier.

    Images are read recursively from *src*; the parent directory name of
    each image is used as its category label.  A previous checkpoint
    (``./models/current.pth``) is loaded automatically when present, so
    repeated runs continue training instead of starting over.

    Parameters
    ----------
    src : str
        Root directory containing one sub-directory per category.
    count : int
        Number of fine-tuning epochs.

    Side effects: writes ``./models/current.pth`` (resume checkpoint)
    and ``./models/trained.pkl`` (inference export).
    """
    path = Path(src)
    dls = DataBlock(
        blocks=(ImageBlock, CategoryBlock),               # image input, category label
        get_items=get_image_files,                        # recursively collect images under path
        splitter=RandomSplitter(valid_pct=0.3, seed=42),  # 30% held out for validation
        get_y=parent_label,                               # parent dir name is the label
        item_tfms=[Resize(192, method='squish')],         # uniform 192x192 inputs
    ).dataloaders(path, bs=32, verbose=True)              # batch size 32

    # resnet18 pretrained weights are fetched on the first run from
    # https://download.pytorch.org/models — behind the corporate firewall set:
    #   set HTTP_PROXY=http://proxy-us-austin.gemalto.com:8080
    #   set HTTPS_PROXY=http://proxy-us-austin.gemalto.com:8080
    sav_path = Path('./models')
    sav_path.mkdir(exist_ok=True, parents=True)

    learn = vision_learner(dls, resnet18, metrics=error_rate)

    # Resume from the previous checkpoint when one exists.
    if os.path.exists("./models/current.pth"):
        print("Loading current.pth to continue learning ...")
        learn.load("current")

    stopwatch = Stopwatch()
    stopwatch.start()
    print("Training ...")
    learn.fine_tune(count)
    stopwatch.stop()
    print("Training took ", stopwatch.elapsed_time(), "seconds")

    learn.save('current')        # checkpoint so the next run can resume
    learn.path = sav_path
    learn.export('trained.pkl')  # self-contained artifact for inference


if __name__ == "__main__":
    if debugger_is_active():
        # FIX: the original called train(3) with the mandatory `src`
        # argument missing (TypeError).  Use the documented example
        # dataset directory as the debug default.
        train("patterns", 3)
        sys.exit()

    if len(sys.argv) < 3 or not sys.argv[2].isdigit():
        # FIX: the usage string had lost its argument placeholders.
        print("Use: python train.py <src-dir> <epochs>")
        print("Example: python train.py patterns 3")
        sys.exit()

    count = int(sys.argv[2])
    # FIX: this prompt string was broken across a raw newline in the
    # original (syntax error); rejoined with a space.
    if count > 5 and not get_yes_no(
            "Epoch count is too big, might take time. Continue?"):
        sys.exit()

    train(sys.argv[1], count)