Fix BitMarModel class and tensor shapes for main
Browse files
- modeling_bitmar.py +16 -0
modeling_bitmar.py
CHANGED
|
@@ -1537,12 +1537,28 @@ class BitMarModel(PreTrainedModel):
|
|
| 1537 |
has_vision: Boolean tensor [batch_size] indicating which samples have real vision features
|
| 1538 |
"""
|
| 1539 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1540 |
|
| 1541 |
if input_ids is None:
|
| 1542 |
raise ValueError("input_ids must be provided")
|
| 1543 |
|
| 1544 |
batch_size, seq_len = input_ids.shape
|
| 1545 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1546 |
# Handle missing vision features
|
| 1547 |
if vision_features is None:
|
| 1548 |
vision_features = torch.zeros(batch_size, self.config.vision_encoder_dim,
|
|
|
|
| 1537 |
has_vision: Boolean tensor [batch_size] indicating which samples have real vision features
|
| 1538 |
"""
|
| 1539 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1540 |
+
|
| 1541 |
+
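# CRITICAL FIX: Cast input_ids to torch.long if they arrive with any other dtype
# NOTE(review): this reads input_ids.dtype before the `input_ids is None` check below, so a None input raises AttributeError instead of the intended ValueError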
|
| 1542 |
+
if input_ids.dtype != torch.long:
|
| 1543 |
+
input_ids = input_ids.long()
|
| 1544 |
+
|
| 1545 |
+
# CRITICAL FIX: Cast labels to torch.long when they are provided with any other dtype
|
| 1546 |
+
if labels is not None and labels.dtype != torch.long:
|
| 1547 |
+
labels = labels.long()
|
| 1548 |
|
| 1549 |
if input_ids is None:
|
| 1550 |
raise ValueError("input_ids must be provided")
|
| 1551 |
|
| 1552 |
batch_size, seq_len = input_ids.shape
|
| 1553 |
|
| 1554 |
+
# Handle a missing attention mask by defaulting to an all-ones float mask shaped like input_ids
|
| 1555 |
+
if attention_mask is None:
|
| 1556 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.float)
|
| 1557 |
+
|
| 1558 |
+
# Ensure attention_mask has dtype torch.float (cast it otherwise)
|
| 1559 |
+
if attention_mask.dtype != torch.float:
|
| 1560 |
+
attention_mask = attention_mask.float()
|
| 1561 |
+
|
| 1562 |
# Handle missing vision features
|
| 1563 |
if vision_features is None:
|
| 1564 |
vision_features = torch.zeros(batch_size, self.config.vision_encoder_dim,
|