Upload folder using huggingface_hub
Browse files- modeling_sa2va_chat.py +4 -3
modeling_sa2va_chat.py
CHANGED
|
@@ -689,6 +689,7 @@ class Sa2VAChatModel(PreTrainedModel):
|
|
| 689 |
input=text, round=1, bot_name=self.bot_name)
|
| 690 |
input_text = past_text + input_text
|
| 691 |
ids = self.tokenizer.encode(input_text)
|
|
|
|
| 692 |
ids = torch.tensor(ids).cuda().unsqueeze(0)
|
| 693 |
|
| 694 |
attention_mask = torch.ones_like(ids, dtype=torch.bool)
|
|
@@ -715,7 +716,8 @@ class Sa2VAChatModel(PreTrainedModel):
|
|
| 715 |
)
|
| 716 |
predict = self.tokenizer.decode(
|
| 717 |
generate_output.sequences[0], skip_special_tokens=False).strip()
|
| 718 |
-
|
|
|
|
| 719 |
# if have seg result, find the seg hidden states
|
| 720 |
hidden_states = generate_output.hidden_states
|
| 721 |
last_hidden_states = [item[-1][0] for item in hidden_states]
|
|
@@ -737,8 +739,7 @@ class Sa2VAChatModel(PreTrainedModel):
|
|
| 737 |
masks = masks.sigmoid() > 0.5
|
| 738 |
masks = masks.cpu().numpy()
|
| 739 |
ret_masks.append(masks)
|
| 740 |
-
|
| 741 |
-
return {'prediction': predict, 'prediction_masks': ret_masks,}
|
| 742 |
|
| 743 |
def get_seg_hidden_states(hidden_states, output_ids, seg_id):
|
| 744 |
seg_mask = output_ids == seg_id
|
|
|
|
| 689 |
input=text, round=1, bot_name=self.bot_name)
|
| 690 |
input_text = past_text + input_text
|
| 691 |
ids = self.tokenizer.encode(input_text)
|
| 692 |
+
ret_past_text = self.tokenizer.decode(ids)
|
| 693 |
ids = torch.tensor(ids).cuda().unsqueeze(0)
|
| 694 |
|
| 695 |
attention_mask = torch.ones_like(ids, dtype=torch.bool)
|
|
|
|
| 716 |
)
|
| 717 |
predict = self.tokenizer.decode(
|
| 718 |
generate_output.sequences[0], skip_special_tokens=False).strip()
|
| 719 |
+
ret_past_text = ret_past_text + self.tokenizer.decode(
|
| 720 |
+
generate_output.sequences[0], skip_special_tokens=False)
|
| 721 |
# if have seg result, find the seg hidden states
|
| 722 |
hidden_states = generate_output.hidden_states
|
| 723 |
last_hidden_states = [item[-1][0] for item in hidden_states]
|
|
|
|
| 739 |
masks = masks.sigmoid() > 0.5
|
| 740 |
masks = masks.cpu().numpy()
|
| 741 |
ret_masks.append(masks)
|
| 742 |
+
return {'prediction': predict, 'prediction_masks': ret_masks, "past_text": ret_past_text}
|
|
|
|
| 743 |
|
| 744 |
def get_seg_hidden_states(hidden_states, output_ids, seg_id):
|
| 745 |
seg_mask = output_ids == seg_id
|