Spaces:
Sleeping
Sleeping
Update my_model/KBVQA.py
Browse files- my_model/KBVQA.py +27 -4
my_model/KBVQA.py
CHANGED
|
@@ -222,7 +222,22 @@ class KBVQA:
|
|
| 222 |
p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
|
| 223 |
|
| 224 |
return p
|
| 225 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
|
| 227 |
def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
|
| 228 |
"""
|
|
@@ -236,13 +251,21 @@ class KBVQA:
|
|
| 236 |
Returns:
|
| 237 |
str: The generated answer to the question.
|
| 238 |
"""
|
|
|
|
|
|
|
| 239 |
free_gpu_resources()
|
| 240 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
| 241 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
| 242 |
self.current_prompt_length = num_tokens
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
| 248 |
free_gpu_resources()
|
|
|
|
| 222 |
p = f"""{history}\n{B_SENT}{B_INST} {B_QES}{current_query}{E_QES}{E_INST}"""
|
| 223 |
|
| 224 |
return p
|
| 225 |
+
|
| 226 |
+
@staticmethod
|
| 227 |
+
def trim_objects(self, detected_objects_str):
|
| 228 |
+
"""
|
| 229 |
+
Trim the last object from the detected objects string.
|
| 230 |
+
|
| 231 |
+
Args:
|
| 232 |
+
- detected_objects_str (str): String containing detected objects.
|
| 233 |
+
|
| 234 |
+
Returns:
|
| 235 |
+
- (str): The string with the last object removed.
|
| 236 |
+
"""
|
| 237 |
+
objects = detected_objects_str.strip().split("\n")
|
| 238 |
+
if len(objects) >= 1:
|
| 239 |
+
return "\n".join(objects[:-1])
|
| 240 |
+
return ""
|
| 241 |
|
| 242 |
def generate_answer(self, question: str, caption: str, detected_objects_str: str) -> str:
|
| 243 |
"""
|
|
|
|
| 251 |
Returns:
|
| 252 |
str: The generated answer to the question.
|
| 253 |
"""
|
| 254 |
+
|
| 255 |
+
|
| 256 |
free_gpu_resources()
|
| 257 |
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
| 258 |
num_tokens = len(self.kbvqa_tokenizer.tokenize(prompt))
|
| 259 |
self.current_prompt_length = num_tokens
|
| 260 |
+
|
| 261 |
+
while self.current_prompt_length > self.max_context_window:
|
| 262 |
+
detected_objects_str = self.trim_objects(detected_objects_str)
|
| 263 |
+
prompt = self.format_prompt(question, caption=caption, objects=detected_objects_str)
|
| 264 |
+
self.current_prompt_length = len(self.kbvqa_tokenizer.tokenize(prompt))
|
| 265 |
+
|
| 266 |
+
if detected_objects_str == "":
|
| 267 |
+
break # Break if no objects are left
|
| 268 |
+
|
| 269 |
|
| 270 |
model_inputs = self.kbvqa_tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to('cuda')
|
| 271 |
free_gpu_resources()
|