Buckets:
"""
LlavaEuroBertAudioForEmbedding: Qwen3VL vision + Qwen2.5-Omni audio + EuroBERT text.

Architecture:
- Vision: Qwen3VLVisionModel (with RoPE, 3D Conv3d patch embed, all layers)
- Merger: PretrainedMerger (top-level, NOT inside vision_tower)
- Audio: Qwen2_5OmniAudioEncoder (Qwen2.5-Omni) + Linear projector
- Text: LlamaModel (EuroBERT, bidirectional)
- LM head: Identity (embedding model, no vocab projection)

Modality loading:
    model = AutoModel.from_pretrained(path, trust_remote_code=True, modality="omni")    # all components (default)
    model = AutoModel.from_pretrained(path, trust_remote_code=True, modality="vision")  # no audio tower/projector
    model = AutoModel.from_pretrained(path, trust_remote_code=True, modality="audio")   # no vision tower/merger
    model = AutoModel.from_pretrained(path, trust_remote_code=True, modality="text")    # text only: no vision or audio components
"""
| from typing import List, Optional, Union | |
| import torch | |
| import torch.nn as nn | |
| from transformers import LlamaConfig, PreTrainedModel, PretrainedConfig | |
| from transformers.modeling_outputs import BaseModelOutputWithPast | |
| from transformers.models.llama.modeling_llama import LlamaModel | |
| from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig | |
| from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel | |
| from transformers.models.qwen2_5_omni.configuration_qwen2_5_omni import Qwen2_5OmniAudioEncoderConfig | |
| from transformers.models.qwen2_5_omni.modeling_qwen2_5_omni import Qwen2_5OmniAudioEncoder | |
# Modalities accepted by both the config and the model:
#   "omni"   - build every component (vision tower + merger, audio tower + projector)
#   "vision" - skip the audio tower and audio projector
#   "audio"  - skip the vision tower and merger
#   "text"   - skip both the vision and the audio components
_VALID_MODALITIES = ("omni", "vision", "audio", "text")
class PretrainedMerger(nn.Module):
    """Fuse groups of vision patch features and project them to the text width.

    Every incoming feature row is layer-normalised at its original width.
    Then ``spatial_merge_size ** 2`` consecutive rows are flattened into a
    single vector, which is passed through a two-layer MLP with a GELU in
    between.

    Args:
        hidden_size: Width of one incoming patch feature.
        out_hidden_size: Width of each projected output vector.
        spatial_merge_size: ``spatial_merge_size ** 2`` feature rows are
            merged into one output row.
    """

    def __init__(self, hidden_size, out_hidden_size, spatial_merge_size=2):
        super().__init__()
        merged_width = hidden_size * (spatial_merge_size ** 2)
        # NOTE: the attribute holds the *merged* width, not the constructor
        # argument; ``forward`` relies on it for the reshape.
        self.hidden_size = merged_width
        # Normalisation happens before merging, hence the un-merged width.
        self.norm = nn.LayerNorm(hidden_size, eps=1e-6)
        self.linear_fc1 = nn.Linear(merged_width, merged_width)
        self.act = nn.GELU()
        self.linear_fc2 = nn.Linear(merged_width, out_hidden_size)

    def forward(self, x):
        """Return the merged and projected features for ``x``.

        The total number of elements in ``x`` must be divisible by the merged
        width, because the normalised tensor is viewed as
        ``(-1, self.hidden_size)``.
        """
        normalised = self.norm(x)
        grouped = normalised.view(-1, self.hidden_size)
        activated = self.act(self.linear_fc1(grouped))
        return self.linear_fc2(activated)
class LlavaEuroBertAudioConfig(PretrainedConfig):
    """Configuration for ``LlavaEuroBertAudioForEmbedding``.

    Args:
        vision_config: Vision settings, as a dict or a config object. A dict
            is wrapped in a generic ``PretrainedConfig``; a missing value
            becomes an empty ``PretrainedConfig``.
        text_config: Text-model settings, handled like ``vision_config``.
        audio_config: Audio-encoder settings, handled like ``vision_config``.
        image_token_index: Token id of the image placeholder.
        audio_token_id: Token id of the audio placeholder.
        audio_start_token_id: Stored on the config as given.
        audio_end_token_id: Stored on the config as given.
        projector_hidden_act: Stored on the config as given.
        tie_word_embeddings: Forwarded to ``PretrainedConfig.__init__``.
        modality: One of ``_VALID_MODALITIES``.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``modality`` is not one of ``_VALID_MODALITIES``.
    """

    model_type = "llava_eurobert_audio"

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        audio_config=None,
        image_token_index=None,
        audio_token_id=None,
        audio_start_token_id=None,
        audio_end_token_id=None,
        projector_hidden_act="gelu",
        tie_word_embeddings=False,
        modality="omni",
        **kwargs,
    ):
        # Normalise the three nested configs the same way: dicts are wrapped,
        # anything falsy (e.g. ``None``) falls back to an empty config.
        nested = (
            ("vision_config", vision_config),
            ("text_config", text_config),
            ("audio_config", audio_config),
        )
        for attr_name, raw in nested:
            wrapped = PretrainedConfig(**raw) if isinstance(raw, dict) else raw
            setattr(self, attr_name, wrapped or PretrainedConfig())

        self.image_token_index = image_token_index
        self.audio_token_id = audio_token_id
        self.audio_start_token_id = audio_start_token_id
        self.audio_end_token_id = audio_end_token_id
        self.projector_hidden_act = projector_hidden_act

        if modality not in _VALID_MODALITIES:
            raise ValueError(f"modality must be one of {_VALID_MODALITIES}, got '{modality}'")
        self.modality = modality

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

    def get_text_config(self, **kwargs):
        """Return the nested text config; any keyword arguments are ignored."""
        return self.text_config
class LlavaEuroBertAudioForEmbedding(PreTrainedModel):
    """Embedding model combining a vision tower, an audio tower and a text model.

    Which components are built depends on ``config.modality``:

    - ``"omni"``: vision tower + merger, audio tower + projector, text model.
    - ``"vision"``: no audio tower / audio projector.
    - ``"audio"``: no vision tower / merger.
    - ``"text"``: text model only.

    The text model is a ``LlamaModel`` whose attention layers are switched to
    non-causal, and ``lm_head`` is an identity, so ``forward`` returns hidden
    states rather than vocabulary logits.
    """

    config_class = LlavaEuroBertAudioConfig
    supports_gradient_checkpointing = True
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _supports_attention_backend = True
    _tied_weights_keys = []
    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
    _keys_to_ignore_on_load_unexpected = []

    def __init__(self, config: LlavaEuroBertAudioConfig):
        """Build the sub-modules selected by ``config.modality``.

        Raises:
            ValueError: If ``config.modality`` is not in ``_VALID_MODALITIES``.
        """
        super().__init__(config)
        modality = getattr(config, "modality", "omni")
        if modality not in _VALID_MODALITIES:
            raise ValueError(f"modality must be one of {_VALID_MODALITIES}, got '{modality}'")
        self._modality = modality

        # --- vision config: coerce a generic config/dict into Qwen3VLVisionConfig
        vision_cfg = config.vision_config
        if not isinstance(vision_cfg, Qwen3VLVisionConfig):
            if hasattr(vision_cfg, "to_dict"):
                d = vision_cfg.to_dict()
            else:
                d = dict(vision_cfg)
            d.pop("model_type", None)
            d.pop("transformers_version", None)
            vision_cfg = Qwen3VLVisionConfig(**d)
        # Deepstack outputs are not used by this model.
        vision_cfg.deepstack_visual_indexes = []
        spatial_merge_size = getattr(vision_cfg, "spatial_merge_size", 2)

        # --- text config: coerce into LlamaConfig, keeping the attention backend
        text_cfg = config.text_config
        if not isinstance(text_cfg, LlamaConfig):
            txt_dict = text_cfg.to_dict() if hasattr(text_cfg, 'to_dict') else dict(text_cfg)
            _saved_attn_impl = getattr(text_cfg, "_attn_implementation", None)
            text_cfg = LlamaConfig(**txt_dict)
            if _saved_attn_impl is not None:
                text_cfg._attn_implementation = _saved_attn_impl
        text_hidden = text_cfg.hidden_size

        self._spatial_merge_size = spatial_merge_size
        self._vision_hidden_size = getattr(vision_cfg, "hidden_size", 768)

        # --- vision tower + top-level merger (skipped for audio/text modes)
        if modality not in ("audio", "text"):
            self.vision_tower = Qwen3VLVisionModel(vision_cfg)
            # The tower's own merger is replaced by an identity; merging is
            # done by the top-level ``self.merger`` instead.
            self.vision_tower.merger = nn.Identity()
            self.vision_tower.deepstack_merger_list = nn.ModuleList()
            self.vision_tower.deepstack_visual_indexes = []
            self.merger = PretrainedMerger(
                vision_cfg.hidden_size, text_hidden, spatial_merge_size
            )
        self.multi_modal_projector = nn.Identity()

        # --- text model, made bidirectional
        self.language_model = LlamaModel(text_cfg)
        self.lm_head = nn.Identity()
        for layer in self.language_model.layers:
            layer.self_attn.is_causal = False

        # --- audio tower + projector (skipped for vision/text modes)
        if modality not in ("vision", "text"):
            aud_cfg = config.audio_config
            aud_dict = aud_cfg.to_dict() if hasattr(aud_cfg, 'to_dict') else aud_cfg
            audio_encoder_config = Qwen2_5OmniAudioEncoderConfig(**aud_dict)
            self.audio_tower = Qwen2_5OmniAudioEncoder(audio_encoder_config)
            output_dim = aud_dict.get('output_dim', 3584)
            self.audio_projector = nn.Linear(output_dim, text_hidden)

        # Checkpoints always contain every component; tell the loader which
        # weight prefixes are expected to be unused in the current mode.
        ignore = []
        if modality in ("audio", "text"):
            ignore.extend([r"^vision_tower\.", r"^merger\."])
        if modality in ("vision", "text"):
            ignore.extend([r"^audio_tower\.", r"^audio_projector\."])
        if ignore:
            self._keys_to_ignore_on_load_unexpected = ignore

        self.post_init()

    def modality(self) -> str:
        """Return the modality this model was built with.

        NOTE(review): this is a plain method (call it as ``model.modality()``),
        not a property; kept that way for backward compatibility.
        """
        return self._modality

    def get_input_embeddings(self):
        """Return the text model's token embedding module."""
        return self.language_model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the text model's token embedding module."""
        self.language_model.embed_tokens = value

    def get_output_embeddings(self):
        """Return ``None``: the model has no vocabulary projection."""
        return None

    def get_image_features(
        self,
        pixel_values: torch.FloatTensor,
        image_grid_thw: torch.LongTensor,
        num_image_tokens: Optional[int] = None,
    ) -> List[torch.Tensor]:
        """Encode images and return one merged feature tensor per grid entry.

        Args:
            pixel_values: Passed to the vision tower as ``hidden_states``.
            image_grid_thw: Passed to the vision tower as ``grid_thw``; the
                product over its last dimension, divided by
                ``spatial_merge_size ** 2``, gives each image's token count.
            num_image_tokens: Accepted but not used.

        Returns:
            The merged features split into one tensor per row of
            ``image_grid_thw``.

        Raises:
            ValueError: If the model was built in audio- or text-only mode.
        """
        if self._modality in ("audio", "text"):
            raise ValueError(
                f"Vision inputs are not available in {self._modality}-only mode. "
                "Load with modality='omni' or modality='vision'."
            )
        vision_output = self.vision_tower(
            hidden_states=pixel_values, grid_thw=image_grid_thw
        )
        # The tower may return a tuple or an output object depending on the
        # library version; take the pooled output when present, else item 0.
        if isinstance(vision_output, tuple):
            raw_hidden = vision_output[0]
        elif hasattr(vision_output, "pooler_output") and vision_output.pooler_output is not None:
            raw_hidden = vision_output.pooler_output
        else:
            raw_hidden = vision_output[0]
        image_features = self.merger(raw_hidden)
        merge_sq = self._spatial_merge_size ** 2
        split_sizes = (image_grid_thw.prod(-1) // merge_sq).tolist()
        return list(torch.split(image_features, split_sizes))

    def get_audio_features(
        self,
        input_features: torch.FloatTensor,
        feature_attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        """Encode audio features and project them to the text hidden size.

        Args:
            input_features: Batched audio features; dimension 0 is the batch,
                dimension 2 is indexed by ``feature_attention_mask``.
            feature_attention_mask: Optional mask selecting the valid frames
                of each sample. When omitted, all frames are used.

        Returns:
            The projected audio encoder output; for batches larger than one
            the per-sample outputs are concatenated along dimension 0.

        Raises:
            ValueError: If the model was built in vision- or text-only mode.
        """
        if self._modality in ("vision", "text"):
            raise ValueError(
                f"Audio inputs are not available in {self._modality}-only mode. "
                "Load with modality='omni' or modality='audio'."
            )
        batch_size = input_features.shape[0]
        if batch_size > 1:
            # Serialize per-sample so the packed-frames GEMM shape stays invariant
            # across batch sizes. Makes batched audio bit-exact to B=1 in bf16,
            # and is substantially faster for B>=16 because B=1 hits a
            # well-optimized kernel while the packed-B=N path thrashes on a
            # (total_frames)^2 sdpa matrix.
            outs = [
                self.get_audio_features(
                    input_features[i : i + 1],
                    feature_attention_mask[i : i + 1] if feature_attention_mask is not None else None,
                )
                for i in range(batch_size)
            ]
            return torch.cat(outs, dim=0)

        if feature_attention_mask is not None:
            feature_lens = feature_attention_mask.sum(-1).long()
            # Keep only the valid frames, then restore (features, frames) layout.
            packed = input_features.permute(0, 2, 1)[feature_attention_mask.bool()].permute(1, 0)
        else:
            feature_lens = torch.full(
                (batch_size,), input_features.shape[2],
                device=input_features.device, dtype=torch.long,
            )
            packed = input_features.transpose(1, 2).reshape(-1, input_features.shape[1]).T
        aftercnn_lens, _ = self.audio_tower._get_feat_extract_output_lengths(feature_lens)
        audio_output = self.audio_tower(
            packed, feature_lens=feature_lens, aftercnn_lens=aftercnn_lens,
        )
        return self.audio_projector(audio_output.last_hidden_state)

    def _placeholder_mask(self, input_ids, inputs_embeds, token_id, config_field):
        """Return a boolean mask shaped like ``inputs_embeds`` marking placeholders.

        Positions are found from ``input_ids`` when available. When the caller
        supplied only ``inputs_embeds``, a position counts as a placeholder if
        its embedding equals the embedding of ``token_id`` in every dimension.

        Raises:
            ValueError: If ``token_id`` is ``None`` (config field unset).
        """
        if token_id is None:
            raise ValueError(
                f"config.{config_field} must be set to merge multimodal features "
                "into the text sequence"
            )
        if input_ids is not None:
            positions = input_ids == token_id
        else:
            embed = self.get_input_embeddings()
            token = torch.tensor(token_id, dtype=torch.long, device=embed.weight.device)
            placeholder = embed(token).to(inputs_embeds.device, inputs_embeds.dtype)
            positions = (inputs_embeds == placeholder).all(-1)
        return positions.unsqueeze(-1).expand_as(inputs_embeds)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values=None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        input_features: Optional[torch.FloatTensor] = None,
        feature_attention_mask: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ):
        """Embed text with optional image and audio features merged in.

        Image features replace the positions of ``config.image_token_index``
        and audio features replace the positions of ``config.audio_token_id``.
        Images are only processed when both ``pixel_values`` and the
        ``image_grid_thw`` keyword argument are given.

        Returns:
            ``BaseModelOutputWithPast`` whose ``last_hidden_state`` is the
            text model's final hidden state (``lm_head`` is an identity).

        Raises:
            ValueError: If an input is given for a modality that was not
                loaded, if not exactly one of ``input_ids`` /
                ``inputs_embeds`` is given, or if a required placeholder
                token id is unset in the config.
        """
        image_grid_thw = kwargs.pop("image_grid_thw", None)
        num_image_tokens = kwargs.pop("num_image_tokens", None)
        # Dropped so they are not forwarded to the language model below.
        kwargs.pop("spatial_shapes", None)
        kwargs.pop("pixel_attention_mask", None)

        if pixel_values is not None and self._modality in ("audio", "text"):
            raise ValueError(
                f"Vision inputs are not available in {self._modality}-only mode. "
                "Load with modality='omni' or modality='vision'."
            )
        if input_features is not None and self._modality in ("vision", "text"):
            raise ValueError(
                f"Audio inputs are not available in {self._modality}-only mode. "
                "Load with modality='omni' or modality='audio'."
            )
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if pixel_values is not None and image_grid_thw is not None:
            image_features = self.get_image_features(
                pixel_values=pixel_values,
                image_grid_thw=image_grid_thw,
                num_image_tokens=num_image_tokens,
            )
            image_features = torch.cat(image_features, dim=0).to(
                inputs_embeds.device, inputs_embeds.dtype
            )
            # Works with either input_ids or inputs_embeds (see helper).
            special_image_mask = self._placeholder_mask(
                input_ids, inputs_embeds, self.config.image_token_index, "image_token_index"
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_features
            )

        if input_features is not None:
            audio_embeds = self.get_audio_features(
                input_features, feature_attention_mask
            )
            audio_embeds_flat = audio_embeds.reshape(
                -1, audio_embeds.shape[-1]
            ).to(inputs_embeds.device, inputs_embeds.dtype)
            audio_mask = self._placeholder_mask(
                input_ids, inputs_embeds, self.config.audio_token_id, "audio_token_id"
            )
            inputs_embeds = inputs_embeds.masked_scatter(
                audio_mask, audio_embeds_flat
            )

        if attention_mask is not None and attention_mask.dim() == 2:
            # Turn the 2D padding mask into an additive 4D mask that is the
            # same for every query position: 0 where attention_mask is 1 and
            # the dtype minimum where it is 0 (no causal triangle).
            dtype = inputs_embeds.dtype
            seq_len = inputs_embeds.shape[1]
            bidi_mask = attention_mask[:, None, None, :].to(dtype=dtype)
            bidi_mask = (1.0 - bidi_mask) * torch.finfo(dtype).min
            attention_mask = bidi_mask.expand(-1, -1, seq_len, -1)

        # vLLM's transformers backend passes `return_dict=False` + `attention_instances`.
        # Force dict-style output internally, and forward remaining kwargs so the
        # vllm attention hook receives its `attention_instances` dict.
        kwargs.pop("return_dict", None)
        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            **kwargs,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
def _register_vllm() -> None:
    """Register the vLLM implementation of this model, on a best-effort basis.

    Does nothing when vLLM is not installed. Otherwise it:

    1. Locates ``vllm_llava_eurobert_audio.py`` next to this file, downloading
       it from the Hub when this module was loaded as remote code (package
       name starting with ``transformers_modules``) and the file is missing.
    2. Adds this file's directory to ``sys.path`` and to the ``PYTHONPATH``
       environment variable.
    3. Imports the sibling module and registers its
       ``LlavaEuroBertAudioForVLLMEmbedding`` class with vLLM's
       ``ModelRegistry``.

    Any failure in steps 1-3 is reported as a warning instead of raising.
    """
    import importlib.util as _iu

    try:
        vllm_spec = _iu.find_spec("vllm")
    except (ImportError, ValueError):
        # ``find_spec`` raises ValueError when ``sys.modules["vllm"]`` exists
        # but has no usable ``__spec__`` (e.g. a stub module). Treat that as
        # "vLLM not installed" rather than breaking the import of this file.
        vllm_spec = None
    if vllm_spec is None:
        return

    try:
        import importlib
        import os
        import shutil
        import sys

        pkg = __package__ or ""
        current_dir = os.path.dirname(os.path.abspath(__file__))
        sibling_name = "vllm_llava_eurobert_audio"
        sibling_path = os.path.join(current_dir, sibling_name + ".py")

        if not os.path.exists(sibling_path):
            # The package name is expected to look like
            # transformers_modules.<owner>.<repo>.<revision>[...].
            parts = pkg.split(".")
            if len(parts) >= 4 and parts[0] == "transformers_modules":
                from huggingface_hub import hf_hub_download

                repo_name = parts[2].replace("_hyphen_", "-").replace("_dot_", ".")
                repo_id = f"{parts[1]}/{repo_name}"
                downloaded = hf_hub_download(
                    repo_id=repo_id,
                    filename=sibling_name + ".py",
                    revision=parts[3],
                )
                shutil.copy(downloaded, sibling_path)

        if current_dir not in sys.path:
            sys.path.insert(0, current_dir)
        # Also export the directory through the environment so that processes
        # started later inherit it.
        existing = os.environ.get("PYTHONPATH", "")
        if current_dir not in existing.split(os.pathsep):
            os.environ["PYTHONPATH"] = (
                current_dir if not existing else current_dir + os.pathsep + existing
            )

        if pkg:
            _lla = importlib.import_module("." + sibling_name, package=pkg)
        else:
            _lla = importlib.import_module(sibling_name)

        from vllm import ModelRegistry

        ModelRegistry.register_model(
            "LlavaEuroBertAudioForEmbedding",
            _lla.LlavaEuroBertAudioForVLLMEmbedding,
        )
    except Exception as e:
        # Deliberately broad: registration is optional and must never prevent
        # the model from loading with the Transformers backend.
        import warnings

        warnings.warn(
            f"jina-embeddings-v5-omni nano: vLLM registration failed "
            f"({type(e).__name__}: {e}); falling back to Transformers backend.",
            stacklevel=2,
        )
# Executed at import time: loading this module registers the vLLM
# implementation when vLLM is installed. Registration failures are reported
# as warnings by ``_register_vllm`` rather than raised.
_register_vllm()
Xet Storage Details
- Size:
- 16.5 kB
- Xet hash:
- 53497c1532da6bf9bf977310202aec9a9f177a804e5745f356ea8a7100bb7a12
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.