Spaces:
Sleeping
Sleeping
Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +7 -3
my_model/KBVQA.py
CHANGED
|
@@ -224,7 +224,7 @@ class KBVQA:
|
|
| 224 |
return p
|
| 225 |
|
| 226 |
@staticmethod
|
| 227 |
-
def trim_objects(
|
| 228 |
"""
|
| 229 |
Trim the last object from the detected objects string.
|
| 230 |
|
|
@@ -257,7 +257,9 @@ class KBVQA:
|
|
| 257 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
| 258 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
| 259 |
self.current_prompt_length = num_tokens
|
| 260 |
-
|
|
|
|
|
|
|
| 261 |
while self.current_prompt_length > self.max_context_window:
|
| 262 |
detected_objects_str = self.trim_objects(detected_objects_str)
|
| 263 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
|
@@ -265,7 +267,9 @@ class KBVQA:
|
|
| 265 |
|
| 266 |
if detected_objects_str == "":
|
| 267 |
break # Break if no objects are left
|
| 268 |
-
|
|
|
|
|
|
|
| 269 |
|
| 270 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
| 271 |
free_gpu_resources()
|
|
|
|
| 224 |
return p
|
| 225 |
|
| 226 |
@staticmethod
|
| 227 |
+
def trim_objects(detected_objects_str):
|
| 228 |
"""
|
| 229 |
Trim the last object from the detected objects string.
|
| 230 |
|
|
|
|
| 257 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
| 258 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
| 259 |
self.current_prompt_length = num_tokens
|
| 260 |
+
if self.current_prompt_length > self.max_context_window:
|
| 261 |
+
trim = True
|
| 262 |
+
st.warning(f"Prompt length is {self.current_prompt_length} which is larger than the maximum context window of LLaMA-2, objects detected with low confidence will be removed one at a time until the prompt length is within the maximum context window ...")
|
| 263 |
while self.current_prompt_length > self.max_context_window:
|
| 264 |
detected_objects_str = self.trim_objects(detected_objects_str)
|
| 265 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
|
|
|
| 267 |
|
| 268 |
if detected_objects_str == "":
|
| 269 |
break # Break if no objects are left
|
| 270 |
+
if trim:
|
| 271 |
+
st.warning(f"New prompt length is: {self.current_prompt_length}")
|
| 272 |
+
trim = False
|
| 273 |
|
| 274 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
| 275 |
free_gpu_resources()
|