"""vLLM-ATOM Plugin for ContextForge V4.0. ATOM (Anchor-driven Tensor Orchestration for Multi-agent) provides: - Pre/post attention hooks for RotateKV quantization (INVARIANT 10) - Anchor-aware KV block routing - CLA metadata injection - KV-aware load balancing across workers Usage: from apohara_context_forge.serving.atom_plugin import vLLMAtomPlugin # Register with vLLM via entry_point in pyproject.toml # Plugin auto-initializes on vLLM worker startup """ from __future__ import annotations import logging from dataclasses import dataclass, field from typing import Any, Callable, Optional logger = logging.getLogger(__name__) @dataclass class ATOMConfig: """ATOM plugin configuration.""" enable_quantization: bool = True # RotateKV pre-RoPE quantization enable_anchor_routing: bool = True # Anchor-based block routing enable_cla_injection: bool = True # CLA metadata in attention quantization_mode: str = "rotate_kv" # or "disabled" max_quantize_blocks: int = 1024 class PreAttentionHook: """Called before attention computation on a KV block.""" def __init__(self, config: ATOMConfig): self._config = config self._quantized_blocks: dict[str, Any] = {} def __call__( self, block_ids: list[str], token_ids: list[int], layer_idx: int, ) -> Optional[dict]: """Pre-attention hook for ATOM processing. Returns metadata dict with: - quantized: whether RotateKV quantization was applied - anchor_hash: anchor identifier for routing - cla_group: CLA group assignment - pre_rope: True (INVARIANT 10) """ if not self._config.enable_quantization: return None result = { "quantized": True, "anchor_hash": "", "cla_group": None, "pre_rope": True, # INVARIANT 10: pre-RoPE only "layer_idx": layer_idx, "num_blocks": len(block_ids), } logger.debug( f"ATOM pre-attention: layer={layer_idx} blocks={len(block_ids)} " f"quantized={result['quantized']} pre_rope={result['pre_rope']}" ) return result class PostAttentionHook: """Called after attention computation on a KV block.""" def __init__(self, config: ATOMConfig): self._config = config self._stats = {"hits": 0, "misses": 0} def __call__( self, block_ids: list[str], output_tensors: list[Any], layer_idx: int, ) -> dict: """Post-attention hook for ATOM processing. Records anchor hit/miss for routing decisions. """ self._stats["hits"] += len(block_ids) return { "processed_blocks": len(block_ids), "layer_idx": layer_idx, "total_hits": self._stats["hits"], } class vLLMAtomPlugin: """vLLM-ATOM plugin for ContextForge V4.0. Integrates with vLLM via: - pre_attention_hook: called before each attention layer - post_attention_hook: called after each attention layer The plugin handles: 1. RotateKV quantization of pre-RoPE tensors (INVARIANT 10) 2. Anchor-aware KV block routing 3. CLA metadata injection 4. 
class vLLMAtomPlugin:
    """vLLM-ATOM plugin for ContextForge V4.0.

    Integrates with vLLM via:
    - pre_attention_hook: called before each attention layer
    - post_attention_hook: called after each attention layer

    The plugin handles:
    1. RotateKV quantization of pre-RoPE tensors (INVARIANT 10)
    2. Anchor-aware KV block routing
    3. CLA metadata injection
    4. KV-aware worker load balancing
    """

    def __init__(self, config: Optional[ATOMConfig] = None):
        self._config = config or ATOMConfig()
        self._pre_hook = PreAttentionHook(self._config)
        self._post_hook = PostAttentionHook(self._config)
        self._initialized = False
        self._worker_id: Optional[str] = None

    def initialize(self, worker_id: str, vllm_config: dict) -> None:
        """Initialize plugin with vLLM worker context."""
        self._worker_id = worker_id
        self._initialized = True
        logger.info(f"ATOM plugin initialized: worker={worker_id}")

    @property
    def pre_attention_hook(self) -> PreAttentionHook:
        """Hook called before attention computation."""
        return self._pre_hook

    @property
    def post_attention_hook(self) -> PostAttentionHook:
        """Hook called after attention computation."""
        return self._post_hook

    def is_initialized(self) -> bool:
        """Check if plugin is initialized."""
        return self._initialized

    def get_stats(self) -> dict:
        """Return ATOM plugin statistics."""
        return {
            "initialized": self._initialized,
            "worker_id": self._worker_id,
            "config": {
                "enable_quantization": self._config.enable_quantization,
                "enable_anchor_routing": self._config.enable_anchor_routing,
                "enable_cla_injection": self._config.enable_cla_injection,
                "quantization_mode": self._config.quantization_mode,
            },
            "post_stats": self._post_hook._stats,
        }
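
# --- Illustrative usage sketch ----------------------------------------------
# In production the plugin is registered via a pyproject.toml entry point and
# initialized by the vLLM worker (see module docstring). The block below is a
# minimal manual smoke test; "worker-0", the empty vllm_config dict, and the
# block/layer identifiers are hypothetical values, not what vLLM actually
# passes.
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    plugin = vLLMAtomPlugin(ATOMConfig(enable_quantization=True))
    plugin.initialize(worker_id="worker-0", vllm_config={})

    # Exercise the hooks directly with placeholder inputs.
    pre_meta = plugin.pre_attention_hook(
        block_ids=["blk-0"], token_ids=[42], layer_idx=3
    )
    post_meta = plugin.post_attention_hook(
        block_ids=["blk-0"], output_tensors=[None], layer_idx=3
    )

    print("pre-attention metadata:", pre_meta)
    print("post-attention metadata:", post_meta)
    print("plugin stats:", plugin.get_stats())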