orrzohar committed on
Commit
18bdd7b
·
1 Parent(s): 24580dc

Fix forward signature and video key filtering

Browse files
Files changed (1) hide show
  1. modeling_blip3o_qwen.py +6 -4
modeling_blip3o_qwen.py CHANGED
@@ -354,10 +354,6 @@ class blip3oMetaForCausalLM(ABC):
354
  text_embeds = text_embeds.clone()
355
  text_embeds[gen_mask] = latent_queries[:num_gen_tokens]
356
 
357
- if labels is not None:
358
- labels = labels.clone()
359
- labels[image_idx] = IGNORE_INDEX
360
-
361
  return None, position_ids, attention_mask, past_key_values, text_embeds, labels, target_image_embeds
362
 
363
  def initialize_vision_tokenizer(self, model_args, tokenizer):
@@ -453,6 +449,8 @@ class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCau
453
  image_grid_thw: Optional[torch.Tensor] = None,
454
  return_dict: Optional[bool] = None,
455
  cache_position: Optional[torch.LongTensor] = None
 
 
456
  ) -> Union[Tuple, CausalLMOutputWithPast]:
457
  gen_image=gen_images
458
 
@@ -805,6 +803,10 @@ class blip3oQwenForCausalLM(Qwen2_5_VLForConditionalGeneration, blip3oMetaForCau
805
  inputs['images'] = images
806
  if image_sizes is not None:
807
  inputs['image_sizes'] = image_sizes
 
 
 
 
808
  return inputs
809
 
810
  AutoConfig.register("blip3o_qwen", blip3oQwenConfig)
 
354
  text_embeds = text_embeds.clone()
355
  text_embeds[gen_mask] = latent_queries[:num_gen_tokens]
356
 
 
 
 
 
357
  return None, position_ids, attention_mask, past_key_values, text_embeds, labels, target_image_embeds
358
 
359
  def initialize_vision_tokenizer(self, model_args, tokenizer):
 
449
  image_grid_thw: Optional[torch.Tensor] = None,
450
  return_dict: Optional[bool] = None,
451
  cache_position: Optional[torch.LongTensor] = None
452
+ ,
453
+ **kwargs
454
  ) -> Union[Tuple, CausalLMOutputWithPast]:
455
  gen_image=gen_images
456
 
 
803
  inputs['images'] = images
804
  if image_sizes is not None:
805
  inputs['image_sizes'] = image_sizes
806
+ # Filter out video-related keys from parent Qwen class
807
+ for key in list(inputs.keys()):
808
+ if "video" in key.lower():
809
+ inputs.pop(key, None)
810
  return inputs
811
 
812
  AutoConfig.register("blip3o_qwen", blip3oQwenConfig)