support batchified-training
modeling_chartmoe.py  CHANGED  (+25 −10)
@@ -16,7 +16,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch
+"""PyTorch ChartMoE model."""
 import copy
 import queue
 import threading
@@ -262,15 +262,30 @@ class ChartMoEForCausalLM(InternLM2PreTrainedModel):
             wrap_target = wrap_target[:, :self.max_length].to(self.device)
             wrap_im_mask = wrap_im_mask[:, :self.max_length].to(self.device)

-            wrap_embeds_list.append(wrap_embeds)
-            wrap_atts_list.append(wrap_atts)
-            wrap_target_list.append(wrap_target)
-            wrap_im_mask_list.append(wrap_im_mask)
-
-        wrap_embeds = torch.cat(wrap_embeds_list)
-        wrap_atts = torch.cat(wrap_atts_list)
-        wrap_target = torch.cat(wrap_target_list)
-        wrap_im_mask = torch.cat(wrap_im_mask_list)
+            wrap_embeds_list.append(wrap_embeds.squeeze(dim=0))
+            wrap_atts_list.append(wrap_atts.squeeze(dim=0))
+            wrap_target_list.append(wrap_target.squeeze(dim=0))
+            wrap_im_mask_list.append(wrap_im_mask.squeeze(dim=0))
+
+        # get the embed of padding token
+        pad = torch.ones([1, 1]) * self.tokenizer.pad_token_id
+        pad = pad.long().to(self.device)
+        pad_emb = self.model.tok_embeddings(pad).squeeze(dim=0)
+
+        # pack and pad
+        longest_len = max(t.size(0) for t in wrap_embeds_list)
+        wrap_embeds = []
+        for wrap_embed in wrap_embeds_list:
+            pad_len = longest_len - wrap_embed.size(0)
+            if pad_len > 0:
+                pad_tensor = pad_emb.repeat(pad_len, 1)
+                wrap_embed = torch.cat([wrap_embed, pad_tensor], dim=0)
+            wrap_embeds.append(wrap_embed)
+        wrap_embeds = torch.stack(wrap_embeds, dim=0)
+
+        wrap_atts = nn.utils.rnn.pad_sequence(wrap_atts_list, batch_first=True, padding_value=0.)
+        wrap_target = nn.utils.rnn.pad_sequence(wrap_target_list, batch_first=True, padding_value=-100).long()
+        wrap_im_mask = nn.utils.rnn.pad_sequence(wrap_im_mask_list, batch_first=True, padding_value=0.)
         return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask

     def mask_human_targets(self, input_ids, pure=False):
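For context (not part of the commit): the change batches variable-length per-sample sequences by right-padding the token embeddings with the pad token's embedding and padding the attention/target/image masks with neutral values (0 for masks, -100 so the loss ignores padding). Below is a minimal standalone sketch of that pack-and-pad idea; the hidden size, sequence lengths, and the zero stand-in for the pad-token embedding are toy assumptions, not values from the model.

import torch
from torch import nn

hidden = 8                          # toy hidden size (assumption)
pad_emb = torch.zeros(1, hidden)    # stands in for tok_embeddings(pad_token_id)

# three samples with different sequence lengths
embeds_list = [torch.randn(5, hidden), torch.randn(3, hidden), torch.randn(7, hidden)]
atts_list = [torch.ones(t.size(0)) for t in embeds_list]
target_list = [torch.randint(0, 100, (t.size(0),)) for t in embeds_list]

# pad each embedding to the longest sequence with the pad-token embedding, then stack
longest = max(t.size(0) for t in embeds_list)
padded = []
for emb in embeds_list:
    gap = longest - emb.size(0)
    if gap > 0:
        emb = torch.cat([emb, pad_emb.repeat(gap, 1)], dim=0)
    padded.append(emb)
embeds = torch.stack(padded, dim=0)

# 1-D masks/targets can use pad_sequence directly
atts = nn.utils.rnn.pad_sequence(atts_list, batch_first=True, padding_value=0.)
target = nn.utils.rnn.pad_sequence(target_list, batch_first=True, padding_value=-100).long()

print(embeds.shape, atts.shape, target.shape)
# torch.Size([3, 7, 8]) torch.Size([3, 7]) torch.Size([3, 7])

Padding the embeddings manually (rather than with pad_sequence) mirrors the diff: the filler rows must be the pad token's embedding vector, not zeros of arbitrary shape, so the padded positions stay consistent with what the tokenizer's pad token would have produced.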