ZipVoice.AXERA / scripts /text_processing.py
HY-2012's picture
First commit
ea47387 verified
Raw
History Blame Contribute Delete
12.2 kB
from __future__ import annotations
import argparse
import re
import unicodedata
from pathlib import Path
import numpy as np
# Sentence-ending marks. When a punctuation run contains any of these,
# normalize_punctuation keeps the last one as the run's replacement mark.
# NOTE(review): the literal contains repeated characters; full-width variants
# may have been intended here -- confirm against the upstream file.
SENTENCE_PUNCT = "。!?.!?。."
# Ellipsis character that is_text_punctuation checks explicitly, alongside its
# Unicode-category test.
ELLIPSIS_PUNCT = {"…"}
# Quote characters. A punctuation run consisting only of these is dropped by
# normalize_punctuation instead of being replaced with a mark.
QUOTE_PUNCT = set("\"'“”‘’「」『』")
def is_text_punctuation(ch: str) -> bool:
    """Return True when *ch* is an ellipsis mark or has a Unicode "P*" category."""
    if ch in ELLIPSIS_PUNCT:
        return True
    category = unicodedata.category(ch)
    return category.startswith("P")
def comma_for_context(left: str, right: str) -> str:
    """Pick the comma used to replace a non-sentence punctuation run.

    Args:
        left: Text immediately before the run (may be empty).
        right: Text immediately after the run (may be empty).

    Returns:
        The comma string chosen for the surrounding context.
    """
    has_ascii_alnum = re.search(r"[A-Za-z0-9]", left + right) is not None
    # NOTE(review): both branches currently yield the same character; a
    # different comma may have been intended for non-ASCII context -- confirm.
    if has_ascii_alnum:
        return ","
    return ","
def setup_jieba_cache(repo_dir: Path) -> None:
    """Point jieba's default tokenizer at a cache directory under *repo_dir*.

    Creates ``<repo_dir>/.work_tmp/jieba`` if needed, then sets ``tmp_dir``
    and ``cache_file`` on ``jieba.dt``.
    """
    # Imported lazily so the module can be loaded without jieba installed.
    import jieba

    jieba_cache_dir = repo_dir.joinpath(".work_tmp", "jieba")
    jieba_cache_dir.mkdir(parents=True, exist_ok=True)

    default_tokenizer = jieba.dt
    default_tokenizer.tmp_dir = str(jieba_cache_dir)
    default_tokenizer.cache_file = "jieba.cache"
def load_text(args: argparse.Namespace) -> str:
    """Read the input text from ``args`` and flatten it onto a single line.

    ``args.text_file`` (read as UTF-8) takes precedence over ``args.text``.
    Runs of non-newline whitespace become one space, runs of newlines become
    one space, and the result is stripped.

    Raises:
        ValueError: If neither ``args.text_file`` nor ``args.text`` is set.
    """
    if args.text_file:
        raw = Path(args.text_file).read_text(encoding="utf-8")
    else:
        if not args.text:
            raise ValueError("Either --text or --text-file is required")
        raw = args.text

    without_blank_runs = re.sub(r"[ \t\r\f\v]+", " ", raw)
    single_line = re.sub(r"\n+", " ", without_blank_runs)
    return single_line.strip()
def normalize_punctuation(text: str) -> str:
    """Collapse punctuation runs to single marks and tidy nearby spacing.

    Whitespace is first collapsed to single spaces and doubled/triple-dot
    ellipses are replaced by a full stop. Each run of consecutive punctuation
    characters is then reduced to one mark: the last sentence-ending mark in
    the run if there is one, otherwise a comma chosen by
    ``comma_for_context``. Runs made up only of quote characters are removed.
    A final series of regex passes cleans up spacing and duplicate marks.

    Args:
        text: Raw input text.

    Returns:
        The normalized, stripped text.
    """
    text = re.sub(r"\s+", " ", text.strip())
    text = text.replace("……", "。").replace("...", ".")

    out: list[str] = []
    length = len(text)
    pos = 0
    while pos < length:
        if not is_text_punctuation(text[pos]):
            out.append(text[pos])
            pos += 1
            continue

        # Find the end of the punctuation run that starts at ``pos``.
        run_end = pos + 1
        while run_end < length and is_text_punctuation(text[run_end]):
            run_end += 1
        run = text[pos:run_end]
        pos = run_end

        # Quote-only runs carry no pause information; drop them entirely.
        if all(p in QUOTE_PUNCT for p in run):
            continue

        enders = [p for p in run if p in SENTENCE_PUNCT]
        before = out[-1] if out else ""
        after = text[run_end] if run_end < length else ""
        mark = enders[-1] if enders else comma_for_context(before, after)

        # Overwrite a directly preceding mark rather than stacking a new one.
        if out and is_text_punctuation(out[-1]):
            out[-1] = mark
        else:
            out.append(mark)

    # Ordered cleanup passes; each operates on the output of the previous one.
    cleanup_rules = (
        (r"[ \t]+([,。!?,.!?.。])", r"\1"),
        (r"([,。!?])\s+", r"\1"),
        (r"([,.!?.。])(?=[A-Za-z0-9])", r"\1 "),
        (r"\s{2,}", " "),
        (r"([,,])\s*[,,]+", r"\1"),
        (r"[,,]+([。!?.!?.。])", r"\1"),
        (r"([。!?.!?.。])([,,]+)", r"\1"),
    )
    result = "".join(out)
    for pattern, replacement in cleanup_rules:
        result = re.sub(pattern, replacement, result)
    return result.strip()
def join_units(left: str, right: str) -> str:
    """Concatenate two text units, inserting a space unless CJK text meets.

    If either side is empty (``right`` after stripping), the other side is
    returned stripped. No separator is used when ``left`` ends with, or the
    stripped ``right`` starts with, a character in U+4E00..U+9FFF; otherwise
    a single space joins them.
    """
    if not left:
        return right.strip()

    tail = right.strip()
    if not tail:
        return left.strip()

    ends_with_cjk = re.search(r"[\u4e00-\u9fff]$", left) is not None
    starts_with_cjk = re.match(r"^[\u4e00-\u9fff]", tail) is not None
    separator = "" if (ends_with_cjk or starts_with_cjk) else " "
    return left.rstrip() + separator + tail
def split_units(text: str) -> list[str]:
    """Split text into clause-like units at delimiter punctuation or newlines.

    Each unit keeps its trailing delimiter (if any) and is stripped of
    surrounding whitespace. Empty input yields ``[]``; input in which no unit
    can be matched is returned whole as a single-element list.
    """
    stripped = text.strip()
    if not stripped:
        return []

    unit_pattern = r"[^。!?!?;;,,、::\n]+[。!?!?;;,,、::]?"
    pieces: list[str] = []
    for raw in re.findall(unit_pattern, stripped):
        piece = raw.strip()
        if piece:
            pieces.append(piece)
    return pieces if pieces else [stripped]
def estimate_lengths(
    prompt_frames: int,
    prompt_tokens: int,
    text_tokens: int,
    speed: float,
    max_feat_len: int,
) -> dict[str, int | bool]:
    """Estimate feature lengths for a prompt plus new text.

    The prompt's frames-per-token rate is applied to the combined token count
    and divided by ``speed``; the result is rounded up and capped at
    ``max_feat_len``.

    Returns:
        A dict with ``raw_features_len`` (uncapped estimate), ``features_len``
        (capped), ``generated_frames`` (capped length minus the prompt frames,
        or the capped length itself when that difference is not positive) and
        ``clamped`` (whether the cap was applied).
    """
    frames_per_token = prompt_frames / prompt_tokens
    total_tokens = prompt_tokens + text_tokens
    raw_len = int(np.ceil(frames_per_token * total_tokens / speed))

    was_clamped = raw_len > max_feat_len
    capped_len = max_feat_len if was_clamped else raw_len

    new_frames = capped_len - prompt_frames
    if new_frames <= 0:
        # Nothing left after subtracting the prompt: fall back to the full length.
        new_frames = capped_len

    return {
        "raw_features_len": raw_len,
        "features_len": capped_len,
        "generated_frames": new_frames,
        "clamped": was_clamped,
    }
def token_count(tokenizer, text: str) -> int:
    """Return how many token ids *tokenizer* produces for *text*."""
    batch_ids = tokenizer.texts_to_token_ids([text])
    return len(batch_ids[0])
def _split_by_chars(tokenizer, text: str, max_text_tokens: int) -> list[str]:
    """Greedily pack the characters of *text* into chunks within the token budget."""
    chunks: list[str] = []
    current = ""
    for char in text:
        candidate = current + char
        if current and token_count(tokenizer, candidate) > max_text_tokens:
            chunks.append(current)
            current = char
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks


def split_long_unit(tokenizer, unit: str, max_text_tokens: int) -> list[str]:
    """Split *unit* into chunks that each fit in ``max_text_tokens`` tokens.

    A unit already within the budget is returned unchanged. A unit containing
    spaces is packed greedily piece by piece (pieces come from
    ``str.split()``); a unit without spaces is packed character by character.
    A whitespace-separated piece that on its own exceeds the budget is also
    split character by character, so that it does not pass through as an
    oversized chunk.

    Args:
        tokenizer: Object accepted by ``token_count``.
        unit: Text unit to split.
        max_text_tokens: Maximum token count per returned chunk.

    Returns:
        The chunks, in order. A single character whose token count exceeds
        the budget is still returned as its own chunk.
    """
    if token_count(tokenizer, unit) <= max_text_tokens:
        return [unit]
    if " " not in unit:
        return _split_by_chars(tokenizer, unit, max_text_tokens)

    chunks: list[str] = []
    current = ""
    for piece in unit.split():
        if token_count(tokenizer, piece) > max_text_tokens:
            # This piece can never fit alongside anything else: flush what we
            # have and break the piece itself down at character level.
            if current:
                chunks.append(current)
                current = ""
            chunks.extend(_split_by_chars(tokenizer, piece, max_text_tokens))
            continue
        candidate = join_units(current, piece)
        if current and token_count(tokenizer, candidate) > max_text_tokens:
            chunks.append(current)
            current = piece
        else:
            current = candidate
    if current:
        chunks.append(current)
    return chunks
def build_segments(
    tokenizer,
    text: str,
    prompt_frames: int,
    prompt_tokens_len: int,
    speed: float,
    max_feat_len: int,
    max_text_tokens: int,
    min_generated_frames: int,
    max_generated_frames: int,
    max_raw_feat_ratio: float,
) -> list[dict[str, object]]:
    """Pack *text* into segments that respect token and frame limits.

    The text is split into units (``split_units``), oversized units are broken
    up (``split_long_unit``), and units are then greedily joined until adding
    the next one would exceed ``max_text_tokens``, ``max_generated_frames`` or
    ``int(max_feat_len * max_raw_feat_ratio)`` raw feature frames. If the
    final segment has fewer than ``min_generated_frames`` generated frames it
    is merged into the previous segment when the merge stays within limits;
    otherwise ``rebalance_short_tail`` is asked to redistribute the last two
    segments.

    Args:
        tokenizer: Object accepted by ``token_count``.
        text: Text to segment.
        prompt_frames: Number of prompt feature frames.
        prompt_tokens_len: Number of prompt tokens.
        speed: Speed factor passed to ``estimate_lengths``.
        max_feat_len: Cap on the feature length.
        max_text_tokens: Maximum text tokens per segment.
        min_generated_frames: Minimum generated frames wanted for the tail.
        max_generated_frames: Maximum generated frames per segment.
        max_raw_feat_ratio: Multiplier on ``max_feat_len`` bounding the raw
            (uncapped) feature-length estimate.

    Returns:
        One record per segment, as produced by ``segment_record``.
    """
    max_raw_feat = int(max_feat_len * max_raw_feat_ratio)

    def make_record(segment_text: str) -> dict[str, object]:
        return segment_record(
            tokenizer,
            segment_text,
            prompt_frames,
            prompt_tokens_len,
            speed,
            max_feat_len,
        )

    def fits(record: dict[str, object]) -> bool:
        return segment_within_limits(
            record,
            max_text_tokens,
            max_generated_frames,
            max_raw_feat,
        )

    units: list[str] = []
    for raw_unit in split_units(text):
        units.extend(split_long_unit(tokenizer, raw_unit, max_text_tokens))

    segments: list[dict[str, object]] = []
    current = ""
    for unit in units:
        candidate = join_units(current, unit)
        candidate_record = make_record(candidate)
        if current and not fits(candidate_record):
            # Adding this unit would overflow: close the running segment and
            # start a new one with the unit.
            segments.append(make_record(current))
            current = unit
        else:
            current = candidate
    if current:
        segments.append(make_record(current))

    tail_too_short = (
        len(segments) >= 2
        and int(segments[-1]["generated_frames"]) < min_generated_frames
    )
    if tail_too_short:
        merged = make_record(
            join_units(str(segments[-2]["text"]), str(segments[-1]["text"]))
        )
        if fits(merged):
            segments[-2:] = [merged]
        else:
            rebalanced = rebalance_short_tail(
                tokenizer,
                str(segments[-2]["text"]),
                str(segments[-1]["text"]),
                prompt_frames,
                prompt_tokens_len,
                speed,
                max_feat_len,
                max_text_tokens,
                min_generated_frames,
                max_generated_frames,
                max_raw_feat_ratio,
            )
            if rebalanced is not None:
                segments[-2], segments[-1] = rebalanced
    return segments
def split_rebalance_pieces(text: str) -> list[str]:
    """Break *text* into the coarsest pieces available for rebalancing.

    Prefers punctuation-delimited units; if there is at most one, falls back
    to whitespace-separated words when the text contains a space, and to
    individual characters otherwise.
    """
    punct_units = split_units(text)
    if len(punct_units) > 1:
        return punct_units
    return text.split() if " " in text else list(text)
def segment_record(
    tokenizer,
    text: str,
    prompt_frames: int,
    prompt_tokens_len: int,
    speed: float,
    max_feat_len: int,
) -> dict[str, object]:
    """Build the record for one segment: its text, token count and length estimates.

    Returns:
        A dict with ``text`` and ``text_tokens`` followed by every key returned
        by ``estimate_lengths`` for that token count.
    """
    n_tokens = token_count(tokenizer, text)
    record: dict[str, object] = {"text": text, "text_tokens": n_tokens}
    record.update(
        estimate_lengths(
            prompt_frames,
            prompt_tokens_len,
            n_tokens,
            speed,
            max_feat_len,
        )
    )
    return record
def segment_within_limits(
    record: dict[str, object],
    max_text_tokens: int,
    max_generated_frames: int,
    max_raw_feat: int,
) -> bool:
    """Return True when *record* respects the token, generated-frame and raw-length caps."""
    caps = (
        ("text_tokens", max_text_tokens),
        ("generated_frames", max_generated_frames),
        ("raw_features_len", max_raw_feat),
    )
    # ``all`` stops at the first violated cap, checking in the order listed.
    return all(int(record[key]) <= cap for key, cap in caps)
def rebalance_short_tail(
    tokenizer,
    prev_text: str,
    tail_text: str,
    prompt_frames: int,
    prompt_tokens_len: int,
    speed: float,
    max_feat_len: int,
    max_text_tokens: int,
    min_generated_frames: int,
    max_generated_frames: int,
    max_raw_feat_ratio: float,
) -> tuple[dict[str, object], dict[str, object]] | None:
    """Move trailing pieces of *prev_text* onto *tail_text* to lengthen a short tail.

    Candidate split points are tried from the end of ``prev_text`` backwards,
    first over the pieces from ``split_rebalance_pieces`` and then, when the
    text contains spaces and the result differs, over whitespace-separated
    words. A split is acceptable when both resulting segments are within the
    token/frame/raw-length limits and the shortened previous segment still has
    at least ``min_generated_frames`` generated frames.

    Returns:
        ``(prev_record, tail_record)`` for the first acceptable split whose
        tail reaches ``min_generated_frames``; otherwise the acceptable split
        with the most tail frames seen, or ``None`` if no split is acceptable.
    """
    max_raw_feat = int(max_feat_len * max_raw_feat_ratio)
    has_space = " " in prev_text

    def glue(parts: list[str]) -> str:
        # Space-containing text is re-joined with spaces; otherwise pieces
        # are concatenated directly.
        if has_space:
            return join_units("", " ".join(parts))
        return "".join(parts)

    def record_for(segment_text: str) -> dict[str, object]:
        return segment_record(
            tokenizer,
            segment_text,
            prompt_frames,
            prompt_tokens_len,
            speed,
            max_feat_len,
        )

    def fits(record: dict[str, object]) -> bool:
        return segment_within_limits(
            record,
            max_text_tokens,
            max_generated_frames,
            max_raw_feat,
        )

    candidate_piece_sets = [split_rebalance_pieces(prev_text)]
    if has_space:
        words = prev_text.split()
        if len(words) > 1 and words != candidate_piece_sets[0]:
            candidate_piece_sets.append(words)

    best_pair: tuple[dict[str, object], dict[str, object]] | None = None
    best_tail_frames = -1
    for pieces in candidate_piece_sets:
        # With fewer than two pieces the range below is empty, so nothing is tried.
        for cut in reversed(range(1, len(pieces))):
            prev_record = record_for(glue(pieces[:cut]))
            tail_record = record_for(join_units(glue(pieces[cut:]), tail_text))

            if not (fits(tail_record) and fits(prev_record)):
                continue
            if int(prev_record["generated_frames"]) < min_generated_frames:
                continue

            tail_frames = int(tail_record["generated_frames"])
            if tail_frames > best_tail_frames:
                best_pair = (prev_record, tail_record)
                best_tail_frames = tail_frames
            if tail_frames >= min_generated_frames:
                return best_pair
    return best_pair
def build_cat_tokens(tokenizer, prompt_tokens: list[int], text_tokens: list[int], max_tokens: int) -> np.ndarray:
    """Concatenate prompt and text token ids into a fixed-width padded array.

    The sequence is ``prompt_tokens + text_tokens + [pad_id]``, written at the
    start of a ``(1, max_tokens)`` int64 array pre-filled with the tokenizer's
    ``pad_id``.

    Args:
        tokenizer: Object exposing a ``pad_id`` attribute.
        prompt_tokens: Token ids of the prompt.
        text_tokens: Token ids of the text to synthesize.
        max_tokens: Width of the returned array.

    Returns:
        An int64 array of shape ``(1, max_tokens)``.

    Raises:
        ValueError: If the concatenated sequence (including the trailing pad
            id) is longer than ``max_tokens``.
    """
    pad_id = tokenizer.pad_id
    # Unpacking accepts any iterable of ids, not only lists.
    cat = [*prompt_tokens, *text_tokens, pad_id]
    if len(cat) > max_tokens:
        # Fail with an actionable message instead of NumPy's broadcast error.
        raise ValueError(
            f"Token sequence length {len(cat)} (prompt + text + pad) "
            f"exceeds max_tokens={max_tokens}"
        )
    cat_tokens = np.full((1, max_tokens), pad_id, dtype=np.int64)
    cat_tokens[0, : len(cat)] = np.asarray(cat, dtype=np.int64)
    return cat_tokens