Spaces:
Sleeping
Sleeping
Update my_model/state_manager.py
Browse files- my_model/state_manager.py +15 -1
my_model/state_manager.py
CHANGED
|
@@ -24,7 +24,7 @@ class StateManager:
|
|
| 24 |
def set_up_widgets(self):
|
| 25 |
|
| 26 |
# Create two columns with different widths
|
| 27 |
-
col1, col2 = st.columns([0.
|
| 28 |
with col1:
|
| 29 |
st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
|
| 30 |
detection_model = st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
|
|
@@ -37,15 +37,19 @@ class StateManager:
|
|
| 37 |
if show_model_settings:
|
| 38 |
self.display_model_settings()
|
| 39 |
|
|
|
|
|
|
|
| 40 |
|
| 41 |
def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
|
| 42 |
|
| 43 |
return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
|
|
|
|
| 44 |
|
| 45 |
@property
|
| 46 |
def settings_changed(self):
|
| 47 |
return self.has_state_changed()
|
| 48 |
|
|
|
|
| 49 |
def display_model_settings(self):
|
| 50 |
st.write("#### Current Model Settings:")
|
| 51 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
|
|
@@ -53,11 +57,13 @@ class StateManager:
|
|
| 53 |
styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
|
| 54 |
st.table(styled_df)
|
| 55 |
|
|
|
|
| 56 |
def display_session_state(self):
|
| 57 |
st.write("Current Model:")
|
| 58 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
|
| 59 |
df = pd.DataFrame(data)
|
| 60 |
st.table(df)
|
|
|
|
| 61 |
|
| 62 |
def load_model(self):
|
| 63 |
"""Load the KBVQA model with specified settings."""
|
|
@@ -76,6 +82,7 @@ class StateManager:
|
|
| 76 |
except Exception as e:
|
| 77 |
st.error(f"Error loading model: {e}")
|
| 78 |
|
|
|
|
| 79 |
# Function to check if any session state values have changed
|
| 80 |
def has_state_changed(self):
|
| 81 |
for key in st.session_state['previous_state']:
|
|
@@ -83,13 +90,16 @@ class StateManager:
|
|
| 83 |
return True # Found a change
|
| 84 |
else: return False # No changes found
|
| 85 |
|
|
|
|
| 86 |
def get_model(self):
|
| 87 |
"""Retrieve the KBVQA model from the session state."""
|
| 88 |
return st.session_state.get('kbvqa', None)
|
| 89 |
|
|
|
|
| 90 |
def is_model_loaded(self):
|
| 91 |
return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
|
| 92 |
|
|
|
|
| 93 |
def reload_detection_model(self):
|
| 94 |
try:
|
| 95 |
free_gpu_resources()
|
|
@@ -112,6 +122,7 @@ class StateManager:
|
|
| 112 |
'analysis_done': False
|
| 113 |
}
|
| 114 |
|
|
|
|
| 115 |
def analyze_image(self, image, kbvqa):
|
| 116 |
img = copy.deepcopy(image)
|
| 117 |
st.text("Analyzing the image .. ")
|
|
@@ -119,13 +130,16 @@ class StateManager:
|
|
| 119 |
image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
|
| 120 |
return caption, detected_objects_str, image_with_boxes
|
| 121 |
|
|
|
|
| 122 |
def add_to_qa_history(self, image_key, question, answer):
|
| 123 |
if image_key in st.session_state['images_data']:
|
| 124 |
st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
|
| 125 |
|
|
|
|
| 126 |
def get_images_data(self):
|
| 127 |
return st.session_state['images_data']
|
| 128 |
|
|
|
|
| 129 |
def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
|
| 130 |
if image_key in st.session_state['images_data']:
|
| 131 |
st.session_state['images_data'][image_key].update({
|
|
|
|
| 24 |
def set_up_widgets(self):
|
| 25 |
|
| 26 |
# Create two columns with different widths
|
| 27 |
+
col1, col2, col3 = st.columns([0.2, 0.6, 0.2]) # Adjust the ratio as needed
|
| 28 |
with col1:
|
| 29 |
st.selectbox("Choose a method:", ["Fine-Tuned Model", "In-Context Learning (n-shots)"], index=0, key='method')
|
| 30 |
detection_model = st.selectbox("Choose a model for objects detection:", ["yolov5", "detic"], index=1, key='detection_model')
|
|
|
|
| 37 |
if show_model_settings:
|
| 38 |
self.display_model_settings()
|
| 39 |
|
| 40 |
+
col3.header("COL3")
|
| 41 |
+
|
| 42 |
|
| 43 |
def set_slider_value(self, text, min_value, max_value, value, step, slider_key_name):
|
| 44 |
|
| 45 |
return st.slider(text, min_value, max_value, value, step, key=slider_key_name)
|
| 46 |
+
|
| 47 |
|
| 48 |
@property
|
| 49 |
def settings_changed(self):
|
| 50 |
return self.has_state_changed()
|
| 51 |
|
| 52 |
+
|
| 53 |
def display_model_settings(self):
|
| 54 |
st.write("#### Current Model Settings:")
|
| 55 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items() if key in ["confidence_level", 'detection_model', 'method', 'kbvqa', 'previous_state', 'settings_changed', ]]
|
|
|
|
| 57 |
styled_df = df.style.set_properties(**{'background-color': 'black', 'color': 'white', 'border-color': 'white'}).set_table_styles([{'selector': 'th','props': [('background-color', 'black'), ('font-weight', 'bold')]}])
|
| 58 |
st.table(styled_df)
|
| 59 |
|
| 60 |
+
|
| 61 |
def display_session_state(self):
|
| 62 |
st.write("Current Model:")
|
| 63 |
data = [{'Key': key, 'Value': str(value)} for key, value in st.session_state.items()]
|
| 64 |
df = pd.DataFrame(data)
|
| 65 |
st.table(df)
|
| 66 |
+
|
| 67 |
|
| 68 |
def load_model(self):
|
| 69 |
"""Load the KBVQA model with specified settings."""
|
|
|
|
| 82 |
except Exception as e:
|
| 83 |
st.error(f"Error loading model: {e}")
|
| 84 |
|
| 85 |
+
|
| 86 |
# Function to check if any session state values have changed
|
| 87 |
def has_state_changed(self):
|
| 88 |
for key in st.session_state['previous_state']:
|
|
|
|
| 90 |
return True # Found a change
|
| 91 |
else: return False # No changes found
|
| 92 |
|
| 93 |
+
|
| 94 |
def get_model(self):
|
| 95 |
"""Retrieve the KBVQA model from the session state."""
|
| 96 |
return st.session_state.get('kbvqa', None)
|
| 97 |
|
| 98 |
+
|
| 99 |
def is_model_loaded(self):
|
| 100 |
return 'kbvqa' in st.session_state and st.session_state['kbvqa'] is not None
|
| 101 |
|
| 102 |
+
|
| 103 |
def reload_detection_model(self):
|
| 104 |
try:
|
| 105 |
free_gpu_resources()
|
|
|
|
| 122 |
'analysis_done': False
|
| 123 |
}
|
| 124 |
|
| 125 |
+
|
| 126 |
def analyze_image(self, image, kbvqa):
|
| 127 |
img = copy.deepcopy(image)
|
| 128 |
st.text("Analyzing the image .. ")
|
|
|
|
| 130 |
image_with_boxes, detected_objects_str = kbvqa.detect_objects(img)
|
| 131 |
return caption, detected_objects_str, image_with_boxes
|
| 132 |
|
| 133 |
+
|
| 134 |
def add_to_qa_history(self, image_key, question, answer):
|
| 135 |
if image_key in st.session_state['images_data']:
|
| 136 |
st.session_state['images_data'][image_key]['qa_history'].append((question, answer))
|
| 137 |
|
| 138 |
+
|
| 139 |
def get_images_data(self):
|
| 140 |
return st.session_state['images_data']
|
| 141 |
|
| 142 |
+
|
| 143 |
def update_image_data(self, image_key, caption, detected_objects_str, analysis_done):
|
| 144 |
if image_key in st.session_state['images_data']:
|
| 145 |
st.session_state['images_data'][image_key].update({
|