| | |
| | |
| |
|
| | import logging |
| | from dataclasses import dataclass |
| | from typing import Any |
| | from typing import Dict |
| | from typing import List |
| | from typing import Literal |
| | from typing import Optional |
| | from typing import Tuple |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from transformers.cache_utils import DynamicCache |
| |
|
| | from .sliding_utils import drop_tokens_from_cache |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class DuplexWindowConfig:
    """Duplex sliding-window configuration.

    Sliding-window modes:
    - "off": sliding window disabled
    - "basic": basic sliding window (triggered by cache length)
    - "context": sliding window with context retention (triggered by unit
      count; generated text is preserved into a "previous" section)
    """

    # Which sliding-window strategy to use: "off" / "basic" / "context".
    sliding_window_mode: str = "off"

    # Basic-mode watermarks, in cache tokens: sliding triggers when the cache
    # exceeds the high watermark and drops units until below the low watermark.
    basic_window_high_tokens: int = 4000
    basic_window_low_tokens: int = 3500

    # Context mode: budget for the retained "previous" content (tokens,
    # excluding the marker) and the maximum number of units kept in cache.
    context_previous_max_tokens: int = 500
    context_max_units: int = 24

    # Extra verification flag — consumed elsewhere; semantics defined by the
    # code that reads this config (not visible in this class).
    verify_mode: bool = False
| |
|
| |
|
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("inf")):
    """Filter a logits distribution using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits tensor of shape (batch_size, vocab_size).
        top_k: keep only the top-k highest-probability tokens (0 disables).
        top_p: keep the smallest set of tokens whose cumulative probability
            exceeds ``top_p`` (0.0 disables). The highest-probability token
            is always kept.
        filter_value: value written into removed positions.

    Returns:
        A filtered copy of ``logits``; the input tensor is left unmodified.
    """
    logits = logits.clone()

    if top_k > 0:
        top_k = min(top_k, logits.size(-1))
        # Remove every token whose logit is below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        probs = F.softmax(sorted_logits, dim=-1)
        cumulative_probs = torch.cumsum(probs, dim=-1)

        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept,
        # and the top-1 token is never removed.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Scatter the removal mask back to the original token order.
        # BUGFIX: the previous advanced-indexing form
        # (logits[0, sorted_indices[mask]]) only worked for batch size 1;
        # scatter handles any batch size.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits[indices_to_remove] = filter_value

    return logits
| |
|
| |
|
| | class StreamDecoder: |
| | def __init__(self, llm, tokenizer, special_token_ids=None, forbidden_token_ids=None): |
| | self.m = llm |
| | self.tokenizer = tokenizer |
| | self.listen_id = self.tokenizer.eos_token_id |
| |
|
| | self.chunk_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_eos|>") |
| | self.chunk_tts_eos_id = self.tokenizer.convert_tokens_to_ids("<|chunk_tts_eos|>") |
| | self.turn_eos_id = self.tokenizer.convert_tokens_to_ids("<|turn_eos|>") |
| | self.speak_id = self.tokenizer.convert_tokens_to_ids("<|speak|>") |
| |
|
| | self.special_token_ids = special_token_ids if special_token_ids is not None else [] |
| |
|
| | |
| | self._all_special_ids = set() |
| | self._all_special_tokens_text = set() |
| | if self.tokenizer: |
| | if hasattr(self.tokenizer, "all_special_ids"): |
| | self._all_special_ids = set(self.tokenizer.all_special_ids) |
| | if hasattr(self.tokenizer, "all_special_tokens"): |
| | self._all_special_tokens_text = set(self.tokenizer.all_special_tokens) |
| |
|
| | custom_special_tokens = [ |
| | "<unit>", |
| | "</unit>", |
| | "<image>", |
| | "</image>", |
| | "<slice>", |
| | "</slice>", |
| | "<|listen|>", |
| | "<|speak|>", |
| | "<|tts_bos|>", |
| | "<|tts_eos|>", |
| | "<|audio_start|>", |
| | "<|audio_end|>", |
| | "<|chunk_eos|>", |
| | "<|chunk_tts_eos|>", |
| | "<|turn_eos|>", |
| | "<|audio_start|>", |
| | "<|audio_end|>", |
| | ] |
| | self._all_special_tokens_text.update(custom_special_tokens) |
| | for token in custom_special_tokens: |
| | token_id = self.tokenizer.convert_tokens_to_ids(token) |
| | if token_id is not None and token_id != self.tokenizer.unk_token_id: |
| | self._all_special_ids.add(token_id) |
| |
|
| | if forbidden_token_ids is None: |
| | self.forbidden_token_ids = [] |
| | elif isinstance(forbidden_token_ids, int): |
| | self.forbidden_token_ids = [self.forbidden_token_ids] |
| | else: |
| | self.forbidden_token_ids = forbidden_token_ids |
| | self.forbidden_token_ids.append(self.chunk_eos_id) |
| |
|
| | assert isinstance(self.forbidden_token_ids, list) |
| |
|
| | self.cache = None |
| | self.context = "" |
| | self.generated_tokens = [] |
| | self.generated_special_tokens = [] |
| | self.reset() |
| | self.embeds = None |
| | self.system_embeds = None |
| |
|
| | |
| | self._unit_history: List[Dict[str, Any]] = [] |
| | self._next_unit_id: int = 0 |
| | self._pending_unit_id: Optional[int] = None |
| | self._pending_unit_start_cache_len: int = 0 |
| | self._system_preserve_length: int = 0 |
| | self._position_offset: int = 0 |
| | self._window_config = DuplexWindowConfig() |
| | self._window_enabled: bool = True |
| | self._rope_inv_freq_cache: Dict[Tuple, torch.Tensor] = {} |
| |
|
| | |
| | |
| | |
| | |
| | self._preserve_prefix_length: int = 0 |
| | self._previous_content_length: int = 0 |
| | self._suffix_token_ids: List[int] = [] |
| |
|
| | |
| | self._previous_marker: str = "\n\nprevious: " |
| | self._previous_marker_token_ids: List[int] = [] |
| | self._has_previous: bool = False |
| |
|
| | |
| | self._previous_text: str = "" |
| | self._previous_token_ids: List[int] = [] |
| |
|
| | |
| | self._sliding_event_count: int = 0 |
| | self._total_dropped_tokens: int = 0 |
| | self._total_dropped_units: int = 0 |
| |
|
    def sliding_embeds(self):
        """Placeholder for embedding-level sliding-window handling.

        Currently a no-op (returns None). Presumably intended as a hook to
        keep cached input embeddings in sync after the KV cache slides —
        TODO confirm the intended semantics.
        """
        pass
| |
|
| | def reset(self): |
| | self.context = "" |
| | self.cache = None |
| | self.generated_tokens = [] |
| | self.generated_special_tokens = [] |
| | self.embeds = None |
| | self.system_embeds = None |
| |
|
| | |
| | old_unit_count = len(self._unit_history) if hasattr(self, "_unit_history") else 0 |
| | self._unit_history = [] |
| | self._next_unit_id = 0 |
| | self._pending_unit_id = None |
| | self._pending_unit_start_cache_len = 0 |
| | self._system_preserve_length = 0 |
| | self._position_offset = 0 |
| | self._rope_inv_freq_cache = {} |
| |
|
| | |
| | self._preserve_prefix_length = 0 |
| | self._previous_content_length = 0 |
| | self._suffix_token_ids = [] |
| | self._previous_marker = "\n\nprevious: " |
| | self._previous_marker_token_ids = [] |
| | self._has_previous = False |
| | self._previous_text = "" |
| | self._previous_token_ids = [] |
| |
|
| | |
| | self._sliding_event_count = 0 |
| | self._total_dropped_tokens = 0 |
| | self._total_dropped_units = 0 |
| |
|
| | if old_unit_count > 0: |
| | logger.info("[SW] reset: cleared %d units, all sliding window state reset", old_unit_count) |
| |
|
| | def get_cache_length(self) -> int: |
| | if self.cache is None: |
| | return 0 |
| | if isinstance(self.cache, DynamicCache): |
| | if len(self.cache.key_cache) > 0 and self.cache.key_cache[0].numel() > 0: |
| | return self.cache.key_cache[0].shape[2] |
| | return 0 |
| | |
| | return self.cache[0][0].shape[2] |
| |
|
| | def get_total_generated_tokens(self) -> int: |
| | return sum(len(u.get("generated_tokens", [])) for u in self._unit_history) |
| |
|
| | def register_unit_start(self) -> int: |
| | self._pending_unit_id = self._next_unit_id |
| | self._pending_unit_start_cache_len = self.get_cache_length() |
| | logger.info( |
| | "[SW] unit_start: pending_unit_id=%d, cache_len=%d, preserve=%d, units=%d", |
| | self._pending_unit_id, |
| | self._pending_unit_start_cache_len, |
| | self._system_preserve_length, |
| | len(self._unit_history), |
| | ) |
| | return self._pending_unit_id |
| |
|
    def register_unit_end(
        self,
        input_type: str,
        generated_tokens: Optional[List[int]] = None,
        is_listen: bool = False,
        generated_text: Optional[str] = None,
    ):
        """Record the unit opened by register_unit_start; its length is the
        cache growth since the start snapshot.

        Should be called after the </unit> token has been fed.

        Args:
            input_type: "audio" / "video" / "omni" / "system"
            generated_tokens: token ids generated within this unit
            is_listen: whether this unit was in the listen state
            generated_text: text generated within this unit (used by the
                context-preserving sliding mode)
        """
        if self._pending_unit_id is None:
            logger.warning("register_unit_end called without register_unit_start")
            return

        # The unit's cache footprint is however much the cache grew since start.
        current_cache_len = self.get_cache_length()
        unit_len = current_cache_len - self._pending_unit_start_cache_len

        if unit_len > 0:
            entry = {
                "unit_id": self._pending_unit_id,
                "length": unit_len,
                "type": input_type,
                "generated_tokens": generated_tokens or [],
                "generated_text": generated_text or "",
                "is_listen": is_listen,
            }
            self._unit_history.append(entry)
            gen_count = len(generated_tokens) if generated_tokens else 0
            gen_text_preview = (
                (generated_text[:30] + "...") if generated_text and len(generated_text) > 30 else (generated_text or "")
            )
            logger.info(
                "[SW] unit_end: unit_id=%d type=%s len=%d gen_tokens=%d is_listen=%s | "
                "cache=%d preserve=%d total_units=%d | text='%s'",
                self._pending_unit_id,
                input_type,
                unit_len,
                gen_count,
                is_listen,
                current_cache_len,
                self._system_preserve_length,
                len(self._unit_history),
                gen_text_preview,
            )
        else:
            # Zero-length units carry no cache tokens and are not tracked.
            logger.warning(
                "[SW] unit_end: unit_id=%d has zero length (start=%d, current=%d), not recorded",
                self._pending_unit_id,
                self._pending_unit_start_cache_len,
                current_cache_len,
            )

        # The id counter advances even for unrecorded units, keeping ids unique.
        self._pending_unit_id = None
        self._pending_unit_start_cache_len = 0
        self._next_unit_id += 1
| |
|
| | def register_system_prompt(self): |
| | """在 system prompt prefill 完成后调用,记录保护长度""" |
| | self._system_preserve_length = self.get_cache_length() |
| | logger.info( |
| | "[SW] system_prompt registered: preserve_length=%d (will be protected from sliding)", |
| | self._system_preserve_length, |
| | ) |
| |
|
| | |
| |
|
| | def _get_rope_theta(self) -> float: |
| | """获取模型的 rope_theta 配置""" |
| | return float(getattr(self.m.config, "rope_theta", 10000.0)) |
| |
|
| | def _drop_tokens_from_cache(self, length: int) -> bool: |
| | """从 cache 中移除指定数量的 tokens(保护 system prompt) |
| | |
| | 移除位于 [preserve, preserve + length) 区间的 tokens |
| | 支持 DynamicCache 和 tuple cache 两种格式 |
| | """ |
| | if self.cache is None or length <= 0: |
| | logger.warning("[SW] _drop_tokens_from_cache: cache is None or length<=0 (length=%d)", length) |
| | return False |
| |
|
| | cache_type = "DynamicCache" if isinstance(self.cache, DynamicCache) else "TupleCache" |
| | cache_len_before = self.get_cache_length() |
| | offset_before = self._position_offset |
| |
|
| | logger.debug( |
| | "[SW] _drop_tokens_from_cache: type=%s, drop=%d tokens from [%d, %d), cache=%d, preserve=%d", |
| | cache_type, |
| | length, |
| | self._system_preserve_length, |
| | self._system_preserve_length + length, |
| | cache_len_before, |
| | self._system_preserve_length, |
| | ) |
| |
|
| | new_cache, new_offset, success = drop_tokens_from_cache( |
| | cache=self.cache, |
| | length=length, |
| | preserve=self._system_preserve_length, |
| | position_offset=self._position_offset, |
| | rope_theta=self._get_rope_theta(), |
| | inv_freq_cache=self._rope_inv_freq_cache, |
| | ) |
| | if success: |
| | self.cache = new_cache |
| | self._position_offset = new_offset |
| |
|
| | if success: |
| | logger.debug( |
| | "[SW] _drop_tokens_from_cache: SUCCESS cache %d -> %d, offset %d -> %d (RoPE reindexed)", |
| | cache_len_before, |
| | self.get_cache_length(), |
| | offset_before, |
| | self._position_offset, |
| | ) |
| | else: |
| | logger.error( |
| | "[SW] _drop_tokens_from_cache: FAILED to drop %d tokens (cache=%d, preserve=%d)", |
| | length, |
| | cache_len_before, |
| | self._system_preserve_length, |
| | ) |
| |
|
| | return success |
| |
|
| | def _drop_unit(self, unit_id: int) -> bool: |
| | """移除指定 unit""" |
| | entries = [u for u in self._unit_history if u["unit_id"] == unit_id] |
| | if not entries: |
| | logger.warning("[SW] _drop_unit: unit_id=%d not found", unit_id) |
| | return False |
| |
|
| | total_len = sum(e["length"] for e in entries) |
| | if total_len <= 0: |
| | logger.warning("[SW] _drop_unit: unit_id=%d has zero total length, removing from history", unit_id) |
| | for e in entries: |
| | self._unit_history.remove(e) |
| | return False |
| |
|
| | cache_before = self.get_cache_length() |
| | if not self._drop_tokens_from_cache(total_len): |
| | logger.error( |
| | "[SW] _drop_unit: failed to drop %d tokens for unit_id=%d from cache (cache=%d, preserve=%d)", |
| | total_len, |
| | unit_id, |
| | cache_before, |
| | self._system_preserve_length, |
| | ) |
| | return False |
| |
|
| | cache_after = self.get_cache_length() |
| | for e in entries: |
| | gen_count = len(e.get("generated_tokens", [])) |
| | logger.info( |
| | "[SW] 🗑️ DROPPED unit_id=%d type=%s len=%d gen_tokens=%d | cache %d -> %d, offset=%d", |
| | e["unit_id"], |
| | e["type"], |
| | e["length"], |
| | gen_count, |
| | cache_before, |
| | cache_after, |
| | self._position_offset, |
| | ) |
| | self._unit_history.remove(e) |
| |
|
| | return True |
| |
|
| | def _drop_next_unit(self) -> bool: |
| | """移除最早的一个非 system unit""" |
| | for entry in self._unit_history: |
| | unit_id = entry.get("unit_id") |
| | if unit_id is None: |
| | continue |
| | |
| | if entry.get("type") == "system": |
| | logger.debug("[SW] _drop_next_unit: skipping system unit_id=%d", unit_id) |
| | continue |
| | logger.debug("[SW] _drop_next_unit: attempting to drop unit_id=%d", unit_id) |
| | if self._drop_unit(unit_id): |
| | return True |
| | logger.debug("[SW] _drop_next_unit: no droppable unit found in %d units", len(self._unit_history)) |
| | return False |
| |
|
    def enforce_window(self) -> bool:
        """Enforce the basic sliding-window policy (cache-length based, kept
        consistent with the simplex decoder).

        When the cache length exceeds the high watermark, the oldest units
        are dropped in a loop until the cache falls below the low watermark.

        Returns:
            True if at least one unit was dropped.
        """
        if not self._window_enabled:
            logger.info("[SW] enforce_window: window disabled, skip")
            return False

        cfg = self._window_config
        cache_len_before = self.get_cache_length()

        if cache_len_before <= cfg.basic_window_high_tokens:
            logger.debug(
                "[SW] enforce_window: cache=%d <= high_water=%d, no sliding needed",
                cache_len_before,
                cfg.basic_window_high_tokens,
            )
            return False

        # Above the high watermark: slide down toward the low watermark.
        logger.info(
            "[SW] ⚡ SLIDING TRIGGERED: cache=%d > high_water=%d, target=low_water=%d",
            cache_len_before,
            cfg.basic_window_high_tokens,
            cfg.basic_window_low_tokens,
        )

        dropped_count = 0
        cache_len = cache_len_before
        while cache_len > cfg.basic_window_low_tokens:
            if not self._drop_next_unit():
                logger.warning("[SW] enforce_window: no more units to drop, stopping")
                break
            dropped_count += 1
            cache_len = self.get_cache_length()

        if dropped_count > 0:
            # Update sliding statistics.
            self._sliding_event_count += 1
            self._total_dropped_tokens += cache_len_before - cache_len
            self._total_dropped_units += dropped_count

            # Sanity check: protected prefix + tracked units should equal cache length.
            expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history)
            is_consistent = expected == cache_len
            logger.info(
                "[SW] ✅ SLIDING DONE: cache %d -> %d, dropped %d units, remaining %d units | "
                "consistency: expected=%d actual=%d %s",
                cache_len_before,
                cache_len,
                dropped_count,
                len(self._unit_history),
                expected,
                cache_len,
                "✓" if is_consistent else "✗ MISMATCH!",
            )
            if not is_consistent:
                logger.error(
                    "[SW] ❌ CONSISTENCY ERROR! preserve=%d + sum(units)=%d != cache=%d, offset=%d",
                    self._system_preserve_length,
                    sum(u["length"] for u in self._unit_history),
                    cache_len,
                    self._position_offset,
                )

        return dropped_count > 0
| |
|
| | |
| |
|
    def register_system_prompt_with_context(
        self,
        suffix_token_ids: Optional[List[int]] = None,
        context_previous_marker: str = "\n\nprevious: ",
    ):
        """Register the system prompt for the context-preserving sliding mode.

        Initial cache layout:      [prefix] [suffix] [units...]
        Layout after first slide:  [prefix] [context_previous_marker + content] [suffix] [units...]

        When this is called, the cache should contain only the prefix
        (without the previous marker); the suffix is fed afterwards.

        Args:
            suffix_token_ids: token ids of the suffix (e.g. the id of <|im_end|>)
            context_previous_marker: marker prefix for the previous section,
                e.g. "\\n\\nprevious: "
        """
        # Everything currently in the cache counts as the protected prefix.
        self._preserve_prefix_length = self.get_cache_length()
        self._previous_content_length = 0
        self._suffix_token_ids = suffix_token_ids or []
        # Protected region = prefix + suffix; previous content is empty so far.
        self._system_preserve_length = self._preserve_prefix_length + len(self._suffix_token_ids)

        # Pre-tokenize the marker so slides can splice it without re-encoding.
        self._previous_marker = context_previous_marker
        self._previous_marker_token_ids = (
            self.tokenizer.encode(context_previous_marker, add_special_tokens=False) if self.tokenizer else []
        )
        self._has_previous = False
        self._previous_text = ""
        self._previous_token_ids = []

        logger.info(
            "[SW-CTX] system_prompt registered: prefix_len=%d, suffix_len=%d, marker='%s' (%d tokens)",
            self._preserve_prefix_length,
            len(self._suffix_token_ids),
            context_previous_marker.replace("\n", "\\n"),
            len(self._previous_marker_token_ids),
        )
        self.log_cache_layout("After register_system_prompt")
| |
|
| | def _extract_generated_text(self, units: List[Dict[str, Any]]) -> Tuple[str, List[int]]: |
| | """从 units 中提取生成的文本和 token ids |
| | |
| | Args: |
| | units: 要提取的 unit 列表 |
| | |
| | Returns: |
| | (text, token_ids): 拼接后的文本和 token ids(过滤掉 special tokens) |
| | """ |
| | text_parts = [] |
| | token_ids = [] |
| |
|
| | for u in units: |
| | |
| | if u.get("is_listen", False): |
| | continue |
| | gen_text = u.get("generated_text", "") |
| | gen_tokens = u.get("generated_tokens", []) |
| |
|
| | |
| | if gen_text: |
| | clean_text = gen_text |
| | for st in self._all_special_tokens_text: |
| | clean_text = clean_text.replace(st, "") |
| | if clean_text.strip(): |
| | text_parts.append(clean_text) |
| |
|
| | |
| | if gen_tokens: |
| | filtered_tokens = [t for t in gen_tokens if t not in self._all_special_ids] |
| | token_ids.extend(filtered_tokens) |
| |
|
| | return "".join(text_parts), token_ids |
| |
|
    def _rebuild_cache_with_previous(
        self,
        new_previous_tokens: List[int],
        units_to_keep_len: Optional[int] = None,
    ) -> bool:
        """Rebuild the cache, inserting the new previous content between the
        prefix and the suffix.

        Cache layout transition:
        [prefix] [old_prev] [suffix] [old_units] → [prefix] [new_prev] [suffix] [remaining_units]

        Args:
            new_previous_tokens: the new previous token ids
            units_to_keep_len: length of units to keep (counted back from the
                end of the cache); when None it is computed from unit_history

        Returns:
            whether the rebuild succeeded
        """
        if self.cache is None:
            logger.warning("[SW-CTX] _rebuild_cache_with_previous: cache is None")
            return False

        old_previous_len = self._previous_content_length
        new_previous_len = len(new_previous_tokens)
        suffix_len = len(self._suffix_token_ids)
        total_cache_len = self.get_cache_length()

        # Default: keep exactly the units still tracked in the history.
        if units_to_keep_len is None:
            units_to_keep_len = sum(u["length"] for u in self._unit_history)

        # Fast path: previous is empty before and after the rebuild, so no
        # re-prefill is needed — only stitch prefix+suffix to the kept units.
        if new_previous_len == 0 and old_previous_len == 0:
            # Protected region is just prefix + suffix here.
            preserve_len = self._preserve_prefix_length + suffix_len

            # Remove the dropped span between the protected region and the
            # kept units, re-aligning the units' RoPE positions if they move.
            if units_to_keep_len > 0:
                # Kept units live at the tail of the old cache.
                prefix_suffix_cache = self._slice_cache(0, preserve_len)
                units_cache = self._slice_cache(total_cache_len - units_to_keep_len, None)

                # Span between the protected region and the kept units.
                dropped_tokens = total_cache_len - preserve_len - units_to_keep_len

                # If anything was removed, the kept units shift left and
                # their rotary positions must be re-indexed.
                if dropped_tokens > 0:
                    old_start = preserve_len + dropped_tokens
                    new_start = preserve_len
                    logger.debug(
                        "[SW-CTX] RoPE reindex (no-op path): old_pos=[%d:%d] -> new_pos=[%d:%d], length=%d",
                        old_start,
                        old_start + units_to_keep_len,
                        new_start,
                        new_start + units_to_keep_len,
                        units_to_keep_len,
                    )
                    units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

                self.cache = self._concat_caches(prefix_suffix_cache, units_cache)
            else:
                self.cache = self._slice_cache(0, preserve_len)

            logger.info(
                "[SW-CTX] _rebuild_cache_with_previous (no-op): previous unchanged (0->0), "
                "just removed unit from cache, cache=%d, units_kept=%d",
                self.get_cache_length(),
                units_to_keep_len,
            )
            return True

        # Slow path: previous content changed — keep the prefix slice.
        prefix_end = self._preserve_prefix_length
        prefix_cache = self._slice_cache(0, prefix_end)

        # Kept units are the tail of the old cache.
        units_start_in_old_cache = total_cache_len - units_to_keep_len
        units_cache = None
        if units_to_keep_len > 0:
            units_cache = self._slice_cache(units_start_in_old_cache, None)

        # The previous content and the suffix are re-prefilled together in a
        # single forward pass, appended to the prefix cache.
        prev_suffix_tokens = new_previous_tokens + self._suffix_token_ids
        prev_suffix_len = len(prev_suffix_tokens)

        new_prefix_prev_suffix_cache = prefix_cache
        if prev_suffix_len > 0:
            # embed_tokens is defined elsewhere in this class (not in view).
            prev_suffix_embeds = self.embed_tokens(prev_suffix_tokens)
            # Positions continue right after the prefix, shifted by the
            # accumulated position offset.
            start_pos = self._preserve_prefix_length + self._position_offset

            # Re-prefill previous+suffix on top of the prefix cache.
            with torch.no_grad():
                device = prev_suffix_embeds.device
                position_ids = torch.arange(
                    start_pos,
                    start_pos + prev_suffix_len,
                    device=device,
                ).unsqueeze(0)

                # Forward pass extends prefix_cache with the new tokens.
                outputs = self.m(
                    inputs_embeds=(
                        prev_suffix_embeds.unsqueeze(0) if prev_suffix_embeds.dim() == 2 else prev_suffix_embeds
                    ),
                    position_ids=position_ids,
                    past_key_values=prefix_cache,
                    use_cache=True,
                    return_dict=True,
                )
                # Cache now covers [prefix][new_prev][suffix].
                new_prefix_prev_suffix_cache = outputs.past_key_values

        # The kept units now sit after the rebuilt system region; re-align
        # their RoPE positions if their absolute start moved.
        new_system_total = prefix_end + new_previous_len + suffix_len
        if units_cache is not None and self._get_cache_len(units_cache) > 0:
            old_start = units_start_in_old_cache
            new_start = new_system_total

            if old_start != new_start:
                units_cache = self._reindex_rope_for_cache(units_cache, old_start, new_start, units_to_keep_len)

        # Stitch the rebuilt system region and the kept units back together.
        if units_cache is not None and self._get_cache_len(units_cache) > 0:
            self.cache = self._concat_caches(new_prefix_prev_suffix_cache, units_cache)
        else:
            self.cache = new_prefix_prev_suffix_cache

        # Update bookkeeping: previous length and the protected span.
        self._previous_content_length = new_previous_len
        # Protected region now includes prefix + previous + suffix.
        self._system_preserve_length = prefix_end + new_previous_len + suffix_len

        # Log a compact summary of the rebuild.
        prev_text_preview = self._previous_text[:50] + "..." if len(self._previous_text) > 50 else self._previous_text
        suffix_preview = self.tokenizer.decode(self._suffix_token_ids) if self._suffix_token_ids else ""
        logger.info(
            "[SW-CTX] _rebuild_cache_with_previous:\n"
            " prefix_len=%d | previous: %d tokens '%s' | suffix: %d tokens '%s'\n"
            " cache: %d -> %d, units_kept=%d, preserve=%d",
            self._preserve_prefix_length,
            new_previous_len,
            prev_text_preview,
            suffix_len,
            suffix_preview,
            old_previous_len + self._preserve_prefix_length + suffix_len + units_to_keep_len,
            self.get_cache_length(),
            units_to_keep_len,
            self._system_preserve_length,
        )
        return True
| |
|
| | def _slice_cache(self, start: int, end: Optional[int], clone: bool = True): |
| | """切片 cache |
| | |
| | Args: |
| | start: 起始位置 |
| | end: 结束位置(None 表示到末尾) |
| | clone: 是否克隆(默认 True,防止共享内存问题) |
| | """ |
| | if self.cache is None: |
| | return None |
| | if isinstance(self.cache, DynamicCache): |
| | |
| | new_key_cache = [ |
| | k[:, :, start:end, :].clone() if clone else k[:, :, start:end, :] for k in self.cache.key_cache |
| | ] |
| | new_value_cache = [ |
| | v[:, :, start:end, :].clone() if clone else v[:, :, start:end, :] for v in self.cache.value_cache |
| | ] |
| | new_cache = DynamicCache() |
| | new_cache.key_cache = new_key_cache |
| | new_cache.value_cache = new_value_cache |
| | return new_cache |
| | else: |
| | |
| | if clone: |
| | return tuple( |
| | (layer[0][:, :, start:end, :].clone(), layer[1][:, :, start:end, :].clone()) for layer in self.cache |
| | ) |
| | else: |
| | return tuple((layer[0][:, :, start:end, :], layer[1][:, :, start:end, :]) for layer in self.cache) |
| |
|
| | def _get_cache_len(self, cache) -> int: |
| | """获取 cache 长度""" |
| | if cache is None: |
| | return 0 |
| | if isinstance(cache, DynamicCache): |
| | if len(cache.key_cache) > 0 and cache.key_cache[0].numel() > 0: |
| | return cache.key_cache[0].shape[2] |
| | return 0 |
| | |
| | if cache and cache[0] and cache[0][0] is not None: |
| | return cache[0][0].shape[2] |
| | return 0 |
| |
|
| | def _concat_caches(self, cache1, cache2): |
| | """拼接两个 cache""" |
| | if cache1 is None: |
| | return cache2 |
| | if cache2 is None: |
| | return cache1 |
| |
|
| | if isinstance(cache1, DynamicCache): |
| | new_cache = DynamicCache() |
| | new_cache.key_cache = [torch.cat([k1, k2], dim=2) for k1, k2 in zip(cache1.key_cache, cache2.key_cache)] |
| | new_cache.value_cache = [ |
| | torch.cat([v1, v2], dim=2) for v1, v2 in zip(cache1.value_cache, cache2.value_cache) |
| | ] |
| | return new_cache |
| | else: |
| | |
| | return tuple( |
| | ( |
| | torch.cat([layer1[0], layer2[0]], dim=2), |
| | torch.cat([layer1[1], layer2[1]], dim=2), |
| | ) |
| | for layer1, layer2 in zip(cache1, cache2) |
| | ) |
| |
|
    def _reindex_rope_for_cache(self, cache, old_start: int, new_start: int, length: int):
        """Re-align the RoPE positions of *cache* from old_start to new_start.

        Rewrites the key tensors so the *length* positions that lived at
        [old_start, old_start+length) behave as if computed at
        [new_start, new_start+length). Only keys are rewritten; value tensors
        are left untouched. NOTE: a DynamicCache is modified in place (and
        returned); a tuple cache yields a new tuple.
        """
        if cache is None or length <= 0:
            return cache

        # Probe any key tensor just to discover the target device.
        device = None
        if isinstance(cache, DynamicCache):
            device = cache.key_cache[0].device if cache.key_cache else None
        else:
            device = cache[0][0].device if cache and cache[0] else None

        if device is None:
            return cache

        old_positions = torch.arange(old_start, old_start + length, device=device, dtype=torch.long)
        new_positions = torch.arange(new_start, new_start + length, device=device, dtype=torch.long)

        # Imported locally — presumably to avoid an import cycle; TODO confirm.
        from .sliding_utils import realign_rotary_suffix

        rope_theta = self._get_rope_theta()

        if isinstance(cache, DynamicCache):
            new_key_cache = []
            for k in cache.key_cache:
                new_k = realign_rotary_suffix(k, old_positions, new_positions, rope_theta, self._rope_inv_freq_cache)
                new_key_cache.append(new_k)
            cache.key_cache = new_key_cache
            return cache
        else:
            new_cache = []
            for layer in cache:
                new_k = realign_rotary_suffix(
                    layer[0], old_positions, new_positions, rope_theta, self._rope_inv_freq_cache
                )
                new_cache.append((new_k, layer[1]))
            return tuple(new_cache)
| |
|
    def _update_previous(
        self,
        new_text: str,
        new_tokens: List[int],
        max_tokens: int,
    ) -> None:
        """Update the "previous" context (and rebuild the cache to match).

        On the first slide the marker + text is added dynamically; later
        slides append text only. When the content exceeds max_tokens it is
        truncated from the left (the marker is always kept). The cache is
        rebuilt afterwards so it stays consistent.

        Args:
            new_text: text to append
            new_tokens: token ids to append
            max_tokens: maximum previous content tokens (marker excluded)
        """
        marker_len = len(self._previous_marker_token_ids)
        tokens_to_drop = 0

        # Nothing new to record: still rebuild so dropped units leave the cache.
        if not new_tokens and not new_text:
            logger.info("[SW-CTX] _update_previous: no new content, skip adding to previous")

            self._rebuild_cache_with_previous(self._previous_token_ids)
            return

        if not self._has_previous:
            # First slide carrying content: install marker + content.
            self._previous_text = new_text
            self._previous_token_ids = self._previous_marker_token_ids.copy() + new_tokens
            self._has_previous = True
            logger.info(
                "[SW-CTX] _update_previous: first slide with content, added marker + %d tokens",
                len(new_tokens),
            )
        else:
            # Subsequent slides: append behind the existing content.
            self._previous_text += new_text
            self._previous_token_ids.extend(new_tokens)

        # Content size excludes the marker tokens.
        content_token_count = len(self._previous_token_ids) - marker_len

        # Left-truncate the content when it outgrows the budget.
        if content_token_count > max_tokens:
            # Drop the oldest content tokens; the marker is kept intact.
            tokens_to_drop = content_token_count - max_tokens
            old_text = self._previous_text

            content_tokens = self._previous_token_ids[marker_len + tokens_to_drop :]
            self._previous_token_ids = self._previous_marker_token_ids.copy() + content_tokens
            # Re-decode so the text mirrors the surviving tokens.
            try:
                self._previous_text = self.tokenizer.decode(
                    content_tokens,
                    skip_special_tokens=True,
                )
            except Exception as e:
                logger.warning("[SW-CTX] _update_previous: decode failed: %s", e)

            # Loud log: truncation loses the oldest context.
            logger.info(
                "[SW-CTX] ⚠️ LEFT TRUNCATION: previous exceeded max_tokens=%d\n"
                " before: %d content tokens, text='%s'\n"
                " after: %d content tokens, text='%s'\n"
                " dropped %d tokens from left",
                max_tokens,
                content_token_count,
                old_text[:60] + "..." if len(old_text) > 60 else old_text,
                len(content_tokens),
                self._previous_text[:60] + "..." if len(self._previous_text) > 60 else self._previous_text,
                tokens_to_drop,
            )

        # Rebuild the cache with the updated previous tokens.
        self._rebuild_cache_with_previous(self._previous_token_ids)

        prev_preview = self._previous_text[:80] + "..." if len(self._previous_text) > 80 else self._previous_text
        content_len = len(self._previous_token_ids) - marker_len
        if tokens_to_drop > 0:
            logger.info(
                "[SW-CTX] _update_previous: +%d tokens, -%d truncated -> %d content tokens (marker=%d) | '%s'",
                len(new_tokens),
                tokens_to_drop,
                content_len,
                marker_len,
                prev_preview,
            )
        else:
            logger.info(
                "[SW-CTX] _update_previous: +%d tokens -> %d content tokens (marker=%d) | '%s'",
                len(new_tokens),
                content_len,
                marker_len,
                prev_preview,
            )
| |
|
    def _drop_unit_with_context(
        self,
        unit_id: int,
        max_previous_tokens: int,
    ) -> Tuple[bool, str, List[int]]:
        """Drop the given unit and return its generated content (for context retention).

        Flow:
        1. Extract the unit's generated content.
        2. Remove the unit from the history (the cache is untouched here).
        3. Append the content to previous.
        4. Rebuild the cache (done inside _update_previous).

        Args:
            unit_id: id of the unit to drop
            max_previous_tokens: maximum size of previous, in tokens

        Returns:
            (success, extracted_text, extracted_tokens): whether it succeeded,
            plus the extracted text and tokens
        """
        entries = [u for u in self._unit_history if u["unit_id"] == unit_id]
        if not entries:
            logger.warning("[SW-CTX] _drop_unit_with_context: unit_id=%d not found", unit_id)
            return False, "", []

        # 1. Harvest the text/tokens this unit generated.
        extracted_text, extracted_tokens = self._extract_generated_text(entries)

        # Zero-length units occupy no cache tokens; just forget them.
        total_len = sum(e["length"] for e in entries)
        if total_len <= 0:
            logger.warning("[SW-CTX] _drop_unit_with_context: unit_id=%d has zero length", unit_id)
            for e in entries:
                self._unit_history.remove(e)
            return False, extracted_text, extracted_tokens

        cache_before = self.get_cache_length()

        # 2. Remove from the history first: _update_previous rebuilds the
        #    cache from whatever remains tracked.
        for e in entries:
            self._unit_history.remove(e)

        # (Cache tokens are not removed here — the rebuild below handles it.)

        # 3+4. Fold the content into previous and rebuild the cache.
        self._update_previous(extracted_text, extracted_tokens, max_previous_tokens)

        cache_after = self.get_cache_length()
        for e in entries:
            logger.info(
                "[SW-CTX] 🗑️ DROPPED unit_id=%d type=%s len=%d, extracted=%d chars | cache %d -> %d",
                e["unit_id"],
                e["type"],
                e["length"],
                len(extracted_text),
                cache_before,
                cache_after,
            )

        return True, extracted_text, extracted_tokens
| |
|
| | def _drop_next_unit_with_context(self, max_previous_tokens: int) -> bool: |
| | """移除最早的一个非 system unit(带 context 保留)""" |
| | for entry in self._unit_history: |
| | unit_id = entry.get("unit_id") |
| | if unit_id is None: |
| | continue |
| | if entry.get("type") == "system": |
| | continue |
| | success, _, _ = self._drop_unit_with_context(unit_id, max_previous_tokens) |
| | if success: |
| | return True |
| | return False |
| |
|
| | def enforce_window_with_context(self) -> bool: |
| | """带 context 保留的滑窗执行 |
| | |
| | 当 unit 数量超过 max_units 时,移除最早的 unit, |
| | 并将其生成内容累积到 previous。 |
| | Cache 会在 _update_previous 中自动重建。 |
| | |
| | Returns: |
| | 是否执行了滑窗 |
| | """ |
| | if not self._window_enabled: |
| | logger.info("[SW-CTX] enforce_window_with_context: window disabled, skip") |
| | return False |
| |
|
| | cfg = self._window_config |
| |
|
| | if cfg.sliding_window_mode != "context": |
| | |
| | return self.enforce_window() |
| |
|
| | cache_len_before = self.get_cache_length() |
| | units_before = len(self._unit_history) |
| |
|
| | |
| | |
| | if units_before <= cfg.context_max_units: |
| | logger.debug( |
| | "[SW-CTX] enforce_window_with_context: no sliding needed (units=%d/%d)", |
| | units_before, |
| | cfg.context_max_units, |
| | ) |
| | self.log_cache_layout("No sliding (units=%d/%d)" % (units_before, cfg.context_max_units)) |
| | return False |
| |
|
| | slide_tag = "slide #%d" % (self._sliding_event_count + 1) |
| | logger.info( |
| | "[SW-CTX] ⚡ SLIDING TRIGGERED (%s): units=%d > max_units=%d, previous=%d tokens", |
| | slide_tag, |
| | units_before, |
| | cfg.context_max_units, |
| | len(self._previous_token_ids), |
| | ) |
| | self.log_cache_layout("Before %s" % slide_tag) |
| |
|
| | |
| | dropped_count = 0 |
| | while len(self._unit_history) > cfg.context_max_units: |
| | if not self._drop_next_unit_with_context(cfg.context_previous_max_tokens): |
| | logger.warning("[SW-CTX] enforce_window_with_context: no more units to drop") |
| | break |
| |
|
| | dropped_count += 1 |
| |
|
| | cache_len_after = self.get_cache_length() |
| |
|
| | if dropped_count > 0: |
| | |
| | self._sliding_event_count += 1 |
| | self._total_dropped_tokens += cache_len_before - cache_len_after |
| | self._total_dropped_units += dropped_count |
| |
|
| | |
| | expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history) |
| | is_consistent = expected == cache_len_after |
| | logger.info( |
| | "[SW-CTX] ✅ SLIDING DONE: cache %d -> %d, dropped %d units, remaining %d units, " |
| | "previous=%d tokens | consistency: %s", |
| | cache_len_before, |
| | cache_len_after, |
| | dropped_count, |
| | len(self._unit_history), |
| | len(self._previous_token_ids), |
| | "✓" if is_consistent else "✗ MISMATCH!", |
| | ) |
| | self.log_cache_layout("After slide #%d" % self._sliding_event_count) |
| |
|
| | return dropped_count > 0 |
| |
|
| | def get_previous_context(self) -> Tuple[str, List[int]]: |
| | """获取当前累积的 previous context |
| | |
| | Returns: |
| | (previous_text, previous_token_ids): 当前累积的文本和 token ids |
| | """ |
| | return self._previous_text, self._previous_token_ids.copy() |
| |
|
| | |
| |
|
| | def log_cache_layout(self, tag: str = "") -> None: |
| | """打印当前 cache 布局(调试用) |
| | |
| | 根据滑窗模式显示不同的布局信息: |
| | - context 模式:[prefix] [previous] [suffix] [units...] |
| | - 其他模式:[system] [units...] |
| | """ |
| | cache_len = self.get_cache_length() |
| | units_len = sum(u["length"] for u in self._unit_history) |
| |
|
| | if self._window_config.sliding_window_mode == "context": |
| | |
| | prefix_len = self._preserve_prefix_length |
| | prev_len = len(self._previous_token_ids) |
| | suffix_len = len(self._suffix_token_ids) |
| |
|
| | |
| | prev_full = "" |
| | if prev_len > 0 and self.tokenizer: |
| | prev_full = self.tokenizer.decode(self._previous_token_ids) |
| | suffix_text = "" |
| | if suffix_len > 0 and self.tokenizer: |
| | suffix_text = self.tokenizer.decode(self._suffix_token_ids) |
| |
|
| | logger.info( |
| | "[SW-CTX] %s Cache Layout:\n" |
| | " [prefix: %d tokens] [previous: %d tokens] [suffix: %d tokens] [units: %d tokens]\n" |
| | " preserve=%d | cache=%d | has_previous=%s\n" |
| | " previous_full: %s\n" |
| | " suffix: %s", |
| | tag, |
| | prefix_len, |
| | prev_len, |
| | suffix_len, |
| | units_len, |
| | self._system_preserve_length, |
| | cache_len, |
| | self._has_previous, |
| | repr(prev_full) if prev_full else "(empty)", |
| | repr(suffix_text) if suffix_text else "(empty)", |
| | ) |
| | else: |
| | |
| | logger.info( |
| | "[SW] %s Cache Layout: [system: %d] [units: %d] | cache=%d", |
| | tag, |
| | self._system_preserve_length, |
| | units_len, |
| | cache_len, |
| | ) |
| |
|
| | def get_window_stats(self) -> Dict[str, Any]: |
| | """获取滑窗统计信息""" |
| | unit_lengths = [u["length"] for u in self._unit_history] |
| | return { |
| | "cache_length": self.get_cache_length(), |
| | "unit_count": len(self._unit_history), |
| | "unit_lengths": unit_lengths, |
| | "unit_total_length": sum(unit_lengths), |
| | "system_preserve_length": self._system_preserve_length, |
| | "position_offset": self._position_offset, |
| | "window_enabled": self._window_enabled, |
| | "total_generated_tokens": self.get_total_generated_tokens(), |
| | "pending_unit_id": self._pending_unit_id, |
| | "next_unit_id": self._next_unit_id, |
| | "config": { |
| | "sliding_window_mode": self._window_config.sliding_window_mode, |
| | "basic_window_high_tokens": self._window_config.basic_window_high_tokens, |
| | "basic_window_low_tokens": self._window_config.basic_window_low_tokens, |
| | "context_previous_max_tokens": self._window_config.context_previous_max_tokens, |
| | "context_max_units": self._window_config.context_max_units, |
| | }, |
| | |
| | "preserve_prefix_length": self._preserve_prefix_length, |
| | "previous_content_length": self._previous_content_length, |
| | "suffix_token_count": len(self._suffix_token_ids), |
| | "previous_text_length": len(self._previous_text), |
| | "previous_token_count": len(self._previous_token_ids), |
| | "has_system_template": self._system_prompt_template is not None, |
| | } |
| |
|
| | def _verify_consistency(self) -> bool: |
| | """验证 unit 历史与 cache 长度一致""" |
| | expected = self._system_preserve_length + sum(u["length"] for u in self._unit_history) |
| | actual = self.get_cache_length() |
| | return expected == actual |
| |
|
| | def dump_unit_history(self, prefix: str = "") -> None: |
| | """打印当前 unit 历史(调试用)""" |
| | cache_len = self.get_cache_length() |
| | unit_sum = sum(u["length"] for u in self._unit_history) |
| | expected = self._system_preserve_length + unit_sum |
| |
|
| | logger.info( |
| | "[SW] %s=== UNIT HISTORY DUMP === cache=%d, preserve=%d, units=%d, offset=%d", |
| | prefix + " " if prefix else "", |
| | cache_len, |
| | self._system_preserve_length, |
| | len(self._unit_history), |
| | self._position_offset, |
| | ) |
| | logger.info( |
| | "[SW] Consistency: preserve(%d) + sum(units)(%d) = %d, actual=%d, %s", |
| | self._system_preserve_length, |
| | unit_sum, |
| | expected, |
| | cache_len, |
| | "✓ MATCH" if expected == cache_len else "✗ MISMATCH!", |
| | ) |
| | for i, u in enumerate(self._unit_history): |
| | gen_count = len(u.get("generated_tokens", [])) |
| | logger.info( |
| | "[SW] [%d] unit_id=%d type=%-6s len=%4d gen=%3d listen=%s", |
| | i, |
| | u["unit_id"], |
| | u["type"], |
| | u["length"], |
| | gen_count, |
| | u.get("is_listen", False), |
| | ) |
| |
|
| | def print_verification_summary(self) -> Dict[str, Any]: |
| | """打印验证摘要(用于对比 off/basic/context 模式) |
| | |
| | Returns: |
| | 包含关键验证数据的字典 |
| | """ |
| | cfg = self._window_config |
| |
|
| | |
| | all_generated_text = [] |
| | all_generated_tokens = [] |
| | for u in self._unit_history: |
| | if not u.get("is_listen", False): |
| | gen_text = u.get("generated_text", "") |
| | gen_tokens = u.get("generated_tokens", []) |
| | if gen_text: |
| | all_generated_text.append(gen_text) |
| | if gen_tokens: |
| | all_generated_tokens.extend(gen_tokens) |
| |
|
| | combined_text = "".join(all_generated_text) |
| |
|
| | summary = { |
| | "mode": cfg.sliding_window_mode, |
| | "final_cache_length": self.get_cache_length(), |
| | "final_unit_count": len(self._unit_history), |
| | "sliding_event_count": self._sliding_event_count, |
| | "total_dropped_tokens": self._total_dropped_tokens, |
| | "total_dropped_units": self._total_dropped_units, |
| | "total_generated_tokens": len(all_generated_tokens), |
| | "generated_text": combined_text, |
| | "previous_text": self._previous_text, |
| | "previous_token_count": len(self._previous_token_ids), |
| | "position_offset": self._position_offset, |
| | "system_preserve_length": self._system_preserve_length, |
| | } |
| |
|
| | logger.info("=" * 70) |
| | logger.info("[VERIFY] === SLIDING WINDOW VERIFICATION SUMMARY ===") |
| | logger.info("[VERIFY] Mode: %s", cfg.sliding_window_mode) |
| | logger.info("[VERIFY] Final cache length: %d", summary["final_cache_length"]) |
| | logger.info("[VERIFY] Final unit count: %d", summary["final_unit_count"]) |
| | logger.info("[VERIFY] Sliding events: %d", summary["sliding_event_count"]) |
| | logger.info( |
| | "[VERIFY] Total dropped: %d tokens, %d units", |
| | summary["total_dropped_tokens"], |
| | summary["total_dropped_units"], |
| | ) |
| | logger.info("[VERIFY] Total generated tokens: %d", summary["total_generated_tokens"]) |
| | logger.info( |
| | "[VERIFY] Generated text: '%s'", combined_text[:100] + "..." if len(combined_text) > 100 else combined_text |
| | ) |
| | if cfg.sliding_window_mode == "context": |
| | logger.info( |
| | "[VERIFY] Previous content: %d tokens, '%s'", |
| | summary["previous_token_count"], |
| | self._previous_text[:50] + "..." if len(self._previous_text) > 50 else self._previous_text, |
| | ) |
| | logger.info("[VERIFY] Position offset: %d", summary["position_offset"]) |
| | logger.info("[VERIFY] System preserve length: %d", summary["system_preserve_length"]) |
| | logger.info("=" * 70) |
| |
|
| | return summary |
| |
|
| | def set_window_config(self, config: DuplexWindowConfig) -> None: |
| | """设置滑窗配置""" |
| | self._window_config = config |
| | logger.info( |
| | "[SW] Window config set: high_water=%d, low_water=%d", |
| | config.basic_window_high_tokens, |
| | config.basic_window_low_tokens, |
| | ) |
| |
|
| | def set_window_enabled(self, enabled: bool) -> None: |
| | """启用/禁用滑窗""" |
| | old_enabled = self._window_enabled |
| | self._window_enabled = enabled |
| | if old_enabled != enabled: |
| | logger.info("[SW] Window enabled: %s -> %s", old_enabled, enabled) |
| |
|
| | def get_context(self): |
| | return self.context |
| |
|
| | def embed_token(self, tid): |
| | if isinstance(tid, int): |
| | tid = torch.tensor([tid], device=self.m.device) |
| | return self.m.model.embed_tokens(tid) |
| |
|
| | def embed_tokens(self, token_ids: List[int]) -> torch.Tensor: |
| | """批量嵌入多个 tokens |
| | |
| | Args: |
| | token_ids: token id 列表 |
| | |
| | Returns: |
| | embeddings tensor [L, H] |
| | """ |
| | if not token_ids: |
| | return torch.empty(0, self.m.config.hidden_size, device=self.m.device) |
| | tids = torch.tensor(token_ids, device=self.m.device) |
| | return self.m.model.embed_tokens(tids) |
| |
|
| | @torch.no_grad() |
| | def feed(self, embeds: torch.Tensor, return_logits: bool = False): |
| | """ |
| | embeds : [L, H] —— new embedding sequence fed into model at once |
| | """ |
| | L = embeds.size(0) |
| | device = embeds.device |
| |
|
| | past_len = self.get_cache_length() |
| | pos_ids = torch.arange(past_len, past_len + L, device=device).unsqueeze(0) |
| |
|
| | out = self.m( |
| | inputs_embeds=embeds.unsqueeze(0), |
| | position_ids=pos_ids, |
| | past_key_values=self.cache, |
| | |
| | return_dict=True, |
| | output_hidden_states=True, |
| | |
| | ) |
| | self.cache = out.past_key_values |
| |
|
| | if return_logits: |
| | logits = self.m.lm_head(out.hidden_states[-1])[:, -1] |
| | return logits, out.hidden_states[-1] |
| |
|
| | @torch.no_grad() |
| | def decode( |
| | self, |
| | logits, |
| | mode: Literal["sampling", "greedy"] = "sampling", |
| | temperature=0.7, |
| | top_k=20, |
| | top_p=0.8, |
| | listen_top_k=None, |
| | listen_prob_scale=1.0, |
| | text_repetition_penalty=1.05, |
| | text_repetition_window_size=512, |
| | debug_print_top5=False, |
| | ): |
| | """ |
| | Args: |
| | logits: |
| | mode: sampling or greedy |
| | temperature: |
| | top_k: |
| | top_p: |
| | listen_top_k: force listen_id to be in top-k to keep |
| | listen_prob_scale: multiply listen_id probability by a weight (<1 means decrease, >1 means increase) |
| | text_repetition_penalty: repetition penalty coefficient, >1.0 means decrease repetition, <1.0 means increase repetition |
| | text_repetition_window_size: repetition penalty window size |
| | debug_print_top5: whether to print debug information for top 5 tokens |
| | |
| | Sampling strategy: |
| | 1. first sample all tokens with original logits (apply temperature) |
| | 2. if sampled chunk_eos, return directly (keep the original model's decision of when to stop) |
| | 3. if not sampled chunk_eos, mask it (set logit to -inf), continue sampling text tokens |
| | 4. apply repetition penalty, top-k, top-p, etc. to the text tokens for the final sampling |
| | """ |
| |
|
| | logits = logits.clone() |
| |
|
| | |
| | eos_id = self.chunk_eos_id |
| |
|
| | with torch.no_grad(): |
| | if mode == "greedy": |
| | sampled_token = torch.argmax(logits[0]).item() |
| | else: |
| | original_probs = F.softmax(logits[0], dim=-1) |
| | sampled_token = torch.multinomial(original_probs, num_samples=1).item() |
| |
|
| | |
| | if sampled_token == eos_id: |
| | next_token_id = torch.tensor([eos_id], device=logits.device) |
| | next_token_str = self.tokenizer.decode(next_token_id) |
| |
|
| | return next_token_id |
| |
|
| | |
| | if self.forbidden_token_ids: |
| | logits[:, self.forbidden_token_ids] = float("-inf") |
| |
|
| | |
| | if debug_print_top5: |
| | print("🔵" * 30) |
| | print("【BEFORE repetition penalty】施加重复惩罚之前的 Top-k logits") |
| | logits_before_penalty = logits[0] / temperature if mode == "sampling" else logits[0] |
| | topk_logits_before, topk_indices_before = torch.topk( |
| | logits_before_penalty, k=min(5, logits_before_penalty.size(-1)) |
| | ) |
| |
|
| | for i, (token_id, logit_val) in enumerate(zip(topk_indices_before.tolist(), topk_logits_before.tolist())): |
| | token_str = self.tokenizer.decode([token_id]) |
| | |
| | if token_str == "\n": |
| | display_str = "\\n" |
| | elif token_str == " ": |
| | display_str = "[SPACE]" |
| | elif token_str == "": |
| | display_str = "[EMPTY]" |
| | elif token_str == "\t": |
| | display_str = "\\t" |
| | else: |
| | display_str = token_str |
| |
|
| | |
| | special_mark = "" |
| | if token_id == self.listen_id: |
| | special_mark = " 🎧[LISTEN]" |
| | elif token_id == self.tokenizer.eos_token_id: |
| | special_mark = " 🛑[EOS]" |
| |
|
| | print(f" {i + 1:2d}. {display_str:10s}{special_mark:15s} (id={token_id:5d}): logit={logit_val:.4f}") |
| | print("🔵" * 30) |
| |
|
| | |
| | if text_repetition_penalty != 1.0 and len(self.generated_tokens) > 0: |
| | |
| | recent_tokens = self.generated_tokens[-text_repetition_window_size:] |
| |
|
| | |
| | recent_tokens = list(set(recent_tokens)) |
| |
|
| | |
| | for token_id in recent_tokens: |
| | if token_id < logits.size(-1): |
| | if text_repetition_penalty > 1.0: |
| | |
| | logits[0, token_id] /= text_repetition_penalty |
| | else: |
| | |
| | logits[0, token_id] *= 1.0 / text_repetition_penalty |
| |
|
| | if listen_prob_scale != 1.0: |
| | logits[0, self.listen_id] *= listen_prob_scale |
| |
|
| | listen_rank = (logits[0] > logits[0, self.listen_id]).sum().item() |
| |
|
| | |
| | if debug_print_top5: |
| | |
| | logits_before_softmax = logits[0] / temperature if mode == "sampling" else logits[0] |
| | top5_logits_before, top5_indices_before = torch.topk( |
| | logits_before_softmax, k=min(5, logits_before_softmax.size(-1)) |
| | ) |
| |
|
| | print("=" * 20) |
| |
|
| | print("\n📊 Top 5 tokens BEFORE softmax (temperature={:.2f}, mode={}):".format(temperature, mode)) |
| | for i, (token_id, logit_val) in enumerate(zip(top5_indices_before.tolist(), top5_logits_before.tolist())): |
| | token_str = self.tokenizer.decode([token_id]) |
| | |
| | if token_str == "\n": |
| | display_str = "\\n" |
| | elif token_str == " ": |
| | display_str = "[SPACE]" |
| | elif token_str == "": |
| | display_str = "[EMPTY]" |
| | elif token_str == "\t": |
| | display_str = "\\t" |
| | else: |
| | display_str = token_str |
| |
|
| | |
| | special_mark = "" |
| | if token_id == self.listen_id: |
| | special_mark = " 🎧[LISTEN]" |
| | elif token_id == self.tokenizer.eos_token_id: |
| | special_mark = " 🛑[EOS]" |
| |
|
| | print(f" {i + 1}. {display_str:10s}{special_mark:15s} (id={token_id:5d}): logit={logit_val:.4f}") |
| |
|
| | |
| | probs = F.softmax(logits[0] / temperature if mode == "sampling" else logits[0], dim=-1) |
| | top5_probs, top5_indices = torch.topk(probs, k=min(5, probs.size(-1))) |
| |
|
| | print("\n📊 Top 5 tokens AFTER softmax (temperature={:.2f}, mode={}):".format(temperature, mode)) |
| | for i, (token_id, prob) in enumerate(zip(top5_indices.tolist(), top5_probs.tolist())): |
| | token_str = self.tokenizer.decode([token_id]) |
| | |
| | if token_str == "\n": |
| | display_str = "\\n" |
| | elif token_str == " ": |
| | display_str = "[SPACE]" |
| | elif token_str == "": |
| | display_str = "[EMPTY]" |
| | elif token_str == "\t": |
| | display_str = "\\t" |
| | else: |
| | display_str = token_str |
| |
|
| | |
| | special_mark = "" |
| | if token_id == self.listen_id: |
| | special_mark = " 🎧[LISTEN]" |
| | elif token_id == self.tokenizer.eos_token_id: |
| | special_mark = " 🛑[EOS]" |
| |
|
| | print( |
| | f" {i + 1}. {display_str:10s}{special_mark:15s} (id={token_id:5d}): {prob:.4f} ({prob * 100:.2f}%)" |
| | ) |
| | |
| | if self.listen_id not in top5_indices.tolist(): |
| | listen_prob = probs[self.listen_id].item() |
| | print(f" ... <|listen|> 🎧 rank={listen_rank + 1}, prob={listen_prob:.6f} ({listen_prob * 100:.4f}%)") |
| |
|
| | if listen_top_k is not None and listen_rank < listen_top_k: |
| | next_token_id = torch.tensor([self.listen_id], device=logits.device) |
| | next_token_str = self.tokenizer.decode(next_token_id) |
| |
|
| | if next_token_str == "<|listen|>": |
| | self.context += " " |
| | else: |
| | self.context += next_token_str |
| |
|
| | return next_token_id |
| |
|
| | if mode == "greedy": |
| | next_token_id = torch.argmax(logits, dim=-1) |
| | elif mode == "sampling": |
| | logits = logits / temperature |
| | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) |
| | probs = F.softmax(logits, dim=-1) |
| | next_token_id = torch.multinomial(probs, num_samples=1).squeeze(1) |
| | else: |
| | raise ValueError("Unsupported decode mode") |
| |
|
| | if next_token_id.item() not in self.special_token_ids: |
| | self.generated_tokens.append(next_token_id.item()) |
| | else: |
| | self.generated_special_tokens.append(next_token_id.item()) |
| |
|
| | return next_token_id |
| |
|