# MilkoTv's picture
# added training sources
# 4b26dc0
'''
File: train.py
Project: PatternsRecognizer
Author: Milko Videv (milko.videv@thalesgroup.com)
-----
Last Modified: Tuesday, 5th March 2024 10:02:08 am
Modified By: Milko Videv (milko.videv@thalesgroup.com)
-----
Copyright 2017 - 2024, Thales DIS, MCS SSH
-----
HISTORY:
Date By Comments
---------- --- ---------------------------------------------------------
'''
from fastai.vision.all import *
from fastcore.all import *
from tools import *
def train(src, count, *, valid_pct=0.3, seed=42, bs=32, img_size=192):
    """Fine-tune a ResNet-18 image classifier on the images found under ``src``.

    Image files are collected recursively from ``src`` and labelled with the
    name of their parent directory. If a ``current.pth`` checkpoint left by a
    previous run exists in the learner's model directory, it is loaded first
    so training continues from it. After training, the weights are saved
    again as ``current.pth`` and the learner is exported to
    ``./models/trained.pkl``.

    Args:
        src: Root directory of the training images (one sub-directory per label).
        count: Number of epochs passed to ``Learner.fine_tune``.
        valid_pct: Fraction of the images held out for validation.
        seed: Seed of the random train/validation split (fixed for reproducibility).
        bs: Batch size of the data loaders.
        img_size: Size every image is squished to before batching.
    """
    path = Path(src)
    dls = DataBlock(
        blocks=(ImageBlock, CategoryBlock),  # image in, category out
        get_items=get_image_files,  # recursively collect image files under path
        splitter=RandomSplitter(valid_pct=valid_pct, seed=seed),
        get_y=parent_label,  # the parent directory name is the label
        item_tfms=[Resize(img_size, method='squish')],  # give all images one size
    ).dataloaders(path, bs=bs, verbose=True)

    # The resnet18 weights are downloaded from https://download.pytorch.org/models
    # on the first run. Behind a proxy, set it up first, e.g.:
    #   set HTTP_PROXY=http://proxy-us-austin.gemalto.com:8080
    #   set HTTPS_PROXY=http://proxy-us-austin.gemalto.com:8080
    sav_path = Path('./models')
    sav_path.mkdir(exist_ok=True, parents=True)
    learn = vision_learner(dls, resnet18, metrics=error_rate)

    # Resume from the previous run. Look for the checkpoint exactly where
    # learn.load()/learn.save() put it (learn.path / learn.model_dir) instead of
    # a separately hard-coded path, so the check and the load cannot disagree.
    checkpoint = learn.path / learn.model_dir / "current.pth"
    if checkpoint.exists():
        print("Loading current.pth to continue learning ...")
        # NOTE(review): presumably requires the same label set as when the
        # checkpoint was saved — confirm before adding/removing label folders.
        learn.load("current")

    stopwatch = Stopwatch()
    stopwatch.start()
    print("Training ...")
    learn.fine_tune(count)
    stopwatch.stop()
    print(f"Training took {stopwatch.elapsed_time()} seconds")

    # Save BEFORE re-pointing learn.path: save() writes under learn.path, which
    # must stay the location the resume check above looks at on the next run.
    learn.save('current')
    # export() writes relative to learn.path, so aim it at the models directory.
    learn.path = sav_path
    learn.export('trained.pkl')
if __name__ == "__main__":
    # Command-line entry point: python train.py <images path> <epochs count>
    if debugger_is_active():
        # No command line is available under the debugger; train() needs both
        # a source directory and an epoch count, so use the documented example.
        train("patterns", 3)
        sys.exit()

    # int() instead of str.isdigit(): isdigit() accepts characters such as
    # "²" that int() rejects, which used to crash with a traceback.
    try:
        src = sys.argv[1]
        epochs = int(sys.argv[2])
    except (IndexError, ValueError):
        epochs = None  # missing or non-numeric -> show usage below

    if epochs is None or epochs < 1:  # at least one epoch is required
        print("Use: python train.py <images path> <epochs count>")
        print("Example: python train.py patterns 3")
        sys.exit(1)  # non-zero status so callers can detect the usage error

    if epochs > 5 and not get_yes_no("Epoch count is too big, might take time. Continue?"):
        sys.exit()

    train(src, epochs)