fix modeling_sa2va_chat.py bug
Browse files- modeling_sa2va_chat.py +2 -2
modeling_sa2va_chat.py
CHANGED
|
@@ -807,7 +807,7 @@ class Sa2VAChatModel(PreTrainedModel):
|
|
| 807 |
bg_cls_token_id = torch.as_tensor([self.bg_cls_token_idx,], dtype=ids.dtype, device=ids.device)
|
| 808 |
bg_cls_embedding = self.language_model.get_input_embeddings()(bg_cls_token_id).clone()
|
| 809 |
output_ids = generate_output.sequences[0][:-1]
|
| 810 |
-
cls_token_mask = ids == self.cls_token_idx
|
| 811 |
|
| 812 |
# get seg tokens
|
| 813 |
seg_token_mask = (output_ids >= self.the_first_seg_token_idx) & (output_ids <= self.the_last_seg_token_idx)
|
|
@@ -841,7 +841,7 @@ class Sa2VAChatModel(PreTrainedModel):
|
|
| 841 |
**m2f_inputs
|
| 842 |
)
|
| 843 |
|
| 844 |
-
tags = re.findall(r'<p>(.*?)</p>',
|
| 845 |
label_id_to_text = {id: tag for id, tag in enumerate(tags)}
|
| 846 |
|
| 847 |
class_queries_logits = m2f_outputs.class_queries_logits
|
|
|
|
| 807 |
bg_cls_token_id = torch.as_tensor([self.bg_cls_token_idx,], dtype=ids.dtype, device=ids.device)
|
| 808 |
bg_cls_embedding = self.language_model.get_input_embeddings()(bg_cls_token_id).clone()
|
| 809 |
output_ids = generate_output.sequences[0][:-1]
|
| 810 |
+
cls_token_mask = ids[0] == self.cls_token_idx
|
| 811 |
|
| 812 |
# get seg tokens
|
| 813 |
seg_token_mask = (output_ids >= self.the_first_seg_token_idx) & (output_ids <= self.the_last_seg_token_idx)
|
|
|
|
| 841 |
**m2f_inputs
|
| 842 |
)
|
| 843 |
|
| 844 |
+
tags = re.findall(r'<p>(.*?)</p>', input_text)
|
| 845 |
label_id_to_text = {id: tag for id, tag in enumerate(tags)}
|
| 846 |
|
| 847 |
class_queries_logits = m2f_outputs.class_queries_logits
|