Spaces:
Sleeping
Sleeping
added training sources
Browse files- get_data.py +56 -0
- train.py +72 -0
get_data.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
File: get_data.py
|
| 3 |
+
Project: PatternsRecognizer
|
| 4 |
+
Author: Milko Videv (milko.videv@thalesgroup.com)
|
| 5 |
+
-----
|
| 6 |
+
Last Modified: Friday, 1st March 2024 6:04:04 pm
|
| 7 |
+
Modified By: Milko Videv (milko.videv@thalesgroup.com>)
|
| 8 |
+
-----
|
| 9 |
+
Copyright 2017 - 2024, Thales DIS, MCS SSH
|
| 10 |
+
-----
|
| 11 |
+
HISTORY:
|
| 12 |
+
Date By Comments
|
| 13 |
+
---------- --- ---------------------------------------------------------
|
| 14 |
+
'''
|
| 15 |
+
|
| 16 |
+
from random import random
|
| 17 |
+
from fastdownload import download_url
|
| 18 |
+
from fastai.vision.all import *
|
| 19 |
+
from time import sleep
|
| 20 |
+
from fastbook import search_images_ddg
|
| 21 |
+
from fastcore.all import *
|
| 22 |
+
|
| 23 |
+
def search_images_fastbook(term, max_images=30):
|
| 24 |
+
print(f"Searching for {max_images} {term}")
|
| 25 |
+
return search_images_ddg(term, max_images=max_images)
|
| 26 |
+
|
| 27 |
+
def get_images(target_path, count, searches):
|
| 28 |
+
path = Path(target_path)
|
| 29 |
+
|
| 30 |
+
for o in searches:
|
| 31 |
+
dest = (path/o)
|
| 32 |
+
dest.mkdir(exist_ok=True, parents=True)
|
| 33 |
+
download_images(dest, urls=search_images_fastbook(f'{o} patterns images', max_images=count))
|
| 34 |
+
sleep(10)
|
| 35 |
+
resize_images(path/o, max_size=400, dest=path/o)
|
| 36 |
+
|
| 37 |
+
print(f"Checking for bad images ...")
|
| 38 |
+
failed = verify_images(get_image_files(path))
|
| 39 |
+
failed.map(Path.unlink)
|
| 40 |
+
print(f"Removed {len(failed)} bad images")
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
if len(sys.argv) < 3 or not sys.argv[2].isdigit():
|
| 44 |
+
print("Use: python get_data.py <target path> <images count> <one or more pattern kinds>\n")
|
| 45 |
+
print("Example: python get_data.py patterns 1 bulgarian indian japanese")
|
| 46 |
+
else:
|
| 47 |
+
target_path = sys.argv[1]
|
| 48 |
+
count = sys.argv[2]
|
| 49 |
+
searches = sys.argv[3:] if len(sys.argv) > 3 else "bulgarian"
|
| 50 |
+
|
| 51 |
+
if len(searches) == 0:
|
| 52 |
+
print(f"Nothing to search for. Exitting.")
|
| 53 |
+
sys.exit()
|
| 54 |
+
|
| 55 |
+
#print(target_path, int(count), searches)
|
| 56 |
+
get_images(target_path, int(count), searches)
|
train.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''
|
| 2 |
+
File: train.py
|
| 3 |
+
Project: PatternsRecognizer
|
| 4 |
+
Author: Milko Videv (milko.videv@thalesgroup.com)
|
| 5 |
+
-----
|
| 6 |
+
Last Modified: Tuesday, 5th March 2024 10:02:08 am
|
| 7 |
+
Modified By: Milko Videv (milko.videv@thalesgroup.com>)
|
| 8 |
+
-----
|
| 9 |
+
Copyright 2017 - 2024, Thales DIS, MCS SSH
|
| 10 |
+
-----
|
| 11 |
+
HISTORY:
|
| 12 |
+
Date By Comments
|
| 13 |
+
---------- --- ---------------------------------------------------------
|
| 14 |
+
'''
|
| 15 |
+
|
| 16 |
+
from fastai.vision.all import *
|
| 17 |
+
from fastcore.all import *
|
| 18 |
+
from tools import *
|
| 19 |
+
|
| 20 |
+
def train(src, count):
|
| 21 |
+
path = Path(src)
|
| 22 |
+
dls = DataBlock(
|
| 23 |
+
blocks=(ImageBlock, CategoryBlock), # block for image and category
|
| 24 |
+
get_items=get_image_files, # recursively get files in path
|
| 25 |
+
splitter=RandomSplitter(valid_pct=0.3, seed=42), # 30% for check
|
| 26 |
+
get_y=parent_label, # use the dir as label
|
| 27 |
+
item_tfms=[Resize(192, method='squish')] # resize them a uniformly
|
| 28 |
+
).dataloaders(path, bs=32, verbose=True) # batch size 32
|
| 29 |
+
|
| 30 |
+
# using resnet Neural Network library
|
| 31 |
+
# Note: on 1st run will download it from https://download.pytorch.org/models)
|
| 32 |
+
# so set up proxy:
|
| 33 |
+
# set HTTP_PROXY=http://proxy-us-austin.gemalto.com:8080
|
| 34 |
+
# set HTTPS_PROXY=http://proxy-us-austin.gemalto.com:8080
|
| 35 |
+
|
| 36 |
+
sav_path = Path('./models')
|
| 37 |
+
sav_path.mkdir(exist_ok=True, parents=True)
|
| 38 |
+
|
| 39 |
+
learn = vision_learner(dls, resnet18, metrics=error_rate)
|
| 40 |
+
|
| 41 |
+
if os.path.exists("./models/current.pth"):
|
| 42 |
+
print(f"Loading current.pth to continue learning ...")
|
| 43 |
+
learn.load("current")
|
| 44 |
+
|
| 45 |
+
stopwatch = Stopwatch()
|
| 46 |
+
stopwatch.start()
|
| 47 |
+
print(f"Training ...")
|
| 48 |
+
learn.fine_tune(count)
|
| 49 |
+
stopwatch.stop()
|
| 50 |
+
print(f"Training took ", stopwatch.elapsed_time(), "seconds")
|
| 51 |
+
|
| 52 |
+
learn.save('current') # todo: load it if existing during next run to continue training
|
| 53 |
+
|
| 54 |
+
learn.path = sav_path
|
| 55 |
+
learn.export('trained.pkl')
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
if debugger_is_active():
|
| 59 |
+
train(3)
|
| 60 |
+
sys.exit()
|
| 61 |
+
else:
|
| 62 |
+
if len(sys.argv) < 3 or not sys.argv[2].isdigit():
|
| 63 |
+
print("Use: python train.py <images path> <epochs count>")
|
| 64 |
+
print("Example: python train.py patterns 3")
|
| 65 |
+
sys.exit()
|
| 66 |
+
elif int(sys.argv[2]) > 5:
|
| 67 |
+
if not get_yes_no("Epoch count is too big, might take time. Continue?"):
|
| 68 |
+
sys.exit()
|
| 69 |
+
|
| 70 |
+
src = sys.argv[1]
|
| 71 |
+
count = sys.argv[2]
|
| 72 |
+
train(src, int(count))
|