estebancarlin committed on
Commit
1cbb7a1
·
verified ·
1 Parent(s): 4e6d38e

Fix BitMarModel class and tensor shapes for main

Browse files
Files changed (1) hide show
  1. modeling_bitmar.py +16 -0
modeling_bitmar.py CHANGED
@@ -1537,12 +1537,28 @@ class BitMarModel(PreTrainedModel):
1537
  has_vision: Boolean tensor [batch_size] indicating which samples have real vision features
1538
  """
1539
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
 
 
 
 
 
 
 
1540
 
1541
  if input_ids is None:
1542
  raise ValueError("input_ids must be provided")
1543
 
1544
  batch_size, seq_len = input_ids.shape
1545
 
 
 
 
 
 
 
 
 
1546
  # Handle missing vision features
1547
  if vision_features is None:
1548
  vision_features = torch.zeros(batch_size, self.config.vision_encoder_dim,
 
1537
  has_vision: Boolean tensor [batch_size] indicating which samples have real vision features
1538
  """
1539
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1540
+
1541
+ # CRITICAL FIX: Ensure input_ids are integers
1542
+ if input_ids.dtype != torch.long:
1543
+ input_ids = input_ids.long()
1544
+
1545
+ # CRITICAL FIX: Ensure labels are integers if provided
1546
+ if labels is not None and labels.dtype != torch.long:
1547
+ labels = labels.long()
1548
 
1549
  if input_ids is None:
1550
  raise ValueError("input_ids must be provided")
1551
 
1552
  batch_size, seq_len = input_ids.shape
1553
 
1554
+ # Handle missing attention mask
1555
+ if attention_mask is None:
1556
+ attention_mask = torch.ones_like(input_ids, dtype=torch.float)
1557
+
1558
+ # Ensure attention_mask is float
1559
+ if attention_mask.dtype != torch.float:
1560
+ attention_mask = attention_mask.float()
1561
+
1562
  # Handle missing vision features
1563
  if vision_features is None:
1564
  vision_features = torch.zeros(batch_size, self.config.vision_encoder_dim,