Spaces:
Sleeping
Sleeping
Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +11 -3
my_model/KBVQA.py
CHANGED
|
@@ -176,12 +176,16 @@ class KBVQA:
|
|
| 176 |
free_gpu_resources()
|
| 177 |
if self.kbvqa_model is not None:
|
| 178 |
del self.kbvqa_model
|
|
|
|
| 179 |
if self.captioner is not None:
|
| 180 |
del self.captioner
|
|
|
|
| 181 |
if self.detector is not None:
|
| 182 |
del self.detector
|
| 183 |
-
|
|
|
|
| 184 |
free_gpu_resources()
|
|
|
|
| 185 |
|
| 186 |
|
| 187 |
def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None, caption: str = None, objects: Optional[str] = None) -> str:
|
|
@@ -253,7 +257,7 @@ class KBVQA:
|
|
| 253 |
|
| 254 |
return output_text.capitalize()
|
| 255 |
|
| 256 |
-
def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
|
| 257 |
"""
|
| 258 |
Prepares the KBVQA model for use, including loading necessary sub-models.
|
| 259 |
|
|
@@ -269,7 +273,11 @@ def prepare_kbvqa_model(only_reload_detection_model: bool = False) -> KBVQA:
|
|
| 269 |
kbvqa.detection_model = st.session_state.detection_model
|
| 270 |
# Progress bar for model loading
|
| 271 |
with kbvqa.col1:
|
| 272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 273 |
if not only_reload_detection_model:
|
| 274 |
progress_bar = st.progress(0)
|
| 275 |
kbvqa.load_detector(kbvqa.detection_model)
|
|
|
|
| 176 |
free_gpu_resources()
|
| 177 |
if self.kbvqa_model is not None:
|
| 178 |
del self.kbvqa_model
|
| 179 |
+
free_gpu_resources()
|
| 180 |
if self.captioner is not None:
|
| 181 |
del self.captioner
|
| 182 |
+
free_gpu_resources()
|
| 183 |
if self.detector is not None:
|
| 184 |
del self.detector
|
| 185 |
+
free_gpu_resources()
|
| 186 |
+
|
| 187 |
free_gpu_resources()
|
| 188 |
+
prepare_kbvqa_model(only_reload_detection_model=False, force_reload=True)
|
| 189 |
|
| 190 |
|
| 191 |
def format_prompt(self, current_query: str, history: Optional[str] = None, sys_prompt: Optional[str] = None, caption: str = None, objects: Optional[str] = None) -> str:
|
|
|
|
| 257 |
|
| 258 |
return output_text.capitalize()
|
| 259 |
|
| 260 |
+
def prepare_kbvqa_model(only_reload_detection_model: bool = False, force_reload: bool = True) -> KBVQA:
|
| 261 |
"""
|
| 262 |
Prepares the KBVQA model for use, including loading necessary sub-models.
|
| 263 |
|
|
|
|
| 273 |
kbvqa.detection_model = st.session_state.detection_model
|
| 274 |
# Progress bar for model loading
|
| 275 |
with kbvqa.col1:
|
| 276 |
+
if force_reload:
|
| 277 |
+
loading_message = 'Force Reloading model.. this should take no more than a few minutes!'
|
| 278 |
+
else: loading_message = 'Looading model.. this should take no more than a few minutes!'
|
| 279 |
+
|
| 280 |
+
with st.spinner(loading_message):
|
| 281 |
if not only_reload_detection_model:
|
| 282 |
progress_bar = st.progress(0)
|
| 283 |
kbvqa.load_detector(kbvqa.detection_model)
|