MilkoTv commited on
Commit
27b04aa
·
1 Parent(s): 69dd959

added files

Browse files
Files changed (7) hide show
  1. app.py +53 -0
  2. models/trained.pkl +3 -0
  3. predict.py +45 -0
  4. samples/bulgarian.jpg +0 -0
  5. samples/indian.jpg +0 -0
  6. samples/japanese.jpg +0 -0
  7. tools.py +62 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ File: recognize.py
3
+ Project: PatternsRecognizer
4
+ Author: Milko Videv (milko.videv@thalesgroup.com)
5
+ -----
6
+ Last Modified: Monday, 4th March 2024 9:38:39 am
7
+ Modified By: Milko Videv (milko.videv@thalesgroup.com>)
8
+ -----
9
+ Copyright 2017 - 2024, Thales DIS, MCS SSH
10
+ -----
11
+ HISTORY:
12
+ Date By Comments
13
+ ---------- --- ---------------------------------------------------------
14
+ '''
15
+
16
+ import gradio as gr
17
+ from predict import *
18
+ from tools import *
19
+ from gradio.themes.utils.colors import slate # type: ignore
20
+
21
+ def render():
22
+
23
+ title = "Recognizer of Bulgarian, Indian and Japanese design patterns"
24
+ description = "Select image or drag some of the examples below."
25
+ examples = [
26
+ './samples/bulgarian.jpg',
27
+ './samples/indian.jpg',
28
+ './samples/japanese.jpg'
29
+ ]
30
+ rand_examples = [
31
+ pick_random_file("./patterns/bulgarian"),
32
+ pick_random_file("./patterns/indian"),
33
+ pick_random_file("./patterns/japanese"),
34
+ ]
35
+
36
+ demo = gr.Interface(fn=predict,
37
+ theme=gr.themes.Monochrome(primary_hue=slate),
38
+ inputs=gr.components.Image(shape=(512, 512), interactive=True),
39
+ outputs=gr.components.Label(num_top_classes=3),
40
+ title=title,
41
+ description=description,
42
+ examples=examples,
43
+ allow_flagging="manual",
44
+ flagging_options=["Correct", "Incorrect"],
45
+ css=
46
+ "#component-3 { background-color: rgb(192, 192, 192) !important; }"
47
+ "#component-3 H1 { margin: 1.5rem 0 1.5rem 0; color: #252873; !important;}"
48
+ "footer { display: none !important; }"
49
+ )
50
+ demo.queue().launch()
51
+
52
+ if __name__ == "__main__":
53
+ render()
models/trained.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7c54172fb330159837492830b11d1f1c15673ad85e80ab742261d3daae6cac0
3
+ size 46967388
predict.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ File: recognize.py
3
+ Project: PatternsRecognizer
4
+ Author: Milko Videv (milko.videv@thalesgroup.com)
5
+ -----
6
+ Last Modified: Tuesday, 5th March 2024 10:02:08 am
7
+ Modified By: Milko Videv (milko.videv@thalesgroup.com>)
8
+ -----
9
+ Copyright 2017 - 2024, Thales DIS, MCS SSH
10
+ -----
11
+ HISTORY:
12
+ Date By Comments
13
+ ---------- --- ---------------------------------------------------------
14
+ '''
15
+
16
+ from fastcore.all import *
17
+ from fastai.vision.all import *
18
+ from tools import *
19
+
20
+ learn = load_learner("./models/trained.pkl")
21
+ labels = learn.dls.vocab
22
+
23
+ def predict(image):
24
+ img = PILImage.create(image)
25
+ pred,pred_idx,probs = learn.predict(img)
26
+ result = {labels[i]: float(probs[i]) for i in range(len(labels))}
27
+
28
+ max_type = max(result, key=result.get)
29
+ max_probability = result[max_type] * 100
30
+ print(f"{max_type} with probability {max_probability:.2f}%")
31
+
32
+ return result
33
+
34
+ if __name__ == "__main__":
35
+ if debugger_is_active():
36
+ predict("./POC/PatternsRecognizer/samples/bulgarian.jpg")
37
+ sys.exit()
38
+ else:
39
+ if len(sys.argv) < 2:
40
+ print("Use: python predict.py <image path>")
41
+ print("Example: python predict.py samples/bulgarian.jpg")
42
+ sys.exit()
43
+
44
+ src = sys.argv[1]
45
+ predict(src)
samples/bulgarian.jpg ADDED
samples/indian.jpg ADDED
samples/japanese.jpg ADDED
tools.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ File: tools.py
3
+ Project: PatternsRecognizer
4
+ Author: Milko Videv (milko.videv@thalesgroup.com)
5
+ -----
6
+ Last Modified: Saturday, 2nd March 2024 10:25:48 am
7
+ Modified By: Milko Videv (milko.videv@thalesgroup.com>)
8
+ -----
9
+ Copyright 2017 - 2024, Thales DIS, MCS SSH
10
+ -----
11
+ HISTORY:
12
+ Date By Comments
13
+ ---------- --- ---------------------------------------------------------
14
+ '''
15
+ import sys
16
+ import time
17
+ import os
18
+ import random
19
+
20
+ def get_yes_no(prompt):
21
+ while True:
22
+ response = input(prompt + " [y]|n: ").strip().lower()
23
+ if response == "" or response in ["yes", "y"]:
24
+ return True
25
+ elif response in ["no", "n"]:
26
+ return False
27
+ else:
28
+ print("\nPlease enter 'yes|y|no|n': ")
29
+
30
+ def debugger_is_active() -> bool:
31
+ return hasattr(sys, 'gettrace') and sys.gettrace() is not None
32
+
33
+ def pick_random_file(directory):
34
+ files = [file for file in os.listdir(directory) if os.path.isfile(os.path.join(directory, file))]
35
+ if files:
36
+ return os.path.join(directory, random.choice(files))
37
+ else:
38
+ return None
39
+
40
+ class Stopwatch:
41
+ def __init__(self):
42
+ self.start_time = None
43
+ self.end_time = None
44
+
45
+ def start(self):
46
+ self.start_time = time.time()
47
+
48
+ def stop(self):
49
+ self.end_time = time.time()
50
+
51
+ def elapsed_time(self):
52
+ if self.start_time is None:
53
+ return 0
54
+ if self.end_time is None:
55
+ elapsed = time.time() - self.start_time
56
+ else:
57
+ elapsed = self.end_time - self.start_time
58
+ return elapsed
59
+
60
+ def reset(self):
61
+ self.start_time = None
62
+ self.end_time = None