| from __future__ import annotations |
|
|
| import argparse |
| import re |
| import unicodedata |
| from pathlib import Path |
|
|
| import numpy as np |
|
|
|
|
# Sentence-ending marks; callers test single characters for membership in this string.
SENTENCE_PUNCT = "。!?.!?。."
# Ellipsis characters that are always classified as punctuation.
ELLIPSIS_PUNCT = set("…")
# Quotation marks (ASCII, curly and CJK corner brackets).
QUOTE_PUNCT = {*"\"'“”‘’「」『』"}
|
|
|
|
def is_text_punctuation(ch: str) -> bool:
    """Report whether a single character counts as punctuation.

    A character qualifies when it is listed in ``ELLIPSIS_PUNCT`` or when its
    Unicode general category starts with ``"P"``.

    Args:
        ch: The character to classify (``unicodedata.category`` requires a
            one-character string).

    Returns:
        ``True`` for punctuation, ``False`` otherwise.
    """
    if ch in ELLIPSIS_PUNCT:
        return True
    general_category = unicodedata.category(ch)
    return general_category.startswith("P")
|
|
|
|
def comma_for_context(left: str, right: str) -> str:
    """Pick the comma that replaces a non-sentence punctuation run.

    The characters on either side of the run are checked for ASCII letters or
    digits to choose between the two candidate marks.

    Args:
        left: Text immediately before the run (may be empty).
        right: Text immediately after the run (may be empty).

    Returns:
        The comma string to insert.
    """
    # NOTE(review): both candidates are the same string as written, so the
    # result currently does not depend on the context -- confirm whether a
    # different comma was intended for non-alphanumeric neighbours.
    neighbours = left + right
    if re.search(r"[A-Za-z0-9]", neighbours):
        return ","
    return ","
|
|
|
|
def setup_jieba_cache(repo_dir: Path) -> None:
    """Redirect the cache of ``jieba.dt`` into the repository work directory.

    Creates ``<repo_dir>/.work_tmp/jieba`` if it does not exist, then sets
    ``tmp_dir`` and ``cache_file`` on ``jieba.dt`` accordingly.

    Args:
        repo_dir: Directory under which the ``.work_tmp/jieba`` folder lives.
    """
    # Imported lazily so the module can be loaded without jieba installed.
    import jieba

    cache_root = repo_dir / ".work_tmp" / "jieba"
    cache_root.mkdir(parents=True, exist_ok=True)

    default_tokenizer = jieba.dt
    default_tokenizer.tmp_dir = str(cache_root)
    default_tokenizer.cache_file = "jieba.cache"
|
|
|
|
def load_text(args: argparse.Namespace) -> str:
    """Read the input text from the parsed CLI arguments and flatten whitespace.

    ``args.text_file`` takes precedence over ``args.text``. Runs of spaces,
    tabs and other horizontal whitespace become one space, then runs of
    newlines become one space, and the result is stripped.

    Args:
        args: Namespace exposing ``text_file`` and ``text`` attributes.

    Returns:
        The flattened text.

    Raises:
        ValueError: If neither ``text_file`` nor ``text`` is set.
    """
    source_path = args.text_file
    if source_path:
        raw = Path(source_path).read_text(encoding="utf-8")
    else:
        raw = args.text
        if not raw:
            raise ValueError("Either --text or --text-file is required")

    # Horizontal whitespace first, newlines second; the two passes are kept
    # separate so a space next to a newline still yields adjacent spaces.
    for pattern in (r"[ \t\r\f\v]+", r"\n+"):
        raw = re.sub(pattern, " ", raw)
    return raw.strip()
|
|
|
|
def normalize_punctuation(text: str) -> str:
    """Collapse punctuation runs in ``text`` down to a single mark each.

    Steps, in order:

    1. Strip the text and collapse every whitespace run to one space.
    2. Replace ``"……"`` with ``"。"`` and ``"..."`` with ``"."``.
    3. Scan for maximal runs of punctuation (per ``is_text_punctuation``).
       A run consisting only of quote characters is dropped; any other run is
       replaced by its last sentence-ending mark, or by a comma chosen via
       ``comma_for_context`` when it holds no sentence-ending mark.
    4. Apply a fixed series of regex clean-ups to spacing and to adjacent
       comma/sentence-mark combinations.

    Args:
        text: Raw input text.

    Returns:
        The normalized, stripped text.
    """
    text = re.sub(r"\s+", " ", text.strip())
    text = text.replace("……", "。").replace("...", ".")

    out: list[str] = []
    length = len(text)
    pos = 0
    while pos < length:
        ch = text[pos]
        if not is_text_punctuation(ch):
            out.append(ch)
            pos += 1
            continue

        # Consume the maximal run of consecutive punctuation starting here.
        end = pos
        while end < length and is_text_punctuation(text[end]):
            end += 1
        run = text[pos:end]
        pos = end

        # Quote-only runs carry no pause information and are removed.
        if all(mark in QUOTE_PUNCT for mark in run):
            continue

        enders = [mark for mark in run if mark in SENTENCE_PUNCT]
        if enders:
            chosen = enders[-1]
        else:
            before = out[-1] if out else ""
            after = text[end] if end < length else ""
            chosen = comma_for_context(before, after)

        if out and is_text_punctuation(out[-1]):
            out[-1] = chosen
        else:
            out.append(chosen)

    # Order matters: each substitution operates on the previous one's output.
    cleanup_steps = (
        (r"[ \t]+([,。!?,.!?.。])", r"\1"),
        (r"([,。!?])\s+", r"\1"),
        (r"([,.!?.。])(?=[A-Za-z0-9])", r"\1 "),
        (r"\s{2,}", " "),
        (r"([,,])\s*[,,]+", r"\1"),
        (r"[,,]+([。!?.!?.。])", r"\1"),
        (r"([。!?.!?.。])([,,]+)", r"\1"),
    )
    result = "".join(out)
    for pattern, substitute in cleanup_steps:
        result = re.sub(pattern, substitute, result)
    return result.strip()
|
|
|
|
def join_units(left: str, right: str) -> str:
    """Concatenate two text units, inserting a space only between non-CJK edges.

    No separator is used when ``left`` ends with, or the stripped ``right``
    starts with, a character in U+4E00..U+9FFF; otherwise a single space is
    inserted.

    Args:
        left: Text accumulated so far (may be empty).
        right: Unit to append; surrounding whitespace is removed.

    Returns:
        The joined text. If one side is empty, the other side stripped.
    """
    if not left:
        return right.strip()

    addition = right.strip()
    if not addition:
        return left.strip()

    # The left-hand check runs on the un-stripped value, so trailing
    # whitespace on ``left`` prevents a match there.
    left_ends_cjk = re.search(r"[\u4e00-\u9fff]$", left) is not None
    right_starts_cjk = re.match(r"^[\u4e00-\u9fff]", addition) is not None
    separator = "" if (left_ends_cjk or right_starts_cjk) else " "
    return left.rstrip() + separator + addition
|
|
|
|
def split_units(text: str) -> list[str]:
    """Split ``text`` into clause-sized units at punctuation marks and newlines.

    Each unit is a run of non-delimiter characters plus at most one trailing
    delimiter mark. Units are stripped and empty ones discarded.

    Args:
        text: Text to split.

    Returns:
        The list of units; ``[]`` for blank input, or ``[text]`` (stripped)
        when the pattern yields nothing, e.g. input made only of delimiters.
    """
    text = text.strip()
    if not text:
        return []

    pattern = r"[^。!?!?;;,,、::\n]+[。!?!?;;,,、::]?"
    # The pattern has no capture groups, so findall yields whole matches.
    stripped = (chunk.strip() for chunk in re.findall(pattern, text))
    units = [chunk for chunk in stripped if chunk]
    return units if units else [text]
|
|
|
|
def estimate_lengths(
    prompt_frames: int,
    prompt_tokens: int,
    text_tokens: int,
    speed: float,
    max_feat_len: int,
) -> dict[str, int | bool]:
    """Estimate feature lengths for a prompt plus new text.

    The prompt's frames-per-token ratio is scaled to the combined token count
    and divided by ``speed``; the rounded-up result is then capped at
    ``max_feat_len``.

    Args:
        prompt_frames: Number of frames in the prompt.
        prompt_tokens: Number of tokens in the prompt (must be non-zero).
        text_tokens: Number of tokens in the text to synthesize.
        speed: Divisor applied to the estimated length (must be non-zero).
        max_feat_len: Upper bound on the total feature length.

    Returns:
        A dict with ``raw_features_len`` (uncapped estimate), ``features_len``
        (capped estimate), ``generated_frames`` (capped length minus the
        prompt frames, or the capped length itself when that difference is
        not positive) and ``clamped`` (whether the cap applied).
    """
    frames_per_token = prompt_frames / prompt_tokens
    total_tokens = prompt_tokens + text_tokens
    raw_features_len = int(np.ceil(frames_per_token * total_tokens / speed))

    clamped = raw_features_len > max_feat_len
    features_len = max_feat_len if clamped else raw_features_len

    generated_frames = features_len - prompt_frames
    if generated_frames <= 0:
        # The cap left no room beyond the prompt; fall back to the full length.
        generated_frames = features_len

    return {
        "raw_features_len": raw_features_len,
        "features_len": features_len,
        "generated_frames": generated_frames,
        "clamped": clamped,
    }
|
|
|
|
def token_count(tokenizer, text: str) -> int:
    """Return how many token ids ``tokenizer`` produces for ``text``.

    Args:
        tokenizer: Object exposing ``texts_to_token_ids(list_of_texts)`` that
            returns one id sequence per input text.
        text: The text to tokenize.

    Returns:
        Length of the id sequence for ``text``.
    """
    batch_ids = tokenizer.texts_to_token_ids([text])
    ids_for_text = batch_ids[0]
    return len(ids_for_text)
|
|
|
|
def _split_by_chars(tokenizer, text: str, max_text_tokens: int) -> list[str]:
    """Greedily cut ``text`` into chunks, growing each chunk one character at a time."""
    chunks: list[str] = []
    current = ""
    for char in text:
        candidate = current + char
        if current and token_count(tokenizer, candidate) > max_text_tokens:
            chunks.append(current)
            current = char
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks


def split_long_unit(tokenizer, unit: str, max_text_tokens: int) -> list[str]:
    """Break a unit that exceeds the token budget into smaller chunks.

    A unit within budget is returned unchanged. A unit containing spaces is
    split on whitespace and the pieces are greedily re-packed with
    ``join_units``; a unit without spaces is split character by character.

    A single whitespace-delimited piece that is itself over budget is also
    split character by character, so the whitespace path gives the same
    guarantee as the no-space path instead of emitting an over-long chunk.

    Args:
        tokenizer: Tokenizer accepted by ``token_count``.
        unit: The text unit to split.
        max_text_tokens: Maximum number of tokens allowed per chunk.

    Returns:
        Chunks in original order. Every chunk is within ``max_text_tokens``
        unless a single character alone already exceeds it.
    """
    if token_count(tokenizer, unit) <= max_text_tokens:
        return [unit]

    if " " not in unit:
        return _split_by_chars(tokenizer, unit, max_text_tokens)

    chunks: list[str] = []
    current = ""
    for piece in unit.split():
        if token_count(tokenizer, piece) > max_text_tokens:
            # The piece cannot fit in any chunk on its own: flush what has
            # been packed so far and fall back to character-level splitting.
            if current:
                chunks.append(current)
            *full_chunks, current = _split_by_chars(tokenizer, piece, max_text_tokens)
            chunks.extend(full_chunks)
            continue

        candidate = join_units(current, piece)
        if current and token_count(tokenizer, candidate) > max_text_tokens:
            chunks.append(current)
            current = piece
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks
|
|
|
|
def build_segments(
    tokenizer,
    text: str,
    prompt_frames: int,
    prompt_tokens_len: int,
    speed: float,
    max_feat_len: int,
    max_text_tokens: int,
    min_generated_frames: int,
    max_generated_frames: int,
    max_raw_feat_ratio: float,
) -> list[dict[str, object]]:
    """Greedily pack text units into segments that respect the size limits.

    The text is split with ``split_units``; over-long units are broken up with
    ``split_long_unit``. Units are appended to the current segment until the
    combined text would exceed the token limit, the generated-frame limit, or
    ``int(max_feat_len * max_raw_feat_ratio)`` raw feature frames. If the last
    segment ends up shorter than ``min_generated_frames`` it is merged into
    the previous one when the merge fits, otherwise ``rebalance_short_tail``
    is tried.

    Args:
        tokenizer: Tokenizer accepted by ``token_count``.
        text: Text to segment.
        prompt_frames: Number of frames in the prompt.
        prompt_tokens_len: Number of tokens in the prompt.
        speed: Speed divisor passed to ``estimate_lengths``.
        max_feat_len: Cap on the total feature length.
        max_text_tokens: Maximum tokens per segment.
        min_generated_frames: Minimum generated frames wanted for the tail.
        max_generated_frames: Maximum generated frames per segment.
        max_raw_feat_ratio: Fraction of ``max_feat_len`` allowed for the
            uncapped length estimate.

    Returns:
        One record per segment with ``text``, ``text_tokens`` and the keys
        produced by ``estimate_lengths``.
    """

    def describe(segment_text: str) -> dict[str, object]:
        # Token count plus length estimates for one candidate segment.
        n_tokens = token_count(tokenizer, segment_text)
        lengths = estimate_lengths(
            prompt_frames,
            prompt_tokens_len,
            n_tokens,
            speed,
            max_feat_len,
        )
        return {"text": segment_text, "text_tokens": n_tokens, **lengths}

    def fits(record: dict[str, object]) -> bool:
        # All three limits must hold for a segment to be acceptable.
        raw_feat_limit = int(max_feat_len * max_raw_feat_ratio)
        return (
            int(record["text_tokens"]) <= max_text_tokens
            and int(record["generated_frames"]) <= max_generated_frames
            and int(record["raw_features_len"]) <= raw_feat_limit
        )

    units = [
        chunk
        for raw_unit in split_units(text)
        for chunk in split_long_unit(tokenizer, raw_unit, max_text_tokens)
    ]

    segments: list[dict[str, object]] = []
    pending = ""
    for unit in units:
        grown = join_units(pending, unit)
        grown_record = describe(grown)
        if pending and not fits(grown_record):
            # Adding this unit would overflow: close the segment without it.
            segments.append(describe(pending))
            pending = unit
        else:
            pending = grown

    if pending:
        segments.append(describe(pending))

    tail_is_short = (
        len(segments) >= 2
        and int(segments[-1]["generated_frames"]) < min_generated_frames
    )
    if tail_is_short:
        prev_text = str(segments[-2]["text"])
        tail_text = str(segments[-1]["text"])
        merged_record = describe(join_units(prev_text, tail_text))
        if fits(merged_record):
            # The last two segments fit together: collapse them into one.
            segments[-2:] = [merged_record]
        else:
            rebalanced = rebalance_short_tail(
                tokenizer,
                prev_text,
                tail_text,
                prompt_frames,
                prompt_tokens_len,
                speed,
                max_feat_len,
                max_text_tokens,
                min_generated_frames,
                max_generated_frames,
                max_raw_feat_ratio,
            )
            if rebalanced is not None:
                segments[-2], segments[-1] = rebalanced

    return segments
|
|
|
|
def split_rebalance_pieces(text: str) -> list[str]:
    """Split ``text`` into the finest pieces available for rebalancing.

    Preference order: punctuation-delimited units (when there is more than
    one), then whitespace-separated words, then individual characters.

    Args:
        text: Segment text to break apart.

    Returns:
        The list of pieces.
    """
    by_punctuation = split_units(text)
    if len(by_punctuation) > 1:
        return by_punctuation
    return text.split() if " " in text else list(text)
|
|
|
|
def segment_record(
    tokenizer,
    text: str,
    prompt_frames: int,
    prompt_tokens_len: int,
    speed: float,
    max_feat_len: int,
) -> dict[str, object]:
    """Build the record describing one segment.

    Args:
        tokenizer: Tokenizer accepted by ``token_count``.
        text: Segment text.
        prompt_frames: Number of frames in the prompt.
        prompt_tokens_len: Number of tokens in the prompt.
        speed: Speed divisor passed to ``estimate_lengths``.
        max_feat_len: Cap on the total feature length.

    Returns:
        A dict with ``text``, ``text_tokens`` and every key returned by
        ``estimate_lengths`` for that token count.
    """
    n_tokens = token_count(tokenizer, text)
    record: dict[str, object] = {"text": text, "text_tokens": n_tokens}
    record.update(
        estimate_lengths(
            prompt_frames,
            prompt_tokens_len,
            n_tokens,
            speed,
            max_feat_len,
        )
    )
    return record
|
|
|
|
def segment_within_limits(
    record: dict[str, object],
    max_text_tokens: int,
    max_generated_frames: int,
    max_raw_feat: int,
) -> bool:
    """Check a segment record against the token and frame limits.

    Args:
        record: Mapping with ``text_tokens``, ``generated_frames`` and
            ``raw_features_len`` entries convertible with ``int``.
        max_text_tokens: Maximum allowed token count.
        max_generated_frames: Maximum allowed generated frames.
        max_raw_feat: Maximum allowed uncapped feature length.

    Returns:
        ``True`` only when all three values are within their limits.
    """
    if not int(record["text_tokens"]) <= max_text_tokens:
        return False
    if not int(record["generated_frames"]) <= max_generated_frames:
        return False
    return int(record["raw_features_len"]) <= max_raw_feat
|
|
|
|
def rebalance_short_tail(
    tokenizer,
    prev_text: str,
    tail_text: str,
    prompt_frames: int,
    prompt_tokens_len: int,
    speed: float,
    max_feat_len: int,
    max_text_tokens: int,
    min_generated_frames: int,
    max_generated_frames: int,
    max_raw_feat_ratio: float,
) -> tuple[dict[str, object], dict[str, object]] | None:
    """Move trailing pieces of the previous segment onto a too-short tail.

    ``prev_text`` is broken into pieces (via ``split_rebalance_pieces`` and,
    when it contains spaces and gives a different result, plain whitespace
    splitting). Split points are tried from the end of the piece list
    towards the start, so the tail grows step by step. A split is accepted
    only when both new segments respect the limits and the shortened previous
    segment still has at least ``min_generated_frames``.

    NOTE(review): when ``prev_text`` contains a space, pieces are re-joined
    with ``" "`` regardless of how they were split -- confirm this is the
    intended spacing for punctuation-delimited pieces.

    Args:
        tokenizer: Tokenizer accepted by ``token_count``.
        prev_text: Text of the second-to-last segment.
        tail_text: Text of the short last segment.
        prompt_frames: Number of frames in the prompt.
        prompt_tokens_len: Number of tokens in the prompt.
        speed: Speed divisor passed to ``estimate_lengths``.
        max_feat_len: Cap on the total feature length.
        max_text_tokens: Maximum tokens per segment.
        min_generated_frames: Minimum generated frames wanted per segment.
        max_generated_frames: Maximum generated frames per segment.
        max_raw_feat_ratio: Fraction of ``max_feat_len`` allowed for the
            uncapped length estimate.

    Returns:
        ``(previous_record, tail_record)`` for the first split whose tail
        reaches ``min_generated_frames``; otherwise the accepted split with
        the longest tail, or ``None`` when no split is acceptable.
    """
    raw_feat_limit = int(max_feat_len * max_raw_feat_ratio)
    has_space = " " in prev_text

    def glue(parts: list[str]) -> str:
        # Re-assemble pieces the same way the source text was spaced.
        if has_space:
            return join_units("", " ".join(parts))
        return "".join(parts)

    def describe(candidate_text: str) -> dict[str, object]:
        return segment_record(
            tokenizer,
            candidate_text,
            prompt_frames,
            prompt_tokens_len,
            speed,
            max_feat_len,
        )

    def fits(record: dict[str, object]) -> bool:
        return segment_within_limits(
            record,
            max_text_tokens,
            max_generated_frames,
            raw_feat_limit,
        )

    candidate_piece_sets = [split_rebalance_pieces(prev_text)]
    if has_space:
        words = prev_text.split()
        if len(words) > 1 and words != candidate_piece_sets[0]:
            candidate_piece_sets.append(words)

    best: tuple[dict[str, object], dict[str, object]] | None = None
    best_tail_frames = -1

    for pieces in candidate_piece_sets:
        if len(pieces) < 2:
            continue
        # Walk the cut point backwards so the tail grows one piece at a time.
        for cut in reversed(range(1, len(pieces))):
            shortened_prev = glue(pieces[:cut])
            grown_tail = join_units(glue(pieces[cut:]), tail_text)
            prev_record = describe(shortened_prev)
            tail_record = describe(grown_tail)

            if not (fits(tail_record) and fits(prev_record)):
                continue
            if int(prev_record["generated_frames"]) < min_generated_frames:
                continue

            tail_frames = int(tail_record["generated_frames"])
            if tail_frames > best_tail_frames:
                best = (prev_record, tail_record)
                best_tail_frames = tail_frames
            if tail_frames >= min_generated_frames:
                # Good enough: stop at the first split that fixes the tail.
                return best

    return best
|
|
|
|
def build_cat_tokens(tokenizer, prompt_tokens: list[int], text_tokens: list[int], max_tokens: int) -> np.ndarray:
    """Build a padded ``(1, max_tokens)`` int64 array of prompt + text token ids.

    The row holds the prompt ids, then the text ids, then one
    ``tokenizer.pad_id``; all remaining positions are also filled with
    ``tokenizer.pad_id``.

    Args:
        tokenizer: Object exposing a ``pad_id`` attribute.
        prompt_tokens: Token ids of the prompt.
        text_tokens: Token ids of the text to synthesize.
        max_tokens: Width of the output row.

    Returns:
        Array of shape ``(1, max_tokens)`` and dtype ``int64``.

    Raises:
        ValueError: If prompt + text + the trailing pad do not fit into
            ``max_tokens``. (Previously this surfaced as an opaque NumPy
            broadcast ``ValueError``.)
    """
    pad_id = tokenizer.pad_id
    # list() lets tuples or arrays be concatenated rather than added.
    cat = list(prompt_tokens) + list(text_tokens) + [pad_id]
    if len(cat) > max_tokens:
        raise ValueError(
            f"Token sequence of length {len(cat)} (prompt + text + trailing pad) "
            f"does not fit into max_tokens={max_tokens}"
        )
    cat_tokens = np.full((1, max_tokens), pad_id, dtype=np.int64)
    cat_tokens[0, : len(cat)] = np.array(cat, dtype=np.int64)
    return cat_tokens
|
|