MilkoTv committed on
Commit
4b26dc0
·
1 Parent(s): 365ca58

added training sources

Browse files
Files changed (2) hide show
  1. get_data.py +56 -0
  2. train.py +72 -0
get_data.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ File: get_data.py
3
+ Project: PatternsRecognizer
4
+ Author: Milko Videv (milko.videv@thalesgroup.com)
5
+ -----
6
+ Last Modified: Friday, 1st March 2024 6:04:04 pm
7
+ Modified By: Milko Videv (milko.videv@thalesgroup.com)
8
+ -----
9
+ Copyright 2017 - 2024, Thales DIS, MCS SSH
10
+ -----
11
+ HISTORY:
12
+ Date By Comments
13
+ ---------- --- ---------------------------------------------------------
14
+ '''
15
+
16
+ from random import random
17
+ from fastdownload import download_url
18
+ from fastai.vision.all import *
19
+ from time import sleep
20
+ from fastbook import search_images_ddg
21
+ from fastcore.all import *
22
+
23
def search_images_fastbook(term, max_images=30):
    """Search for images of ``term`` via fastbook's DuckDuckGo helper.

    Prints a one-line progress message, then forwards the query to
    ``search_images_ddg`` and returns whatever that call returns.

    Args:
        term: Search phrase to look up.
        max_images: Upper bound on results, passed through as ``max_images``.
    """
    print("Searching for {} {}".format(max_images, term))
    found = search_images_ddg(term, max_images=max_images)
    return found
26
+
27
def get_images(target_path, count, searches):
    """Download, resize and verify training images for each search term.

    For every term in ``searches`` a sub-directory ``target_path/<term>`` is
    created, the URLs returned for the query "<term> patterns images" are
    downloaded into it, and the directory is resized in place with
    ``max_size=400``. Once all terms are processed, every image file found
    under ``target_path`` is passed to ``verify_images`` and the files it
    reports are deleted.

    Args:
        target_path: Root directory that receives one sub-directory per term.
        count: Maximum number of images to request per term.
        searches: Iterable of search terms. A single string is accepted and
            treated as one term.
    """
    if isinstance(searches, str):
        # Iterating a bare string would yield one "term" per character
        # (e.g. "bulgarian" -> "b", "u", "l", ...), so wrap it in a list.
        searches = [searches]

    path = Path(target_path)

    for term in searches:
        dest = path / term
        dest.mkdir(exist_ok=True, parents=True)
        urls = search_images_fastbook(f'{term} patterns images', max_images=count)
        download_images(dest, urls=urls)
        # Pause between searches; presumably to avoid being throttled by the
        # search backend -- confirm before shortening.
        sleep(10)
        resize_images(dest, max_size=400, dest=dest)

    print("Checking for bad images ...")
    failed = verify_images(get_image_files(path))
    failed.map(Path.unlink)
    print(f"Removed {len(failed)} bad images")
41
+
42
if __name__ == "__main__":
    # Command line: get_data.py <target path> <images count> [<pattern kind> ...]
    if len(sys.argv) < 3 or not sys.argv[2].isdigit():
        print("Use: python get_data.py <target path> <images count> <one or more pattern kinds>\n")
        print("Example: python get_data.py patterns 1 bulgarian indian japanese")
    else:
        target_path = sys.argv[1]
        count = int(sys.argv[2])
        # The fallback must be a list: a bare "bulgarian" string would be
        # iterated character by character and trigger nine one-letter searches.
        searches = sys.argv[3:] or ["bulgarian"]

        get_images(target_path, count, searches)
train.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ File: train.py
3
+ Project: PatternsRecognizer
4
+ Author: Milko Videv (milko.videv@thalesgroup.com)
5
+ -----
6
+ Last Modified: Tuesday, 5th March 2024 10:02:08 am
7
+ Modified By: Milko Videv (milko.videv@thalesgroup.com)
8
+ -----
9
+ Copyright 2017 - 2024, Thales DIS, MCS SSH
10
+ -----
11
+ HISTORY:
12
+ Date By Comments
13
+ ---------- --- ---------------------------------------------------------
14
+ '''
15
+
16
+ from fastai.vision.all import *
17
+ from fastcore.all import *
18
+ from tools import *
19
+
20
def train(src, count):
    """Fine-tune a resnet18 image classifier on the images found under ``src``.

    Images are labelled by the name of their parent directory, 30% of them
    (fixed seed 42) are held out for validation, and each item is squished to
    192 px. If ``./models/current.pth`` exists it is loaded first so training
    continues from the previous run. After ``count`` epochs of ``fine_tune``
    the weights are saved as ``current`` and the learner is exported as
    ``trained.pkl`` with ``./models`` set as the learner path.

    Args:
        src: Directory containing one sub-directory of images per label.
        count: Number of epochs passed to ``fine_tune``.
    """
    images_root = Path(src)

    data_block = DataBlock(
        blocks=(ImageBlock, CategoryBlock),               # image in, category out
        get_items=get_image_files,                        # collect image files under the source
        splitter=RandomSplitter(valid_pct=0.3, seed=42),  # hold out 30% for validation
        get_y=parent_label,                               # label = parent directory name
        item_tfms=[Resize(192, method='squish')],         # uniform item size
    )
    dls = data_block.dataloaders(images_root, bs=32, verbose=True)  # batch size 32

    # The resnet weights are downloaded on first use from
    # https://download.pytorch.org/models, so behind the corporate proxy set:
    #   set HTTP_PROXY=http://proxy-us-austin.gemalto.com:8080
    #   set HTTPS_PROXY=http://proxy-us-austin.gemalto.com:8080

    models_dir = Path('./models')
    models_dir.mkdir(exist_ok=True, parents=True)

    learn = vision_learner(dls, resnet18, metrics=error_rate)

    # Resume from the checkpoint written by a previous run, if there is one.
    checkpoint = models_dir / 'current.pth'
    if os.path.exists(checkpoint):
        print("Loading current.pth to continue learning ...")
        learn.load("current")

    timer = Stopwatch()
    timer.start()
    print("Training ...")
    learn.fine_tune(count)
    timer.stop()
    print("Training took ", timer.elapsed_time(), "seconds")

    # Checkpoint picked up by the resume step above on the next run.
    learn.save('current')

    # Export the inference bundle with the models directory as learner path.
    learn.path = models_dir
    learn.export('trained.pkl')
56
+
57
if __name__ == "__main__":
    if debugger_is_active():
        # Debug shortcut: skip argument parsing and run a short training.
        # train() needs the image directory as well as the epoch count; the
        # original call train(3) raised TypeError. "patterns" is the directory
        # used in the usage example below -- adjust if your data lives elsewhere.
        train("patterns", 3)
        sys.exit()

    # Command line: train.py <images path> <epochs count>
    if len(sys.argv) < 3 or not sys.argv[2].isdigit():
        print("Use: python train.py <images path> <epochs count>")
        print("Example: python train.py patterns 3")
        sys.exit()

    src = sys.argv[1]
    count = int(sys.argv[2])

    # Ask for confirmation before a long run.
    if count > 5 and not get_yes_no("Epoch count is too big, might take time. Continue?"):
        sys.exit()

    train(src, count)