zhouyik commited on
Commit
0d41017
·
verified ·
1 Parent(s): 8f774d4

fix modeling_sa2va_chat.py bug

Browse files
Files changed (1) hide show
  1. modeling_sa2va_chat.py +2 -2
modeling_sa2va_chat.py CHANGED
@@ -807,7 +807,7 @@ class Sa2VAChatModel(PreTrainedModel):
807
  bg_cls_token_id = torch.as_tensor([self.bg_cls_token_idx,], dtype=ids.dtype, device=ids.device)
808
  bg_cls_embedding = self.language_model.get_input_embeddings()(bg_cls_token_id).clone()
809
  output_ids = generate_output.sequences[0][:-1]
810
- cls_token_mask = ids == self.cls_token_idx
811
 
812
  # get seg tokens
813
  seg_token_mask = (output_ids >= self.the_first_seg_token_idx) & (output_ids <= self.the_last_seg_token_idx)
@@ -841,7 +841,7 @@ class Sa2VAChatModel(PreTrainedModel):
841
  **m2f_inputs
842
  )
843
 
844
- tags = re.findall(r'<p>(.*?)</p>', predict)
845
  label_id_to_text = {id: tag for id, tag in enumerate(tags)}
846
 
847
  class_queries_logits = m2f_outputs.class_queries_logits
 
807
  bg_cls_token_id = torch.as_tensor([self.bg_cls_token_idx,], dtype=ids.dtype, device=ids.device)
808
  bg_cls_embedding = self.language_model.get_input_embeddings()(bg_cls_token_id).clone()
809
  output_ids = generate_output.sequences[0][:-1]
810
+ cls_token_mask = ids[0] == self.cls_token_idx
811
 
812
  # get seg tokens
813
  seg_token_mask = (output_ids >= self.the_first_seg_token_idx) & (output_ids <= self.the_last_seg_token_idx)
 
841
  **m2f_inputs
842
  )
843
 
844
+ tags = re.findall(r'<p>(.*?)</p>', input_text)
845
  label_id_to_text = {id: tag for id, tag in enumerate(tags)}
846
 
847
  class_queries_logits = m2f_outputs.class_queries_logits