"""Internal runtime support for the LocateAnything-3B hybrid batch decoder. This file keeps only the model-loading, tokenization, image-encoding, stock processor, and sample-token helpers that ``engine_hybrid.py`` needs. Important env knobs: LA_FLASH_MODEL HF repo id / local path of the model (default nvidia/LocateAnything-3B) HF_HUB_OFFLINE=1 read the local HF cache only (no network); unset -> download on first use LA_FLASH_ATTN sdpa, eager, magi, or la_flash; la_flash uses FlashAttention sparse ranges LA_FLASH_STRICT_ATTN 1 -> fail if the requested backend is unavailable; default 0 falls back to sdpa LA_FLASH_VISION_ATTN auto, flash_attention_2, sdpa, or eager (default auto) LA_FLASH_HYBRID_PREFILL shared, none, per_row, or batch prompt KV prefill (default shared) MTP_BATCH_VISION 0 -> per-image vision encode (default 1: batched when flash is present) LA_FLASH_VISION_ENCODE_BATCH_SIZE max images per MoonViT encode micro-batch (default 8; <=0 disables limit) MTP_BATCH_SAN 0 -> per-row logits/sample pipeline (default 1: batched over [B,6,V]) AR_BATCH_SAN 0 -> per-row AR sample pipeline (default 1: batched over [B,1,V]) """ import inspect import os, warnings, importlib, torch from types import SimpleNamespace import numpy as np from transformers import AutoModel, AutoTokenizer, AutoProcessor # By default let transformers fetch the model on first use; set HF_HUB_OFFLINE=1 yourself # to read the local HF cache only (e.g. air-gapped / already-downloaded runs). MODEL = os.environ.get("LA_FLASH_MODEL", "nvidia/LocateAnything-3B") LLM_ATTN_MODES = ("sdpa", "eager", "magi", "la_flash") VISION_ATTN_MODES = ("auto", "flash_attention_2", "sdpa", "eager") def _normalize_attn_mode(value): mode = (value or "sdpa").strip().lower().replace("-", "_") aliases = { "": "sdpa", "manual": "eager", "torch": "eager", "torch_eager": "eager", "torch_sdpa": "sdpa", "scaled_dot_product_attention": "sdpa", "flash": "la_flash", "la_flash": "la_flash", "kernel": "la_flash", "cuda": "la_flash", "range": "la_flash", "range_attention": "la_flash", "flex_flash": "magi", "flex_flash_attention": "magi", "flex_flash_attn": "magi", } mode = aliases.get(mode, mode) if mode not in LLM_ATTN_MODES: raise ValueError( f"LA_FLASH_ATTN must be one of {', '.join(LLM_ATTN_MODES)}; got {value!r}" ) return mode def _normalize_vision_attn_mode(value): mode = (value or "auto").strip().lower().replace("-", "_") aliases = { "": "auto", "flash": "flash_attention_2", "flash_attention2": "flash_attention_2", "fa2": "flash_attention_2", "manual": "eager", } mode = aliases.get(mode, mode) if mode not in VISION_ATTN_MODES: raise ValueError( f"LA_FLASH_VISION_ATTN must be one of {', '.join(VISION_ATTN_MODES)}; got {value!r}" ) return mode ATTN_MODE = _normalize_attn_mode(os.environ.get("LA_FLASH_ATTN", "sdpa")) REMOTE_ATTN_MODE = "sdpa" if ATTN_MODE in {"la_flash", "magi"} else ATTN_MODE VISION_ATTN_MODE = _normalize_vision_attn_mode(os.environ.get("LA_FLASH_VISION_ATTN", "auto")) MAX_DIM = 1024 DEV, DT = "cuda", torch.bfloat16 N_FUTURE = 6 # = config.block_size (MTP window) _PROMPT = "Locate all the instances that matches the following description: " def _env_flag(name, default=False): val = os.environ.get(name) if val is None: return default return val.strip().lower() not in {"0", "false", "no", "off"} def _env_int(name): val = os.environ.get(name) if val is None or val.strip() == "": return None return int(val) def _strict_attn(): return _env_flag("LA_FLASH_STRICT_ATTN", False) def _fallback_to_sdpa(model, requested, reason): if requested == "sdpa": raise RuntimeError(f"LA_FLASH_ATTN=sdpa failed: {reason}") from reason message = f"LA_FLASH_ATTN={requested} is unavailable; falling back to sdpa. Reason: {reason}" if _strict_attn(): raise RuntimeError(message) from reason warnings.warn(message) _set_llm_mode(model, "sdpa") model._la_flash_requested_attn_original = requested model._la_flash_attn_fallback_reason = str(reason) return "sdpa" # Optional compile for the shared Qwen2 core. This is off by default because the # hybrid scheduler already varies query/cache shapes and first-call compile cost is high. MTP_COMPILE = os.environ.get("MTP_COMPILE", "0") == "1" # Batch the MoonViT vision encode across a micro-batch's images: pack N images into ONE # extract_feature. With flash present, MoonViT's varlen cu_seqlens path is block-diagonal per # image and equivalent to per-image encode. # Without flash, sdpa builds a dense [1,S,S] mask -> O(S^2) N^2 -> per-image fallback (auto, see # _vision_is_flash). Default ON; set MTP_BATCH_VISION=0 to force per-image. BATCH_VISION = os.environ.get("MTP_BATCH_VISION", "1") == "1" _vision_encode_batch_size = _env_int("LA_FLASH_VISION_ENCODE_BATCH_SIZE") VISION_ENCODE_BATCH_SIZE = 8 if _vision_encode_batch_size is None else max(0, _vision_encode_batch_size) # Batch the per-row box-decode (sample_tokens): run the row-independent logits pipeline # (rep-penalty / per-row temperature / top_p / top_k / softmax / sample) ONCE over the whole # [B,6,V] step instead of B times on [1,6,V]; only the variable-length box assembly stays per-row. # Greedy is BIT-IDENTICAL to the per-row san (argmax, no RNG). Default ON; MTP_BATCH_SAN=0 -> per-row. BATCH_SAN = os.environ.get("MTP_BATCH_SAN", "1") == "1" # Batch the AR repair sampler over [B,1,V]. This shares the exact filtering # helpers with MTP batching but skips box/ref decoding, so it only replaces the # repeated stock one-token sample calls. Sampling itself stays row-ordered by # default to preserve the stock RNG consumption pattern for AR repair. AR_BATCH_SAN = os.environ.get("AR_BATCH_SAN", "1") == "1" _tok = _proc = _model = None def _magi_diag(): lines = [] try: import magi_attention lines.append(f"magi_attention: OK file={getattr(magi_attention, '__file__', None)}") lines.append(f"magi_attention.__version__={getattr(magi_attention, '__version__', '')}") except Exception as e: lines.append(f"magi_attention: FAIL {type(e).__name__}: {e}") return "\n".join(lines) try: from magi_attention.functional.flex_flash_attn import flex_flash_attn_func lines.append(f"magi_attention.functional.flex_flash_attn: OK func={flex_flash_attn_func}") except Exception as e: lines.append(f"magi_attention.functional.flex_flash_attn: FAIL {type(e).__name__}: {e}") return "\n".join(lines) def _remote_magi_diag(model=None): lines = [] try: if model is not None: mod = importlib.import_module(type(model.language_model.model).__module__) else: # Best effort: if the dynamic module is not imported yet this may fail; # the post-load diagnostic below will still work. mod = importlib.import_module("transformers_modules.LocateAnything-3B.modeling_qwen2") lines.append(f"remote_qwen2_module={getattr(mod, '__file__', None)}") lines.append(f"remote_qwen2._MAGI_AVAILABLE={getattr(mod, '_MAGI_AVAILABLE', '')!r}") lines.append(f"remote_qwen2.flex_flash_attn_func={getattr(mod, 'flex_flash_attn_func', '')}") except Exception as e: lines.append(f"remote_qwen2: diagnostic failed {type(e).__name__}: {e}") return "\n".join(lines) def _attn_class_diag(model): try: llm = model.language_model.model classes = [type(layer.self_attn).__name__ for layer in llm.layers[:4]] return ( f"llm._attn_implementation={getattr(llm, '_attn_implementation', None)!r}\n" f"config._attn_implementation={getattr(llm.config, '_attn_implementation', None)!r}\n" f"first_attn_classes={classes}" ) except Exception as e: return f"attention class diagnostic failed {type(e).__name__}: {e}" def _set_vision_attention_mode(model): """Match HF's MoonViT policy: prefer flash_attention_2, then sdpa, then eager.""" vm = getattr(model, "vision_model", None) if vm is None: return None mod = importlib.import_module(type(vm).__module__) funcs = getattr(mod, "VL_VISION_ATTENTION_FUNCTIONS", {}) has_flash = getattr(mod, "flash_attn_varlen_func", None) is not None requested = VISION_ATTN_MODE if requested == "auto": candidates = ("flash_attention_2", "sdpa", "eager") else: candidates = (requested, "flash_attention_2", "sdpa", "eager") chosen = None for candidate in candidates: if candidate == "flash_attention_2" and not has_flash: continue if candidate in funcs: chosen = candidate break if chosen is None: raise RuntimeError("MoonViT has no supported attention implementation.") if requested == "flash_attention_2" and chosen != "flash_attention_2": warnings.warn("LA_FLASH_VISION_ATTN=flash_attention_2 requested but flash-attn is unavailable; " f"using {chosen}.") elif requested not in {"auto", chosen}: warnings.warn(f"LA_FLASH_VISION_ATTN={requested} is unavailable; using {chosen}.") if hasattr(model.config, "vision_config"): model.config.vision_config._attn_implementation = chosen try: vm.config._attn_implementation = chosen except Exception: pass try: for block in vm.encoder.blocks: block.attn_implementation = chosen except Exception as exc: raise RuntimeError("Failed to configure MoonViT attention implementation.") from exc model._la_flash_vision_attn = chosen return chosen def load(): """Lazy model load with HF remote-code semantics plus release backends. The text decoder is pinned to one of sdpa/eager/magi/la_flash. MoonViT is configured independently and follows the HF policy: flash_attention_2 when flash-attn is importable, otherwise sdpa, otherwise eager. """ global _tok, _proc, _model if _model is None: _tok = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True) _proc = AutoProcessor.from_pretrained(MODEL, trust_remote_code=True) attn_impl = REMOTE_ATTN_MODE if ATTN_MODE == "magi" and os.environ.get("LA_FLASH_DEBUG", "0") != "0": print("LA Flash magi pre-load diagnostic:", flush=True) print(_magi_diag(), flush=True) _model = AutoModel.from_pretrained(MODEL, torch_dtype=DT, trust_remote_code=True, attn_implementation=attn_impl).to(DEV).eval() _set_vision_attention_mode(_model) actual_attn = getattr(_model.language_model.model, "_attn_implementation", None) if ATTN_MODE == "magi" and os.environ.get("LA_FLASH_DEBUG", "0") != "0": print("LA Flash magi post-load diagnostic:", flush=True) print(_remote_magi_diag(_model), flush=True) print(_attn_class_diag(_model), flush=True) if ATTN_MODE == "magi": try: qwen2_mod = importlib.import_module(type(_model.language_model.model).__module__) if not getattr(qwen2_mod, "_MAGI_AVAILABLE", False): raise RuntimeError( "remote module reports _MAGI_AVAILABLE=False.\n" f"{_remote_magi_diag(_model)}\n{_magi_diag()}" ) first_attn = type(_model.language_model.model.layers[0].self_attn).__name__ if actual_attn != "sdpa" or first_attn != "_BatchedMagiAttention": _set_llm_mode(_model, "magi") actual_attn = getattr(_model.language_model.model, "_attn_implementation", None) first_attn = type(_model.language_model.model.layers[0].self_attn).__name__ if os.environ.get("LA_FLASH_DEBUG", "0") != "0": print("LA Flash magi post-swap diagnostic:", flush=True) print(_attn_class_diag(_model), flush=True) if actual_attn != "sdpa" or first_attn != "_BatchedMagiAttention": raise RuntimeError( "batched magi attention did not activate. " f"actual_attn={actual_attn!r}; first_attn={first_attn!r}; " f"{_remote_magi_diag(_model)}; {_attn_class_diag(_model)}" ) _model._la_flash_requested_attn = "magi" except Exception as exc: _fallback_to_sdpa(_model, "magi", exc) else: try: _set_llm_mode(_model, ATTN_MODE) # decode-safe mask plumbing for sdpa/eager/la_flash except Exception as exc: _fallback_to_sdpa(_model, ATTN_MODE, exc) if MTP_COMPILE: _maybe_compile(_model) return _tok, _proc, _model def _maybe_compile(model): """Compile the shared Qwen2Model core (base.forward). It backs BOTH prefill (called directly) and decode (language_model.forward -> self.model). lm_head + MoonViT left eager. dynamic=True so the varying decode S/kvlen don't trigger a recompile storm. No-op + warning if triton is missing (inductor needs it on GPU). First call pays the compile cost (~42s warm / ~187s cold).""" try: import triton # noqa: F401 except Exception: warnings.warn("MTP_COMPILE set but triton is unavailable; running without torch.compile.") return import torch._dynamo as _dyn _dyn.config.cache_size_limit = max(_dyn.config.cache_size_limit, 64) base = model.language_model.model if not getattr(base, "_mtp_compiled", False): base.forward = torch.compile(base.forward, dynamic=True) base._mtp_compiled = True def build_batched_magi_attention_class(mod): """Build a Qwen2 attention subclass backed by Magi's flex_flash_attn. The official LocateAnything ``Qwen2MagiAttention`` asserts ``bsz == 1`` and relies on ``Qwen2Model._attn_implementation == "magi"`` to build a single sample range plan. For release batch inference the hybrid scheduler passes a batched Magi range plan directly to this layer; a 4D-mask conversion path remains as a compatibility fallback. """ flex_flash_attn_func = getattr(mod, "flex_flash_attn_func", None) if flex_flash_attn_func is None: try: from magi_attention.functional.flex_flash_attn import flex_flash_attn_func except Exception as exc: raise RuntimeError( "LA_FLASH_ATTN=magi requires " "magi_attention.functional.flex_flash_attn.flex_flash_attn_func." ) from exc FULL, CAUSAL = 0, 1 causal_plan_cache = {} try: magi_params = set(inspect.signature(flex_flash_attn_func).parameters) except (TypeError, ValueError): magi_params = set() supports_disable_fwd_atomic = "disable_fwd_atomic_reduction" in magi_params def _disjoint_q_ranges(q_ranges): seen = set() for start, end in q_ranges: key = (int(start), int(end)) if key in seen: return False seen.add(key) return True def _plan_disjoint_q_ranges(plan): cached = plan.get("_la_flash_disjoint_q_ranges") if cached is not None: return bool(cached) q_ranges = plan["q_ranges"].detach().to(device="cpu", dtype=torch.int32).tolist() disjoint = _disjoint_q_ranges(q_ranges) try: plan["_la_flash_disjoint_q_ranges"] = disjoint except Exception: pass return disjoint def _tensor_plan(q_ranges, k_ranges, types, device): return { "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=device).contiguous(), "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=device).contiguous(), "attn_type_map": torch.tensor(types, dtype=torch.int32, device=device).contiguous(), "_la_flash_disjoint_q_ranges": _disjoint_q_ranges(q_ranges), } def _offset_plan(plan, q_offset, k_offset): return ( (plan["q_ranges"] + int(q_offset)).tolist(), (plan["k_ranges"] + int(k_offset)).tolist(), plan["attn_type_map"].tolist(), ) def _causal_plan(bsz, q_len, kv_seq_len, device): key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index) cached = causal_plan_cache.get(key) if cached is not None: return cached q_ranges, k_ranges, types = [], [], [] for b in range(int(bsz)): q_base = b * int(q_len) k_base = b * int(kv_seq_len) q_ranges.append([q_base, q_base + int(q_len)]) k_ranges.append([k_base, k_base + int(kv_seq_len)]) types.append(CAUSAL) plan = _tensor_plan(q_ranges, k_ranges, types, device) plan.update( { "flash_cu_seqlens_q": torch.arange( 0, (int(bsz) + 1) * int(q_len), int(q_len), dtype=torch.int32, device=device, ), "flash_cu_seqlens_k": torch.arange( 0, (int(bsz) + 1) * int(kv_seq_len), int(kv_seq_len), dtype=torch.int32, device=device, ), "flash_causal": True, } ) causal_plan_cache[key] = plan return plan def _row_segments(row): idx = np.flatnonzero(row) if idx.size == 0: return ((0, 1),) split = np.flatnonzero(np.diff(idx) > 1) + 1 starts = np.concatenate((idx[:1], idx[split])) ends = np.concatenate((idx[split - 1], idx[-1:])) + 1 return tuple((int(s), int(e)) for s, e in zip(starts, ends)) def _visible_from_4d_mask(attention_mask, kv_seq_len): mask = attention_mask[:, :, :, :kv_seq_len] if mask.dtype == torch.bool: return mask[:, 0].detach().to(device="cpu", dtype=torch.bool).contiguous() mask_cpu = mask[:, 0].detach().to(device="cpu").contiguous() if getattr(attention_mask, "_la_flash_visible_mask", False): return (mask_cpu > 0).to(dtype=torch.bool) max_value = float(mask_cpu.max().item()) if mask_cpu.numel() else 0.0 min_value = float(mask_cpu.min().item()) if mask_cpu.numel() else 0.0 if max_value > 0.0 and min_value >= 0.0: return (mask_cpu > 0).to(dtype=torch.bool) return (mask_cpu >= 0).to(dtype=torch.bool) def _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device): cache_key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index) cached = getattr(attention_mask, "_la_flash_magi_plan", None) if cached is not None and cached[0] == cache_key: return cached[1] visible = _visible_from_4d_mask(attention_mask, int(kv_seq_len)).numpy() q_ranges, k_ranges, types = [], [], [] for b in range(int(bsz)): q_base = b * int(q_len) k_base = b * int(kv_seq_len) run_start = 0 run_segments = _row_segments(visible[b, 0]) for q in range(1, int(q_len)): segments = _row_segments(visible[b, q]) if segments == run_segments: continue for start, end in run_segments: q_ranges.append([q_base + run_start, q_base + q]) k_ranges.append([k_base + start, k_base + end]) types.append(FULL) run_start = q run_segments = segments for start, end in run_segments: q_ranges.append([q_base + run_start, q_base + int(q_len)]) k_ranges.append([k_base + start, k_base + end]) types.append(FULL) plan = _tensor_plan(q_ranges, k_ranges, types, device) try: attention_mask._la_flash_magi_plan = (cache_key, plan) except Exception: pass return plan def _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device): if int(bsz) == 1: return attention_mask q_ranges, k_ranges, types = [], [], [] for b in range(int(bsz)): qs, ks, ts = _offset_plan( attention_mask, q_offset=b * int(q_len), k_offset=b * int(kv_seq_len), ) q_ranges.extend(qs) k_ranges.extend(ks) types.extend(ts) return _tensor_plan(q_ranges, k_ranges, types, device) def _magi_plan(attention_mask, bsz, q_len, kv_seq_len, device): if isinstance(attention_mask, dict): if attention_mask.get("_la_flash_batched", False): return attention_mask return _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device) if attention_mask is None: return _causal_plan(bsz, q_len, kv_seq_len, device) if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " f"but is {attention_mask.size()}" ) return _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device) class _BatchedMagiAttention(mod.Qwen2Attention): """MagiAttention path with true batch inference via packed token ranges.""" def forward( self, hidden_states: torch.Tensor, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, **kwargs, ): if output_attentions: raise NotImplementedError("MagiAttention does not support output_attentions=True") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using " f"{self.__class__.__name__} for auto-regressive decoding with k/v caching, " "please initialize the attention class with a layer index." ) kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = mod.apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs) kv_seq_len = key_states.shape[-2] plan = _magi_plan(attention_mask, bsz, q_len, kv_seq_len, query_states.device) magi_extra_kwargs = {} if supports_disable_fwd_atomic: magi_extra_kwargs["disable_fwd_atomic_reduction"] = ( (not self.training) and _plan_disjoint_q_ranges(plan) ) query_states = query_states.transpose(1, 2).reshape( bsz * q_len, self.num_heads, self.head_dim).contiguous() key_states = key_states.transpose(1, 2).reshape( bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous() value_states = value_states.transpose(1, 2).reshape( bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous() attn_output, _ = flex_flash_attn_func( query_states, key_states, value_states, q_ranges=plan["q_ranges"], k_ranges=plan["k_ranges"], attn_type_map=plan["attn_type_map"], softmax_scale=getattr(self, "softmax_scale", self.head_dim ** -0.5), softcap=0.0, deterministic=False, **magi_extra_kwargs, ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value return _BatchedMagiAttention def build_la_flash_attention_class(mod): """Build a Qwen2 attention subclass backed by LA Flash sparse ranges.""" try: from kernel_utils import is_available, range_attention except Exception as exc: raise RuntimeError( "LA_FLASH_ATTN=la_flash requires kernel_utils and FlashAttention." ) from exc if not is_available(): raise RuntimeError( "LA_FLASH_ATTN=la_flash requires flash_attn.flash_attn_varlen_func." ) FULL, CAUSAL = 0, 1 causal_plan_cache = {} def _tensor_plan(q_ranges, k_ranges, types, device): max_q_len = max((int(end) - int(start) for start, end in q_ranges), default=0) max_k_len = max((int(end) - int(start) for start, end in k_ranges), default=0) plan = { "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=device).contiguous(), "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=device).contiguous(), "attn_type_map": torch.tensor(types, dtype=torch.int32, device=device).contiguous(), "max_q_len": max_q_len, "max_k_len": max_k_len, } plan.update(_la_flash_group_plan_tensors(q_ranges, types, device)) return plan def _offset_plan(plan, q_offset, k_offset): return ( (plan["q_ranges"] + int(q_offset)).tolist(), (plan["k_ranges"] + int(k_offset)).tolist(), plan["attn_type_map"].tolist(), ) def _causal_plan(bsz, q_len, kv_seq_len, device): key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index) cached = causal_plan_cache.get(key) if cached is not None: return cached q_ranges, k_ranges, types = [], [], [] for b in range(int(bsz)): q_base = b * int(q_len) k_base = b * int(kv_seq_len) q_ranges.append([q_base, q_base + int(q_len)]) k_ranges.append([k_base, k_base + int(kv_seq_len)]) types.append(CAUSAL) plan = _tensor_plan(q_ranges, k_ranges, types, device) plan.update( { "flash_cu_seqlens_q": torch.arange( 0, (int(bsz) + 1) * int(q_len), int(q_len), dtype=torch.int32, device=device, ), "flash_cu_seqlens_k": torch.arange( 0, (int(bsz) + 1) * int(kv_seq_len), int(kv_seq_len), dtype=torch.int32, device=device, ), "flash_causal": True, } ) causal_plan_cache[key] = plan return plan def _row_segments(row): idx = np.flatnonzero(row) if idx.size == 0: return ((0, 1),) split = np.flatnonzero(np.diff(idx) > 1) + 1 starts = np.concatenate((idx[:1], idx[split])) ends = np.concatenate((idx[split - 1], idx[-1:])) + 1 return tuple((int(s), int(e)) for s, e in zip(starts, ends)) def _visible_from_4d_mask(attention_mask, kv_seq_len): mask = attention_mask[:, :, :, :kv_seq_len] if mask.dtype == torch.bool: return mask[:, 0].detach().to(device="cpu", dtype=torch.bool).contiguous() mask_cpu = mask[:, 0].detach().to(device="cpu").contiguous() if getattr(attention_mask, "_la_flash_visible_mask", False): return (mask_cpu > 0).to(dtype=torch.bool) max_value = float(mask_cpu.max().item()) if mask_cpu.numel() else 0.0 min_value = float(mask_cpu.min().item()) if mask_cpu.numel() else 0.0 if max_value > 0.0 and min_value >= 0.0: return (mask_cpu > 0).to(dtype=torch.bool) return (mask_cpu >= 0).to(dtype=torch.bool) def _prefix_len(row): idx = np.flatnonzero(row) if idx.size == 0: return None end = int(idx[-1]) + 1 if not bool(row[:end].all()) or bool(row[end:].any()): return None return end def _causal_plan_from_visible(visible, bsz, q_len, kv_seq_len, device): q_ranges, k_ranges, types = [], [], [] packed_flash = True for b in range(int(bsz)): first_len = _prefix_len(visible[b, 0]) if first_len is None: return None valid_len = int(first_len) + int(q_len) - 1 if valid_len < int(q_len) or valid_len > int(kv_seq_len): return None for q in range(int(q_len)): row_len = _prefix_len(visible[b, q]) expected = valid_len - int(q_len) + q + 1 if row_len != expected: return None q_base = b * int(q_len) k_base = b * int(kv_seq_len) q_ranges.append([q_base, q_base + int(q_len)]) k_ranges.append([k_base, k_base + valid_len]) types.append(CAUSAL) packed_flash = packed_flash and valid_len == int(kv_seq_len) plan = _tensor_plan(q_ranges, k_ranges, types, device) plan["_la_flash_disjoint_q_ranges"] = True if packed_flash: plan.update( { "flash_cu_seqlens_q": torch.arange( 0, (int(bsz) + 1) * int(q_len), int(q_len), dtype=torch.int32, device=device, ), "flash_cu_seqlens_k": torch.arange( 0, (int(bsz) + 1) * int(kv_seq_len), int(kv_seq_len), dtype=torch.int32, device=device, ), "flash_causal": True, } ) return plan def _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device): cache_key = (int(bsz), int(q_len), int(kv_seq_len), device.type, device.index, "la_flash") cached = getattr(attention_mask, "_la_flash_range_plan", None) if cached is not None and cached[0] == cache_key: return cached[1] visible = _visible_from_4d_mask(attention_mask, int(kv_seq_len)).numpy() plan = _causal_plan_from_visible(visible, bsz, q_len, kv_seq_len, device) if plan is not None: try: attention_mask._la_flash_range_plan = (cache_key, plan) except Exception: pass return plan q_ranges, k_ranges, types = [], [], [] for b in range(int(bsz)): q_base = b * int(q_len) k_base = b * int(kv_seq_len) run_start = 0 run_segments = _row_segments(visible[b, 0]) for q in range(1, int(q_len)): segments = _row_segments(visible[b, q]) if segments == run_segments: continue for start, end in run_segments: q_ranges.append([q_base + run_start, q_base + q]) k_ranges.append([k_base + start, k_base + end]) types.append(FULL) run_start = q run_segments = segments for start, end in run_segments: q_ranges.append([q_base + run_start, q_base + int(q_len)]) k_ranges.append([k_base + start, k_base + end]) types.append(FULL) plan = _tensor_plan(q_ranges, k_ranges, types, device) try: attention_mask._la_flash_range_plan = (cache_key, plan) except Exception: pass return plan def _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device): if int(bsz) == 1: return attention_mask q_ranges, k_ranges, types = [], [], [] for b in range(int(bsz)): qs, ks, ts = _offset_plan( attention_mask, q_offset=b * int(q_len), k_offset=b * int(kv_seq_len), ) q_ranges.extend(qs) k_ranges.extend(ks) types.extend(ts) return _tensor_plan(q_ranges, k_ranges, types, device) def _range_plan(attention_mask, bsz, q_len, kv_seq_len, device): if isinstance(attention_mask, dict): if attention_mask.get("_la_flash_batched", False): return attention_mask return _plan_from_magi_dict(attention_mask, bsz, q_len, kv_seq_len, device) if attention_mask is None: return _causal_plan(bsz, q_len, kv_seq_len, device) if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " f"but is {attention_mask.size()}" ) return _plan_from_visible_mask(attention_mask, bsz, q_len, kv_seq_len, device) class _LaFlashAttention(mod.Qwen2Attention): """Range-plan attention path backed by FlashAttention sparse ranges.""" def forward( self, hidden_states: torch.Tensor, attention_mask=None, position_ids=None, past_key_value=None, output_attentions=False, use_cache=False, **kwargs, ): if output_attentions: raise NotImplementedError("LA Flash attention does not support output_attentions=True") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: if self.layer_idx is None: raise ValueError( f"The cache structure has changed since version v4.36. If you are using " f"{self.__class__.__name__} for auto-regressive decoding with k/v caching, " "please initialize the attention class with a layer index." ) kv_seq_len += past_key_value.get_seq_length(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = mod.apply_rotary_pos_emb( query_states, key_states, cos, sin, position_ids) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs) kv_seq_len = key_states.shape[-2] dense_backend = os.environ.get("LA_FLASH_DENSE_BACKEND", "sdpa").strip().lower() if dense_backend == "sdpa" and not isinstance(attention_mask, dict): dense_key_states = mod.repeat_kv(key_states, self.num_key_value_groups) dense_value_states = mod.repeat_kv(value_states, self.num_key_value_groups) if attention_mask is not None: if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): raise ValueError( f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, " f"but is {attention_mask.size()}" ) query_for_sdpa = query_states.contiguous() key_for_sdpa = dense_key_states.contiguous() value_for_sdpa = dense_value_states.contiguous() is_causal = False elif past_key_value is None: query_for_sdpa = query_states key_for_sdpa = dense_key_states value_for_sdpa = dense_value_states is_causal = bool(self.is_causal and q_len > 1) else: query_for_sdpa = key_for_sdpa = value_for_sdpa = None is_causal = False if query_for_sdpa is not None: attn_output = torch.nn.functional.scaled_dot_product_attention( query_for_sdpa, key_for_sdpa, value_for_sdpa, attn_mask=attention_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value plan = _range_plan(attention_mask, bsz, q_len, kv_seq_len, query_states.device) query_states = query_states.transpose(1, 2).reshape( bsz * q_len, self.num_heads, self.head_dim).contiguous() key_states = key_states.transpose(1, 2).reshape( bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous() value_states = value_states.transpose(1, 2).reshape( bsz * kv_seq_len, self.num_key_value_heads, self.head_dim).contiguous() attn_output = range_attention( query_states, key_states, value_states, plan["q_ranges"], plan["k_ranges"], plan["attn_type_map"], getattr(self, "softmax_scale", self.head_dim ** -0.5), segment_offsets=plan.get("segment_offsets"), group_q_ranges=plan.get("group_q_ranges"), group_attn_type_map=plan.get("group_attn_type_map"), max_q_len=plan.get("max_q_len"), max_k_len=plan.get("max_k_len"), flash_cu_seqlens_q=plan.get("flash_cu_seqlens_q"), flash_cu_seqlens_k=plan.get("flash_cu_seqlens_k"), flash_causal=plan.get("flash_causal"), disjoint_q_ranges=plan.get("_la_flash_disjoint_q_ranges"), ) attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output, None, past_key_value return _LaFlashAttention def _is_magi_plan(obj): return isinstance(obj, dict) and { "q_ranges", "k_ranges", "attn_type_map", }.issubset(obj.keys()) def _la_flash_group_plan_tensors(q_ranges, types, device): """Group consecutive Magi range entries that share the same query span. Magi-style plans may represent one query span with multiple disjoint key spans. LA Flash consumes those as one FlashAttention-backed softmax group. """ if not q_ranges: return { "group_q_ranges": torch.empty((0, 2), dtype=torch.int32, device=device), "segment_offsets": torch.zeros((1,), dtype=torch.int32, device=device), "group_attn_type_map": torch.empty((0,), dtype=torch.int32, device=device), } grouped_q, grouped_types, offsets = [], [], [0] last_q = None last_type = None for idx, (q_range, attn_type) in enumerate(zip(q_ranges, types)): key = (int(q_range[0]), int(q_range[1])) attn_type = int(attn_type) if last_q is None: grouped_q.append([key[0], key[1]]) grouped_types.append(attn_type) last_q = key last_type = attn_type continue if key == last_q and attn_type == last_type: continue offsets.append(idx) grouped_q.append([key[0], key[1]]) grouped_types.append(attn_type) last_q = key last_type = attn_type offsets.append(len(q_ranges)) return { "group_q_ranges": torch.tensor(grouped_q, dtype=torch.int32, device=device).contiguous(), "segment_offsets": torch.tensor(offsets, dtype=torch.int32, device=device).contiguous(), "group_attn_type_map": torch.tensor(grouped_types, dtype=torch.int32, device=device).contiguous(), "max_q_len": max((end - start for start, end in grouped_q), default=0), } def _record_sparse_plan_stats(model, q_ranges, k_ranges, types): if os.environ.get("LA_FLASH_PLAN_STATS", "0") != "1": return stats = getattr(model, "_la_flash_sparse_plan_stats", None) if stats is None: stats = { "calls": 0, "ranges": 0, "q_tokens": 0, "k_tokens": 0, "max_q_len": 0, "max_k_len": 0, "full_ranges": 0, "causal_ranges": 0, "other_ranges": 0, } model._la_flash_sparse_plan_stats = stats stats["calls"] += 1 stats["ranges"] += len(q_ranges) for (q_start, q_end), (k_start, k_end), attn_type in zip(q_ranges, k_ranges, types): q_len = int(q_end) - int(q_start) k_len = int(k_end) - int(k_start) stats["q_tokens"] += q_len stats["k_tokens"] += k_len stats["max_q_len"] = max(stats["max_q_len"], q_len) stats["max_k_len"] = max(stats["max_k_len"], k_len) attn_type = int(attn_type) if attn_type == 0: stats["full_ranges"] += 1 elif attn_type == 1: stats["causal_ranges"] += 1 else: stats["other_ranges"] += 1 def build_magi_scheduler_ranges(model, attention_mask_2d, input_ids, past_len, mtp_window=False): """Build batched Magi ranges directly from the hybrid scheduler mask. The official Qwen2 SDPA dispatcher may optimize an all-valid 2D mask to ``None`` before decoder layers see it. That is correct for plain causal attention but loses LocateAnything's MTP generation-window rule. Building ranges here keeps Magi batch inference exact and avoids per-layer dense mask conversion. """ requested_attn = getattr(model, "_la_flash_requested_attn", ATTN_MODE) if requested_attn not in {"magi", "la_flash"}: return None if attention_mask_2d is None or not hasattr(attention_mask_2d, "dim") or attention_mask_2d.dim() != 2: return None bsz, q_len = int(input_ids.shape[0]), int(input_ids.shape[1]) key_len = int(attention_mask_2d.shape[1]) dev = input_ids.device llm = model.language_model.model block = int(getattr(llm, "block_size", N_FUTURE)) causal_attn = bool(getattr(llm, "causal_attn", False)) use_mtp_window = bool(mtp_window and q_len >= block and key_len >= block) q0 = max(0, q_len - block) k0 = max(0, key_len - block) blocked_k = k0 - 1 past_len = int(past_len) key_valid = attention_mask_2d.detach().to(device="cpu", dtype=torch.bool).contiguous().numpy() key_idx = np.arange(key_len) q_ranges, k_ranges, types = [], [], [] if not use_mtp_window: causal_q_ranges, causal_k_ranges, causal_types = [], [], [] causal_fast_path = True packed_flash = True for b in range(bsz): valid = np.flatnonzero(key_valid[b]) if valid.size == 0: causal_fast_path = False break valid_len = int(valid[-1]) + 1 if valid_len < q_len or not bool(key_valid[b, :valid_len].all()) or bool(key_valid[b, valid_len:].any()): causal_fast_path = False break packed_flash = packed_flash and valid_len == key_len q_base = b * q_len k_base = b * key_len causal_q_ranges.append([q_base, q_base + q_len]) causal_k_ranges.append([k_base, k_base + valid_len]) causal_types.append(1) if causal_fast_path: plan = { "q_ranges": torch.tensor(causal_q_ranges, dtype=torch.int32, device=dev).contiguous(), "k_ranges": torch.tensor(causal_k_ranges, dtype=torch.int32, device=dev).contiguous(), "attn_type_map": torch.tensor(causal_types, dtype=torch.int32, device=dev).contiguous(), "max_q_len": q_len, "max_k_len": max((end - start for start, end in causal_k_ranges), default=0), "_la_flash_batched": True, "_la_flash_disjoint_q_ranges": True, } if packed_flash: plan.update( { "flash_cu_seqlens_q": torch.arange( 0, (bsz + 1) * q_len, q_len, dtype=torch.int32, device=dev, ), "flash_cu_seqlens_k": torch.arange( 0, (bsz + 1) * key_len, key_len, dtype=torch.int32, device=dev, ), "flash_causal": True, } ) plan.update(_la_flash_group_plan_tensors(causal_q_ranges, causal_types, dev)) _record_sparse_plan_stats(model, causal_q_ranges, causal_k_ranges, causal_types) return plan def row_segments(row): idx = np.flatnonzero(row) if idx.size == 0: return ((0, 1),) split = np.flatnonzero(np.diff(idx) > 1) + 1 starts = np.concatenate((idx[:1], idx[split])) ends = np.concatenate((idx[split - 1], idx[-1:])) + 1 return tuple((int(s), int(e)) for s, e in zip(starts, ends)) for b in range(bsz): q_base = b * q_len k_base = b * key_len run_start = 0 run_segments = None if use_mtp_window and not causal_attn: prefix_q_len = q0 prefix_k_end = past_len + prefix_q_len prefix_ok = ( prefix_q_len > 0 and prefix_k_end <= key_len and bool(key_valid[b, :prefix_k_end].all()) ) window_prefix_ok = blocked_k <= 0 or bool(key_valid[b, :blocked_k].all()) window_ok = bool(key_valid[b, k0:key_len].all()) if prefix_ok: q_ranges.append([q_base, q_base + prefix_q_len]) k_ranges.append([k_base, k_base + prefix_k_end]) types.append(1) run_start = prefix_q_len if run_start == prefix_q_len and prefix_q_len < q_len and window_prefix_ok and window_ok: if blocked_k > 0: q_ranges.append([q_base + prefix_q_len, q_base + q_len]) k_ranges.append([k_base, k_base + blocked_k]) types.append(0) q_ranges.append([q_base + prefix_q_len, q_base + q_len]) k_ranges.append([k_base + k0, k_base + key_len]) types.append(0) continue for q in range(run_start, q_len): visible = key_valid[b] & (key_idx <= q + past_len) if use_mtp_window and q >= q0: if not causal_attn: visible = visible.copy() visible[k0:key_len] = key_valid[b, k0:key_len] if blocked_k >= 0: if visible.base is None: visible[blocked_k] = False else: visible = visible.copy() visible[blocked_k] = False segments = row_segments(visible) if run_segments is None: run_segments = segments continue if segments == run_segments: continue for start, end in run_segments: q_ranges.append([q_base + run_start, q_base + q]) k_ranges.append([k_base + start, k_base + end]) types.append(0) run_start = q run_segments = segments for start, end in run_segments: q_ranges.append([q_base + run_start, q_base + q_len]) k_ranges.append([k_base + start, k_base + end]) types.append(0) seen_q_ranges = set() disjoint_q_ranges = True for start, end in q_ranges: key = (int(start), int(end)) if key in seen_q_ranges: disjoint_q_ranges = False break seen_q_ranges.add(key) plan = { "q_ranges": torch.tensor(q_ranges, dtype=torch.int32, device=dev).contiguous(), "k_ranges": torch.tensor(k_ranges, dtype=torch.int32, device=dev).contiguous(), "attn_type_map": torch.tensor(types, dtype=torch.int32, device=dev).contiguous(), "max_q_len": max((end - start for start, end in q_ranges), default=0), "max_k_len": max((end - start for start, end in k_ranges), default=0), "_la_flash_batched": True, "_la_flash_disjoint_q_ranges": disjoint_q_ranges, } plan.update(_la_flash_group_plan_tensors(q_ranges, types, dev)) _record_sparse_plan_stats(model, q_ranges, k_ranges, types) return plan def _direct_base_forward( base, input_ids=None, visual_features=None, image_token_index=None, attention_mask=None, position_ids=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, ): mod = importlib.import_module(type(base).__module__) output_attentions = output_attentions if output_attentions is not None else base.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else base.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else base.config.use_cache if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") if input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") past_key_values_length = 0 use_legacy_cache = False if use_cache: Cache = getattr(mod, "Cache") DynamicCache = getattr(mod, "DynamicCache") use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: if past_key_values is None: past_key_values = DynamicCache() else: past_key_values = DynamicCache.from_legacy_cache(past_key_values) past_key_values_length = past_key_values.get_seq_length() if position_ids is None: dev = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=dev, ).unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() if inputs_embeds is None: inputs_embeds = base.image_processing(input_ids, visual_features, image_token_index) hidden_states = inputs_embeds all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None next_decoder_cache = None for decoder_layer in base.layers: if output_hidden_states: all_hidden_states += (hidden_states,) layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) hidden_states = layer_outputs[0] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) hidden_states = base.norm(hidden_states) if output_hidden_states: all_hidden_states += (hidden_states,) next_cache = None if use_cache: next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache return SimpleNamespace( last_hidden_state=hidden_states, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns, ) def language_model_forward(model, **kwargs): """Forward through the text LM, bypassing official dense-mask prep for sparse plans.""" lm = model.language_model return_logits = kwargs.pop("return_logits", True) logits_slice = kwargs.pop("logits_slice", None) attention_mask = kwargs.get("attention_mask") use_direct_sparse = ( getattr(model, "_la_flash_requested_attn", ATTN_MODE) in {"magi", "la_flash"} and _is_magi_plan(attention_mask) ) if not use_direct_sparse: return lm(**kwargs) labels = kwargs.pop("labels", None) if labels is not None: raise NotImplementedError("labels are not supported in the direct sparse-plan decode forward") output_attentions = kwargs.get("output_attentions", None) output_hidden_states = kwargs.get("output_hidden_states", None) base_out = _direct_base_forward(lm.model, **kwargs) logits = None if return_logits: hidden_states = base_out.last_hidden_state if logits_slice is not None: hidden_states = hidden_states[:, logits_slice, :] logits = lm.lm_head(hidden_states).float() return SimpleNamespace( logits=logits, past_key_values=base_out.past_key_values, hidden_states=base_out.hidden_states if output_hidden_states else None, attentions=base_out.attentions if output_attentions else None, ) _EagerCls = _SdpaCls = _LaFlashCls = _MagiCls = None def _attn_classes(mode=None): """Attention classes from the dynamic Qwen2 remote module. The official Qwen2Model mask dispatcher only implements ``sdpa`` and single-row ``magi``. Eager, LA Flash, and batched Magi inference therefore swap the layer class while keeping the model's mask dispatcher pinned to ``sdpa``. """ global _EagerCls, _SdpaCls, _LaFlashCls, _MagiCls mode = _normalize_attn_mode(mode) if mode is not None else None if _SdpaCls is None: mod = importlib.import_module(type(_model.language_model.model).__module__) _EagerCls = mod.Qwen2Attention _SdpaCls = mod.Qwen2SdpaAttention else: mod = importlib.import_module(type(_model.language_model.model).__module__) if (mode is None or mode == "la_flash") and _LaFlashCls is None: _LaFlashCls = build_la_flash_attention_class(mod) if (mode is None or mode == "magi") and _MagiCls is None: _MagiCls = build_batched_magi_attention_class(mod) if getattr(mod, "_MAGI_AVAILABLE", False) else None return _EagerCls, _SdpaCls, _LaFlashCls, _MagiCls def _set_llm_mode(model, mode): """Swap every Qwen2 decoder layer's attention class. Release backends keep ``Qwen2Model._attn_implementation='sdpa'`` so the official Qwen2 mask dispatcher stays available for dense-mask modes. The local ``la_flash`` and batched ``magi`` wrappers can also consume scheduler-built sparse plans directly, avoiding repeated per-layer dense mask conversion. """ mode = _normalize_attn_mode(mode) eager, sdpa, la_flash, magi = _attn_classes(mode) impl = "sdpa" if mode == "sdpa": cls = sdpa elif mode == "eager": cls = eager elif mode == "la_flash": cls = la_flash elif mode == "magi": if magi is None: raise RuntimeError("MagiAttention is unavailable in the current Python environment.") cls = magi else: raise ValueError(f"unknown LLM attention mode: {mode}") llm = model.language_model.model for lyr in llm.layers: lyr.self_attn.__class__ = cls if mode == "magi": lyr.self_attn.softmax_scale = lyr.self_attn.head_dim ** -0.5 llm._attn_implementation = impl llm.config._attn_implementation = llm._attn_implementation if hasattr(model.config, "text_config"): model.config.text_config._attn_implementation = llm._attn_implementation model.config._attn_implementation = llm._attn_implementation model._la_flash_requested_attn = mode _st = _hp = None def _helpers(): """The model's own sample_tokens / handle_pattern (the exact box decoders).""" global _st, _hp if _st is None: m = importlib.import_module(type(load()[2]).__module__) _st, _hp = m.sample_tokens, m.handle_pattern return _st, _hp _gu = None def _gen_utils(): """The model's generate_utils module (apply_repetition_penalty / top_p_logits / top_k_logits / decode_bbox_avg / decode_ref / dists) -- the pieces sample_tokens_batched reuses verbatim.""" global _gu if _gu is None: m = importlib.import_module(type(load()[2]).__module__) _gu = importlib.import_module(m.sample_tokens.__module__) return _gu def _env_float(name, default): val = os.environ.get(name) if val is None or val.strip() == "": return float(default) return float(val) def _coord_fallback_mode(): mode = os.environ.get("LA_FLASH_COORD_FALLBACK_MODE", "legacy").strip().lower().replace("-", "_") aliases = { "": "legacy", "official": "legacy", "range": "legacy", "spread": "legacy", "none": "off", "disable": "off", "disabled": "off", "entropy_variance": "uncertainty", "entropy_var": "uncertainty", "ent_var": "uncertainty", "entropy_std": "uncertainty", } mode = aliases.get(mode, mode) if mode not in {"legacy", "uncertainty", "off"}: raise ValueError( "LA_FLASH_COORD_FALLBACK_MODE must be one of legacy, uncertainty, off" ) return mode def _coord_uncertainty_threshold(coord_start_token_id, coord_end_token_id): """Return the coord uncertainty threshold in raw coord-token units. Backward-compatible behavior: - LA_FLASH_COORD_UNCERTAINTY_THRESH > 1 is treated as raw coord-token RMSE. - LA_FLASH_COORD_UNCERTAINTY_THRESH <= 1 is treated as normalized by coord span. - LA_FLASH_COORD_UNCERTAINTY_NORM_THRESH is an explicit normalized override. """ coord_span = max(float(coord_end_token_id - coord_start_token_id + 1), 1.0) norm_val = os.environ.get("LA_FLASH_COORD_UNCERTAINTY_NORM_THRESH") if norm_val is not None and norm_val.strip() != "": return float(norm_val) * coord_span val = os.environ.get("LA_FLASH_COORD_UNCERTAINTY_THRESH") if val is None or val.strip() == "": return 20.0 threshold = float(val) if 0.0 < threshold <= 1.0: return threshold * coord_span return threshold def _decode_bbox_with_uncertainty(logits, probs, token_ids, keep_k=4, generation_mode="hybrid"): """Decode an MTP box with configurable coord uncertainty fallback. The default mode is the official LocateAnything rule. ``uncertainty`` keeps the same frame checks and top-k coord selection, but uses one scalar criterion per coordinate: the posterior RMSE of committing to the current MAP coordinate among valid coord candidates. This is the Bayes risk under squared coordinate error, so probabilities and token distances are folded into one threshold in coordinate-token units. """ gu = _gen_utils() mode = _coord_fallback_mode() if mode == "legacy" or generation_mode != "hybrid": return gu.decode_bbox_avg(logits, probs, token_ids, keep_k=keep_k, generation_mode=generation_mode) coord_start_token_id = token_ids["coord_start_token_id"] coord_end_token_id = token_ids["coord_end_token_id"] box_start_token_id = token_ids["box_start_token_id"] box_end_token_id = token_ids["box_end_token_id"] none_token_id = token_ids["none_token_id"] null_token_id = token_ids["null_token_id"] device = logits.device box_type = gu.is_valid_box_frame( probs, token_ids, start_thresh=_env_float("LA_FLASH_COORD_BOX_START_THRESH", 0.7), end_thresh=_env_float("LA_FLASH_COORD_BOX_END_THRESH", 0.2), topk=keep_k, ) if box_type == "empty_box": return torch.tensor([ box_start_token_id, none_token_id, box_end_token_id, null_token_id, null_token_id, null_token_id, ], dtype=torch.long, device=device) if box_type == "illegal_box": return None pos_probs, pos_ids = torch.topk(probs[1:5], k=keep_k, dim=-1) valid = (pos_ids >= coord_start_token_id) & (pos_ids <= coord_end_token_id) has_valid = valid.any(dim=-1) if not has_valid.all(): return None first_valid_idx = valid.long().argmax(dim=-1, keepdim=True) first_valid_ids = pos_ids.gather(-1, first_valid_idx).squeeze(-1) if mode == "off": final_coords = first_valid_ids else: valid_counts = valid.sum(dim=-1) valid_probs = torch.where(valid, pos_probs, torch.zeros_like(pos_probs)) valid_mass = valid_probs.sum(dim=-1).clamp_min(1e-12) weights = valid_probs / valid_mass.unsqueeze(-1) coord_values = (pos_ids - coord_start_token_id).to(dtype=torch.float32) map_coord = (first_valid_ids - coord_start_token_id).to(dtype=torch.float32) uncertainty = (weights * (coord_values - map_coord.unsqueeze(-1)).pow(2)).sum(dim=-1).sqrt() is_abnormal = ( (valid_counts > 1) & (uncertainty > _coord_uncertainty_threshold(coord_start_token_id, coord_end_token_id)) ) final_coords = torch.where(is_abnormal, torch.tensor(0, device=device), first_valid_ids) start_t = torch.tensor([box_start_token_id], dtype=final_coords.dtype, device=device) end_t = torch.tensor([box_end_token_id], dtype=final_coords.dtype, device=device) return torch.cat([start_t, final_coords, end_t]) def _apply_repetition_penalty_lowmem(logits, generated, repetition_penalty): """Apply the stock repetition penalty without allocating a [B, S, V] mask.""" if repetition_penalty == 1.0: return logits _, _, vocab_size = logits.shape for row in range(logits.shape[0]): valid_tokens = generated[row].unique() valid_tokens = valid_tokens[(valid_tokens >= 0) & (valid_tokens < vocab_size)] if valid_tokens.numel() == 0: continue row_logits = logits[row, :, valid_tokens] logits[row, :, valid_tokens] = torch.where( row_logits > 0, row_logits / repetition_penalty, row_logits * repetition_penalty, ) return logits def _finite_logit_bounds(dtype): finfo = torch.finfo(dtype) return finfo.min, finfo.max def _finite_logits(logits): if not logits.dtype.is_floating_point: logits = logits.float() min_val, max_val = _finite_logit_bounds(logits.dtype) return torch.nan_to_num(logits, nan=min_val, posinf=max_val, neginf=min_val) def _finite_logits_(logits): if not logits.dtype.is_floating_point: return logits.float() min_val, max_val = _finite_logit_bounds(logits.dtype) return logits.nan_to_num_(nan=min_val, posinf=max_val, neginf=min_val) def _top_p_logits_slice_(logits, top_p): sorted_logits, sorted_indices = torch.sort(logits, descending=True) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = False remove = torch.zeros_like(logits, dtype=torch.bool, device=logits.device) remove.scatter_(-1, sorted_indices, sorted_indices_to_remove) logits.masked_fill_(remove, torch.finfo(logits.dtype).min) return logits def _top_p_logits_(logits, top_p): """In-place nucleus filtering with bounded sort workspace. The MTP sampler uses logits shaped ``[B, 6, V]``. Top-p is independent for each row and each future position, so filtering one position at a time keeps the expensive sorted-index workspace at ``[B, V]`` instead of ``[B, 6, V]``. """ if logits.dim() == 3 and logits.shape[1] > 1: for pos in range(logits.shape[1]): _top_p_logits_slice_(logits[:, pos, :], top_p) return logits return _top_p_logits_slice_(logits, top_p) def _top_k_logits_(logits, top_k): """In-place top-k filtering mirroring generate_utils.top_k_logits.""" top_k = min(int(top_k), logits.size(-1)) threshold = torch.topk(logits, top_k)[0][..., -1, None] logits.masked_fill_(logits < threshold, torch.finfo(logits.dtype).min) return logits def _safe_probs(filtered_logits): """Softmax with CUDA-multinomial-safe cleanup and row-wise argmax fallback.""" filtered_logits = _finite_logits(filtered_logits) probs = torch.softmax(filtered_logits, dim=-1, dtype=torch.float32) probs = torch.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0).clamp_min_(0.0) row_sum = probs.sum(dim=-1, keepdim=True) bad = (~torch.isfinite(row_sum)) | (row_sum <= 0) if bool(bad.any().item()): fallback = torch.zeros_like(probs) fallback.scatter_(-1, filtered_logits.argmax(dim=-1, keepdim=True), 1.0) probs = torch.where(bad, fallback, probs) row_sum = probs.sum(dim=-1, keepdim=True) return probs / row_sum.clamp_min(1.0e-20) def _sample_top_p_sorted_tokens(logits, top_p): """Sample from top-p filtered logits without scattering back to vocab order.""" sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) remove = cumulative_probs > top_p remove[..., 1:] = remove[..., :-1].clone() remove[..., 0] = False sorted_logits.masked_fill_(remove, torch.finfo(sorted_logits.dtype).min) sorted_probs = _safe_probs(sorted_logits) sample_idx = sorted_probs.argmax(dim=-1) try: sample_idx = torch.distributions.Categorical(probs=sorted_probs).sample() except Exception: pass return sorted_indices.gather(-1, sample_idx.unsqueeze(-1)).squeeze(-1) @torch.no_grad() def sample_tokens_batched(logits, generated, token_ids, per_row_temp, repetition_penalty=1.0, top_p=None, top_k=None, keep_k_avg=4, generation_mode='fast'): """Batched fork of generate_utils.sample_tokens for the MTP window [B,6,V]. The logits pipeline (rep-penalty / per-row temperature / top_p / top_k / softmax / sample) is ROW-INDEPENDENT, so run it ONCE over the whole batch instead of B times on [1,6,V] (the per-row san defeats batching by slicing wlogits[b:b+1]). Only the variable-length box ASSEMBLY (decode_bbox_avg -> ragged shapes, where sample_tokens' final torch.stack throws) stays per-row, returned as a LIST. Equivalence to per-row san: every pipeline op reduces on dim=-1 only (never crosses the row dim), so row b's processed logits/probs are bit-identical to slicing first -> greedy (per_row_temp==0, argmax branch, no RNG) is BIT-EXACT. Under sampling, one batched Categorical changes the global RNG consumption order vs B per-row draws -> box-size jitter (blessed; greedy is the exact gate). apply_repetition_penalty already loops per-row internally, so passing the full [B,M] `generated` is row-correct. keep_k_avg/generation_mode mirror sample_tokens' decode_bbox_avg call EXACTLY (note: the per-row san passes keep_k=5 but decode_bbox_avg reads keep_k_avg, default 4 -- so 5 is a no-op there; we replicate keep_k_avg=4). Returns (x0[B,6], boxes: list of B 1-D LongTensors).""" gu = _gen_utils() B, S, V = logits.shape # S = N_FUTURE = 6 if repetition_penalty != 1.0: logits = _apply_repetition_penalty_lowmem(logits, generated, repetition_penalty) t = per_row_temp.to(dtype=logits.dtype).view(B, 1, 1) sample_rows = per_row_temp > 0 if bool(sample_rows.all().item()): logits.div_(t.clamp(min=1e-8)) elif bool(sample_rows.any().item()): idx = sample_rows.nonzero(as_tuple=True)[0] logits[idx].div_(t[idx].clamp(min=1e-8)) logits = _finite_logits_(logits) if top_p is not None and top_p < 1: logits = _top_p_logits_(logits, top_p) if top_k is not None and top_k > 0: logits = _top_k_logits_(logits, top_k) probs = _safe_probs(logits) x0 = probs.argmax(dim=-1) # [B,6]; greedy rows are final here samp = per_row_temp > 0 if bool(samp.any()): # sampling rows: ONE batched Categorical draw idx = samp.nonzero(as_tuple=True)[0] try: x0[idx] = gu.dists.Categorical(probs=probs[idx]).sample() except Exception: pass # keep argmax (matches san's except: probs.max) boxes = [] fallback = torch.zeros(1, dtype=x0.dtype, device=x0.device) for b in range(B): # variable-length box assembly (per-row, exact) db = _decode_bbox_with_uncertainty( logits[b], probs[b], token_ids, keep_k=keep_k_avg, generation_mode=generation_mode) if db is not None: boxes.append(db) else: ref = gu.decode_ref(logits[b], probs[b], token_ids) if ref is None: boxes.append(fallback) elif torch.is_tensor(ref): boxes.append(ref.to(dtype=x0.dtype, device=x0.device)) else: boxes.append(torch.tensor(ref, dtype=x0.dtype, device=x0.device)) return x0, boxes @torch.no_grad() def sample_next_tokens_batched(logits, generated, per_row_temp, repetition_penalty=1.0, top_p=None, top_k=None): """Batched one-token sampler for AR repair rows. This mirrors the row-independent part of ``sample_tokens`` for logits shaped ``[B,1,V]``. It intentionally does not run bbox/ref assembly because AR mode only needs the next token before the state machine classifies it. """ gu = _gen_utils() if logits.dim() != 3 or logits.shape[1] != 1: raise ValueError(f"AR batched sampler expects logits [B,1,V], got {tuple(logits.shape)}") B = int(logits.shape[0]) if repetition_penalty != 1.0: logits = _apply_repetition_penalty_lowmem(logits, generated, repetition_penalty) t = per_row_temp.to(dtype=logits.dtype).view(B, 1, 1) sample_rows = per_row_temp > 0 if bool(sample_rows.all().item()): logits.div_(t.clamp(min=1e-8)) elif bool(sample_rows.any().item()): idx = sample_rows.nonzero(as_tuple=True)[0] logits[idx].div_(t[idx].clamp(min=1e-8)) logits = _finite_logits_(logits) sorted_top_p = os.environ.get("AR_SORTED_TOPP", "0") == "1" default_top_p = sorted_top_p and top_p is not None and top_p < 1 and (top_k is None or top_k <= 0) if default_top_p and bool(sample_rows.all().item()): return _sample_top_p_sorted_tokens(logits, top_p) if top_p is not None and top_p < 1: logits = _top_p_logits_(logits, top_p) if top_k is not None and top_k > 0: logits = _top_k_logits_(logits, top_k) probs = _safe_probs(logits) x0 = probs.argmax(dim=-1) if bool(sample_rows.any().item()): # Keep row-ordered sampling as the release default. A single batched # Categorical is faster, but it consumes RNG differently from stock AR # repair and can alter default-temperature termination behavior. for row in sample_rows.nonzero(as_tuple=True)[0].tolist(): try: x0[row : row + 1] = gu.dists.Categorical(probs=probs[row : row + 1]).sample() except Exception: pass return x0 def load_pil(p): from PIL import Image im = Image.open(p).convert("RGB"); w, h = im.size if max(w, h) > MAX_DIM: s = MAX_DIM / max(w, h); im = im.resize((max(1, round(w*s)), max(1, round(h*s))), Image.LANCZOS) return im def _preproc_one(im): """CPU-side processor for one image -> (pixel_values[bf16], grid[int32]). Split out of _encode_image so _encode_images can batch the GPU encode while preprocessing stays per-image.""" tok, proc, model = load() msg = [{"role": "user", "content": [{"type": "image", "image": im}, {"type": "text", "text": "x"}]}] text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True) imgs, vids = proc.process_vision_info(msg) inp = proc(text=[text], images=imgs, videos=vids, return_tensors="pt").to(DEV) grid = inp.get("image_grid_hws") if isinstance(grid, np.ndarray): grid = torch.from_numpy(grid).to(DEV, dtype=torch.int32) return inp["pixel_values"].to(DT), grid def _vision_is_flash(): """True iff MoonViT will actually run flash_attn_varlen (so cross-image packing is block-diagonal = exact AND a win). If the vision blocks are on sdpa/eager, OR the flash wheel is absent (multihead_attention falls back to the dense-mask sdpa path), packing is O(S^2) N^2 -> caller must stay per-image.""" vm = load()[2].vision_model mod = importlib.import_module(type(vm).__module__) if getattr(mod, "flash_attn_varlen_func", None) is None: return False try: return vm.encoder.blocks[0].attn_implementation == "flash_attention_2" except Exception: return False @torch.no_grad() def _encode_images(ims): """N images -> list of [n_img_tokens, C] mlp1-projected visual_features, one per image (row-order). Drop-in for [_encode_image(im) for im in ims]. With flash present (_vision_is_flash) and N>1, packs images into extract_feature micro-batches: MoonViT's varlen cu_seqlens path is block-diagonal by image. Without flash, the dense SDPA fallback would scale with the packed total sequence length, so this function falls back to per-image encode. MTP_BATCH_VISION=0 also forces per-image encode.""" tok, proc, model = load() pvs, grids = [], [] for im in ims: pv, g = _preproc_one(im) pvs.append(pv); grids.append(g) if BATCH_VISION and len(ims) > 1 and _vision_is_flash(): if VISION_ENCODE_BATCH_SIZE <= 0 or VISION_ENCODE_BATCH_SIZE >= len(ims): vit_list = model.extract_feature(torch.cat(pvs, dim=0), torch.cat(grids, dim=0)) else: vit_list = [] for start in range(0, len(ims), VISION_ENCODE_BATCH_SIZE): end = min(start + VISION_ENCODE_BATCH_SIZE, len(ims)) vit_list.extend( model.extract_feature( torch.cat(pvs[start:end], dim=0), torch.cat(grids[start:end], dim=0), ) ) return [model.mlp1(v) for v in vit_list] # one [P_i, C] per image (patch_merger split) return [model.mlp1(torch.cat(model.extract_feature(pv, g), dim=0)) for pv, g in zip(pvs, grids)] # per-image (flash absent / N==1 / forced off) @torch.no_grad() def _encode_image(im): """Single-image convenience wrapper (single-image callers); = _encode_images([im])[0] (takes the per-image path inside _encode_images, so bit-identical to the original).""" return _encode_images([im])[0] @torch.no_grad() def _tokenize(im, query): """1-D prompt token ids for (image, query). Uses the model's own chat template.""" tok, proc, model = load() msg = [{"role": "user", "content": [{"type": "image", "image": im}, {"type": "text", "text": _PROMPT + query + "."}]}] text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True) imgs, vids = proc.process_vision_info(msg) return proc(text=[text], images=imgs, videos=vids, return_tensors="pt").to(DEV)["input_ids"][0] @torch.no_grad() def _tokenize_cached_image(query, image_token_count, im=None): """Tokenize a prompt when the image token count is already known. This keeps the processor's chat template, but directly expands ```` from the cached visual feature length. It avoids re-running the CPU image processor for every category prompt that shares the same image. """ tok, proc, model = load() msg = [{"role": "user", "content": [{"type": "image", "image": im}, {"type": "text", "text": _PROMPT + query + "."}]}] text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True) placeholder = f"<{getattr(proc, 'image_placeholder', 'image')}-1>" image_token = getattr(proc, "image_token", "") image_start = getattr(proc, "image_start_token", "") image_end = getattr(proc, "image_end_token", "") replacement = f"{image_start}{image_token * int(image_token_count)}{image_end}" if placeholder not in text: raise ValueError(f"cached image placeholder {placeholder!r} was not found in chat template") text = text.replace(placeholder, replacement, 1) return tok([text], return_tensors="pt").to(DEV)["input_ids"][0] def _proc_full(im, query): """Full processor dict (input_ids, attention_mask, pixel_values, image_grid_hws) — used by the bench to drive the STOCK generate for the equivalence check.""" tok, proc, model = load() msg = [{"role": "user", "content": [{"type": "image", "image": im}, {"type": "text", "text": _PROMPT + query + "."}]}] text = proc.py_apply_chat_template(msg, tokenize=False, add_generation_prompt=True) imgs, vids = proc.process_vision_info(msg) inp = proc(text=[text], images=imgs, videos=vids, return_tensors="pt").to(DEV) grid = inp.get("image_grid_hws") if isinstance(grid, np.ndarray): grid = torch.from_numpy(grid).to(DEV, dtype=torch.int32) inp["image_grid_hws"] = grid return inp def _pad_generated(prompt_ids, gen_ids, img_tok, dev): """Per-row [prompt + accepted] left-padded with the image token (already in every prompt -> .unique() unchanged -> repetition penalty identical to single-run).""" rows = [list(prompt_ids[b].tolist()) + gen_ids[b] for b in range(len(prompt_ids))] M = max(len(r) for r in rows) out = torch.full((len(rows), M), img_tok, dtype=torch.long, device=dev) for b, r in enumerate(rows): out[b, M - len(r):] = torch.tensor(r, dtype=torch.long, device=dev) return out