Spaces:
Sleeping
Sleeping
Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +20 -1
my_model/KBVQA.py
CHANGED
|
@@ -21,6 +21,7 @@ class KBVQA():
|
|
| 21 |
self.kbvqa_tokenizer = None
|
| 22 |
self.captioner = None
|
| 23 |
self.detector = None
|
|
|
|
| 24 |
self.kbvqa_model = None
|
| 25 |
self.access_token = os.getenv("HUGGINGFACE_TOKEN")
|
| 26 |
# self.kbvqa_model_loaded = self.all_models_loaded()
|
|
@@ -87,6 +88,22 @@ class KBVQA():
|
|
| 87 |
def all_models_loaded(self):
|
| 88 |
return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
|
| 92 |
def format_prompt(self, current_query, history = None , sys_prompt=None, caption=None, objects=None):
|
|
@@ -144,11 +161,13 @@ class KBVQA():
|
|
| 144 |
def prepare_kbvqa_model(detection_model):
|
| 145 |
free_gpu_resources()
|
| 146 |
kbvqa = KBVQA()
|
|
|
|
| 147 |
# Progress bar for model loading
|
| 148 |
with st.spinner('Loading model...'):
|
| 149 |
|
| 150 |
progress_bar = st.progress(0)
|
| 151 |
-
|
|
|
|
| 152 |
progress_bar.progress(33)
|
| 153 |
kbvqa.load_caption_model()
|
| 154 |
free_gpu_resources()
|
|
|
|
| 21 |
self.kbvqa_tokenizer = None
|
| 22 |
self.captioner = None
|
| 23 |
self.detector = None
|
| 24 |
+
sel.detection_model = None
|
| 25 |
self.kbvqa_model = None
|
| 26 |
self.access_token = os.getenv("HUGGINGFACE_TOKEN")
|
| 27 |
# self.kbvqa_model_loaded = self.all_models_loaded()
|
|
|
|
| 88 |
def all_models_loaded(self):
|
| 89 |
return self.kbvqa_model is not None and self.captioner is not None and self.detector is not None
|
| 90 |
|
| 91 |
+
def force_reload_model(self):
|
| 92 |
+
free_gpu_resources()
|
| 93 |
+
if self.kbvqa_model is not None:
|
| 94 |
+
del self.kbvqa_model
|
| 95 |
+
if self.captioner is not None:
|
| 96 |
+
del self.captioner
|
| 97 |
+
if self.detector is not None:
|
| 98 |
+
del self.detector
|
| 99 |
+
|
| 100 |
+
free_gpu_resources()
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
|
| 107 |
|
| 108 |
|
| 109 |
def format_prompt(self, current_query, history = None , sys_prompt=None, caption=None, objects=None):
|
|
|
|
| 161 |
def prepare_kbvqa_model(detection_model):
|
| 162 |
free_gpu_resources()
|
| 163 |
kbvqa = KBVQA()
|
| 164 |
+
kbvqa.detection_model = detection_model
|
| 165 |
# Progress bar for model loading
|
| 166 |
with st.spinner('Loading model...'):
|
| 167 |
|
| 168 |
progress_bar = st.progress(0)
|
| 169 |
+
|
| 170 |
+
kbvqa.load_detector(kbvqa.detection_model)
|
| 171 |
progress_bar.progress(33)
|
| 172 |
kbvqa.load_caption_model()
|
| 173 |
free_gpu_resources()
|