from __future__ import annotations

import argparse
import re
import unicodedata
from pathlib import Path

import numpy as np

# CJK / fullwidth punctuation, spelled as escapes on purpose: a fullwidth->ASCII
# width-normalisation pass over this file would otherwise silently fold them
# into their ASCII look-alikes (making, e.g., both branches of
# comma_for_context return the same ASCII comma).
_IDEO_FULL_STOP = "\u3002"  # ideographic full stop
_IDEO_COMMA = "\u3001"  # ideographic (enumeration) comma
_FW_COMMA = "\uff0c"  # fullwidth comma
_FW_EXCLAMATION = "\uff01"  # fullwidth exclamation mark
_FW_QUESTION = "\uff1f"  # fullwidth question mark
_FW_FULL_STOP = "\uff0e"  # fullwidth full stop
_FW_SEMICOLON = "\uff1b"  # fullwidth semicolon
_FW_COLON = "\uff1a"  # fullwidth colon

# Marks that end a sentence (CJK and ASCII); a punctuation run containing one
# of these collapses to its last such mark.
SENTENCE_PUNCT = (
    _IDEO_FULL_STOP + _FW_EXCLAMATION + _FW_QUESTION + ".!?" + _FW_FULL_STOP
)
ELLIPSIS_PUNCT = {"…"}
# Quote characters; a punctuation run made only of these is dropped entirely.
QUOTE_PUNCT = set("\"'“”‘’「」『』")

# CJK-style marks: whitespace directly after them is removed.
_CJK_PAUSE = _FW_COMMA + _IDEO_FULL_STOP + _FW_EXCLAMATION + _FW_QUESTION
# Marks that receive one space before a following Latin letter or digit.
_LATIN_PAUSE = ",.!?" + _FW_FULL_STOP + _IDEO_FULL_STOP
_COMMAS = "," + _FW_COMMA
# Characters split_units breaks after.  ASCII "." is not included, so a
# decimal such as 3.14 stays inside one unit.
_UNIT_DELIMS = (
    _IDEO_FULL_STOP
    + _FW_EXCLAMATION
    + _FW_QUESTION
    + "!?"
    + _FW_SEMICOLON
    + ";"
    + _FW_COMMA
    + ","
    + _IDEO_COMMA
    + _FW_COLON
    + ":"
)

_LATIN_OR_DIGIT_RE = re.compile(r"[A-Za-z0-9]")
_SPACE_BEFORE_PUNCT_RE = re.compile(rf"[ \t]+([{_CJK_PAUSE}{_LATIN_PAUSE}])")
_SPACE_AFTER_CJK_PUNCT_RE = re.compile(rf"([{_CJK_PAUSE}])\s+")
# Insert one space after punctuation that precedes Latin text or digits, but
# never when the punctuation sits between two digits, so "3.14" is not
# rewritten as "3. 14".
_PUNCT_BEFORE_ALNUM_RE = re.compile(
    rf"(?<![0-9])([{_LATIN_PAUSE}])(?=[A-Za-z0-9])"
)
_DIGIT_PUNCT_BEFORE_LETTER_RE = re.compile(
    rf"(?<=[0-9])([{_LATIN_PAUSE}])(?=[A-Za-z])"
)
_MULTI_SPACE_RE = re.compile(r"\s{2,}")
_REPEATED_COMMA_RE = re.compile(rf"([{_COMMAS}])\s*[{_COMMAS}]+")
_COMMA_BEFORE_SENTENCE_END_RE = re.compile(rf"[{_COMMAS}]+([{SENTENCE_PUNCT}])")
_COMMA_AFTER_SENTENCE_END_RE = re.compile(rf"([{SENTENCE_PUNCT}])[{_COMMAS}]+")
_UNIT_RE = re.compile(rf"[^{_UNIT_DELIMS}\n]+[{_UNIT_DELIMS}]?")


def is_text_punctuation(ch: str) -> bool:
    """Return True if *ch* is punctuation for normalisation purposes.

    Any character whose Unicode general category starts with "P" counts, as
    does the ellipsis character.
    """
    return ch in ELLIPSIS_PUNCT or unicodedata.category(ch).startswith("P")


def comma_for_context(left: str, right: str) -> str:
    """Choose the comma that fits between *left* and *right*.

    Returns the ASCII "," when either neighbour contains a Latin letter or
    digit, otherwise the fullwidth comma used in CJK text.
    """
    if _LATIN_OR_DIGIT_RE.search(left + right):
        return ","
    return _FW_COMMA


def setup_jieba_cache(repo_dir: Path) -> None:
    """Configure jieba to keep its cache under ``repo_dir/.work_tmp/jieba``.

    Creates the directory if needed, then sets ``jieba.dt.tmp_dir`` to it and
    ``jieba.dt.cache_file`` to ``"jieba.cache"``. jieba is imported lazily so
    the rest of the module works without it installed.
    """
    import jieba

    cache_dir = repo_dir / ".work_tmp" / "jieba"
    cache_dir.mkdir(parents=True, exist_ok=True)
    jieba.dt.tmp_dir = str(cache_dir)
    jieba.dt.cache_file = "jieba.cache"


def load_text(args: argparse.Namespace) -> str:
    """Read the input text from ``args.text_file`` or ``args.text``.

    ``text_file`` (read as UTF-8) takes precedence over ``text``. Every run of
    ASCII whitespace, newlines included, is collapsed to a single space and the
    result is stripped.

    Raises:
        ValueError: if neither argument provides text.
    """
    if args.text_file:
        text = Path(args.text_file).read_text(encoding="utf-8")
    elif args.text:
        text = args.text
    else:
        raise ValueError("Either --text or --text-file is required")
    # One pass so that e.g. " \n " becomes a single space, not three.
    text = re.sub(r"[ \t\r\n\f\v]+", " ", text)
    return text.strip()


def normalize_punctuation(text: str) -> str:
    """Collapse whitespace and canonicalise punctuation in *text*.

    After mapping "……" to the ideographic full stop and "..." to ".", every
    maximal run of punctuation becomes a single mark:

    * a run made only of quote characters is dropped;
    * a run containing sentence marks becomes its last sentence mark;
    * any other run becomes a comma chosen by :func:`comma_for_context`.

    Spacing is then tidied: spaces before punctuation and after CJK marks are
    removed, one space is inserted before a following Latin letter/digit (but
    not between two digits), repeated whitespace is collapsed, and commas that
    repeat or touch a sentence mark are removed.
    """
    text = re.sub(r"\s+", " ", text.strip())
    text = text.replace("……", _IDEO_FULL_STOP).replace("...", ".")

    chars: list[str] = []
    n = len(text)
    i = 0
    while i < n:
        ch = text[i]
        if not is_text_punctuation(ch):
            chars.append(ch)
            i += 1
            continue

        j = i
        while j < n and is_text_punctuation(text[j]):
            j += 1
        punct_run = text[i:j]

        if all(p in QUOTE_PUNCT for p in punct_run):
            i = j
            continue

        sentence_marks = [p for p in punct_run if p in SENTENCE_PUNCT]
        if sentence_marks:
            mark = sentence_marks[-1]
        else:
            left = chars[-1] if chars else ""
            right = text[j] if j < n else ""
            mark = comma_for_context(left, right)

        # Never emit two punctuation marks back to back.
        if chars and is_text_punctuation(chars[-1]):
            chars[-1] = mark
        else:
            chars.append(mark)
        i = j

    normalized = "".join(chars)
    normalized = _SPACE_BEFORE_PUNCT_RE.sub(r"\1", normalized)
    normalized = _SPACE_AFTER_CJK_PUNCT_RE.sub(r"\1", normalized)
    normalized = _PUNCT_BEFORE_ALNUM_RE.sub(r"\1 ", normalized)
    normalized = _DIGIT_PUNCT_BEFORE_LETTER_RE.sub(r"\1 ", normalized)
    normalized = _MULTI_SPACE_RE.sub(" ", normalized)
    normalized = _REPEATED_COMMA_RE.sub(r"\1", normalized)
    normalized = _COMMA_BEFORE_SENTENCE_END_RE.sub(r"\1", normalized)
    normalized = _COMMA_AFTER_SENTENCE_END_RE.sub(r"\1", normalized)
    return normalized.strip()


def join_units(left: str, right: str) -> str:
    """Join two text pieces, inserting a space only between non-CJK text.

    No space is added when *left* ends or *right* starts with a CJK ideograph
    (U+4E00..U+9FFF). If either side is empty the other is returned stripped.
    """
    if not left:
        return right.strip()
    right = right.strip()
    if not right:
        return left.strip()
    if re.search(r"[\u4e00-\u9fff]$", left) or re.match(r"^[\u4e00-\u9fff]", right):
        return left.rstrip() + right
    return left.rstrip() + " " + right


def split_units(text: str) -> list[str]:
    """Split *text* into clause-like units.

    Each unit runs up to and including one delimiter (CJK/ASCII sentence marks,
    commas, semicolons, colons, the ideographic comma); newlines also break
    units. Units are stripped and empty ones discarded. Returns ``[]`` for
    blank text and ``[text]`` if no unit matches.
    """
    text = text.strip()
    if not text:
        return []
    units = [m.group(0).strip() for m in _UNIT_RE.finditer(text) if m.group(0).strip()]
    return units or [text]


def estimate_lengths(
    prompt_frames: int,
    prompt_tokens: int,
    text_tokens: int,
    speed: float,
    max_feat_len: int,
) -> dict[str, int | bool]:
    """Estimate feature-frame lengths for synthesising text after a prompt.

    The prompt's frames-per-token rate is extrapolated to prompt + text tokens
    and divided by *speed*; the ceiling of that is the raw estimate, which is
    then capped at *max_feat_len*. *prompt_tokens* and *speed* are divisors and
    must be non-zero.

    Returns:
        ``raw_features_len``: uncapped estimate;
        ``features_len``: ``min(raw_features_len, max_feat_len)``;
        ``generated_frames``: ``features_len - prompt_frames``, or
        ``features_len`` when that difference is not positive;
        ``clamped``: whether the cap was applied.
    """
    raw_features_len = int(
        np.ceil(prompt_frames / prompt_tokens * (prompt_tokens + text_tokens) / speed)
    )
    features_len = min(raw_features_len, max_feat_len)
    generated_frames = features_len - prompt_frames
    if generated_frames <= 0:
        generated_frames = features_len
    return {
        "raw_features_len": raw_features_len,
        "features_len": features_len,
        "generated_frames": generated_frames,
        "clamped": raw_features_len > max_feat_len,
    }


def token_count(tokenizer, text: str) -> int:
    """Return how many token ids *tokenizer* produces for *text*.

    Uses ``tokenizer.texts_to_token_ids([text])[0]``; assumes that method takes
    a list of strings and returns one id sequence per string.
    """
    return len(tokenizer.texts_to_token_ids([text])[0])


def _split_by_chars(tokenizer, text: str, max_text_tokens: int) -> list[str]:
    """Greedily pack the characters of *text* into chunks within the budget.

    A single character that alone exceeds *max_text_tokens* still becomes its
    own chunk.
    """
    chunks: list[str] = []
    current = ""
    for char in text:
        candidate = current + char
        if current and token_count(tokenizer, candidate) > max_text_tokens:
            chunks.append(current)
            current = char
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks


def split_long_unit(tokenizer, unit: str, max_text_tokens: int) -> list[str]:
    """Split *unit* into chunks of at most *max_text_tokens* tokens.

    A unit already within budget is returned as ``[unit]``. Units containing
    spaces are packed word by word (joined with :func:`join_units`); a word
    that is too long on its own, and any unit without spaces, is packed
    character by character.
    """
    if token_count(tokenizer, unit) <= max_text_tokens:
        return [unit]
    if " " not in unit:
        return _split_by_chars(tokenizer, unit, max_text_tokens)

    word_chunks: list[str] = []
    current = ""
    for word in unit.split():
        candidate = join_units(current, word)
        if current and token_count(tokenizer, candidate) > max_text_tokens:
            word_chunks.append(current)
            current = word
        else:
            current = candidate
    if current:
        word_chunks.append(current)

    chunks: list[str] = []
    for chunk in word_chunks:
        if token_count(tokenizer, chunk) > max_text_tokens:
            # A single over-long word: fall back to character packing so no
            # chunk exceeds the token budget.
            pieces = _split_by_chars(tokenizer, chunk, max_text_tokens)
            chunks.extend(p.strip() for p in pieces if p.strip())
        else:
            chunks.append(chunk)
    return chunks


def build_segments(
    tokenizer,
    text: str,
    prompt_frames: int,
    prompt_tokens_len: int,
    speed: float,
    max_feat_len: int,
    max_text_tokens: int,
    min_generated_frames: int,
    max_generated_frames: int,
    max_raw_feat_ratio: float,
) -> list[dict[str, object]]:
    """Pack *text* into synthesis segments that respect the length budgets.

    The text is split with :func:`split_units` and :func:`split_long_unit`,
    then units are greedily joined (via :func:`join_units`) while the segment
    stays within the token, generated-frame and raw-feature limits (the latter
    is ``int(max_feat_len * max_raw_feat_ratio)``). If the final segment
    generates fewer than *min_generated_frames* frames it is merged into the
    previous one when that fits, otherwise rebalanced with
    :func:`rebalance_short_tail` when possible.

    Returns a list of records as produced by :func:`segment_record`.
    """
    max_raw_feat = int(max_feat_len * max_raw_feat_ratio)

    def record(segment_text: str) -> dict[str, object]:
        return segment_record(
            tokenizer, segment_text, prompt_frames, prompt_tokens_len, speed, max_feat_len
        )

    def fits(rec: dict[str, object]) -> bool:
        return segment_within_limits(rec, max_text_tokens, max_generated_frames, max_raw_feat)

    units: list[str] = []
    for unit in split_units(text):
        units.extend(split_long_unit(tokenizer, unit, max_text_tokens))

    segments: list[dict[str, object]] = []
    current = ""
    for unit in units:
        candidate = join_units(current, unit)
        if current and not fits(record(candidate)):
            segments.append(record(current))
            current = unit
        else:
            current = candidate
    if current:
        segments.append(record(current))

    if len(segments) >= 2 and int(segments[-1]["generated_frames"]) < min_generated_frames:
        merged = record(join_units(str(segments[-2]["text"]), str(segments[-1]["text"])))
        if fits(merged):
            segments[-2:] = [merged]
        else:
            rebalanced = rebalance_short_tail(
                tokenizer,
                str(segments[-2]["text"]),
                str(segments[-1]["text"]),
                prompt_frames,
                prompt_tokens_len,
                speed,
                max_feat_len,
                max_text_tokens,
                min_generated_frames,
                max_generated_frames,
                max_raw_feat_ratio,
            )
            if rebalanced is not None:
                segments[-2], segments[-1] = rebalanced
    return segments


def split_rebalance_pieces(text: str) -> list[str]:
    """Split *text* into the pieces used as cut points when rebalancing.

    Prefers punctuation units; falls back to whitespace-separated words when
    the text is a single unit containing a space, and to characters otherwise.
    """
    units = split_units(text)
    if len(units) > 1:
        return units
    if " " in text:
        return text.split()
    return list(text)


def segment_record(
    tokenizer,
    text: str,
    prompt_frames: int,
    prompt_tokens_len: int,
    speed: float,
    max_feat_len: int,
) -> dict[str, object]:
    """Return ``{"text", "text_tokens", **estimate_lengths(...)}`` for *text*."""
    text_tokens = token_count(tokenizer, text)
    est = estimate_lengths(
        prompt_frames,
        prompt_tokens_len,
        text_tokens,
        speed,
        max_feat_len,
    )
    return {"text": text, "text_tokens": text_tokens, **est}


def segment_within_limits(
    record: dict[str, object],
    max_text_tokens: int,
    max_generated_frames: int,
    max_raw_feat: int,
) -> bool:
    """Return True if *record* is within the token, frame and raw-feature caps."""
    return (
        int(record["text_tokens"]) <= max_text_tokens
        and int(record["generated_frames"]) <= max_generated_frames
        and int(record["raw_features_len"]) <= max_raw_feat
    )


def _join_rebalance_pieces(pieces: list[str], mode: str) -> str:
    """Reassemble a slice of rebalance pieces into text.

    ``"concat"``: the source had no spaces, so plain concatenation is exact.
    ``"words"``: pieces came from ``str.split()``; rejoin with single spaces.
    ``"units"``: punctuation units; rejoin with :func:`join_units`, the rule
    used to assemble segments, so no space is added next to CJK text.
    """
    if mode == "concat":
        return "".join(pieces)
    if mode == "words":
        return " ".join(pieces)
    joined = ""
    for piece in pieces:
        joined = join_units(joined, piece)
    return joined


def rebalance_short_tail(
    tokenizer,
    prev_text: str,
    tail_text: str,
    prompt_frames: int,
    prompt_tokens_len: int,
    speed: float,
    max_feat_len: int,
    max_text_tokens: int,
    min_generated_frames: int,
    max_generated_frames: int,
    max_raw_feat_ratio: float,
) -> tuple[dict[str, object], dict[str, object]] | None:
    """Move a suffix of *prev_text* onto the too-short *tail_text*.

    Cut points come from :func:`split_rebalance_pieces` (plus, when the text
    contains spaces, whitespace words) and are tried from the end of
    *prev_text* backwards. A cut is valid when both resulting segments are
    within limits and the previous segment still generates at least
    *min_generated_frames* frames. Returns the first valid pair whose tail
    reaches *min_generated_frames*; otherwise the valid pair with the longest
    tail; ``None`` if no cut is valid.
    """
    max_raw_feat = int(max_feat_len * max_raw_feat_ratio)
    has_space = " " in prev_text

    first_pieces = split_rebalance_pieces(prev_text)
    if not has_space:
        first_mode = "concat"
    elif first_pieces == prev_text.split():
        first_mode = "words"
    else:
        first_mode = "units"
    piece_sets: list[tuple[list[str], str]] = [(first_pieces, first_mode)]
    if has_space:
        word_pieces = prev_text.split()
        if len(word_pieces) > 1 and word_pieces != first_pieces:
            piece_sets.append((word_pieces, "words"))

    best: tuple[dict[str, object], dict[str, object]] | None = None
    best_tail_frames = -1
    for pieces, mode in piece_sets:
        if len(pieces) < 2:
            continue
        for split_at in range(len(pieces) - 1, 0, -1):
            new_prev = _join_rebalance_pieces(pieces[:split_at], mode)
            moved_tail = _join_rebalance_pieces(pieces[split_at:], mode)
            new_tail = join_units(moved_tail, tail_text)
            prev_record = segment_record(
                tokenizer, new_prev, prompt_frames, prompt_tokens_len, speed, max_feat_len
            )
            tail_record = segment_record(
                tokenizer, new_tail, prompt_frames, prompt_tokens_len, speed, max_feat_len
            )
            if not segment_within_limits(
                tail_record, max_text_tokens, max_generated_frames, max_raw_feat
            ):
                continue
            if not segment_within_limits(
                prev_record, max_text_tokens, max_generated_frames, max_raw_feat
            ):
                continue
            if int(prev_record["generated_frames"]) < min_generated_frames:
                continue
            tail_frames = int(tail_record["generated_frames"])
            if tail_frames > best_tail_frames:
                best = (prev_record, tail_record)
                best_tail_frames = tail_frames
                if tail_frames >= min_generated_frames:
                    return best
    return best


def build_cat_tokens(
    tokenizer, prompt_tokens: list[int], text_tokens: list[int], max_tokens: int
) -> np.ndarray:
    """Build the padded ``(1, max_tokens)`` int64 token array for the model.

    The row holds ``prompt_tokens + text_tokens + [pad_id]`` followed by
    ``tokenizer.pad_id`` up to *max_tokens*.

    Raises:
        ValueError: if the concatenated sequence is longer than *max_tokens*.
    """
    pad_id = tokenizer.pad_id
    cat = prompt_tokens + text_tokens + [pad_id]
    if len(cat) > max_tokens:
        raise ValueError(
            f"prompt + text + pad needs {len(cat)} tokens but max_tokens is {max_tokens}"
        )
    cat_tokens = np.full((1, max_tokens), pad_id, dtype=np.int64)
    cat_tokens[0, : len(cat)] = np.asarray(cat, dtype=np.int64)
    return cat_tokens