| """ | |
| Shared constants, utilities, and base class for blend test scripts. | |
| """ | |
| import time | |
| from pathlib import Path | |
| from typing import List, Tuple | |
| from transformers import AutoTokenizer, AutoConfig | |
| import sglang as sgl | |
| from sglang.srt.utils.triton_attention_score import warmup_triton_kernels | |
| from qcfuse_config import ( | |
| DEFAULT_CRITICAL_LAYERS, | |
| MODEL_TOP10_CRITICAL_LAYERS, | |
| SUPPORTED_BASELINES, | |
| ) | |
| # ==================== Constants ==================== | |
| DEFAULT_DATA_DIR = Path(__file__).parent | |
| # Frontend-only delimiter. The tokenizer path splits on this string before | |
| # tokenization, so it must not rely on whitespace to avoid token merges. | |
| BLEND_SEP = "<|blendsep|>" | |
| # Model template configurations | |
| TEMPLATES = { | |
| "llama": ( | |
| "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n", | |
| "<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n", | |
| "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", | |
| ), | |
| "mistral": ("<s>[INST]", "", "[/INST]"), | |
| "qwen": ( | |
| "<|im_start|>system\n", | |
| "<|im_end|>\n<|im_start|>user\n", | |
| "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n", | |
| ), | |
| } | |
| # Blend baseline configurations: (style, start, method) | |
| BLEND_CONFIG = { | |
| "fullcomp": ("FULLCOMPUTE", 0, "none"), | |
| "ours": ("KVCOMPUTE", 0, "attn"), | |
| } | |
| def _critical_model_key(model_name: str) -> str: | |
| name_lower = model_name.lower() | |
| if name_lower.startswith("qwen3-8b"): | |
| return "qwen3-8b" | |
| elif name_lower.startswith("qwen3-14b"): | |
| return "qwen3-14b" | |
| elif name_lower.startswith("llama"): | |
| return "llama3.1-8b" | |
| elif name_lower.startswith("mistral"): | |
| return "mistral-7b" | |
| raise ValueError(f"critical layers are not configured for model {model_name}") | |
| def get_critical_layers( | |
| model_name: str, num_layers: int, critical_layers: int = DEFAULT_CRITICAL_LAYERS | |
| ) -> List[int]: | |
| """Return Top-K critical layers as 0-based indices. | |
| critical_layers=-1 selects every layer. | |
| """ | |
| critical_layers = int(critical_layers) | |
| if critical_layers == -1: | |
| return list(range(num_layers)) | |
| model_key = _critical_model_key(model_name) | |
| top_layers = MODEL_TOP10_CRITICAL_LAYERS[model_key] | |
| if critical_layers < 1 or critical_layers > len(top_layers): | |
| raise ValueError( | |
| f"critical_layers must be -1 or an integer in [1, {len(top_layers)}], " | |
| f"got {critical_layers}" | |
| ) | |
| return _validate_explicit_layers( | |
| model_name, | |
| num_layers, | |
| top_layers[:critical_layers], | |
| ) | |
| def _validate_explicit_layers( | |
| model_name: str, num_layers: int, layers: List[int] | |
| ) -> List[int]: | |
| invalid_layers = [layer for layer in layers if layer < 0 or layer >= num_layers] | |
| if invalid_layers: | |
| raise ValueError( | |
| f"critical layers {invalid_layers} are out of range for " | |
| f"model {model_name} with {num_layers} layers" | |
| ) | |
| return layers | |
| def _set_critical_layers(engine, model_name: str, layers: List[int]) -> None: | |
| layers = _validate_explicit_layers( | |
| model_name, | |
| engine._get_model_config()["num_layers"], | |
| [int(layer) for layer in layers], | |
| ) | |
| engine.critical_layers = layers | |
| engine.attn_start, engine.attn_end = 0, max(layers) + 1 | |
| def set_ours_layers( | |
| engine, model_name: str, critical_layers: int = DEFAULT_CRITICAL_LAYERS | |
| ): | |
| """Set the critical layer Top-K used by the ours baseline.""" | |
| num_layers = engine._get_model_config()["num_layers"] | |
| layers = get_critical_layers( | |
| model_name, num_layers, critical_layers=critical_layers | |
| ) | |
| _set_critical_layers(engine, model_name, layers) | |
| # ==================== Base Engine ==================== | |
| class BlendEngineBase: | |
| """Base class with shared blend engine functionality.""" | |
| def __init__(self, model_path: str, baseline: str = "ours"): | |
| self.model_name = Path(model_path).name.lower() | |
| self.model_path = model_path | |
| self.context_length = 32000 | |
| self.attn_start = 0 | |
| self.attn_end = -1 | |
| self.critical_layers = None | |
| self._model_config = None | |
| self.llm = sgl.Engine( | |
| model_path=model_path, | |
| mem_fraction_static=0.8, | |
| context_length=self.context_length, | |
| tp_size=1, | |
| disable_cuda_graph=True, | |
| trust_remote_code=True, | |
| disable_radix_cache=True, | |
| chunked_prefill_size=-1, | |
| dtype="bfloat16", | |
| attention_backend="triton", | |
| ) | |
| self.tokenizer = AutoTokenizer.from_pretrained( | |
| model_path, trust_remote_code=True | |
| ) | |
| self.set_baseline(baseline) | |
| def set_baseline(self, baseline: str): | |
| """Switch baseline configuration.""" | |
| if baseline not in SUPPORTED_BASELINES: | |
| raise ValueError( | |
| f"Unsupported baseline={baseline!r}; expected one of " | |
| f"{SUPPORTED_BASELINES}" | |
| ) | |
| self.baseline = baseline | |
| self.critical_layers = None | |
| cfg = BLEND_CONFIG[baseline] | |
| self.first_style, self.start, self.method = cfg | |
| def _get_model_config(self) -> dict: | |
| """Get model architecture parameters (cached).""" | |
| if self._model_config is not None: | |
| return self._model_config | |
| config = AutoConfig.from_pretrained(self.model_path, trust_remote_code=True) | |
| head_dim = getattr(config, "head_dim", None) | |
| if head_dim is None: | |
| head_dim = config.hidden_size // config.num_attention_heads | |
| if getattr(config, "multi_query_attention", False): | |
| num_kv_heads = getattr(config, "multi_query_group_num", 1) | |
| else: | |
| num_kv_heads = getattr( | |
| config, | |
| "num_key_value_heads", | |
| getattr( | |
| config, | |
| "multi_query_group_num", | |
| config.num_attention_heads, | |
| ), | |
| ) | |
| self._model_config = { | |
| "head_dim": head_dim, | |
| "num_layers": getattr(config, "num_hidden_layers", 32), | |
| "num_heads": getattr(config, "num_attention_heads", 32), | |
| "num_kv_heads": num_kv_heads, | |
| } | |
| return self._model_config | |
| def _get_template(self) -> Tuple[str, str, str]: | |
| """Get model template based on model name.""" | |
| for prefix, template in TEMPLATES.items(): | |
| if self.model_name.startswith(prefix): | |
| return template | |
| return ("", "", "") | |
| def _build_prompt( | |
| self, system_prompt: str, docs: List[str], q_prompt: List[str], use_sep: bool | |
| ) -> Tuple[str, str]: | |
| """Build complete prompt from components.""" | |
| sys_h, sys_e, asst_h = self._get_template() | |
| prefix = sys_h + system_prompt + sys_e | |
| suffix = "".join(q_prompt) + "\n\n## Answer\n" + asst_h | |
| if use_sep: | |
| query_sep = BLEND_SEP.join(q_prompt) | |
| return BLEND_SEP.join([prefix] + docs + [suffix]), query_sep | |
| return prefix + "".join(docs) + suffix, suffix | |
| def check_prompt_length( | |
| self, system_prompt: str, docs: List[str], q_prompt: List[str], | |
| max_new_tokens: int, | |
| ) -> Tuple[bool, int]: | |
| """Check if prompt length exceeds context_length.""" | |
| prompt, _ = self._build_prompt(system_prompt, docs, q_prompt, use_sep=False) | |
| token_count = len(self.tokenizer.encode(prompt)) | |
| max_allowed = self.context_length - max_new_tokens | |
| return token_count <= max_allowed, token_count | |
| def _timed_generate(self, prompt: str, params: dict, **kwargs) -> dict: | |
| """Run streaming generate and return {text, ttft, decode_time}.""" | |
| start = time.time() | |
| ttft, text = None, "" | |
| for out in self.llm.generate(prompt, params, stream=True, **kwargs): | |
| if ttft is None and out.get("text"): | |
| ttft = time.time() - start | |
| text = out.get("text", "") | |
| ttft = ttft or (time.time() - start) | |
| return {"text": text, "ttft": ttft, "decode_time": time.time() - start - ttft} | |
| def warmup(self, num_warmup: int = 3): | |
| cfg = self._get_model_config() | |
| warmup_triton_kernels( | |
| head_dims=[cfg["head_dim"]], | |
| num_warmup_iters=3, | |
| num_layers=cfg["num_layers"], | |
| num_heads=cfg["num_heads"], | |
| num_kv_heads=cfg["num_kv_heads"], | |
| ) | |
| sys_h, sys_e, asst_h = self._get_template() | |
| warmup_prompt = ( | |
| sys_h + "You are a helpful assistant." + sys_e | |
| + "Hello, how are you?" + asst_h | |
| ) | |
| for _ in range(num_warmup): | |
| for _ in self.llm.generate( | |
| warmup_prompt, | |
| {"temperature": 0, "max_new_tokens": 1}, | |
| stream=True, | |
| blend_style=None, | |
| ): | |
| pass | |
| def shutdown(self): | |
| if hasattr(self, "llm"): | |
| self.llm.shutdown() | |
| def __enter__(self): | |
| return self | |
| def __exit__(self, exc_type, exc, tb): | |
| self.shutdown() | |
| return False | |
Xet Storage Details
- Size:
- 9.15 kB
- Xet hash:
- cc5b1d66290dfbb1ea3d1aced6808d21c800b50b8f4e3492182a2b39cf749a04
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.