lajota13 commited on
Commit
0ef999d
·
verified ·
1 Parent(s): ed943ba

Update seasonal_color_analysis/fe.py

Browse files
Files changed (1) hide show
  1. seasonal_color_analysis/fe.py +21 -17
seasonal_color_analysis/fe.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  from io import BytesIO
3
- from PIL import Image, ImageDraw
4
  import uuid
5
  import json
6
  import datetime
@@ -58,19 +58,18 @@ def get_season_description(season: str) -> tuple[str, str]:
58
 
59
 
60
  @st.cache_data
61
- def predict(img_bytes: bytes) -> tuple[np.ndarray | None, dict[str, float], np.ndarray, np.ndarray]:
62
- with Image.open(BytesIO(img_bytes)) as img:
63
- batch_boxes, proba_dicts, np_season_embeddings, np_facenet_embeddings = st.session_state["classifier"].predict([img.convert("RGB")])
64
- return batch_boxes[0], proba_dicts[0], np_season_embeddings[0], np_facenet_embeddings[0]
65
 
66
 
67
  @st.cache_data
68
- def draw_bbox(img_bytes: bytes, bbox: np.ndarray) -> Image:
69
- with Image.open(BytesIO(img_bytes)) as img:
70
- _img = img.copy()
71
- draw = ImageDraw.Draw(_img)
72
- draw.rectangle(bbox.tolist(), outline="green", width=img.size[0] // 100)
73
- return _img
74
 
75
 
76
  @st.cache_data
@@ -179,12 +178,17 @@ img_stream = st.file_uploader(
179
  )
180
 
181
  if img_stream is not None:
182
- img_bytes = img_stream.getvalue()
183
- bbox, proba_dict, np_season_embedding, np_facenet_embedding = predict(img_bytes)
 
 
 
 
 
184
  if bbox is None:
185
  col1, col2 = st.columns(2)
186
  with col1:
187
- st.image(img_bytes, caption="Your image")
188
  with col2:
189
  st.write("⚠️\n\nIt was not possibile to detect any face in your image, try uploading another one\n\n⚠️")
190
  else:
@@ -196,10 +200,10 @@ if img_stream is not None:
196
  second_most_likely_prob = np.sort(probs)[-2]
197
  col1, col2 = st.columns(2)
198
  with col1:
199
- st.image(img_bytes, caption="Your image")
200
  with col2:
201
- img_w_bbox = draw_bbox(img_bytes, bbox)
202
- st.image(np.array(img_w_bbox), caption="Detected face")
203
 
204
  st.header("Your result")
205
 
 
1
  import os
2
  from io import BytesIO
3
+ from PIL import Image, ImageDraw, ImageOps
4
  import uuid
5
  import json
6
  import datetime
 
58
 
59
 
60
  @st.cache_data
61
+ def predict(np_img: np.ndarray) -> tuple[np.ndarray | None, dict[str, float], np.ndarray, np.ndarray]:
62
+ img = Image.fromarray(np_img)
63
+ batch_boxes, proba_dicts, np_season_embeddings, np_facenet_embeddings = st.session_state["classifier"].predict([img])
64
+ return batch_boxes[0], proba_dicts[0], np_season_embeddings[0], np_facenet_embeddings[0]
65
 
66
 
67
  @st.cache_data
68
+ def draw_bbox(np_img: np.ndarray, bbox: np.ndarray) -> np.ndarray:
69
+ img = Image.fromarray(np_img)
70
+ draw = ImageDraw.Draw(img)
71
+ draw.rectangle(bbox.tolist(), outline="green", width=img.size[0] // 100)
72
+ return np.array(img)
 
73
 
74
 
75
  @st.cache_data
 
178
  )
179
 
180
  if img_stream is not None:
181
+ with Image.open(img_stream) as img:
182
+ np_img = np.array(
183
+ ImageOps.exif_transpose(
184
+ img.convert("RGB")
185
+ )
186
+ )
187
+ bbox, proba_dict, np_season_embedding, np_facenet_embedding = predict(np_img)
188
  if bbox is None:
189
  col1, col2 = st.columns(2)
190
  with col1:
191
+ st.image(np_img, caption="Your image")
192
  with col2:
193
  st.write("⚠️\n\nIt was not possibile to detect any face in your image, try uploading another one\n\n⚠️")
194
  else:
 
200
  second_most_likely_prob = np.sort(probs)[-2]
201
  col1, col2 = st.columns(2)
202
  with col1:
203
+ st.image(np_img, caption="Your image")
204
  with col2:
205
+ np_img_w_bbox = draw_bbox(np_img, bbox)
206
+ st.image(np_img_w_bbox, caption="Detected face")
207
 
208
  st.header("Your result")
209