Buckets:
| """Sentence-transformers integration for jina-embeddings-v5-omni-nano (base + LoRA). | |
| Supports text, image, video, and audio with per-task adapter routing: | |
| from sentence_transformers import SentenceTransformer | |
| model = SentenceTransformer( | |
| "jinaai/jina-embeddings-v5-omni-nano", | |
| trust_remote_code=True, | |
| model_kwargs={"default_task": "retrieval"}, | |
| ) | |
| q = model.encode("What is ML?", prompt_name="query") | |
| d = model.encode("ML is ...", prompt_name="document") | |
| img = model.encode(Image.open("photo.jpg")) | |
| vid = model.encode("clip.mp4") | |
| aud = model.encode("speech.wav") | |
| """ | |
| import json | |
| import os | |
| from typing import Any, Dict, List, Optional, Union | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from transformers import AutoConfig, AutoModel, AutoTokenizer | |
| MAX_SEQ_LENGTH = 8192 | |
| IMAGE_PROMPT = "<image>" | |
| VIDEO_PROMPT = "<image>" | |
| AUDIO_EXTENSIONS = {".wav", ".mp3", ".flac", ".ogg", ".m4a", ".opus", ".webm"} | |
| VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".flv", ".wmv"} | |
| PDF_EXTENSIONS = {".pdf"} | |
| SVG_EXTENSIONS = {".svg"} | |
| PDF_DPI = 150 | |
| TASK_NAMES = ["retrieval", "text-matching", "clustering", "classification"] | |
| EVAL_IMAGE_MIN_PIXELS = 262144 | |
| EVAL_IMAGE_MAX_PIXELS = 1310720 | |
| EVAL_VIDEO_MAX_PIXELS = 12845056 | |
| EVAL_VIDEO_NUM_FRAMES = 32 | |
| def _pil_image(): | |
| """Return the PIL.Image module, with a clean ImportError if pillow is not | |
| installed. Wrapped in `try` so transformers' AST-based `check_imports` | |
| does not list PIL as a top-level required dependency: text-only and | |
| audio-only users should not need pillow installed. | |
| """ | |
| try: | |
| from PIL import Image as _PILImage | |
| except ImportError as e: | |
| raise ImportError( | |
| "Encoding images or rasterising PDFs needs `pip install pillow`." | |
| ) from e | |
| return _PILImage | |
| def _is_image(x) -> bool: | |
| try: | |
| from PIL import Image as PILImage | |
| return isinstance(x, PILImage.Image) | |
| except ImportError: | |
| return False | |
| def _is_video_path(x) -> bool: | |
| if not isinstance(x, str): | |
| return False | |
| return any(x.lower().endswith(ext) for ext in VIDEO_EXTENSIONS) | |
| def _is_audio_path(x) -> bool: | |
| if not isinstance(x, str): | |
| return False | |
| return any(x.lower().endswith(ext) for ext in AUDIO_EXTENSIONS) | |
| def _is_pdf_path(x) -> bool: | |
| if not isinstance(x, str): | |
| return False | |
| return any(x.lower().endswith(ext) for ext in PDF_EXTENSIONS) | |
| def _is_svg_path(x) -> bool: | |
| if not isinstance(x, str): | |
| return False | |
| return any(x.lower().split("?", 1)[0].endswith(ext) for ext in SVG_EXTENSIONS) | |
| def _is_audio_array(x) -> bool: | |
| try: | |
| import numpy as np | |
| except ImportError: | |
| return False | |
| return isinstance(x, np.ndarray) and x.ndim == 1 and np.issubdtype(x.dtype, np.floating) | |
| class _AudioWrapper: | |
| def __init__(self, array, sampling_rate: int = 16000): | |
| self.array = array | |
| self.sampling_rate = sampling_rate | |
| def _download_if_url(x): | |
| """If x is an http(s) URL, download to a hashed local cache and return the | |
| local path. Otherwise return x unchanged. | |
| """ | |
| if not isinstance(x, str): | |
| return x | |
| if not (x.startswith("http://") or x.startswith("https://")): | |
| return x | |
| import hashlib, os, tempfile, urllib.request | |
| from urllib.parse import urlparse | |
| cache = os.path.join(tempfile.gettempdir(), "jina_omni_media_cache") | |
| os.makedirs(cache, exist_ok=True) | |
| h = hashlib.sha256(x.encode("utf-8")).hexdigest()[:16] | |
| url_path = urlparse(x).path | |
| _, ext = os.path.splitext(url_path) | |
| local = os.path.join(cache, f"{h}{ext}" if ext else h) | |
| if not os.path.isfile(local) or os.path.getsize(local) == 0: | |
| urllib.request.urlretrieve(x, local) | |
| return local | |
| def _looks_like_svg(data): | |
| if not data: | |
| return False | |
| head = data[:4096].lstrip().lower() | |
| return b"<svg" in head | |
| def _svg_to_image(svg): | |
| try: | |
| import cairosvg | |
| except ImportError as e: | |
| raise ImportError("Encoding SVG images needs `pip install cairosvg pillow`.") from e | |
| import io | |
| png = cairosvg.svg2png(bytestring=svg if isinstance(svg, (bytes, bytearray)) else None, | |
| url=svg if isinstance(svg, str) else None) | |
| _PILImage = _pil_image() | |
| return _PILImage.open(io.BytesIO(png)).convert("RGB") | |
| def _sniff_media_type_bytes(head): | |
| """Return 'image'/'svg'/'video'/'audio'/'pdf'/None from content headers.""" | |
| if _looks_like_svg(head): | |
| return "svg" | |
| if not head or len(head) < 8: | |
| return None | |
| if head[:3] == b"\xff\xd8\xff": return "image" | |
| if head[:8] == b"\x89PNG\r\n\x1a\n": return "image" | |
| if head[:6] in (b"GIF87a", b"GIF89a"): return "image" | |
| if head[:4] == b"RIFF" and head[8:12] == b"WEBP": return "image" | |
| if head[:2] == b"BM": return "image" | |
| if head[:4] in (b"II*\x00", b"MM\x00*"): return "image" | |
| if head[4:12] in (b"ftypavif", b"ftypavis"): return "image" | |
| if head[4:12] in (b"ftypheic", b"ftypheix", b"ftypmif1", b"ftypmsf1"): | |
| return "image" | |
| if head[:3] == b"ID3": return "audio" | |
| if head[:2] in (b"\xff\xfb", b"\xff\xf3", b"\xff\xf2"): return "audio" | |
| if head[:4] == b"fLaC": return "audio" | |
| if head[:4] == b"OggS": return "audio" | |
| if head[:4] == b"RIFF" and head[8:12] == b"WAVE": return "audio" | |
| if head[4:12] in (b"M4A ", b"M4B ", b"M4P "): return "audio" | |
| if head[:4] == b"\x1a\x45\xdf\xa3": return "video" | |
| if head[4:8] == b"ftyp": return "video" | |
| if head[:4] == b"RIFF" and head[8:12] == b"AVI ": return "video" | |
| if head[:3] == b"FLV": return "video" | |
| if head[:4] == b"0&\xb2u": return "video" | |
| if head[:5] == b"%PDF-": return "pdf" | |
| return None | |
| def _sniff_media_type(path): | |
| try: | |
| with open(path, "rb") as f: | |
| data = f.read(4096) | |
| kind = _sniff_media_type_bytes(data) | |
| if kind is None and _is_svg_path(path): | |
| return "svg" | |
| return kind | |
| except OSError: | |
| return None | |
| def _resolve_input(x): | |
| """Normalize any input to (kind, value). Accepts: | |
| - PIL.Image -> image | |
| - np.ndarray HxWx3 uint8 -> image (via PIL.fromarray) | |
| - np.ndarray TxHxWx3 uint8 -> video (saved to /tmp via imageio) | |
| - np.ndarray 1-D float -> audio | |
| - np.ndarray 2-D float (C,N) or (N,C) -> audio (mono mixdown) | |
| - torch.Tensor -> converted to numpy, recurse | |
| - bytes / io.IOBase -> sniff + route | |
| - str URL -> downloaded + routed | |
| - str path -> content-sniffed + routed | |
| - str -> text | |
| """ | |
| import os as _os | |
| import io | |
| if _is_image(x): | |
| return ("image", x) | |
| if _is_audio_array(x): | |
| return ("audio", x) | |
| try: | |
| import numpy as _np | |
| except ImportError: | |
| _np = None | |
| if _np is not None and isinstance(x, _np.ndarray): | |
| # Image (H,W,3|4) uint8 | |
| if x.ndim == 3 and x.shape[-1] in (3, 4) and x.dtype == _np.uint8: | |
| _PILImage = _pil_image() | |
| mode = "RGBA" if x.shape[-1] == 4 else "RGB" | |
| return ("image", _PILImage.fromarray(x, mode).convert("RGB")) | |
| # Video (T,H,W,3|4) uint8 | |
| if x.ndim == 4 and x.shape[-1] in (3, 4) and x.dtype == _np.uint8: | |
| # Pass frames straight to the processor — no mp4 round-trip, no | |
| # av/imageio needed. Drop alpha if present. | |
| return ("video", x if x.shape[-1] == 3 else x[..., :3]) | |
| # Audio multichannel 2D float -> mono mixdown | |
| if x.ndim == 2 and _np.issubdtype(x.dtype, _np.floating): | |
| audio = x.mean(axis=0 if x.shape[0] <= 8 else 1).astype(_np.float32) | |
| return ("audio", audio) | |
| # torch.Tensor -> numpy and recurse | |
| try: | |
| import torch as _torch | |
| except ImportError: | |
| _torch = None | |
| if _torch is not None and isinstance(x, _torch.Tensor): | |
| return _resolve_input(x.detach().cpu().numpy()) | |
| # bytes / BytesIO / file-like | |
| if isinstance(x, (bytes, bytearray)): | |
| data = bytes(x) | |
| elif isinstance(x, io.IOBase): | |
| data = x.read() | |
| else: | |
| data = None | |
| if data is not None: | |
| kind = _sniff_media_type_bytes(data[:4096]) | |
| if kind == "image": | |
| _PILImage = _pil_image() | |
| return ("image", _PILImage.open(io.BytesIO(data)).convert("RGB")) | |
| if kind == "svg": | |
| return ("image", _svg_to_image(bytes(data))) | |
| if kind in ("video", "audio"): | |
| import tempfile as _tf | |
| ext = ".mp4" if kind == "video" else ".wav" | |
| tf = _tf.NamedTemporaryFile(suffix=ext, delete=False) | |
| tf.write(data); tf.close() | |
| return (kind, tf.name) | |
| if kind == "pdf": | |
| # pypdfium2 reads bytes directly — no temp file needed. | |
| return ("pdf", bytes(data)) | |
| if isinstance(x, str): | |
| local = _download_if_url(x) | |
| if _os.path.isfile(local): | |
| kind = _sniff_media_type(local) | |
| if kind == "image": | |
| _PILImage = _pil_image() | |
| return ("image", _PILImage.open(local).convert("RGB")) | |
| if kind == "svg": | |
| return ("image", _svg_to_image(local)) | |
| if kind in ("video", "audio", "pdf"): | |
| return (kind, local) | |
| return ("text", x) | |
| return ("text", str(x)) | |
| def _is_media_string(x) -> bool: | |
| if not isinstance(x, str): | |
| return False | |
| return _resolve_input(x)[0] in ("image", "video", "audio", "pdf") | |
| def _prompt_from_kwargs(st_model, kwargs): | |
| prompt = kwargs.get("prompt") | |
| if prompt is None: | |
| prompt_name = kwargs.get("prompt_name") or getattr(st_model, "default_prompt_name", None) | |
| prompt = (getattr(st_model, "prompts", {}) or {}).get(prompt_name, "") if prompt_name else "" | |
| return prompt or "" | |
| def _raw_media_parts(st_model, value, kwargs): | |
| prompt = _prompt_from_kwargs(st_model, kwargs) | |
| return (prompt, value) if prompt else (value,) | |
| def _prompted_parts(st_model, value, kwargs): | |
| parts = value if isinstance(value, tuple) else (value,) | |
| prompt = _prompt_from_kwargs(st_model, kwargs) | |
| return (prompt, *parts) if prompt else parts | |
| def _align_eval_processor(processor): | |
| video_processor = getattr(processor, "video_processor", None) | |
| if video_processor is None: | |
| return | |
| if hasattr(video_processor, "do_sample_frames"): | |
| video_processor.do_sample_frames = False | |
| for attr in ("max_frames", "num_frames"): | |
| if hasattr(video_processor, attr): | |
| setattr(video_processor, attr, EVAL_VIDEO_NUM_FRAMES) | |
| if hasattr(video_processor, "size") and isinstance(video_processor.size, dict): | |
| video_processor.size = { | |
| **video_processor.size, | |
| "longest_edge": EVAL_VIDEO_MAX_PIXELS, | |
| "shortest_edge": EVAL_IMAGE_MIN_PIXELS, | |
| } | |
| if hasattr(video_processor, "max_pixels"): | |
| video_processor.max_pixels = EVAL_VIDEO_MAX_PIXELS | |
| if hasattr(video_processor, "min_pixels"): | |
| video_processor.min_pixels = EVAL_IMAGE_MIN_PIXELS | |
| def _build_eval_image_prompt(processor, prefix: str = ""): | |
| image_token = getattr(processor, "image_token", IMAGE_PROMPT) | |
| text = f"{prefix or ''}<|vision_start|>{image_token}<|vision_end|>" | |
| try: | |
| return processor.apply_chat_template( | |
| [{"role": "user", "content": text}], | |
| tokenize=False, | |
| add_generation_prompt=False, | |
| ) | |
| except (ValueError, AttributeError): | |
| return f"{prefix or ''}{IMAGE_PROMPT}" | |
| def _audio_output_length(feature_attention_mask): | |
| real_frames = feature_attention_mask.sum(-1) | |
| aftercnn = (real_frames - 1) // 2 + 1 | |
| return int(((aftercnn - 2) // 2 + 1).item()) | |
| def _load_audio_array(audio_input): | |
| import numpy as np | |
| if isinstance(audio_input, _AudioWrapper): | |
| return audio_input.array.astype(np.float32), audio_input.sampling_rate | |
| if isinstance(audio_input, str): | |
| try: | |
| import librosa | |
| except ImportError as e: | |
| raise ImportError( | |
| "Loading audio from a file path needs `pip install librosa`" | |
| " (or pass a 1-D numpy float32 waveform at 16 kHz)." | |
| ) from e | |
| audio, sr = librosa.load(audio_input, sr=16000) | |
| return audio.astype(np.float32), sr | |
| if isinstance(audio_input, np.ndarray): | |
| return audio_input.astype(np.float32), 16000 | |
| raise TypeError(f"Unsupported audio input type: {type(audio_input)}") | |
| def _build_audio_model_inputs(owner, audio_input, device, prefix: str = ""): | |
| import numpy as np | |
| from transformers import WhisperFeatureExtractor | |
| audio, sr = _load_audio_array(audio_input) | |
| if not np.isfinite(audio).all(): | |
| audio = np.nan_to_num(audio, nan=0.0, posinf=0.0, neginf=0.0) | |
| peak = float(np.max(np.abs(audio))) if audio.size else 0.0 | |
| if peak > 1.0: | |
| audio = audio / peak | |
| feat_ext = WhisperFeatureExtractor(feature_size=128) | |
| audio_inputs = feat_ext( | |
| audio, | |
| sampling_rate=sr, | |
| return_tensors="pt", | |
| padding="max_length", | |
| return_attention_mask=True, | |
| ) | |
| input_features = audio_inputs["input_features"] | |
| feature_attention_mask = audio_inputs["attention_mask"] | |
| n_tokens = _audio_output_length(feature_attention_mask) | |
| start = owner.tokenizer.convert_ids_to_tokens(owner.config.audio_start_token_id) | |
| token = owner.tokenizer.convert_ids_to_tokens(owner.config.audio_token_id) | |
| end = owner.tokenizer.convert_ids_to_tokens(owner.config.audio_end_token_id) | |
| audio_seq = start + token * n_tokens + end | |
| text = f"{prefix or ''}{audio_seq}" | |
| try: | |
| prompt = owner.processor.apply_chat_template( | |
| [{"role": "user", "content": text}], | |
| tokenize=False, | |
| add_generation_prompt=False, | |
| ) | |
| except (ValueError, AttributeError): | |
| prompt = text | |
| out = owner.processor(text=[prompt], return_tensors="pt", padding=False, truncation=False) | |
| model_dtype = next(owner.model.parameters()).dtype | |
| inputs = {k: v.to(device) for k, v in out.items() if torch.is_tensor(v)} | |
| inputs["input_features"] = input_features.to(device=device, dtype=model_dtype) | |
| inputs["feature_attention_mask"] = feature_attention_mask.to(device) | |
| pos_builder = globals().get("_get_1d_position_ids") | |
| if pos_builder is not None: | |
| inputs["position_ids"] = pos_builder(inputs["attention_mask"]) | |
| return inputs | |
| def _extract_audio_from_video(video_path): | |
| """Return mono float32 audio @ 16 kHz decoded from the video's audio track, or | |
| None if no audio stream is present. PyAV is already a dep for video decoding.""" | |
| try: | |
| import av | |
| import numpy as np | |
| from av.audio.resampler import AudioResampler | |
| except ImportError: | |
| return None | |
| container = av.open(video_path) | |
| try: | |
| audio_stream = next((s for s in container.streams if s.type == "audio"), None) | |
| if audio_stream is None: | |
| return None | |
| resampler = AudioResampler(format="flt", layout="mono", rate=16000) | |
| samples = [] | |
| for frame in container.decode(audio=0): | |
| for rf in resampler.resample(frame): | |
| samples.append(rf.to_ndarray().flatten()) | |
| for rf in resampler.resample(None): | |
| samples.append(rf.to_ndarray().flatten()) | |
| if not samples: | |
| return None | |
| return np.concatenate(samples).astype(np.float32) | |
| finally: | |
| container.close() | |
| def _eval_video_frames(video_path): | |
| if not isinstance(video_path, str): | |
| return video_path | |
| try: | |
| import av | |
| import numpy as np | |
| except ImportError: | |
| return video_path | |
| container = av.open(video_path) | |
| try: | |
| frames = [frame.to_image().convert("RGB") for frame in container.decode(video=0)] | |
| finally: | |
| container.close() | |
| if not frames: | |
| return video_path | |
| if len(frames) <= EVAL_VIDEO_NUM_FRAMES: | |
| return frames | |
| indices = np.linspace(0, len(frames) - 1, EVAL_VIDEO_NUM_FRAMES, dtype=int).tolist() | |
| return [frames[i] for i in indices] | |
| def _pdf_to_images(pdf, dpi: int = PDF_DPI): | |
| """Rasterise every page of a PDF to a list of PIL.Image (RGB). | |
| `pdf` may be a path, raw bytes, BytesIO, or an existing list of PIL.Images | |
| (returned as-is). Lazy-imports `pypdfium2` so users who never touch PDFs | |
| are not forced to install it. | |
| """ | |
| _PILImage = _pil_image() # PIL is a hard dep of the image path | |
| if isinstance(pdf, list) and pdf and all(isinstance(p, _PILImage.Image) for p in pdf): | |
| return pdf | |
| try: | |
| import pypdfium2 as pdfium | |
| except ImportError as e: | |
| raise ImportError( | |
| "Decoding PDF pages needs `pip install pypdfium2`." | |
| ) from e | |
| import io as _io | |
| if isinstance(pdf, (bytes, bytearray)): | |
| doc = pdfium.PdfDocument(bytes(pdf)) | |
| elif isinstance(pdf, _io.IOBase): | |
| doc = pdfium.PdfDocument(pdf.read()) | |
| else: | |
| doc = pdfium.PdfDocument(pdf) | |
| scale = dpi / 72.0 | |
| pages = [] | |
| try: | |
| for page in doc: | |
| pil = page.render(scale=scale).to_pil().convert("RGB") | |
| pages.append(pil) | |
| finally: | |
| doc.close() | |
| return pages | |
| def _patch_st_encode_multipart(): | |
| """Intercept ST.encode for multipart tuple inputs so PIL.Image and | |
| np.ndarray media parts bypass ST's length-sort.""" | |
| import importlib | |
| import torch | |
| try: | |
| st_mod = importlib.import_module("sentence_transformers.SentenceTransformer") | |
| except ImportError: | |
| return | |
| _ST = st_mod.SentenceTransformer | |
| if getattr(_ST.encode, "_omni_multipart_patched", False): | |
| return | |
| _orig = _ST.encode | |
| def _encode(self, sentences, *args, **kwargs): | |
| def _is_nonstring_input(x): | |
| # anything other than a pure string becomes a 1-part multipart item | |
| return not isinstance(x, str) | |
| single_bare = _is_nonstring_input(sentences) and not isinstance(sentences, list) | |
| list_with_nonstr = (isinstance(sentences, list) and sentences | |
| and any(_is_nonstring_input(s) for s in sentences)) | |
| single_media_string = isinstance(sentences, str) and _is_media_string(sentences) | |
| list_with_media_string = (isinstance(sentences, list) and sentences | |
| and any(isinstance(s, str) and _is_media_string(s) for s in sentences)) | |
| fwd_keys = getattr(self[0], "forward_kwargs", set()) | |
| forward_kwargs = {k: kwargs[k] for k in fwd_keys if k in kwargs} | |
| if single_media_string or list_with_media_string: | |
| if single_media_string: | |
| batch = [_raw_media_parts(self, sentences, kwargs)] | |
| else: | |
| batch = [_raw_media_parts(self, s, kwargs) for s in sentences] | |
| features = {"_multipart_batch": batch, "_is_multipart_batch": True} | |
| with torch.no_grad(): | |
| out = self[0](features, **forward_kwargs) | |
| emb = out["sentence_embedding"] | |
| if kwargs.get("convert_to_numpy", True): | |
| emb = emb.detach().cpu().float().numpy() | |
| if single_media_string: | |
| emb = emb[0] if hasattr(emb, "__getitem__") else emb | |
| return emb | |
| if single_bare or list_with_nonstr: | |
| if single_bare: | |
| batch = [_prompted_parts(self, sentences, kwargs)] | |
| else: | |
| batch = [_prompted_parts(self, s, kwargs) for s in sentences] | |
| features = {"_multipart_batch": batch, "_is_multipart_batch": True} | |
| with torch.no_grad(): | |
| out = self[0](features, **forward_kwargs) | |
| emb = out["sentence_embedding"] | |
| if kwargs.get("convert_to_numpy", True): | |
| emb = emb.detach().cpu().float().numpy() | |
| if single_bare: | |
| emb = emb[0] if hasattr(emb, "__getitem__") else emb | |
| return emb | |
| result = _orig(self, sentences, *args, **kwargs) | |
| # ST 5.x applies truncate_dim without L2 renormalization; the README | |
| # promises unit-norm truncated embeddings, so restore that here. | |
| if kwargs.get("truncate_dim") is not None and not kwargs.get("normalize_embeddings", False): | |
| import numpy as _np | |
| if torch.is_tensor(result): | |
| result = torch.nn.functional.normalize(result, p=2, dim=-1) | |
| elif isinstance(result, _np.ndarray): | |
| n = _np.linalg.norm(result, axis=-1, keepdims=True) + 1e-12 | |
| result = result / n | |
| return result | |
| _encode._omni_multipart_patched = True | |
| _ST.encode = _encode | |
| def encode_query(self, sentences, *args, **kwargs): | |
| kwargs.setdefault("prompt_name", "query") | |
| return self.encode(sentences, *args, **kwargs) | |
| def encode_document(self, sentences, *args, **kwargs): | |
| kwargs.setdefault("prompt_name", "document") | |
| return self.encode(sentences, *args, **kwargs) | |
| _ST.encode_query = encode_query | |
| _ST.encode_document = encode_document | |
| _patch_st_encode_multipart() | |
| class Transformer(nn.Module): | |
| save_in_root: bool = True | |
| # Tells sentence-transformers to thread these kwargs from encode() through | |
| # to our forward() — otherwise ST filters unknown kwargs out. | |
| forward_kwargs = {"task", "truncate_dim"} | |
| def __init__( | |
| self, | |
| model_name_or_path: str = "jinaai/jina-embeddings-v5-omni-nano", | |
| max_seq_length: Optional[int] = None, | |
| config_args: Optional[Dict[str, Any]] = None, | |
| model_args: Optional[Dict[str, Any]] = None, | |
| tokenizer_args: Optional[Dict[str, Any]] = None, | |
| cache_dir: Optional[str] = None, | |
| backend: str = "torch", | |
| task: Optional[str] = None, | |
| default_task: Optional[str] = None, | |
| **kwargs, | |
| ) -> None: | |
| super().__init__() | |
| if backend != "torch": | |
| raise ValueError( | |
| f"Backend '{backend}' is not supported, please use 'torch' instead" | |
| ) | |
| config_kwargs = dict(config_args or {}) | |
| model_kwargs = dict(model_args or {}) | |
| tokenizer_kwargs = dict(tokenizer_args or {}) | |
| # Default-task resolution precedence (highest to lowest): | |
| # 1. `task` / `default_task` kwarg to this __init__ | |
| # 2. `model_args={'default_task': ...}` (legacy path) | |
| # 3. JINA_V5_TASK env var | |
| # 4. unset -> encode() must pass task= | |
| self.default_task = ( | |
| task | |
| or default_task | |
| or model_kwargs.pop("default_task", None) | |
| or os.environ.get("JINA_V5_TASK") | |
| ) | |
| if self.default_task and self.default_task not in TASK_NAMES: | |
| raise ValueError( | |
| f"Invalid task: {self.default_task}. Must be one of {TASK_NAMES}." | |
| ) | |
| # setdefault so caller-provided trust_remote_code isn't duplicated | |
| config_kwargs.setdefault("trust_remote_code", True) | |
| model_kwargs.setdefault("trust_remote_code", True) | |
| tokenizer_kwargs.setdefault("trust_remote_code", True) | |
| # Dedupe cache_dir: we pass it explicitly below, so strip any copy | |
| # that sentence-transformers may have also threaded through *_args. | |
| for _kw in (config_kwargs, model_kwargs, tokenizer_kwargs): | |
| _kw.pop("cache_dir", None) | |
| self.config = AutoConfig.from_pretrained( | |
| model_name_or_path, cache_dir=cache_dir, **config_kwargs | |
| ) | |
| self.model = AutoModel.from_pretrained( | |
| model_name_or_path, cache_dir=cache_dir, **model_kwargs, | |
| ) | |
| self.tokenizer = self.model.tokenizer | |
| # AutoProcessor pulls in PIL transitively; lazy-import so users on | |
| # text-only setups (no pillow installed) can still load the model. | |
| try: | |
| from transformers import AutoProcessor as _AutoProcessor | |
| processor_kwargs = dict(tokenizer_kwargs) | |
| processor_kwargs.setdefault("min_pixels", EVAL_IMAGE_MIN_PIXELS) | |
| processor_kwargs.setdefault("max_pixels", EVAL_IMAGE_MAX_PIXELS) | |
| self.processor = _AutoProcessor.from_pretrained( | |
| model_name_or_path, cache_dir=cache_dir, **processor_kwargs, | |
| ) | |
| _align_eval_processor(self.processor) | |
| except Exception: | |
| self.processor = None | |
| tc = getattr(self.config, "text_config", self.config) | |
| max_pos = getattr(tc, "max_position_embeddings", MAX_SEQ_LENGTH) | |
| self.max_seq_length = max_seq_length or min(max_pos, MAX_SEQ_LENGTH) | |
| def tokenize( | |
| self, | |
| texts: Union[List[str], List[Dict], list], | |
| padding: Union[str, bool] = True, | |
| **kwargs, | |
| ) -> Dict[str, torch.Tensor]: | |
| if texts and any(isinstance(t, tuple) for t in texts): | |
| # Wrap non-tuple entries as 1-tuples so every batch slot goes | |
| # through _encode_parts. Lets users mix: [(t,img), "plain text"]. | |
| wrapped = [t if isinstance(t, tuple) else (t,) for t in texts] | |
| return {"_multipart_batch": wrapped, "_is_multipart_batch": True} | |
| resolved = [_resolve_input(t) for t in texts] | |
| # Heterogeneous batch (e.g. ["speech.wav", "plain text"]) — route through | |
| # the multipart path where each element is dispatched on its own kind. | |
| if len({k for k, _ in resolved}) > 1: | |
| wrapped = [t if isinstance(t, tuple) else (t,) for t in texts] | |
| return {"_multipart_batch": wrapped, "_is_multipart_batch": True} | |
| first_kind = resolved[0][0] | |
| values = [v for _, v in resolved] | |
| if first_kind == "image": | |
| return {"_images": values, "_is_image_batch": True} | |
| if first_kind == "video": | |
| return {"_video_paths": values, "_is_video_batch": True} | |
| if first_kind == "audio": | |
| return {"_audio_paths": values, "_is_audio_batch": True} | |
| if first_kind == "pdf": | |
| return {"_pdfs": values, "_is_pdf_batch": True} | |
| if isinstance(texts[0], dict): | |
| texts = [next(iter(t.values())) for t in texts] | |
| elif isinstance(texts[0], (list, tuple)): | |
| texts = [t[0] for t in texts] | |
| return self.tokenizer( | |
| [str(s) for s in texts], | |
| max_length=self.max_seq_length, | |
| truncation=True, | |
| padding=padding, | |
| return_tensors="pt", | |
| ) | |
| def _resolve_task(self, task: Optional[str]) -> str: | |
| if task is None: | |
| if self.default_task is None: | |
| raise ValueError( | |
| "Task must be specified. Set it during loading " | |
| "(model_kwargs={'default_task': 'retrieval'}) or pass " | |
| "task='retrieval' to encode()." | |
| ) | |
| task = self.default_task | |
| if task not in TASK_NAMES: | |
| raise ValueError(f"Invalid task: {task}. Must be one of {TASK_NAMES}.") | |
| return task | |
| def _last_token_pool(self, hidden, attention_mask): | |
| seq_lens = attention_mask.sum(dim=1) - 1 | |
| pooled = hidden[torch.arange(hidden.shape[0], device=hidden.device), seq_lens] | |
| return F.normalize(pooled, p=2, dim=-1).float() | |
| def _encode_single_image(self, image, device, prefix: str = "") -> torch.Tensor: | |
| prompt = _build_eval_image_prompt(self.processor, prefix=prefix) | |
| inputs = self.processor(images=image, text=prompt, return_tensors="pt", truncation=False) | |
| inputs = {k: v.to(device) for k, v in inputs.items() if torch.is_tensor(v)} | |
| with torch.no_grad(): | |
| hidden = self.model(**inputs).last_hidden_state | |
| return self._last_token_pool(hidden, inputs["attention_mask"]).squeeze(0) | |
| def _encode_single_video(self, video_path, device) -> torch.Tensor: | |
| video = _eval_video_frames(video_path) | |
| inputs = self.processor(videos=video, text=VIDEO_PROMPT, return_tensors="pt", truncation=False) | |
| inputs = {k: v.to(device) for k, v in inputs.items() if torch.is_tensor(v)} | |
| with torch.no_grad(): | |
| hidden = self.model(**inputs).last_hidden_state | |
| return self._last_token_pool(hidden, inputs["attention_mask"]).squeeze(0) | |
| def _encode_single_audio(self, audio_input, device, prefix: str = "") -> torch.Tensor: | |
| inputs = _build_audio_model_inputs(self, audio_input, device, prefix=prefix) | |
| with torch.no_grad(): | |
| hidden = self.model(**inputs).last_hidden_state | |
| return self._last_token_pool(hidden, inputs["attention_mask"]).squeeze(0) | |
| def _encode_single_pdf(self, pdf, device) -> torch.Tensor: | |
| """Encode a PDF as a fused sequence of page images (single embedding). | |
| Pages are rasterised with pypdfium2 then fed through the same | |
| multipart fusion path used for tuples — so a 3-page PDF produces | |
| a single embedding spanning all three rendered pages. | |
| """ | |
| pages = _pdf_to_images(pdf) | |
| if not pages: | |
| raise ValueError("PDF has 0 pages — nothing to encode.") | |
| return self._encode_parts(tuple(pages), device) | |
| def _encode_composite_parts(self, expanded, device) -> torch.Tensor: | |
| import numpy as np | |
| from transformers import WhisperFeatureExtractor | |
| content = [] | |
| images, videos = [], [] | |
| audio_features, feature_masks = [], [] | |
| feat_ext = None | |
| for kind, p in expanded: | |
| if kind == "text": | |
| content.append({"type": "text", "text": str(p)}) | |
| elif kind == "image": | |
| content.append({"type": "image"}) | |
| images.append(p) | |
| elif kind == "video": | |
| content.append({"type": "video"}) | |
| videos.append(_eval_video_frames(p) if isinstance(p, str) else p) | |
| elif kind == "audio": | |
| if feat_ext is None: | |
| feat_ext = WhisperFeatureExtractor(feature_size=128) | |
| audio_arr, sr = _load_audio_array(p) | |
| if not np.isfinite(audio_arr).all(): | |
| audio_arr = np.nan_to_num(audio_arr, nan=0.0, posinf=0.0, neginf=0.0) | |
| peak = float(np.max(np.abs(audio_arr))) if audio_arr.size else 0.0 | |
| if peak > 1.0: | |
| audio_arr = audio_arr / peak | |
| audio_inputs = feat_ext( | |
| audio_arr, | |
| sampling_rate=sr, | |
| return_tensors="pt", | |
| padding="max_length", | |
| return_attention_mask=True, | |
| ) | |
| feat_mask = audio_inputs["attention_mask"] | |
| n_tokens = _audio_output_length(feat_mask) | |
| start = self.tokenizer.convert_ids_to_tokens(self.config.audio_start_token_id) | |
| token = self.tokenizer.convert_ids_to_tokens(self.config.audio_token_id) | |
| end = self.tokenizer.convert_ids_to_tokens(self.config.audio_end_token_id) | |
| content.append({"type": "text", "text": start + token * n_tokens + end}) | |
| audio_features.append(audio_inputs["input_features"]) | |
| feature_masks.append(feat_mask) | |
| has_chat_template = getattr(self.processor, "chat_template", None) is not None | |
| if has_chat_template: | |
| prompt = self.processor.apply_chat_template( | |
| [{"role": "user", "content": content}], | |
| tokenize=False, | |
| add_generation_prompt=False, | |
| ) | |
| if images or videos: | |
| image_token = getattr(self.processor, "image_token", "<|image_pad|>") | |
| video_token = getattr(self.processor, "video_token", "<|video_pad|>") | |
| flat = [] | |
| for c in content: | |
| if c.get("type") == "text": | |
| flat.append(c["text"]) | |
| elif c.get("type") == "image": | |
| flat.append(f"<|vision_start|>{image_token}<|vision_end|>") | |
| elif c.get("type") == "video": | |
| flat.append(f"<|vision_start|>{video_token}<|vision_end|>") | |
| prompt_flat = self.processor.apply_chat_template( | |
| [{"role": "user", "content": "".join(flat)}], | |
| tokenize=False, | |
| add_generation_prompt=False, | |
| ) | |
| if "<|vision_start|>" in prompt_flat: | |
| prompt = prompt_flat | |
| else: | |
| pieces = [] | |
| for c in content: | |
| if c.get("type") == "text": | |
| pieces.append(c["text"]) | |
| elif c.get("type") == "image": | |
| pieces.append(IMAGE_PROMPT) | |
| elif c.get("type") == "video": | |
| pieces.append(VIDEO_PROMPT) | |
| prompt = "".join(pieces) | |
| proc_kwargs = {"text": [prompt], "return_tensors": "pt", "padding": False, "truncation": False} | |
| if images: | |
| proc_kwargs["images"] = images | |
| if videos: | |
| proc_kwargs["videos"] = videos | |
| out = self.processor(**proc_kwargs) | |
| model_dtype = next(self.model.parameters()).dtype | |
| inputs = {k: v.to(device) if torch.is_tensor(v) else v for k, v in out.items()} | |
| if audio_features: | |
| inputs["input_features"] = torch.cat(audio_features, dim=0).to(device=device, dtype=model_dtype) | |
| inputs["feature_attention_mask"] = torch.cat(feature_masks, dim=0).to(device) | |
| if "Qwen" in type(self.processor).__name__: | |
| ids = inputs["input_ids"].squeeze(0) | |
| mm_ids = torch.zeros_like(ids, dtype=torch.int32) | |
| image_token_id = self.processor.tokenizer.convert_tokens_to_ids(getattr(self.processor, "image_token", "<image>")) | |
| video_token_id = self.processor.tokenizer.convert_tokens_to_ids(getattr(self.processor, "video_token", "<video>")) | |
| audio_token_id = self.processor.tokenizer.convert_tokens_to_ids(self.tokenizer.convert_ids_to_tokens(self.config.audio_token_id)) | |
| mm_ids += (ids == image_token_id).to(torch.int32) | |
| mm_ids += 2 * (ids == video_token_id).to(torch.int32) | |
| mm_ids += 3 * (ids == audio_token_id).to(torch.int32) | |
| inputs["mm_token_type_ids"] = mm_ids.unsqueeze(0) | |
| mask = inputs["attention_mask"] | |
| pos = mask.long().cumsum(-1) - 1 | |
| pos = pos.masked_fill(mask == 0, 0) | |
| inputs["position_ids"] = pos.unsqueeze(0).expand(3, -1, -1).contiguous() | |
| else: | |
| pos_builder = globals().get("_get_1d_position_ids") | |
| if pos_builder is not None: | |
| inputs["position_ids"] = pos_builder(inputs["attention_mask"]) | |
| with torch.no_grad(): | |
| hidden = self.model(**inputs).last_hidden_state | |
| return self._last_token_pool(hidden, inputs["attention_mask"]).squeeze(0) | |
| def _encode_parts(self, parts, device) -> torch.Tensor: | |
| """Fuse a tuple of parts into one embedding in a single forward pass. | |
| Each part may be a URL, a local path (sniffed by magic bytes if no | |
| extension), a PIL.Image, a 1-D numpy audio array, a PDF (rasterised | |
| to one image per page), or plain text. A video with an audio track | |
| is auto-expanded to [extracted_audio, video] so the audio tokens | |
| precede the video tokens. | |
| """ | |
| import numpy as np | |
| from transformers import WhisperFeatureExtractor | |
| # Normalize every part first (URL -> path, content-sniff if needed). | |
| resolved = [_resolve_input(p) for p in parts] | |
| # Expand videos-with-audio: prepend extracted audio. | |
| # Expand PDFs: rasterise into one image-part per page. | |
| expanded = [] | |
| for kind, value in resolved: | |
| if kind == "video": | |
| if isinstance(value, str): | |
| aud = _extract_audio_from_video(value) | |
| if aud is not None and aud.size > 0: | |
| expanded.append(("audio", aud)) | |
| expanded.append(("video", value)) | |
| elif kind == "pdf": | |
| for page in _pdf_to_images(value): | |
| expanded.append(("image", page)) | |
| else: | |
| expanded.append((kind, value)) | |
| ids_chunks, mask_chunks = [], [] | |
| pix_images, img_grid = [], [] | |
| pix_videos, vid_grid = [], [] | |
| audio_features = [] | |
| feat_ext = None | |
| if len(expanded) == 1 and expanded[0][0] == "image": | |
| return self._encode_single_image(expanded[0][1], device) | |
| if len(expanded) == 1 and expanded[0][0] == "audio": | |
| return self._encode_single_audio(expanded[0][1], device) | |
| if len(expanded) == 2 and expanded[0][0] == "text" and expanded[1][0] == "image": | |
| return self._encode_single_image(expanded[1][1], device, prefix=str(expanded[0][1])) | |
| if len(expanded) == 2 and expanded[0][0] == "text" and expanded[1][0] == "audio": | |
| return self._encode_single_audio(expanded[1][1], device, prefix=str(expanded[0][1])) | |
| return self._encode_composite_parts(expanded, device) | |
| def forward( | |
| self, | |
| features: Dict[str, torch.Tensor], | |
| task: Optional[str] = None, | |
| truncate_dim: Optional[int] = None, | |
| **kwargs, | |
| ) -> Dict[str, torch.Tensor]: | |
| self.model.eval() | |
| device = next(self.model.parameters()).device | |
| task = self._resolve_task(task) | |
| self.model.set_adapter([task]) | |
| if features.get("_is_multipart_batch"): | |
| embs = [self._encode_parts(parts, device) for parts in features["_multipart_batch"]] | |
| features["sentence_embedding"] = torch.stack(embs) | |
| return self._maybe_truncate(features, truncate_dim) | |
| if features.get("_is_image_batch"): | |
| embs = [self._encode_single_image(img, device) for img in features["_images"]] | |
| features["sentence_embedding"] = torch.stack(embs) | |
| return self._maybe_truncate(features, truncate_dim) | |
| if features.get("_is_video_batch"): | |
| embs = [self._encode_single_video(p, device) for p in features["_video_paths"]] | |
| features["sentence_embedding"] = torch.stack(embs) | |
| return self._maybe_truncate(features, truncate_dim) | |
| if features.get("_is_audio_batch"): | |
| embs = [self._encode_single_audio(p, device) for p in features["_audio_paths"]] | |
| features["sentence_embedding"] = torch.stack(embs) | |
| return self._maybe_truncate(features, truncate_dim) | |
| if features.get("_is_pdf_batch"): | |
| embs = [self._encode_single_pdf(p, device) for p in features["_pdfs"]] | |
| features["sentence_embedding"] = torch.stack(embs) | |
| return self._maybe_truncate(features, truncate_dim) | |
| batch = {k: v.to(device) for k, v in features.items() if torch.is_tensor(v)} | |
| with torch.no_grad(): | |
| hidden = self.model(**batch).last_hidden_state | |
| features["sentence_embedding"] = self._last_token_pool(hidden, batch["attention_mask"]) | |
| return self._maybe_truncate(features, truncate_dim) | |
| def _maybe_truncate(features, truncate_dim): | |
| # Slicing an L2-normalized vector and renormalizing is equivalent to | |
| # truncate-then-normalize on the raw pooled vector — so this produces a | |
| # unit-norm matryoshka embedding. | |
| if truncate_dim is not None: | |
| emb = features["sentence_embedding"][..., :truncate_dim] | |
| features["sentence_embedding"] = F.normalize(emb, p=2, dim=-1) | |
| return features | |
| def get_word_embedding_dimension(self) -> int: | |
| tc = getattr(self.config, "text_config", self.config) | |
| return getattr(tc, "hidden_size", 768) | |
| def get_sentence_embedding_dimension(self) -> int: | |
| return self.get_word_embedding_dimension() | |
| def get_max_seq_length(self) -> int: | |
| return self.max_seq_length | |
| def save(self, output_path: str, safe_serialization: bool = True, **kwargs) -> None: | |
| self.model.save_pretrained(output_path, safe_serialization=safe_serialization) | |
| self.tokenizer.save_pretrained(output_path) | |
| config = {"max_seq_length": self.max_seq_length} | |
| with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as f: | |
| json.dump(config, f, indent=2) | |
| def load(cls, input_path: str) -> "Transformer": | |
| # Signature must have exactly 1 param so ST routes through the direct | |
| # constructor path (which maps model_kwargs -> model_args correctly). | |
| config_path = os.path.join(input_path, "sentence_bert_config.json") | |
| extra = {} | |
| if os.path.exists(config_path): | |
| with open(config_path) as f: | |
| extra = json.load(f) | |
| return cls(model_name_or_path=input_path, **extra) | |
Xet Storage Details
- Size:
- 41.7 kB
- Xet hash:
- 4aa0b1a738c53a3178b93897645e4e813625c08a8d4fe8c91a8a0ff3495c2892
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.