Rattpon commited on
Commit
def557a
Β·
1 Parent(s): f73b6c4

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -153
app.py DELETED
@@ -1,153 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- from PIL import Image
4
- import streamlit as st
5
- import tensorflow as tf
6
- from tensorflow.keras.models import load_model
7
-
8
- # most of this code has been obtained from Datature's prediction script
9
- # https://github.com/datature/resources/blob/main/scripts/bounding_box/prediction.py
10
-
11
- st.set_option('deprecation.showfileUploaderEncoding', False)
12
-
13
- @st.cache(allow_output_mutation=True)
14
- def load_model():
15
- return tf.saved_model.load('./saved_model')
16
-
17
- def load_label_map(label_map_path):
18
- """
19
- Reads label map in the format of .pbtxt and parse into dictionary
20
- Args:
21
- label_map_path: the file path to the label_map
22
- Returns:
23
- dictionary with the format of {label_index: {'id': label_index, 'name': label_name}}
24
- """
25
- label_map = {}
26
-
27
- with open(label_map_path, "r") as label_file:
28
- for line in label_file:
29
- if "id" in line:
30
- label_index = int(line.split(":")[-1])
31
- label_name = next(label_file).split(":")[-1].strip().strip('"')
32
- label_map[label_index] = {"id": label_index, "name": label_name}
33
- return label_map
34
-
35
- def predict_class(image, model):
36
- image = tf.cast(image, tf.float32)
37
- image = tf.image.resize(image, [150, 150])
38
- image = np.expand_dims(image, axis = 0)
39
- return model.predict(image)
40
-
41
- def plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape):
42
- for idx, each_bbox in enumerate(bboxes):
43
- color = color_map[classes[idx]]
44
-
45
- ## Draw bounding box
46
- cv2.rectangle(
47
- image_origi,
48
- (int(each_bbox[1] * origi_shape[1]),
49
- int(each_bbox[0] * origi_shape[0]),),
50
- (int(each_bbox[3] * origi_shape[1]),
51
- int(each_bbox[2] * origi_shape[0]),),
52
- color,
53
- 2,
54
- )
55
- ## Draw label background
56
- cv2.rectangle(
57
- image_origi,
58
- (int(each_bbox[1] * origi_shape[1]),
59
- int(each_bbox[2] * origi_shape[0]),),
60
- (int(each_bbox[3] * origi_shape[1]),
61
- int(each_bbox[2] * origi_shape[0] + 15),),
62
- color,
63
- -1,
64
- )
65
- ## Insert label class & score
66
- cv2.putText(
67
- image_origi,
68
- "Class: {}, Score: {}".format(
69
- str(category_index[classes[idx]]["name"]),
70
- str(round(scores[idx], 2)),
71
- ),
72
- (int(each_bbox[1] * origi_shape[1]),
73
- int(each_bbox[2] * origi_shape[0] + 10),),
74
- cv2.FONT_HERSHEY_SIMPLEX,
75
- 0.3,
76
- (0, 0, 0),
77
- 1,
78
- cv2.LINE_AA,
79
- )
80
- return image_origi
81
-
82
-
83
- # Webpage code starts here
84
-
85
- #TODO change this
86
- st.title('YOUR PROJECT NAME')
87
- st.text('made by XXX')
88
- st.markdown('## Description about your project')
89
-
90
- with st.spinner('Model is being loaded...'):
91
- model = load_model()
92
-
93
- # ask user to upload an image
94
- file = st.file_uploader("Upload image", type=["jpg", "png"])
95
-
96
- if file is None:
97
- st.text('Waiting for upload...')
98
- else:
99
- st.text('Running inference...')
100
- # open image
101
- test_image = Image.open(file).convert("RGB")
102
- origi_shape = np.asarray(test_image).shape
103
- # resize image to default shape
104
- default_shape = 320
105
- image_resized = np.array(test_image.resize((default_shape, default_shape)))
106
-
107
- ## Load color map
108
- category_index = load_label_map("./label_map.pbtxt")
109
-
110
- # TODO Add more colors if there are more classes
111
- # color of each label. check label_map.pbtxt to check the index for each class
112
- color_map = {
113
- 1: [255, 0, 0], # bad -> red
114
- 2: [0, 255, 0] # good -> green
115
- }
116
-
117
- ## The model input needs to be a tensor
118
- input_tensor = tf.convert_to_tensor(image_resized)
119
- ## The model expects a batch of images, so add an axis with `tf.newaxis`.
120
- input_tensor = input_tensor[tf.newaxis, ...]
121
-
122
- ## Feed image into model and obtain output
123
- detections_output = model(input_tensor)
124
- num_detections = int(detections_output.pop("num_detections"))
125
- detections = {key: value[0, :num_detections].numpy() for key, value in detections_output.items()}
126
- detections["num_detections"] = num_detections
127
-
128
- ## Filter out predictions below threshold
129
- # if threshold is higher, there will be fewer predictions
130
- # TODO change this number to see how the predictions change
131
- confidence_threshold = 0.8
132
- indexes = np.where(detections["detection_scores"] > confidence_threshold)
133
-
134
- ## Extract predicted bounding boxes
135
- bboxes = detections["detection_boxes"][indexes]
136
- # there are no predicted boxes
137
- if len(bboxes) == 0:
138
- st.error('No boxes predicted')
139
- # there are predicted boxes
140
- else:
141
- st.success('Boxes predicted')
142
- classes = detections["detection_classes"][indexes].astype(np.int64)
143
- scores = detections["detection_scores"][indexes]
144
-
145
- # plot boxes and labels on image
146
- image_origi = np.array(Image.fromarray(image_resized).resize((origi_shape[1], origi_shape[0])))
147
- image_origi = plot_boxes_on_img(color_map, classes, bboxes, image_origi, origi_shape)
148
-
149
- # show image in web page
150
- st.image(Image.fromarray(image_origi), caption="Image with predictions", width=400)
151
- st.markdown("### Predicted boxes")
152
- for idx in range(len((bboxes))):
153
- st.markdown(f"* Class: {str(category_index[classes[idx]]['name'])}, confidence score: {str(round(scores[idx], 2))}")