Shrikrishna commited on
Commit
3bbd8ef
·
1 Parent(s): be82a88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -1
app.py CHANGED
@@ -5,4 +5,90 @@ import json
5
  import numpy as np
6
  import cv2
7
 
8
- st.title("Welcome!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import numpy as np
6
  import cv2
7
 
8
+ st.title("Welcome!")
9
+
10
+ __class_name_to_number = {}
11
+ __class_number_to_name = {}
12
+ __model = None
13
+
14
+ def classify_image(image_base64_data, file_path=None):
15
+
16
+ imgs = get_cropped_image_if_2_eyes_new(file_path, image_base64_data)
17
+
18
+ result = []
19
+ for img in imgs:
20
+ scalled_raw_img = cv2.resize(img, (32, 32))
21
+ img_har = w2d(img, 'db1', 5)
22
+ scalled_img_har = cv2.resize(img_har, (32, 32))
23
+ combined_img = np.vstack((scalled_raw_img.reshape(32 * 32 * 3, 1), scalled_img_har.reshape(32 * 32, 1)))
24
+
25
+ len_image_array = 32*32*3 + 32*32
26
+
27
+ final = combined_img.reshape(1,len_image_array).astype(float)
28
+ result.append({
29
+ 'class': class_number_to_name(__model.predict(final)[0]),
30
+ 'class_probability': np.around(__model.predict_proba(final)*100,2).tolist()[0],
31
+ 'class_dictionary': __class_name_to_number
32
+ })
33
+
34
+ return result
35
+
36
+
37
+ def get_cropped_image_if_2_eyes_new(image_path, image_base64_data):
38
+ face_cascade = cv2.CascadeClassifier('haarcascade_frontalface_default.xml')
39
+ eye_cascade = cv2.CascadeClassifier('haarcascade_eye.xml')
40
+
41
+ if image_path:
42
+ img = cv2.imread(image_path)
43
+ else:
44
+ img = get_cv2_image_from_base64_string(image_base64_data)
45
+
46
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
47
+ faces = face_cascade.detectMultiScale(gray, 1.3, 5)
48
+
49
+ cropped_faces = []
50
+ for (x,y,w,h) in faces:
51
+ roi_gray = gray[y:y+h, x:x+w]
52
+ roi_color = img[y:y+h, x:x+w]
53
+ eyes = eye_cascade.detectMultiScale(roi_gray)
54
+ if len(eyes) >= 2:
55
+ cropped_faces.append(roi_color)
56
+ return cropped_faces
57
+
58
+ def get_cv2_image_from_base64_string(b64str):
59
+ '''
60
+ credit: https://stackoverflow.com/questions/33754935/read-a-base-64-encoded-image-from-memory-using-opencv-python-library
61
+ :param uri:
62
+ :return:
63
+ '''
64
+ encoded_data = b64str.split(',')[1]
65
+ nparr = np.frombuffer(base64.b64decode(encoded_data), np.uint8)
66
+ img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
67
+ return img
68
+
69
+ def load_saved_artifacts():
70
+ print("loading saved artifacts...start")
71
+ global __class_name_to_number
72
+ global __class_number_to_name
73
+
74
+ with open("class_dictionary.json", "r") as f:
75
+ __class_name_to_number = json.load(f)
76
+ __class_number_to_name = {v:k for k,v in __class_name_to_number.items()}
77
+
78
+ global __model
79
+ if __model is None:
80
+ __model = pickle.load(open('saved_model.pkl','rb'))
81
+ st.text("loading saved artifacts...done")
82
+
83
+ def class_number_to_name(class_num):
84
+ return __class_number_to_name[class_num]
85
+
86
+ def get_b64_test_image_for_virat():
87
+ with open("b64.txt") as f:
88
+ return f.read()
89
+
90
+ uploaded_image = st.file_uploader('Choose an image')
91
+ load_saved_artifacts()
92
+
93
+ st.text(classify_image(get_b64_test_image_for_virat(), "sharapova1.jpg"))
94
+