Spaces:
Sleeping
Sleeping
Update my_model/tabs/run_inference.py
Browse files
my_model/tabs/run_inference.py
CHANGED
|
@@ -17,7 +17,6 @@ from my_model.state_manager import StateManager
|
|
| 17 |
from my_model.config import inference_config as config
|
| 18 |
|
| 19 |
|
| 20 |
-
|
| 21 |
class InferenceRunner(StateManager):
|
| 22 |
"""
|
| 23 |
Manages the user interface and interactions for running inference using the Streamlit-based Knowledge-Based Visual
|
|
@@ -244,16 +243,15 @@ class InferenceRunner(StateManager):
|
|
| 244 |
reload_kbvqa = False
|
| 245 |
reload_detection_model = False
|
| 246 |
force_reload_full_model = False
|
| 247 |
-
|
| 248 |
|
| 249 |
if self.is_model_loaded and self.settings_changed:
|
| 250 |
self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
|
| 251 |
-
|
| 252 |
st.session_state.button_label = (
|
| 253 |
"Reload Model" if (self.is_model_loaded and
|
| 254 |
st.session_state.kbvqa.detection_model != st.session_state['detection_model']) or
|
| 255 |
-
|
| 256 |
-
|
| 257 |
else "Load Model"
|
| 258 |
)
|
| 259 |
|
|
@@ -269,10 +267,11 @@ class InferenceRunner(StateManager):
|
|
| 269 |
fine_tuned_model_already_loaded = True
|
| 270 |
else:
|
| 271 |
load_fine_tuned_model = True
|
| 272 |
-
elif st.session_state.button_label == "Reload Model"
|
| 273 |
-
|
| 274 |
force_reload_full_model = True
|
| 275 |
-
elif (self.is_model_loaded and st.session_state.kbvqa.detection_model !=
|
|
|
|
| 276 |
reload_detection_model = True
|
| 277 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets,
|
| 278 |
disabled=self.is_widget_disabled):
|
|
@@ -298,7 +297,7 @@ class InferenceRunner(StateManager):
|
|
| 298 |
st.session_state['time_taken_to_load_model'] = int(time.time() - t1)
|
| 299 |
st.session_state['loading_in_progress'] = False
|
| 300 |
st.session_state['model_loaded'] = True
|
| 301 |
-
|
| 302 |
elif st.session_state.method == "Vision-Language Embeddings Alignment":
|
| 303 |
self.col1.warning(
|
| 304 |
f'Model using {st.session_state.method} is desgined but requires large scale data and multiple '
|
|
@@ -308,7 +307,7 @@ class InferenceRunner(StateManager):
|
|
| 308 |
st.write(st.session_state['previous_state']['method'])
|
| 309 |
if st.session_state['kbvqa'] is not None:
|
| 310 |
st.write(st.session_state['kbvqa'].kbvqa_model_name)
|
| 311 |
-
|
| 312 |
if self.is_model_loaded:
|
| 313 |
free_gpu_resources()
|
| 314 |
st.session_state['loading_in_progress'] = False
|
|
|
|
| 17 |
from my_model.config import inference_config as config
|
| 18 |
|
| 19 |
|
|
|
|
| 20 |
class InferenceRunner(StateManager):
|
| 21 |
"""
|
| 22 |
Manages the user interface and interactions for running inference using the Streamlit-based Knowledge-Based Visual
|
|
|
|
| 243 |
reload_kbvqa = False
|
| 244 |
reload_detection_model = False
|
| 245 |
force_reload_full_model = False
|
|
|
|
| 246 |
|
| 247 |
if self.is_model_loaded and self.settings_changed:
|
| 248 |
self.col1.warning("Model settings have changed, please reload the model, this will take a second .. ")
|
| 249 |
+
# self.update_prev_state()
|
| 250 |
st.session_state.button_label = (
|
| 251 |
"Reload Model" if (self.is_model_loaded and
|
| 252 |
st.session_state.kbvqa.detection_model != st.session_state['detection_model']) or
|
| 253 |
+
(st.session_state['previous_state']['method'] is not None and
|
| 254 |
+
st.session_state['method'] != st.session_state['previous_state']['method'])
|
| 255 |
else "Load Model"
|
| 256 |
)
|
| 257 |
|
|
|
|
| 267 |
fine_tuned_model_already_loaded = True
|
| 268 |
else:
|
| 269 |
load_fine_tuned_model = True
|
| 270 |
+
elif st.session_state.button_label == "Reload Model" and st.session_state['method'] != \
|
| 271 |
+
st.session_state['previous_state']['method']: # check if the model size have changed
|
| 272 |
force_reload_full_model = True
|
| 273 |
+
elif (self.is_model_loaded and st.session_state.kbvqa.detection_model !=
|
| 274 |
+
st.session_state['detection_model']):
|
| 275 |
reload_detection_model = True
|
| 276 |
if nested_col12.button("Force Reload", on_click=self.disable_widgets,
|
| 277 |
disabled=self.is_widget_disabled):
|
|
|
|
| 297 |
st.session_state['time_taken_to_load_model'] = int(time.time() - t1)
|
| 298 |
st.session_state['loading_in_progress'] = False
|
| 299 |
st.session_state['model_loaded'] = True
|
| 300 |
+
|
| 301 |
elif st.session_state.method == "Vision-Language Embeddings Alignment":
|
| 302 |
self.col1.warning(
|
| 303 |
f'Model using {st.session_state.method} is desgined but requires large scale data and multiple '
|
|
|
|
| 307 |
st.write(st.session_state['previous_state']['method'])
|
| 308 |
if st.session_state['kbvqa'] is not None:
|
| 309 |
st.write(st.session_state['kbvqa'].kbvqa_model_name)
|
| 310 |
+
|
| 311 |
if self.is_model_loaded:
|
| 312 |
free_gpu_resources()
|
| 313 |
st.session_state['loading_in_progress'] = False
|