Coobiw committed
Commit 0f06978 · verified · 1 Parent(s): 58240a8

support batchified-training

Files changed (1): modeling_chartmoe.py (+25 -10)
modeling_chartmoe.py CHANGED
@@ -16,7 +16,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""PyTorch InternLMXComposer2 model."""
+"""PyTorch ChartMoE model."""
 import copy
 import queue
 import threading
@@ -262,15 +262,30 @@ class ChartMoEForCausalLM(InternLM2PreTrainedModel):
             wrap_target = wrap_target[:, :self.max_length].to(self.device)
             wrap_im_mask = wrap_im_mask[:, :self.max_length].to(self.device)
 
-            wrap_embeds_list.append(wrap_embeds)
-            wrap_atts_list.append(wrap_atts)
-            wrap_target_list.append(wrap_target)
-            wrap_im_mask_list.append(wrap_im_mask)
-
-        wrap_embeds = torch.cat(wrap_embeds_list)
-        wrap_atts = torch.cat(wrap_atts_list)
-        wrap_target = torch.cat(wrap_target_list)
-        wrap_im_mask = torch.cat(wrap_im_mask_list)
+            wrap_embeds_list.append(wrap_embeds.squeeze(dim=0))
+            wrap_atts_list.append(wrap_atts.squeeze(dim=0))
+            wrap_target_list.append(wrap_target.squeeze(dim=0))
+            wrap_im_mask_list.append(wrap_im_mask.squeeze(dim=0))
+
+        # get the embed of padding token
+        pad = torch.ones([1, 1]) * self.tokenizer.pad_token_id
+        pad = pad.long().to(self.device)
+        pad_emb = self.model.tok_embeddings(pad).squeeze(dim=0)
+
+        # pack and pad
+        longest_len = max(t.size(0) for t in wrap_embeds_list)
+        wrap_embeds = []
+        for wrap_embed in wrap_embeds_list:
+            pad_len = longest_len - wrap_embed.size(0)
+            if pad_len > 0:
+                pad_tensor = pad_emb.repeat(pad_len, 1)
+                wrap_embed = torch.cat([wrap_embed, pad_tensor], dim=0)
+            wrap_embeds.append(wrap_embed)
+        wrap_embeds = torch.stack(wrap_embeds, dim=0)
+
+        wrap_atts = nn.utils.rnn.pad_sequence(wrap_atts_list, batch_first=True, padding_value=0.)
+        wrap_target = nn.utils.rnn.pad_sequence(wrap_target_list, batch_first=True, padding_value=-100).long()
+        wrap_im_mask = nn.utils.rnn.pad_sequence(wrap_im_mask_list, batch_first=True, padding_value=0.)
         return wrap_embeds, wrap_atts, wrap_target, wrap_im_mask
 
     def mask_human_targets(self, input_ids, pure=False):
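For context, the change replaces plain torch.cat (which only works when every sample in the batch has the same length) with right-padding to the batch's longest sequence: padded positions in the inputs-embeds get the pad token's embedding, while pad_sequence fills the attention mask and image mask with 0 and the targets with -100 so the loss ignores them. Below is a minimal, self-contained sketch of that scheme; hidden_dim, lengths, and the zero pad_emb vector are illustrative stand-ins, not values from the repository.

import torch
from torch import nn

# Standalone sketch of the batch-padding scheme in the commit above.
# hidden_dim, lengths, and pad_emb are made-up stand-ins for the model's
# embedding width, per-sample sequence lengths, and tok_embeddings(pad_token_id).
hidden_dim = 8
lengths = [5, 3, 7]

embeds_list = [torch.randn(n, hidden_dim) for n in lengths]   # per-sample [L_i, D] embeddings
atts_list = [torch.ones(n) for n in lengths]                  # per-sample attention masks
target_list = [torch.randint(0, 100, (n,)) for n in lengths]  # per-sample token targets
pad_emb = torch.zeros(1, hidden_dim)                          # stand-in pad-token embedding

# Right-pad each embedding to the longest sample, then stack into [B, L, D].
longest_len = max(t.size(0) for t in embeds_list)
padded = []
for e in embeds_list:
    pad_len = longest_len - e.size(0)
    if pad_len > 0:
        e = torch.cat([e, pad_emb.repeat(pad_len, 1)], dim=0)
    padded.append(e)
wrap_embeds = torch.stack(padded, dim=0)

# pad_sequence covers the 1-D per-token tensors: 0 for the attention mask,
# -100 so the loss ignores padded target positions.
wrap_atts = nn.utils.rnn.pad_sequence(atts_list, batch_first=True, padding_value=0.)
wrap_target = nn.utils.rnn.pad_sequence(target_list, batch_first=True, padding_value=-100).long()

assert wrap_embeds.shape == (3, 7, hidden_dim)
assert wrap_atts.shape == wrap_target.shape == (3, 7)

Using the real pad-token embedding (rather than the zero vector in this sketch) keeps padded positions consistent with what the model would see for actual pad tokens; either way they are excluded from the loss by the -100 targets and from attention by the zero mask.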