Spaces:
Sleeping
Sleeping
| ''' | |
| File: train.py | |
| Project: PatternsRecognizer | |
| Author: Milko Videv (milko.videv@thalesgroup.com) | |
| ----- | |
| Last Modified: Tuesday, 5th March 2024 10:02:08 am | |
| Modified By: Milko Videv (milko.videv@thalesgroup.com) | |
| ----- | |
| Copyright 2017 - 2024, Thales DIS, MCS SSH | |
| ----- | |
| HISTORY: | |
| Date By Comments | |
| ---------- --- --------------------------------------------------------- | |
| ''' | |
| from fastai.vision.all import * | |
| from fastcore.all import * | |
| from tools import * | |
def train(src, count):
    """Fine-tune a ResNet-18 image classifier on a folder of labelled images.

    Images are collected recursively from ``src``; each image's label is the
    name of its parent directory. 30% of the images (fixed seed 42) are held
    out for validation, and every image is squished to 192x192.

    If a ``current`` checkpoint from a previous run exists, it is loaded first
    so training continues from where it stopped. After training, the learner
    is saved again as the ``current`` checkpoint and the whole Learner is
    exported to ``./models/trained.pkl``.

    Args:
        src: Root directory of the image dataset (one sub-directory per class).
        count: Number of epochs passed to ``Learner.fine_tune`` (fastai's
            ``fine_tune`` also runs its own frozen warm-up epoch(s) first).
    """
    path = Path(src)
    dls = DataBlock(
        blocks=(ImageBlock, CategoryBlock),               # image in, category out
        get_items=get_image_files,                        # recursively collect image files
        splitter=RandomSplitter(valid_pct=0.3, seed=42),  # hold out 30% for validation
        get_y=parent_label,                               # label = parent directory name
        item_tfms=[Resize(192, method='squish')],         # resize all images uniformly
    ).dataloaders(path, bs=32, verbose=True)              # batch size 32

    # Using the resnet neural network library.
    # Note: on the 1st run the pretrained weights are downloaded from
    # https://download.pytorch.org/models, so set up the proxy first:
    #   set HTTP_PROXY=http://proxy-us-austin.gemalto.com:8080
    #   set HTTPS_PROXY=http://proxy-us-austin.gemalto.com:8080
    sav_path = Path('./models')
    sav_path.mkdir(exist_ok=True, parents=True)
    learn = vision_learner(dls, resnet18, metrics=error_rate)

    # Learner.save/load resolve checkpoint names against
    # learn.path / learn.model_dir, so derive the file location from the
    # learner rather than hard-coding it. This keeps the existence check in
    # sync with where learn.save('current') below actually writes.
    checkpoint = Path(learn.path) / learn.model_dir / "current.pth"
    if checkpoint.exists():
        print("Loading current.pth to continue learning ...")
        learn.load("current")

    stopwatch = Stopwatch()
    stopwatch.start()
    print("Training ...")
    learn.fine_tune(count)
    stopwatch.stop()
    print(f"Training took {stopwatch.elapsed_time()} seconds")

    # Checkpoint for the next run; picked up by the load above.
    learn.save('current')
    # export() writes relative to learn.path, so this produces
    # ./models/trained.pkl.
    learn.path = sav_path
    learn.export('trained.pkl')
if __name__ == "__main__":
    # Command-line entry point: python train.py <images path> <epochs count>
    if debugger_is_active():
        # When launched from a debugger the command-line arguments are not
        # read; train on the dataset from the usage example with a fixed
        # epoch count. train() requires both the dataset path and the count.
        train("patterns", 3)
        sys.exit()

    # str.isdecimal() (unlike isdigit()) accepts only characters int() can
    # parse, e.g. it rejects superscripts such as "²". A count below 1 is
    # not a meaningful number of epochs.
    if len(sys.argv) < 3 or not sys.argv[2].isdecimal() or int(sys.argv[2]) < 1:
        print("Use: python train.py <images path> <epochs count>")
        print("Example: python train.py patterns 3")
        sys.exit(1)  # non-zero status: invalid invocation

    src = sys.argv[1]
    count = int(sys.argv[2])
    # Ask for confirmation before a long run; declining exits without training.
    if count > 5 and not get_yes_no("Epoch count is too big, might take time. Continue?"):
        sys.exit()
    train(src, count)