"""Black-box HSP wrapper: page-based prompt compression for a Hugging Face causal LM,
applied before generation without modifying the wrapped model."""

from dataclasses import dataclass
from typing import List, Optional, Dict, Any, Tuple

import time
import torch
from transformers import PreTrainedModel, PreTrainedTokenizerBase

from Segmenter import HSPPlannerConfig, Segmenter
from PageEncoder import PageEncoder
from QueryPlanner import QueryPlanner


@dataclass
class HSPWrapperConfig:
    page_size: int = 64
    flow_window: int = 4
    flash_top_k: int = 22
    anchor_pages: int = 4
    pad_token_id: Optional[int] = None
    newline_token_ids: Optional[Tuple[int, ...]] = None
    newline_token_id: int = 198
    sentence_boundary_ids: Optional[Tuple[int, ...]] = None

    allow_implicit_query: bool = False

    query_block_size: int = 64
    use_dynamic_token_weights: bool = True
    min_length_for_dynamic_weights: int = 256

    use_query_multitoken_semantic: bool = True
    min_query_tokens_for_multi: int = 4
    max_query_tokens_for_multi: int = 32
    identity_mean_weight: float = 0.7
    identity_max_weight: float = 0.3
    use_lexical_overlap_score: bool = True
    semantic_score_weight: float = 0.7
    lexical_score_weight: float = 0.3


class HSPBlackBoxWrapper:
    """Wraps a black-box Hugging Face causal LM with HSP page-based prompt compression."""

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizerBase,
        cfg: HSPWrapperConfig,
        idf_weights: Optional[torch.Tensor] = None,
        device: Optional[torch.device] = None,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.wrap_cfg = cfg

        if device is None:
            device = getattr(model, "device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        self.device = device
        self.model.to(self.device)

        # Resolve pad_token_id: prefer the tokenizer's value, fall back to 0,
        # and refuse a silent mismatch between wrapper config and tokenizer.
        if self.wrap_cfg.pad_token_id is None:
            if self.tokenizer.pad_token_id is not None:
                self.wrap_cfg.pad_token_id = int(self.tokenizer.pad_token_id)
            else:
                self.wrap_cfg.pad_token_id = 0
        else:
            if (
                self.tokenizer.pad_token_id is not None
                and int(self.wrap_cfg.pad_token_id) != int(self.tokenizer.pad_token_id)
            ):
                raise ValueError(
                    f"pad_token_id mismatch: wrapper cfg={self.wrap_cfg.pad_token_id}, "
                    f"tokenizer={self.tokenizer.pad_token_id}."
                )

        # Auto-detect single-token newline ids from the tokenizer when not configured.
        if self.wrap_cfg.newline_token_ids is None:
            auto_newline_ids: List[int] = []
            for s in ["\n", "\r\n", "\n\n"]:
                ids = self.tokenizer.encode(s, add_special_tokens=False)
                if len(ids) == 1:
                    tid = int(ids[0])
                    if tid not in auto_newline_ids:
                        auto_newline_ids.append(tid)
            if auto_newline_ids:
                self.wrap_cfg.newline_token_ids = tuple(auto_newline_ids)
                self.wrap_cfg.newline_token_id = auto_newline_ids[0]

        # Sentence boundaries default to the newline ids plus single-token sentence
        # punctuation (full-width CJK and ASCII variants); duplicates are skipped.
        if self.wrap_cfg.sentence_boundary_ids is None:
            boundary_ids: List[int] = []
            if self.wrap_cfg.newline_token_ids is not None:
                boundary_ids.extend(int(x) for x in self.wrap_cfg.newline_token_ids)

            punct_candidates = ["。", "！", "？", ".", "!", "?"]
            for s in punct_candidates:
                ids = self.tokenizer.encode(s, add_special_tokens=False)
                if len(ids) == 1:
                    tid = int(ids[0])
                    if tid not in boundary_ids:
                        boundary_ids.append(tid)

            if boundary_ids:
                self.wrap_cfg.sentence_boundary_ids = tuple(boundary_ids)

        hsp_cfg = HSPPlannerConfig(
            page_size=cfg.page_size,
            flow_window=cfg.flow_window,
            flash_top_k=cfg.flash_top_k,
            anchor_pages=cfg.anchor_pages,
            pad_token_id=cfg.pad_token_id,
            newline_token_id=cfg.newline_token_id,
            newline_token_ids=cfg.newline_token_ids,
            sentence_boundary_ids=cfg.sentence_boundary_ids,
            identity_mean_weight=cfg.identity_mean_weight,
            identity_max_weight=cfg.identity_max_weight,
            lambda_semantic=cfg.semantic_score_weight,
            lambda_lexical=cfg.lexical_score_weight,
            min_query_tokens_for_multi=cfg.min_query_tokens_for_multi,
            max_query_tokens_for_multi=cfg.max_query_tokens_for_multi,
        )
        self.hsp_cfg = hsp_cfg

        # Boundary ids for the segmenter: sentence boundaries if available,
        # else the newline ids, else the single configured newline id.
        boundary_ids = (
            hsp_cfg.sentence_boundary_ids
            if hsp_cfg.sentence_boundary_ids is not None
            else (
                hsp_cfg.newline_token_ids
                if (hsp_cfg.newline_token_ids is not None and len(hsp_cfg.newline_token_ids) > 0)
                else ((hsp_cfg.newline_token_id,) if hsp_cfg.newline_token_id is not None else tuple())
            )
        )
        self.segmenter = Segmenter(
            cfg=hsp_cfg,
            query_block_size=cfg.query_block_size,
            boundary_ids=boundary_ids,
            align_explicit_query_pos=False,
        )

        hidden_dim = model.config.hidden_size
        self.page_encoder = PageEncoder(hsp_cfg, hidden_dim, idf_weights=idf_weights)
        self.query_planner = QueryPlanner(hsp_cfg, query_dim=hidden_dim)

        self.to(self.device)

    def to(self, device: torch.device):
        self.device = device
        self.model.to(device)
        self.segmenter.to(device)
        self.page_encoder.to(device)
        self.query_planner.to(device)
        return self

    def _compute_local_token_weights(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Per-token salience weights: in-sequence IDF-style scores, or uniform weights."""
        if not self.wrap_cfg.use_dynamic_token_weights:
            return attention_mask.float()

        B, L = input_ids.shape
        pad_id = self.wrap_cfg.pad_token_id
        device = input_ids.device

        weights = torch.zeros_like(input_ids, dtype=torch.float, device=device)

        for b in range(B):
            valid = attention_mask[b].bool() if attention_mask is not None else (input_ids[b] != pad_id)
            ids_valid = input_ids[b, valid]
            L_b = ids_valid.numel()

            # Short sequences get uniform weights; dynamic weighting needs enough tokens.
            if L_b == 0 or L_b < self.wrap_cfg.min_length_for_dynamic_weights:
                weights[b, valid] = 1.0
                continue

            # In-sequence IDF-style score: token ids that repeat often weigh less.
            uniq, counts = ids_valid.unique(return_counts=True)
            counts = counts.float()
            w = torch.log1p(L_b / (1.0 + counts))

            # Scatter the per-id weight back to each token position.
            w_pos = torch.zeros_like(ids_valid, dtype=torch.float, device=device)
            for uid, wv in zip(uniq.tolist(), w.tolist()):
                mask = (ids_valid == uid)
                if mask.any():
                    w_pos[mask] = wv

            # Normalize to [0, 1] within the sequence.
            maxv = float(w_pos.max().item())
            if maxv > 0:
                w_pos = w_pos / maxv

            weights[b, valid] = w_pos

        weights = weights + 1e-6
        return weights

    def _build_inputs_from_texts(
        self,
        contexts: List[str],
        questions: Optional[List[str]],
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Tokenize and right-pad a batch of context(+question) prompts.

        Returns (input_ids, attention_mask, explicit_query_pos); the query position
        tensor is None in implicit mode, where no question text is appended.
        """
        B = len(contexts)
        if questions is None:
            questions = [""] * B
        else:
            assert len(contexts) == len(questions)

        ctx_enc = [self.tokenizer.encode(c, add_special_tokens=False) for c in contexts]

        # Implicit mode: every question is empty, so the planner must infer salience on its own.
        implicit_mode = False
        if self.wrap_cfg.allow_implicit_query:
            no_query_flags = [((q is None) or (str(q).strip() == "")) for q in questions]
            if all(no_query_flags):
                implicit_mode = True

        if implicit_mode:
            q_enc = [[] for _ in range(B)]
        else:
            q_enc = [self.tokenizer.encode((q if q is not None else ""), add_special_tokens=False) for q in questions]

        input_ids_list = []
        explicit_qp_list = []

        for i in range(B):
            ids_ctx = ctx_enc[i]
            ids_q = q_enc[i]
            all_ids = ids_ctx if implicit_mode else (ids_ctx + ids_q)
            if len(all_ids) == 0:
                all_ids = [self.wrap_cfg.pad_token_id]
                qp = 0
            else:
                qp = len(ids_ctx)
            input_ids_list.append(all_ids)
            if not implicit_mode:
                explicit_qp_list.append(qp)

        max_len = max(len(x) for x in input_ids_list)
        pad_id = self.wrap_cfg.pad_token_id

        input_ids = torch.full((B, max_len), pad_id, dtype=torch.long, device=self.device)
        attention_mask = torch.zeros((B, max_len), dtype=torch.long, device=self.device)

        for i in range(B):
            L_i = len(input_ids_list[i])
            input_ids[i, :L_i] = torch.tensor(input_ids_list[i], dtype=torch.long, device=self.device)
            attention_mask[i, :L_i] = 1

        explicit_qp = None if implicit_mode else torch.tensor(explicit_qp_list, dtype=torch.long, device=self.device)
        return input_ids, attention_mask, explicit_qp

    @torch.no_grad()
    def compress_inputs_for_prefill(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        explicit_query_pos: Optional[torch.Tensor],
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Select which context pages to keep and return the compressed inputs plus stats."""
        B, L = input_ids.shape

        # Split each sequence into context/query and lay the tokens out into pages.
        aligned_qp, split_info, layout = self.segmenter.build_layout(
            input_ids=input_ids,
            attention_mask=attention_mask,
            explicit_query_pos=explicit_query_pos,
        )

        # Embed tokens with whichever embedding layer the wrapped model exposes.
        if hasattr(self.model, "get_input_embeddings") and self.model.get_input_embeddings() is not None:
            embed = self.model.get_input_embeddings()
            hidden = embed(input_ids)
        elif hasattr(self.model, "model") and hasattr(self.model.model, "embed_tokens"):
            hidden = self.model.model.embed_tokens(input_ids)
        elif hasattr(self.model, "transformer") and hasattr(self.model.transformer, "wte"):
            hidden = self.model.transformer.wte(input_ids)
        else:
            raise RuntimeError("Cannot locate input embedding layer for wrapped model.")

        token_level_weights = self._compute_local_token_weights(input_ids, attention_mask)
        block_repr = self.page_encoder(
            hidden_states=hidden,
            layout=layout,
            input_ids=input_ids,
            token_level_weights=token_level_weights,
        )

        token_valid = layout.token_valid
        query_hidden_list = []
        pos_idx = torch.arange(L, device=self.device)

        # Pool the query span into a single vector: a normalized mix of
        # (optionally IDF-weighted) mean pooling and max pooling.
        w_mean = getattr(self.hsp_cfg, "identity_mean_weight", 0.7)
        w_max = getattr(self.hsp_cfg, "identity_max_weight", 0.3)
        w_sum = w_mean + w_max
        if w_sum <= 0:
            w_mean_eff, w_max_eff = 1.0, 0.0
        else:
            w_mean_eff, w_max_eff = w_mean / w_sum, w_max / w_sum

        has_idf = getattr(self.page_encoder, "idf_weights", None) is not None

        for b in range(B):
            split_b = split_info[b]
            qs, qe = int(split_b.query_start), int(split_b.query_end)
            if qe < qs:
                # Empty query span: fall back to the single token at the aligned query position.
                qp = int(aligned_qp[b].item())
                query_hidden_list.append(hidden[b, qp])
                continue
            span_mask = (pos_idx >= qs) & (pos_idx <= qe) & token_valid[b]
            if not span_mask.any():
                qp = int(aligned_qp[b].item())
                query_hidden_list.append(hidden[b, qp])
                continue
            h_b = hidden[b]
            ids_b = input_ids[b]
            if has_idf:
                idf_vec = self.page_encoder.idf_weights
                idf_b = idf_vec[ids_b]
                weights = idf_b * span_mask.float()
                w_sum_b = float(weights.sum().item())
                if w_sum_b > 1e-6:
                    mean_b = (h_b * weights.unsqueeze(-1)).sum(dim=0) / w_sum_b
                else:
                    mean_b = h_b[span_mask].mean(dim=0)
            else:
                mean_b = h_b[span_mask].mean(dim=0)
            max_b, _ = h_b[span_mask].max(dim=0)
            pooled_b = w_mean_eff * mean_b + w_max_eff * max_b
            query_hidden_list.append(pooled_b)

        query_hidden = torch.stack(query_hidden_list, dim=0)

        # Optionally keep the top-weighted query tokens individually so the planner
        # can do multi-token semantic matching against pages.
        query_token_hidden_list = []
        query_token_weight_list = []

        min_q = getattr(self.hsp_cfg, "min_query_tokens_for_multi", 4)
        max_q = getattr(self.hsp_cfg, "max_query_tokens_for_multi", 32)
        use_multi = self.wrap_cfg.use_query_multitoken_semantic and min_q > 0

        for b in range(B):
            split_b = split_info[b]
            qs, qe = int(split_b.query_start), int(split_b.query_end)
            if qe < qs or not use_multi:
                query_token_hidden_list.append(None)
                query_token_weight_list.append(None)
                continue
            span_mask = (pos_idx >= qs) & (pos_idx <= qe) & token_valid[b]
            idx_span = pos_idx[span_mask]
            Tq = idx_span.numel()
            if Tq < min_q:
                query_token_hidden_list.append(None)
                query_token_weight_list.append(None)
                continue
            h_b = hidden[b]
            w_b = token_level_weights[b]
            w_span = w_b[idx_span]
            top_k = min(Tq, max_q)
            vals, indices = torch.topk(w_span, k=top_k, largest=True, sorted=True)
            idx_top = idx_span[indices]
            h_top = h_b[idx_top]
            w_top = vals / (vals.sum() + 1e-6)
            query_token_hidden_list.append(h_top)
            query_token_weight_list.append(w_top)

        # Score every page against the query and decide which pages to keep.
        keep_pages = self.query_planner(
            block_repr=block_repr,
            layout=layout,
            query_hidden=query_hidden,
            query_pos=aligned_qp,
            input_ids=input_ids,
            token_level_weights=token_level_weights,
            split_results=split_info,
            query_token_hidden_list=query_token_hidden_list,
            query_token_weight_list=query_token_weight_list,
        )

        # Map page-level keep decisions back to a token-level keep mask.
        token2page = layout.token2page
        token_valid = layout.token_valid
        keep_pages_bool = keep_pages.bool()
        token2page_clamped = token2page.clamp(min=0)
        keep_token = token_valid & keep_pages_bool.gather(1, token2page_clamped)

        boundary_ids = (
            self.hsp_cfg.sentence_boundary_ids
            if self.hsp_cfg.sentence_boundary_ids is not None
            else (
                self.hsp_cfg.newline_token_ids
                if (self.hsp_cfg.newline_token_ids is not None and len(self.hsp_cfg.newline_token_ids) > 0)
                else ((self.hsp_cfg.newline_token_id,) if self.hsp_cfg.newline_token_id is not None else tuple())
            )
        )

        # Snap kept spans to sentence boundaries: if any token of a sentence is kept,
        # keep the whole sentence (context region only).
        kept_context_token_indices: List[List[int]] = []
        context_lens: List[int] = []
        for b in range(B):
            qp = int(aligned_qp[b].item())
            qp = max(0, min(qp, L))
            context_lens.append(qp)
            if qp <= 0:
                kept_context_token_indices.append([])
                continue
            keep_ctx = keep_token[b, :qp].clone()
            valid_ctx = token_valid[b, :qp]
            if boundary_ids is not None and len(boundary_ids) > 0:
                ids_slice = input_ids[b, :qp]
                is_boundary = torch.zeros(qp, dtype=torch.bool, device=input_ids.device)
                for bid in boundary_ids:
                    is_boundary |= (ids_slice == int(bid))
                boundary_pos = torch.nonzero(is_boundary, as_tuple=False).flatten().tolist()
                start = 0
                for p in boundary_pos:
                    end = int(p) + 1
                    if end <= start:
                        continue
                    if keep_ctx[start:end].any():
                        keep_ctx[start:end] = valid_ctx[start:end]
                    start = end
                if start < qp and keep_ctx[start:qp].any():
                    keep_ctx[start:qp] = valid_ctx[start:qp]
            kept_idx = torch.nonzero(keep_ctx & valid_ctx, as_tuple=False).flatten().detach().cpu().tolist()
            kept_context_token_indices.append([int(x) for x in kept_idx])

        compressed = self.segmenter.compress(
            input_ids=input_ids,
            attention_mask=attention_mask,
            layout=layout,
            keep_pages=keep_pages,
            query_pos=aligned_qp,
        )
        L_comp = compressed["input_ids"].size(1)
        stats = {
            "aligned_query_pos": [int(x) for x in aligned_qp.detach().cpu().tolist()],
            "context_len": context_lens,
            "kept_context_token_indices": kept_context_token_indices,
            "original_len": int(L),
            "compressed_len": int(L_comp),
            "compression_ratio": float(L_comp / max(L, 1)),
        }
        return compressed, stats

    @torch.no_grad()
    def generate_batch(
        self,
        contexts: List[str],
        questions: List[str],
        max_new_tokens: int = 128,
        use_hsp: bool = True,
        **gen_kwargs,
    ) -> Dict[str, Any]:
        """Generate answers for (context, question) pairs, densely or with HSP compression."""
        assert len(contexts) == len(questions)
        self.model.eval()

        input_ids, attention_mask, explicit_qp = self._build_inputs_from_texts(contexts, questions)

        if not use_hsp:
            # Dense baseline: generate directly on the uncompressed prompt.
            t0 = time.time()
            outputs = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_new_tokens,
                use_cache=True,
                **gen_kwargs,
            )
            t1 = time.time()

            # generate() returns [padded prompt | new tokens], so slice at the padded
            # prompt width rather than at each row's unpadded length.
            prompt_width = input_ids.size(1)
            texts: List[str] = []
            for i in range(outputs.size(0)):
                answer_tokens = outputs[i, prompt_width:]
                texts.append(self.tokenizer.decode(answer_tokens, skip_special_tokens=True))

            meta = {
                "mode": "dense",
                "time_total": float(t1 - t0),
                "original_len": int(input_ids.size(1)),
                "compressed_len": int(input_ids.size(1)),
                "compression_ratio": 1.0,
            }
            return {"outputs": texts, "meta": meta}

        # HSP path: compress the prompt first, then generate on the shorter input.
        t0 = time.time()
        compressed_inputs, stats = self.compress_inputs_for_prefill(
            input_ids=input_ids,
            attention_mask=attention_mask,
            explicit_query_pos=explicit_qp,
        )
        t1 = time.time()

        gen_outputs = self.model.generate(
            input_ids=compressed_inputs["input_ids"],
            attention_mask=compressed_inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            use_cache=True,
            **gen_kwargs,
        )
        t2 = time.time()

        # Slice at the compressed prompt width, mirroring the dense branch.
        comp_prompt_width = compressed_inputs["input_ids"].size(1)
        texts: List[str] = []
        for i in range(gen_outputs.size(0)):
            answer_tokens = gen_outputs[i, comp_prompt_width:]
            texts.append(self.tokenizer.decode(answer_tokens, skip_special_tokens=True))

        meta = {
            "mode": "hsp",
            "time_prefill": float(t1 - t0),
            "time_total": float(t2 - t0),
            "original_len": stats["original_len"],
            "compressed_len": stats["compressed_len"],
            "compression_ratio": stats["compression_ratio"],
        }
        return {"outputs": texts, "meta": meta}