Spaces:
Sleeping
Sleeping
rmm
committed on
Commit
·
71dfd99
1
Parent(s):
8c4b1f7
feat: extended InputObservation to contain species/prediction info
Browse files- when manual validation is performed (dropdown selection among
species), it is written to the observations (and not the
dynamically-created dicts).
- TODO: decide if we need to retain public_observations in
session_state, or just generate the dict each time it is needed.
- src/classifier/classifier_image.py +15 -16
- src/input/input_observation.py +27 -1
- src/main.py +4 -2
- src/utils/metadata_handler.py +4 -1
src/classifier/classifier_image.py
CHANGED
|
@@ -10,6 +10,7 @@ import whale_viewer as viewer
|
|
| 10 |
from hf_push_observations import push_observations
|
| 11 |
from utils.grid_maker import gridder
|
| 12 |
from utils.metadata_handler import metadata2md
|
|
|
|
| 13 |
|
| 14 |
def add_header_text() -> None:
|
| 15 |
"""
|
|
@@ -24,12 +25,11 @@ def add_header_text() -> None:
|
|
| 24 |
def cetacean_just_classify(cetacean_classifier):
|
| 25 |
|
| 26 |
images = st.session_state.images
|
| 27 |
-
observations = st.session_state.observations
|
| 28 |
hashes = st.session_state.image_hashes
|
| 29 |
|
| 30 |
for hash in hashes:
|
| 31 |
image = images[hash]
|
| 32 |
-
observation = observations[hash].to_dict()
|
| 33 |
# run classifier model on `image`, and persistently store the output
|
| 34 |
out = cetacean_classifier(image) # get top 3 matches
|
| 35 |
st.session_state.whale_prediction1[hash] = out['predictions'][0]
|
|
@@ -39,8 +39,6 @@ def cetacean_just_classify(cetacean_classifier):
|
|
| 39 |
msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
|
| 40 |
g_logger.info(msg)
|
| 41 |
|
| 42 |
-
# store the elements of the observation that will be transmitted (not image)
|
| 43 |
-
st.session_state.public_observations[hash] = observation
|
| 44 |
if st.session_state.MODE_DEV_STATEFUL:
|
| 45 |
st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
|
| 46 |
|
|
@@ -58,7 +56,8 @@ def cetacean_show_results_and_review():
|
|
| 58 |
|
| 59 |
for hash in hashes:
|
| 60 |
image = images[hash]
|
| 61 |
-
observation = observations[hash].to_dict()
|
|
|
|
| 62 |
|
| 63 |
with grid[col]:
|
| 64 |
st.image(image, use_column_width=True)
|
|
@@ -75,14 +74,19 @@ def cetacean_show_results_and_review():
|
|
| 75 |
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
| 76 |
selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
| 81 |
|
|
|
|
|
|
|
| 82 |
st.session_state.public_observations[hash] = observation
|
|
|
|
| 83 |
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
| 84 |
# TODO: the metadata only fills properly if `validate` was clicked.
|
| 85 |
-
st.markdown(metadata2md(hash))
|
| 86 |
|
| 87 |
msg = f"[D] full observation after inference: {observation}"
|
| 88 |
g_logger.debug(msg)
|
|
@@ -138,12 +142,7 @@ def cetacean_show_results():
|
|
| 138 |
|
| 139 |
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
| 140 |
#
|
| 141 |
-
st.markdown(metadata2md(hash))
|
| 142 |
-
# TODO: FIXME: this is the data that will get pushed -- it DOESN'T reflect any adjustments
|
| 143 |
-
# # made via the dropdown on the last step!!!!
|
| 144 |
-
#st.markdown(f"- **selected species**: {observation['predicted_class']}")
|
| 145 |
-
st.markdown(f"- **selected species**: {st.session_state.whale_prediction1[hash]}")
|
| 146 |
-
st.markdown(f"- **hash**: {hash}")
|
| 147 |
|
| 148 |
msg = f"[D] full observation after inference: {observation}"
|
| 149 |
g_logger.debug(msg)
|
|
@@ -223,4 +222,4 @@ def cetacean_classify_show_and_review(cetacean_classifier):
|
|
| 223 |
for i in range(len(whale_classes)):
|
| 224 |
viewer.display_whale(whale_classes, i)
|
| 225 |
o += 1
|
| 226 |
-
col = (col + 1) % row_size
|
|
|
|
| 10 |
from hf_push_observations import push_observations
|
| 11 |
from utils.grid_maker import gridder
|
| 12 |
from utils.metadata_handler import metadata2md
|
| 13 |
+
from input.input_observation import InputObservation
|
| 14 |
|
| 15 |
def add_header_text() -> None:
|
| 16 |
"""
|
|
|
|
| 25 |
def cetacean_just_classify(cetacean_classifier):
|
| 26 |
|
| 27 |
images = st.session_state.images
|
| 28 |
+
#observations = st.session_state.observations
|
| 29 |
hashes = st.session_state.image_hashes
|
| 30 |
|
| 31 |
for hash in hashes:
|
| 32 |
image = images[hash]
|
|
|
|
| 33 |
# run classifier model on `image`, and persistently store the output
|
| 34 |
out = cetacean_classifier(image) # get top 3 matches
|
| 35 |
st.session_state.whale_prediction1[hash] = out['predictions'][0]
|
|
|
|
| 39 |
msg = f"[D]2 classify_whale_done for {hash}: {st.session_state.classify_whale_done[hash]}, whale_prediction1: {st.session_state.whale_prediction1[hash]}"
|
| 40 |
g_logger.info(msg)
|
| 41 |
|
|
|
|
|
|
|
| 42 |
if st.session_state.MODE_DEV_STATEFUL:
|
| 43 |
st.write(f"*[D] Observation {hash} classified as {st.session_state.whale_prediction1[hash]}*")
|
| 44 |
|
|
|
|
| 56 |
|
| 57 |
for hash in hashes:
|
| 58 |
image = images[hash]
|
| 59 |
+
#observation = observations[hash].to_dict()
|
| 60 |
+
_observation:InputObservation = observations[hash]
|
| 61 |
|
| 62 |
with grid[col]:
|
| 63 |
st.image(image, use_column_width=True)
|
|
|
|
| 74 |
ix = viewer.WHALE_CLASSES.index(pred1) if pred1 in viewer.WHALE_CLASSES else None
|
| 75 |
selected_class = st.selectbox(f"Species for observation {str(o)}", viewer.WHALE_CLASSES, index=ix)
|
| 76 |
|
| 77 |
+
_observation.set_selected_class(selected_class)
|
| 78 |
+
#observation['predicted_class'] = selected_class
|
| 79 |
+
# this logic is now in the InputObservation class automatically
|
| 80 |
+
#if selected_class != st.session_state.whale_prediction1[hash]:
|
| 81 |
+
# observation['class_overriden'] = selected_class # TODO: this should be boolean!
|
| 82 |
|
| 83 |
+
# store the elements of the observation that will be transmitted (not image)
|
| 84 |
+
observation = _observation.to_dict()
|
| 85 |
st.session_state.public_observations[hash] = observation
|
| 86 |
+
|
| 87 |
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
| 88 |
# TODO: the metadata only fills properly if `validate` was clicked.
|
| 89 |
+
st.markdown(metadata2md(hash, debug=True))
|
| 90 |
|
| 91 |
msg = f"[D] full observation after inference: {observation}"
|
| 92 |
g_logger.debug(msg)
|
|
|
|
| 142 |
|
| 143 |
#st.button(f"Upload observation {str(o)} to THE INTERNET!", on_click=push_observations)
|
| 144 |
#
|
| 145 |
+
st.markdown(metadata2md(hash, debug=True))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
|
| 147 |
msg = f"[D] full observation after inference: {observation}"
|
| 148 |
g_logger.debug(msg)
|
|
|
|
| 222 |
for i in range(len(whale_classes)):
|
| 223 |
viewer.display_whale(whale_classes, i)
|
| 224 |
o += 1
|
| 225 |
+
col = (col + 1) % row_size
|
src/input/input_observation.py
CHANGED
|
@@ -68,7 +68,10 @@ class InputObservation:
|
|
| 68 |
self.time = time
|
| 69 |
self.uploaded_file = uploaded_file
|
| 70 |
self.image_md5 = image_md5
|
|
|
|
| 71 |
self._top_predictions = []
|
|
|
|
|
|
|
| 72 |
|
| 73 |
InputObservation._inst_count += 1
|
| 74 |
self._inst_id = InputObservation._inst_count
|
|
@@ -81,11 +84,30 @@ class InputObservation:
|
|
| 81 |
|
| 82 |
def set_top_predictions(self, top_predictions:list):
|
| 83 |
self._top_predictions = top_predictions
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
@property
|
| 87 |
def top_predictions(self):
|
| 88 |
return self._top_predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
# add a method to assign the image_md5 only once
|
| 91 |
def assign_image_md5(self):
|
|
@@ -194,6 +216,10 @@ class InputObservation:
|
|
| 194 |
"image_datetime_raw": self.image_datetime_raw,
|
| 195 |
"date": str(self.date),
|
| 196 |
"time": str(self.time),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
#"uploaded_file": self.uploaded_file # can't serialize this in json, not sent to dataset anyway.
|
| 198 |
}
|
| 199 |
|
|
|
|
| 68 |
self.time = time
|
| 69 |
self.uploaded_file = uploaded_file
|
| 70 |
self.image_md5 = image_md5
|
| 71 |
+
# attributes that get set after predictions/processing
|
| 72 |
self._top_predictions = []
|
| 73 |
+
self._selected_class = None
|
| 74 |
+
self._class_overriden = False
|
| 75 |
|
| 76 |
InputObservation._inst_count += 1
|
| 77 |
self._inst_id = InputObservation._inst_count
|
|
|
|
| 84 |
|
| 85 |
def set_top_predictions(self, top_predictions:list):
|
| 86 |
self._top_predictions = top_predictions
|
| 87 |
+
if len(top_predictions) > 0:
|
| 88 |
+
self.set_selected_class(top_predictions[0])
|
| 89 |
|
| 90 |
+
def set_selected_class(self, selected_class:str):
|
| 91 |
+
self._selected_class = selected_class
|
| 92 |
+
if selected_class != self._top_predictions[0]:
|
| 93 |
+
self.set_class_overriden(True)
|
| 94 |
+
|
| 95 |
+
def set_class_overriden(self, class_overriden:bool):
|
| 96 |
+
self._class_overriden = class_overriden
|
| 97 |
+
|
| 98 |
+
# add getters for the top_predictions, selected_class and class_overriden
|
| 99 |
@property
|
| 100 |
def top_predictions(self):
|
| 101 |
return self._top_predictions
|
| 102 |
+
|
| 103 |
+
@property
|
| 104 |
+
def selected_class(self):
|
| 105 |
+
return self._selected_class
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def class_overriden(self):
|
| 109 |
+
return self._class_overriden
|
| 110 |
+
|
| 111 |
|
| 112 |
# add a method to assign the image_md5 only once
|
| 113 |
def assign_image_md5(self):
|
|
|
|
| 216 |
"image_datetime_raw": self.image_datetime_raw,
|
| 217 |
"date": str(self.date),
|
| 218 |
"time": str(self.time),
|
| 219 |
+
"selected_class": self._selected_class,
|
| 220 |
+
"top_prediction": self._top_predictions[0] if len(self._top_predictions) else None,
|
| 221 |
+
"class_overriden": self._class_overriden,
|
| 222 |
+
|
| 223 |
#"uploaded_file": self.uploaded_file # can't serialize this in json, not sent to dataset anyway.
|
| 224 |
}
|
| 225 |
|
src/main.py
CHANGED
|
@@ -237,7 +237,8 @@ def main() -> None:
|
|
| 237 |
if st.sidebar.button(":white_check_mark:[**Validate**]"):
|
| 238 |
# create a dictionary with the submitted observation
|
| 239 |
tab_log.info(f"{st.session_state.observations}")
|
| 240 |
-
df = pd.DataFrame(st.session_state.observations, index=[0])
|
|
|
|
| 241 |
with tab_coords:
|
| 242 |
st.table(df)
|
| 243 |
# there doesn't seem to be any actual validation here?? TODO: find validator function (each element is validated by the input box, but is there something at the whole image level?)
|
|
@@ -320,7 +321,8 @@ def main() -> None:
|
|
| 320 |
cetacean_show_results()
|
| 321 |
|
| 322 |
st.divider()
|
| 323 |
-
df = pd.DataFrame(st.session_state.observations, index=[0])
|
|
|
|
| 324 |
st.table(df)
|
| 325 |
|
| 326 |
# didn't decide what the next state is here - I think we are in the terminal state.
|
|
|
|
| 237 |
if st.sidebar.button(":white_check_mark:[**Validate**]"):
|
| 238 |
# create a dictionary with the submitted observation
|
| 239 |
tab_log.info(f"{st.session_state.observations}")
|
| 240 |
+
df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
|
| 241 |
+
#df = pd.DataFrame(st.session_state.observations, index=[0])
|
| 242 |
with tab_coords:
|
| 243 |
st.table(df)
|
| 244 |
# there doesn't seem to be any actual validation here?? TODO: find validator function (each element is validated by the input box, but is there something at the whole image level?)
|
|
|
|
| 321 |
cetacean_show_results()
|
| 322 |
|
| 323 |
st.divider()
|
| 324 |
+
#df = pd.DataFrame(st.session_state.observations, index=[0])
|
| 325 |
+
df = pd.DataFrame([obs.to_dict() for obs in st.session_state.observations.values()])
|
| 326 |
st.table(df)
|
| 327 |
|
| 328 |
# didn't decide what the next state is here - I think we are in the terminal state.
|
src/utils/metadata_handler.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
-
def metadata2md(image_hash:str) -> str:
|
| 4 |
"""Get metadata from cache and return as markdown-formatted key-value list
|
| 5 |
|
| 6 |
Args:
|
| 7 |
image_hash (str): The hash of the image to get metadata for
|
|
|
|
| 8 |
|
| 9 |
Returns:
|
| 10 |
str: Markdown-formatted key-value list of metadata
|
|
@@ -12,6 +13,8 @@ def metadata2md(image_hash:str) -> str:
|
|
| 12 |
"""
|
| 13 |
markdown_str = "\n"
|
| 14 |
keys_to_print = ["author_email", "latitude", "longitude", "date", "time"]
|
|
|
|
|
|
|
| 15 |
|
| 16 |
observation = st.session_state.public_observations.get(image_hash, {})
|
| 17 |
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
+
def metadata2md(image_hash:str, debug:bool=False) -> str:
|
| 4 |
"""Get metadata from cache and return as markdown-formatted key-value list
|
| 5 |
|
| 6 |
Args:
|
| 7 |
image_hash (str): The hash of the image to get metadata for
|
| 8 |
+
debug (bool, optional): Whether to print additional fields.
|
| 9 |
|
| 10 |
Returns:
|
| 11 |
str: Markdown-formatted key-value list of metadata
|
|
|
|
| 13 |
"""
|
| 14 |
markdown_str = "\n"
|
| 15 |
keys_to_print = ["author_email", "latitude", "longitude", "date", "time"]
|
| 16 |
+
if debug:
|
| 17 |
+
keys_to_print += ["image_md5", "selected_class", "top_prediction", "class_overriden"]
|
| 18 |
|
| 19 |
observation = st.session_state.public_observations.get(image_hash, {})
|
| 20 |
|