ModerRAS's picture
Fix Rust encoded cache label repairs
7934324
raw
history blame
47.5 kB
use anyhow::{bail, Context, Result};
use clap::Parser;
use fancy_regex::Regex as FancyRegex;
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rayon::prelude::*;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use std::time::Instant;
/// Built-in BIO label list that `load_label_ids` uses in place of a label schema file.
///
/// The position of each entry becomes its numeric label id, so the order is significant:
/// `"O"` is id 0, followed by `B-`/`I-` pairs for every entity type.
const FALLBACK_LABELS: [&str; 37] = [
"O",
"B-TITLE_CHS",
"I-TITLE_CHS",
"B-TITLE_CHT",
"I-TITLE_CHT",
"B-TITLE_JPN",
"I-TITLE_JPN",
"B-TITLE_LATIN",
"I-TITLE_LATIN",
"B-TITLE_MIXED",
"I-TITLE_MIXED",
"B-PATH_TITLE_CHS",
"I-PATH_TITLE_CHS",
"B-PATH_TITLE_CHT",
"I-PATH_TITLE_CHT",
"B-PATH_TITLE_JPN",
"I-PATH_TITLE_JPN",
"B-PATH_TITLE_LATIN",
"I-PATH_TITLE_LATIN",
"B-PATH_TITLE_MIXED",
"I-PATH_TITLE_MIXED",
"B-PATH_SEASON",
"I-PATH_SEASON",
"B-SEASON",
"I-SEASON",
"B-EPISODE",
"I-EPISODE",
"B-SPECIAL",
"I-SPECIAL",
"B-GROUP",
"I-GROUP",
"B-RESOLUTION",
"I-RESOLUTION",
"B-SOURCE",
"I-SOURCE",
"B-TAG",
"I-TAG",
];
const SOURCE_TOKEN_PATTERN: &str = r"WEB[-_ ]?DL|WEB[-_ ]?Rip|BDRip|BluRay|BDMV|BD|DVDRip|DVD|TVRip|HDTV|Netflix|NF|AMZN|Baha|CR|ABEMA|DSNP|U[-_ ]?NEXT|Hulu|AT[-_ ]?X|x26[45]|h\.?26[45]|HEVC|AVC|AV1|AAC\d*(?:\.\d+)?|AAC|FLAC|MP3|DTS|Opus|CHS|CHT|GB|BIG5|JPN?|JPSC|JPTC|繁中|简中";
// Lazily initialised compiled regexes. Each cell is filled on first use by the accessor
// function of the matching name (e.g. `resolution_re()` fills `RESOLUTION_RE`).
// The `FancyRegex` cells hold patterns that use look-around; the others use the plain
// `regex` engine.
static RESOLUTION_RE: OnceLock<FancyRegex> = OnceLock::new();
static SOURCE_RE: OnceLock<Regex> = OnceLock::new();
static SOURCE_TAG_RE: OnceLock<Regex> = OnceLock::new();
static SPECIAL_TAG_RE: OnceLock<Regex> = OnceLock::new();
static SPECIAL_CODE_RE: OnceLock<Regex> = OnceLock::new();
static EPISODE_CONTEXT_RE: OnceLock<Regex> = OnceLock::new();
static EPISODE_SPAN_RE: OnceLock<FancyRegex> = OnceLock::new();
static READING_MARKER_RE: OnceLock<FancyRegex> = OnceLock::new();
static ROMAN_MARKER_RE: OnceLock<FancyRegex> = OnceLock::new();
static CJK_MARKER_RE: OnceLock<Regex> = OnceLock::new();
static SPECIAL_CONTEXT_PREFIX_RE: OnceLock<Regex> = OnceLock::new();
/// Characters that count as pure separators when `mark_adjacent_title_separators_o` clears
/// title labels after a season marker.
///
/// The last entry is the fullwidth tilde (U+FF5E), written as an escape so that it cannot be
/// folded into the ASCII `~` by a Unicode-normalising tool; the list previously contained the
/// ASCII tilde twice.
const SEPARATOR_CHARS: &[char] = &[' ', '\t', '-', '_', '.', '|', '~', '\u{FF5E}'];
// Command-line options for the cache builder.
// NOTE: plain `//` comments are used on purpose; `///` doc comments on a clap derive struct
// become help text and would change the program's `--help` output.
#[derive(Parser, Debug)]
#[command(
about = "Build split train/eval encoded AniFileBERT shard caches",
version
)]
struct Args {
// JSONL input; every non-blank line must carry `tokens` and `labels` arrays.
#[arg(long)]
input: PathBuf,
// JSON object mapping token text to integer id; must contain [PAD], [UNK], [CLS], [SEP].
#[arg(long)]
vocab_file: PathBuf,
// Directory that receives the `train/` and `eval/` shards plus the manifests.
#[arg(long)]
output_dir: PathBuf,
// JSON file with a `labels` array; see `load_label_ids` for the fallback behaviour.
#[arg(long, default_value = "label_schema.json")]
label_schema_file: PathBuf,
// Fixed sequence length of every encoded row, including [CLS] and [SEP]; `main` requires >= 4.
#[arg(long, default_value_t = 128)]
max_length: usize,
// Maximum number of rows per shard file; `main` rejects 0.
#[arg(long, default_value_t = 25_000)]
shard_size: usize,
// Stop after this many rows have been loaded; 0 means no limit.
#[arg(long, default_value_t = 0)]
limit_rows: usize,
// Fraction of rows assigned to the train split.
#[arg(long, default_value_t = 0.98)]
train_split: f64,
// Seed for the shuffle RNG.
#[arg(long, default_value_t = 42)]
seed: u64,
// Keep the input order instead of shuffling before the split.
#[arg(long)]
no_shuffle: bool,
// Size of the global rayon thread pool; 0 leaves rayon's default configuration in place.
#[arg(long, default_value_t = 0)]
threads: usize,
}
/// Shape of the label schema JSON file: only the `labels` array is read.
#[derive(Debug, Deserialize)]
struct LabelSchema {
/// Label names in id order (index 0 becomes label id 0).
labels: Vec<String>,
}
/// One parsed line of the input JSONL file.
#[derive(Clone)]
struct SourceRow {
/// Zero-based index of the line in the input file (blank lines still count).
row_index: usize,
/// The unparsed line, written back verbatim by `write_eval_records`.
raw_line: String,
/// Optional `filename` string field of the JSON object.
filename: Option<String>,
/// `tokens` array; `load_rows` guarantees it has the same length as `labels`.
tokens: Vec<String>,
/// `labels` array, one BIO label per token.
labels: Vec<String>,
/// Optional `tokenizer_variant` string field; the value `"char"` is special-cased when encoding.
tokenizer_variant: Option<String>,
}
/// Token vocabulary with the ids of the four special tokens resolved up front.
#[derive(Clone)]
struct Vocab {
/// Token text to id; `load_vocab` rejects ids that do not fit in `u16`.
ids: HashMap<String, u16>,
/// Id of `[PAD]`, used to fill unused positions.
pad_id: u16,
/// Id of `[UNK]`, returned by `token_id` for unknown tokens.
unk_id: u16,
/// Id of `[CLS]`, written at position 0 of every row.
cls_id: u16,
/// Id of `[SEP]`, written right after the last kept token.
sep_id: u16,
}
/// Read-only state shared by all shard writers while encoding rows.
#[derive(Clone)]
struct EncodeContext {
/// Token vocabulary.
vocab: Vocab,
/// Canonical label name to numeric label id.
label_ids: HashMap<String, i16>,
/// Fixed width of every encoded row.
max_length: usize,
}
/// Per-shard entry of a split's `manifest.json`.
#[derive(Serialize)]
struct ShardManifest {
/// Number of rows stored in the shard.
rows: usize,
/// File name (relative to the split directory) of the token-id array.
input_ids: String,
/// File name of the attention-mask array.
attention_mask: String,
/// File name of the label-id array.
labels: String,
}
/// Summary of one split, embedded in the top-level `manifest.json`.
#[derive(Serialize)]
struct SplitSummary {
/// Split name as passed to `write_split`.
split: String,
/// Number of rows in the split.
rows: usize,
/// Number of shard files written for the split.
shards: usize,
/// Directory name of the split relative to the output directory (same as the split name).
directory: String,
}
/// Program entry point.
///
/// Validates the CLI options, loads the vocabulary, label ids and input rows, optionally
/// shuffles the rows with a seeded RNG, splits them into train/eval, writes both splits as
/// sharded `.npy` files, copies the raw eval lines to `eval_records.jsonl`, and finally writes
/// (and prints) a top-level `manifest.json`.
///
/// # Errors
/// Returns an error for invalid options, unreadable or invalid inputs, fewer than two rows, or
/// any failure while writing output files.
fn main() -> Result<()> {
    let args = Args::parse();
    if args.max_length < 4 {
        bail!("--max-length must be at least 4");
    }
    if args.shard_size == 0 {
        bail!("--shard-size must be positive");
    }
    // Both bounds are exclusive, as the message states. The condition is written as a negated
    // conjunction so that NaN (for which every comparison is false) is rejected as well.
    if !(args.train_split > 0.0 && args.train_split < 1.0) {
        bail!("--train-split must be > 0 and < 1");
    }
    if args.threads > 0 {
        rayon::ThreadPoolBuilder::new()
            .num_threads(args.threads)
            .build_global()
            .context("failed to configure rayon thread pool")?;
    }

    let started = Instant::now();
    let vocab = load_vocab(&args.vocab_file)?;
    let label_ids = load_label_ids(&args.label_schema_file)?;
    let mut rows = load_rows(&args.input, args.limit_rows)?;
    if rows.len() < 2 {
        bail!("need at least two rows to build train/eval cache");
    }
    if !args.no_shuffle {
        let mut rng = StdRng::seed_from_u64(args.seed);
        rows.shuffle(&mut rng);
    }

    // Clamp so that both splits contain at least one row.
    let split_idx = ((rows.len() as f64) * args.train_split) as usize;
    let split_idx = split_idx.max(1).min(rows.len() - 1);
    let (train_rows, eval_rows) = rows.split_at(split_idx);

    fs::create_dir_all(&args.output_dir).with_context(|| {
        format!(
            "failed to create output directory {}",
            args.output_dir.display()
        )
    })?;

    let context = EncodeContext {
        vocab,
        label_ids,
        max_length: args.max_length,
    };
    let train_summary = write_split(
        "train",
        train_rows,
        &args.output_dir,
        &context,
        args.shard_size,
    )?;
    let eval_summary = write_split(
        "eval",
        eval_rows,
        &args.output_dir,
        &context,
        args.shard_size,
    )?;
    write_eval_records(eval_rows, &args.output_dir.join("eval_records.jsonl"))?;

    let manifest = json!({
        "format": "anifilebert.encoded_dataset_cache.v1",
        "input": args.input,
        "vocab_file": args.vocab_file,
        "label_schema_file": args.label_schema_file,
        "output_dir": args.output_dir,
        "max_length": args.max_length,
        "shard_size": args.shard_size,
        "limit_rows": args.limit_rows,
        "source_rows": train_rows.len() + eval_rows.len(),
        "train_split": args.train_split,
        "seed": args.seed,
        "shuffle": !args.no_shuffle,
        "train": train_summary,
        "eval": eval_summary,
        "eval_records": "eval_records.jsonl",
        "elapsed_seconds": started.elapsed().as_secs_f64(),
    });
    let manifest_path = args.output_dir.join("manifest.json");
    fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
        .with_context(|| format!("failed to write {}", manifest_path.display()))?;
    println!("{}", serde_json::to_string_pretty(&manifest)?);
    Ok(())
}
/// Reads a JSON vocabulary (`{"token": id, ...}`) and resolves the special-token ids.
///
/// # Errors
/// Fails when the file cannot be read or parsed, when an id does not fit in `u16`, or when
/// one of `[PAD]`, `[UNK]`, `[CLS]`, `[SEP]` is missing.
fn load_vocab(path: &Path) -> Result<Vocab> {
    let contents = fs::read_to_string(path)
        .with_context(|| format!("failed to read vocab {}", path.display()))?;
    let parsed: HashMap<String, u64> = serde_json::from_str(&contents)
        .with_context(|| format!("invalid vocab {}", path.display()))?;

    let mut ids = HashMap::with_capacity(parsed.len());
    for (token, id) in parsed {
        match u16::try_from(id) {
            Ok(narrow) => {
                ids.insert(token, narrow);
            }
            Err(_) => bail!("vocab id for token '{token}' exceeds u16: {id}"),
        }
    }

    let lookup = |token: &str| -> Result<u16> {
        ids.get(token)
            .copied()
            .with_context(|| format!("vocab is missing special token {token}"))
    };
    // Resolved in this order so the first missing token reported stays the same.
    let pad_id = lookup("[PAD]")?;
    let unk_id = lookup("[UNK]")?;
    let cls_id = lookup("[CLS]")?;
    let sep_id = lookup("[SEP]")?;

    Ok(Vocab {
        ids,
        pad_id,
        unk_id,
        cls_id,
        sep_id,
    })
}
fn load_label_ids(path: &Path) -> Result<HashMap<String, i16>> {
let labels = match fs::read_to_string(path) {
Ok(text) => {
serde_json::from_str::<LabelSchema>(&text)
.with_context(|| format!("invalid label schema {}", path.display()))?
.labels
}
Err(_) => FALLBACK_LABELS
.iter()
.map(|label| (*label).to_string())
.collect(),
};
if labels.is_empty() {
bail!("label schema has no labels");
}
Ok(labels
.into_iter()
.enumerate()
.map(|(idx, label)| (label, idx as i16))
.collect())
}
/// Parses the JSONL input into `SourceRow`s.
///
/// Blank lines are skipped. When `limit_rows` is non-zero, reading stops once that many rows
/// have been collected. Line numbers in error messages are one-based.
///
/// # Errors
/// Fails when the file cannot be opened or read, when a line is not valid JSON, when `tokens`
/// or `labels` is not an array, or when the two arrays differ in length.
fn load_rows(path: &Path, limit_rows: usize) -> Result<Vec<SourceRow>> {
    let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
    let optional_text = |value: &Value, key: &str| -> Option<String> {
        value.get(key).and_then(Value::as_str).map(ToOwned::to_owned)
    };

    let mut rows = Vec::new();
    for (idx, line) in BufReader::new(file).lines().enumerate() {
        let line_no = idx + 1;
        if limit_rows > 0 && rows.len() >= limit_rows {
            break;
        }
        let raw_line = line.with_context(|| format!("failed reading line {}", line_no))?;
        if raw_line.trim().is_empty() {
            continue;
        }
        let value: Value = serde_json::from_str(&raw_line)
            .with_context(|| format!("failed to parse JSONL line {}", line_no))?;
        let tokens = string_array_field(&value, "tokens", line_no)?;
        let labels = string_array_field(&value, "labels", line_no)?;
        if tokens.len() != labels.len() {
            bail!(
                "line {} has mismatched token/label lengths: {} vs {}",
                line_no,
                tokens.len(),
                labels.len()
            );
        }
        let filename = optional_text(&value, "filename");
        let tokenizer_variant = optional_text(&value, "tokenizer_variant");
        rows.push(SourceRow {
            row_index: idx,
            raw_line,
            filename,
            tokens,
            labels,
            tokenizer_variant,
        });
    }
    Ok(rows)
}
/// Extracts `value[field]` as a vector of strings.
///
/// String items are copied as-is, `null` becomes the empty string, and every other JSON value
/// is rendered with its JSON text form.
///
/// # Errors
/// Fails when the field is absent or is not an array.
fn string_array_field(value: &Value, field: &str, line_no: usize) -> Result<Vec<String>> {
    let items = value
        .get(field)
        .and_then(Value::as_array)
        .with_context(|| format!("line {line_no} missing array field '{field}'"))?;

    let mut texts = Vec::with_capacity(items.len());
    for item in items {
        texts.push(match item {
            Value::String(text) => text.clone(),
            Value::Null => String::new(),
            other => other.to_string(),
        });
    }
    Ok(texts)
}
/// Writes one split (`train` or `eval`) as shards under `output_dir/<split>/`.
///
/// Rows are cut into chunks of at most `shard_size`, the chunks are encoded and written in
/// parallel, and a `manifest.json` listing the shards is written into the split directory.
///
/// # Errors
/// Fails when the directory or manifest cannot be written or when any shard fails.
fn write_split(
    split: &str,
    rows: &[SourceRow],
    output_dir: &Path,
    context: &EncodeContext,
    shard_size: usize,
) -> Result<SplitSummary> {
    let split_dir = output_dir.join(split);
    fs::create_dir_all(&split_dir)
        .with_context(|| format!("failed to create {}", split_dir.display()))?;

    let chunks: Vec<(usize, &[SourceRow])> = rows.chunks(shard_size).enumerate().collect();
    let shard_count = chunks.len();
    let shards = chunks
        .par_iter()
        .map(|&(shard_idx, chunk)| write_shard(split, shard_idx, chunk, &split_dir, context))
        .collect::<Result<Vec<_>>>()?;

    let manifest_path = split_dir.join("manifest.json");
    let manifest = json!({
        "format": "anifilebert.virtual_dataset.shards.v1",
        "generated_by": "tools/encoded_dataset_cache",
        "split": split,
        "max_length": context.max_length,
        "total_rows": rows.len(),
        "shards": shards,
    });
    fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
        .with_context(|| format!("failed to write {}", manifest_path.display()))?;

    Ok(SplitSummary {
        split: split.to_string(),
        rows: rows.len(),
        shards: shard_count,
        directory: split.to_string(),
    })
}
/// Encodes `rows` and writes them as three `.npy` files in `split_dir`.
///
/// The files are named `part-<split>-s<index, 6 digits>.{input_ids,attention_mask,labels}.npy`
/// and each holds a `rows.len() x max_length` array.
///
/// # Errors
/// Fails when a row cannot be encoded (the message names the one-based source line) or when
/// a file cannot be written.
fn write_shard(
    split: &str,
    shard_idx: usize,
    rows: &[SourceRow],
    split_dir: &Path,
    context: &EncodeContext,
) -> Result<ShardManifest> {
    let row_count = rows.len();
    let width = context.max_length;
    let capacity = row_count.saturating_mul(width);

    let mut input_ids: Vec<u16> = Vec::with_capacity(capacity);
    let mut attention_mask: Vec<u8> = Vec::with_capacity(capacity);
    let mut labels: Vec<i16> = Vec::with_capacity(capacity);
    for row in rows {
        let (row_ids, row_mask, row_labels) = encode_row(row, context)
            .with_context(|| format!("failed to encode source line {}", row.row_index + 1))?;
        input_ids.extend(row_ids);
        attention_mask.extend(row_mask);
        labels.extend(row_labels);
    }

    let base = format!("part-{split}-s{shard_idx:06}");
    let input_name = format!("{base}.input_ids.npy");
    let mask_name = format!("{base}.attention_mask.npy");
    let label_name = format!("{base}.labels.npy");

    write_npy_u16(&split_dir.join(&input_name), &input_ids, row_count, width)?;
    write_npy_u8(&split_dir.join(&mask_name), &attention_mask, row_count, width)?;
    write_npy_i16(&split_dir.join(&label_name), &labels, row_count, width)?;

    Ok(ShardManifest {
        rows: row_count,
        input_ids: input_name,
        attention_mask: mask_name,
        labels: label_name,
    })
}
/// Encodes one row into fixed-width `(input_ids, attention_mask, label_ids)` vectors.
///
/// Layout: `[CLS]` at position 0, then up to `max_length - 2` character tokens, then `[SEP]`;
/// the remainder is `[PAD]` with mask 0. Label positions that do not hold a token keep the
/// value `-100`.
///
/// # Errors
/// Fails when a canonical label is not present in `context.label_ids`.
fn encode_row(row: &SourceRow, context: &EncodeContext) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> {
    let (tokens, labels) = labels_for_char_tokenizer(row);
    let width = context.max_length;

    let mut input_ids = vec![context.vocab.pad_id; width];
    let mut attention_mask = vec![0u8; width];
    let mut label_ids = vec![-100i16; width];

    input_ids[0] = context.vocab.cls_id;
    attention_mask[0] = 1;

    // Two slots are reserved for [CLS] and [SEP].
    let kept = tokens.len().min(labels.len()).min(width.saturating_sub(2));
    for (offset, (token, label)) in tokens.iter().zip(labels.iter()).take(kept).enumerate() {
        let slot = offset + 1;
        input_ids[slot] = token_id(&context.vocab, token);
        attention_mask[slot] = 1;
        let canonical = canonical_bio_label(label);
        label_ids[slot] = context
            .label_ids
            .get(&canonical)
            .copied()
            .with_context(|| format!("unknown label '{canonical}'"))?;
    }

    let sep_pos = kept + 1;
    input_ids[sep_pos] = context.vocab.sep_id;
    attention_mask[sep_pos] = 1;
    Ok((input_ids, attention_mask, label_ids))
}
/// Produces character-level tokens and labels for a row.
///
/// Without a filename the row's tokens are simply expanded to characters. With a filename the
/// labels are first repaired, then:
/// 1. a `"char"` row whose tokens already equal the filename's characters is returned as-is;
/// 2. otherwise the labels are projected onto the filename's characters when every token can
///    be located in the filename;
/// 3. otherwise the tokens are expanded to characters.
fn labels_for_char_tokenizer(row: &SourceRow) -> (Vec<String>, Vec<String>) {
    let mut repaired = row.labels.clone();
    let Some(filename) = row.filename.as_deref() else {
        return align_tokens_to_chars(&row.tokens, &repaired);
    };

    repair_known_label_issues(filename, &row.tokens, &mut repaired);

    let is_char_variant = row.tokenizer_variant.as_deref() == Some("char");
    if is_char_variant && row.tokens == chars_as_strings(filename) {
        return (row.tokens.clone(), repaired);
    }

    match project_labels_from_filename(filename, &row.tokens, &repaired) {
        Some(projected) => projected,
        None => align_tokens_to_chars(&row.tokens, &repaired),
    }
}
/// Projects token-level BIO labels onto the characters of `filename`.
///
/// Every token is located in the filename; the characters it covers receive the token's
/// entity. For a token that starts with an opening bracket and ends with a closing bracket,
/// the two bracket characters are left unlabelled. Characters not covered by a labelled token
/// become `O`. `B-`/`I-` prefixes are then regenerated from runs of equal entities.
///
/// Returns `None` when the tokens cannot be located in order inside `filename`, or when the
/// number of offsets differs from the number of labels.
fn project_labels_from_filename(
    filename: &str,
    source_tokens: &[String],
    source_labels: &[String],
) -> Option<(Vec<String>, Vec<String>)> {
    let offsets = token_offsets_in_text(filename, source_tokens)?;
    if offsets.len() != source_labels.len() {
        return None;
    }

    let tokens = chars_as_strings(filename);
    let mut char_entities: Vec<Option<String>> = vec![None; tokens.len()];
    for (idx, (token_start, token_end)) in offsets.into_iter().enumerate() {
        let Some(entity) = bio_entity(&source_labels[idx]) else {
            continue;
        };
        // Keep the surrounding brackets of a wrapped token out of the entity span.
        let (start, end) = if is_wrapped_token(&source_tokens[idx]) && token_end > token_start + 1
        {
            (token_start + 1, token_end - 1)
        } else {
            (token_start, token_end)
        };
        let stop = end.min(char_entities.len());
        for slot in char_entities.iter_mut().take(stop).skip(start) {
            *slot = Some(entity.clone());
        }
    }

    let mut labels = Vec::with_capacity(tokens.len());
    let mut previous: Option<String> = None;
    for current in char_entities {
        let label = match (&current, &previous) {
            (Some(entity), Some(prior)) if entity == prior => format!("I-{entity}"),
            (Some(entity), _) => format!("B-{entity}"),
            (None, _) => "O".to_string(),
        };
        labels.push(label);
        previous = current;
    }
    Some((tokens, labels))
}
/// Locates each token in `text`, scanning left to right, and returns `(start, end)` character
/// offsets (end exclusive) per token.
///
/// Each search starts where the previous token ended, so text between tokens is skipped. An
/// empty token yields an empty span at the current position. Returns `None` as soon as a
/// token cannot be found.
fn token_offsets_in_text(text: &str, tokens: &[String]) -> Option<Vec<(usize, usize)>> {
    // Byte offset just past the previously matched token.
    let mut cursor = 0usize;
    tokens
        .iter()
        .map(|token| {
            if token.is_empty() {
                let here = char_index_at_byte(text, cursor);
                return Some((here, here));
            }
            let start_byte = cursor + text.get(cursor..)?.find(token.as_str())?;
            let end_byte = start_byte + token.len();
            cursor = end_byte;
            Some((
                char_index_at_byte(text, start_byte),
                char_index_at_byte(text, end_byte),
            ))
        })
        .collect()
}
/// Expands every token into its characters and spreads the token's label over them.
///
/// A `B-X` token yields `B-X` for its first character and `I-X` for the rest; any other label
/// is repeated unchanged for every character. Empty tokens are dropped together with their
/// label.
fn align_tokens_to_chars(tokens: &[String], labels: &[String]) -> (Vec<String>, Vec<String>) {
    let mut char_tokens = Vec::new();
    let mut char_labels = Vec::new();
    for (token, label) in tokens.iter().zip(labels) {
        let pieces = chars_as_strings(token);
        let Some(continuation_count) = pieces.len().checked_sub(1) else {
            continue;
        };
        match label.strip_prefix("B-") {
            Some(entity) => {
                char_labels.push(label.clone());
                char_labels
                    .extend(std::iter::repeat(format!("I-{entity}")).take(continuation_count));
            }
            None => {
                char_labels.extend(std::iter::repeat(label.clone()).take(pieces.len()));
            }
        }
        char_tokens.extend(pieces);
    }
    (char_tokens, char_labels)
}
/// Relabels structural metadata (special tags, sources, resolutions) that appears after the
/// first episode span.
///
/// `offsets` holds the character span of every token in `text`; `labels` is updated in place.
/// Two passes are made:
/// 1. For every bracketed group that starts at or after the end of the first episode span:
///    if the trimmed content as a whole is a special tag/code it becomes `SPECIAL`; if it is
///    a pure list of source tokens it becomes `SOURCE`; otherwise individual resolution and
///    source matches inside it are labelled.
/// 2. Resolution and source matches anywhere in `text` after the episode span are labelled.
///
/// Labelling goes through `label_span_if_safe`, which skips spans that touch group, episode
/// or season labels.
fn repair_structural_meta_labels(
    text: &str,
    _tokens: &[String],
    labels: &mut [String],
    offsets: &[(usize, usize)],
) {
    let episode_end = first_episode_span_end(labels, offsets, text);

    for (inner_start, inner_end) in bracket_inner_spans(text) {
        let bracket_start = inner_start.saturating_sub(1);
        if bracket_start < episode_end {
            continue;
        }
        let inner = chars_range_to_string(text, inner_start, inner_end);
        let (trim_start, trim_end) = trimmed_bounds(&inner);
        if trim_start >= trim_end {
            continue;
        }
        let clean = chars_slice_to_string(&inner, trim_start, trim_end);
        // `clean` begins `trim_start` characters into the bracket content, so positions found
        // inside `clean` must be shifted by that amount to land on the right characters of
        // `text`. Omitting it moved spans left whenever the bracket had leading whitespace.
        let clean_start = inner_start + trim_start;

        if special_tag_re().is_match(&clean) || special_code_re().is_match(&clean) {
            let indices = token_indices_for_span(offsets, inner_start, inner_end);
            label_span_if_safe(labels, &indices, "SPECIAL");
            continue;
        }
        if source_tag_re().is_match(&clean) {
            let indices = token_indices_for_span(offsets, inner_start, inner_end);
            label_span_if_safe(labels, &indices, "SOURCE");
            continue;
        }
        for mat in resolution_re()
            .find_iter(&clean)
            .filter_map(|item| item.ok())
        {
            let start = clean_start + char_index_at_byte(&clean, mat.start());
            let end = clean_start + char_index_at_byte(&clean, mat.end());
            let indices = token_indices_for_span(offsets, start, end);
            label_span_if_safe(labels, &indices, "RESOLUTION");
        }
        for mat in source_re().find_iter(&clean) {
            if !has_ascii_token_boundaries(&clean, mat.start(), mat.end()) {
                continue;
            }
            let start = clean_start + char_index_at_byte(&clean, mat.start());
            let end = clean_start + char_index_at_byte(&clean, mat.end());
            let indices = token_indices_for_span(offsets, start, end);
            label_span_if_safe(labels, &indices, "SOURCE");
        }
    }

    for mat in resolution_re().find_iter(text).filter_map(|item| item.ok()) {
        let start = char_index_at_byte(text, mat.start());
        if start < episode_end {
            continue;
        }
        let end = char_index_at_byte(text, mat.end());
        let indices = token_indices_for_span(offsets, start, end);
        label_span_if_safe(labels, &indices, "RESOLUTION");
    }
    for mat in source_re().find_iter(text) {
        if !has_ascii_token_boundaries(text, mat.start(), mat.end()) {
            continue;
        }
        let start = char_index_at_byte(text, mat.start());
        if start < episode_end {
            continue;
        }
        let end = char_index_at_byte(text, mat.end());
        let indices = token_indices_for_span(offsets, start, end);
        label_span_if_safe(labels, &indices, "SOURCE");
    }
}
/// Applies in-place label repairs to `labels` using the original filename `text`.
///
/// Does nothing when `tokens` and `labels` differ in length or when the tokens cannot be
/// located in `text`. Otherwise it (1) relabels sequel/season markers that are currently part
/// of a title as `SEASON`, and (2) runs `repair_structural_meta_labels`.
fn repair_known_label_issues(text: &str, tokens: &[String], labels: &mut [String]) {
if tokens.len() != labels.len() {
return;
}
let Some(offsets) = token_offsets_in_text(text, tokens) else {
return;
};
// Cheap substring pre-filter: the marker search below only runs when one of these hints
// occurs in the text, either verbatim or after lower-casing both sides.
let quick_text = text.to_lowercase();
let has_sequel_marker_hint = [
" II", " III", " IV", " V", " VI", " VII", " VIII", " IX", "Ⅱ", "Ⅲ", "Ⅳ", "Ⅴ", "Ⅵ", "Ⅶ",
"Ⅷ", "Ⅸ", "之章", "之期", "之季", "之部", "ノ章", "ノ期", "の章", "の期", "貳", "贰", "弐",
"弍", "參", "叁", "参", "肆", "陸", "陆", "Ni ", " ni ", " no Sara", "Gakki",
]
.iter()
.any(|needle| text.contains(needle) || quick_text.contains(&needle.to_lowercase()));
if has_sequel_marker_hint {
for (start, end) in find_sequel_season_markers(text) {
// Skip markers that come after an already labelled season.
if labels_have_season_before(labels, &offsets, start) {
continue;
}
let indices = token_indices_for_span(&offsets, start, end);
if indices.is_empty() {
continue;
}
// Never overwrite tokens that already carry one of these non-title entities.
if indices.iter().any(|idx| {
matches!(
label_entity(&labels[*idx]),
Some(
"GROUP"
| "EPISODE"
| "RESOLUTION"
| "SOURCE"
| "SPECIAL"
| "TAG"
| "PATH_SEASON"
)
)
}) {
continue;
}
// Only markers that overlap at least one title-labelled token are converted.
if !indices.iter().any(|idx| is_title_like_label(&labels[*idx])) {
continue;
}
label_span_indices(labels, &indices, "SEASON");
mark_adjacent_title_separators_o(tokens, labels, &indices);
}
}
repair_structural_meta_labels(text, tokens, labels, &offsets);
}
/// Finds sequel/season markers in `text` and returns their character spans, sorted and
/// without overlaps.
///
/// Candidates come from the reading-marker, roman-numeral and CJK-marker regexes, plus a
/// standalone `Ni` that follows one of the known base titles. A regex candidate is kept only
/// when `season_marker_number` recognises it and episode context follows it. When two kept
/// spans overlap, the strictly longer one wins; on a tie the earlier one stays.
fn find_sequel_season_markers(text: &str) -> Vec<(usize, usize)> {
    // Byte spans of every regex candidate, before validation.
    let mut byte_spans: Vec<(usize, usize)> = Vec::new();
    let look_around_matches = reading_marker_re()
        .find_iter(text)
        .chain(roman_marker_re().find_iter(text))
        .filter_map(|item| item.ok());
    for mat in look_around_matches {
        byte_spans.push((mat.start(), mat.end()));
    }
    for mat in cjk_marker_re().find_iter(text) {
        byte_spans.push((mat.start(), mat.end()));
    }

    let mut repairs: Vec<(usize, usize)> = byte_spans
        .into_iter()
        .filter(|&(start, end)| {
            season_marker_number(&text[start..end]).is_some() && has_episode_context(text, end)
        })
        .map(|(start, end)| {
            (
                char_index_at_byte(text, start),
                char_index_at_byte(text, end),
            )
        })
        .collect();

    for (base, value) in standalone_ni_season_bases() {
        let mut search_start = 0usize;
        while let Some(relative) = text[search_start..].find(*base) {
            let base_end = search_start + relative + base.len();
            search_start = base_end;
            let Some((ni_start, ni_end)) = standalone_ni_after_base(text, base_end) else {
                continue;
            };
            let qualifies = *value == 2
                && has_episode_context(text, ni_end)
                && has_ascii_token_boundaries(text, ni_start, ni_end);
            if qualifies {
                repairs.push((
                    char_index_at_byte(text, ni_start),
                    char_index_at_byte(text, ni_end),
                ));
            }
        }
    }

    repairs.sort();
    let mut deduped: Vec<(usize, usize)> = Vec::new();
    for repair in repairs {
        let overlaps_previous =
            matches!(deduped.last(), Some(previous) if repair.0 < previous.1);
        if !overlaps_previous {
            deduped.push(repair);
            continue;
        }
        if let Some(previous) = deduped.last_mut() {
            let repair_len = repair.1.saturating_sub(repair.0);
            let previous_len = previous.1.saturating_sub(previous.0);
            if repair_len > previous_len {
                *previous = repair;
            }
        }
    }
    deduped
}
/// Interprets a marker string as a season number.
///
/// After stripping surrounding brackets and whitespace, the forms are tried in this order:
/// roman numeral, romanised reading (whitespace-collapsed, lower-cased), the bare word `ni`,
/// `第<number><季|期|部|章>`, and finally a single CJK numeral optionally followed by a
/// particle-plus-unit suffix. Returns `None` when no form applies.
fn season_marker_number(text: &str) -> Option<u8> {
    let clean = clean_marker_text(text);
    if clean.is_empty() {
        return None;
    }

    let lowered = clean
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ")
        .to_lowercase();
    let named = roman_numeral_value(&clean).or_else(|| reading_marker_value(&lowered));
    if named.is_some() {
        return named;
    }
    if lowered == "ni" {
        return Some(2);
    }

    let glyphs: Vec<char> = clean.chars().collect();
    if glyphs.first() == Some(&'第') {
        if let Some('季' | '期' | '部' | '章') = glyphs.last() {
            // Everything between the leading `第` and the trailing unit character.
            let inner: String = glyphs
                .iter()
                .skip(1)
                .take(glyphs.len().saturating_sub(2))
                .collect();
            return cn_number_to_int(&inner);
        }
    }

    let (first, remainder) = glyphs.split_first()?;
    let value = cn_number_to_int(&first.to_string())?;
    let rest: String = remainder.iter().collect();
    if rest.trim().is_empty() || cjk_marker_suffix_remainder_ok(&rest) {
        Some(value)
    } else {
        None
    }
}
/// Strips surrounding whitespace and bracket characters from a marker string.
///
/// Whitespace is trimmed, then any run of bracket characters at either end is removed, then
/// whitespace is trimmed again. Recognised brackets are `[ ]`, `( )`, `【 】`, `《 》` and the
/// fullwidth parentheses U+FF08 / U+FF09. The fullwidth pair is written with escapes so it
/// cannot be folded into ASCII parentheses by Unicode normalisation; the list previously
/// named the ASCII parentheses twice, which left the second pair unreachable.
fn clean_marker_text(text: &str) -> String {
    text.trim()
        .trim_matches(|ch| {
            matches!(
                ch,
                '[' | ']'
                    | '('
                    | ')'
                    | '【'
                    | '】'
                    | '《'
                    | '》'
                    | '\u{FF08}'
                    | '\u{FF09}'
            )
        })
        .trim()
        .to_string()
}
/// Converts a small number written in ASCII digits or CJK numerals to an integer.
///
/// Accepted forms after trimming: anything `u8` parsing accepts, a single CJK numeral, `十X`
/// (10 + X), `X十` (X * 10) and `X十Y` (X * 10 + Y). In the compound forms a character that is
/// not a known numeral contributes 0. Returns `None` for empty or otherwise shaped input.
fn cn_number_to_int(text: &str) -> Option<u8> {
    let text = text.trim();
    if text.is_empty() {
        return None;
    }
    if let Ok(value) = text.parse::<u8>() {
        return Some(value);
    }
    if let Some(value) = cn_digit_value(text) {
        return Some(value);
    }

    let digit_or_zero = |ch: char| cn_digit_value(&ch.to_string()).unwrap_or(0);
    let glyphs: Vec<char> = text.chars().collect();
    // Arm order matters: a two-character string starting with `十` is always read as 10 + X.
    match glyphs.as_slice() {
        ['十', ones] => Some(10 + digit_or_zero(*ones)),
        [tens, '十'] => Some(digit_or_zero(*tens) * 10),
        [tens, '十', ones] => Some(digit_or_zero(*tens) * 10 + digit_or_zero(*ones)),
        _ => None,
    }
}
/// Maps a single CJK numeral (including the listed variant and formal forms) to 1..=10.
///
/// Returns `None` unless `text` is exactly one recognised character.
fn cn_digit_value(text: &str) -> Option<u8> {
    let mut glyphs = text.chars();
    let (Some(glyph), None) = (glyphs.next(), glyphs.next()) else {
        return None;
    };
    let value = match glyph {
        '一' => 1,
        '二' | '兩' | '两' | '貳' | '贰' | '弐' | '弍' => 2,
        '三' | '參' | '叁' | '参' => 3,
        '四' | '肆' => 4,
        '五' | '伍' => 5,
        '六' | '陸' | '陆' => 6,
        '七' | '柒' => 7,
        '八' | '捌' => 8,
        '九' | '玖' => 9,
        '十' => 10,
        _ => return None,
    };
    Some(value)
}
/// Maps a roman numeral from II to IX, in ASCII letters or as a single Unicode numeral
/// character, to its value. The match is exact and case-sensitive; `I` is not accepted.
fn roman_numeral_value(text: &str) -> Option<u8> {
    // (ASCII spelling, single-character form, value)
    const NUMERALS: [(&str, &str, u8); 8] = [
        ("II", "Ⅱ", 2),
        ("III", "Ⅲ", 3),
        ("IV", "Ⅳ", 4),
        ("V", "Ⅴ", 5),
        ("VI", "Ⅵ", 6),
        ("VII", "Ⅶ", 7),
        ("VIII", "Ⅷ", 8),
        ("IX", "Ⅸ", 9),
    ];
    NUMERALS
        .iter()
        .find(|(ascii, glyph, _)| text == *ascii || text == *glyph)
        .map(|(_, _, value)| *value)
}
/// Maps a lower-cased, single-spaced romanised season phrase to its season number.
///
/// The comparison is exact; callers are expected to normalise case and whitespace first.
fn reading_marker_value(text: &str) -> Option<u8> {
    const SECOND: [&str; 7] = [
        "ni no sara",
        "ni no shou",
        "ni no sho",
        "ni no syo",
        "ni no shō",
        "ni gakki",
        "sono ni",
    ];
    const THIRD: [&str; 4] = ["san no sara", "san no shou", "san no sho", "san no syo"];
    const FOURTH: [&str; 3] = ["yon no sara", "shi no sara", "shin no sara"];
    const FIFTH: [&str; 2] = ["go no sara", "gou no sara"];

    if SECOND.contains(&text) {
        Some(2)
    } else if THIRD.contains(&text) {
        Some(3)
    } else if FOURTH.contains(&text) {
        Some(4)
    } else if FIFTH.contains(&text) {
        Some(5)
    } else {
        None
    }
}
/// Checks whether `rest`, ignoring all whitespace, is exactly a particle (`ノ`, `の` or `之`)
/// followed by a unit character (`章`, `期`, `季` or `部`).
fn cjk_marker_suffix_remainder_ok(rest: &str) -> bool {
    let mut glyphs = rest.chars().filter(|ch| !ch.is_whitespace());
    let particle = glyphs.next();
    let unit = glyphs.next();
    let extra = glyphs.next();
    matches!(
        (particle, unit, extra),
        (
            Some('ノ' | 'の' | '之'),
            Some('章' | '期' | '季' | '部'),
            None
        )
    )
}
/// Reports whether the text right after a marker looks like the start of an episode number.
///
/// `marker_end_byte` is a byte offset into `text`. The tail is tested directly first; if that
/// fails, one closing bracket and up to two bracketed special markers (as matched by
/// `special_context_prefix_re`) are skipped before testing again.
fn has_episode_context(text: &str, marker_end_byte: usize) -> bool {
    let after_marker = &text[marker_end_byte..];
    let episode_follows = episode_context_re();
    if episode_follows.is_match(after_marker) {
        return true;
    }

    let trimmed = after_marker.trim_start();
    let past_bracket =
        match trimmed.strip_prefix(|ch: char| matches!(ch, ']' | ')' | '】' | '》')) {
            Some(rest) => rest.trim_start(),
            None => trimmed,
        };
    let past_specials = match special_context_prefix_re().find(past_bracket) {
        Some(mat) => &past_bracket[mat.end()..],
        None => past_bracket,
    };
    episode_follows.is_match(past_specials)
}
/// Returns the character offset just past the first successful `episode_span_re` match in
/// `text`, or `None` when nothing matches. Matches that produce a regex error are skipped.
fn first_episode_regex_end(text: &str) -> Option<usize> {
    let first = episode_span_re()
        .find_iter(text)
        .find_map(|item| item.ok())?;
    Some(char_index_at_byte(text, first.end()))
}
/// Reports whether any token that ends at or before `marker_start` (a character offset)
/// already carries a season-like label.
fn labels_have_season_before(
    labels: &[String],
    offsets: &[(usize, usize)],
    marker_start: usize,
) -> bool {
    for (label, &(_, token_end)) in labels.iter().zip(offsets) {
        if token_end <= marker_start && is_season_like_label(label) {
            return true;
        }
    }
    false
}
/// Returns the indices of all tokens whose span satisfies `token_start < end` and
/// `token_end > start`, i.e. tokens overlapping the half-open range `start..end`.
fn token_indices_for_span(offsets: &[(usize, usize)], start: usize, end: usize) -> Vec<usize> {
    let mut hits = Vec::new();
    for (idx, &(token_start, token_end)) in offsets.iter().enumerate() {
        let begins_before_span_end = token_start < end;
        let ends_after_span_start = token_end > start;
        if begins_before_span_end && ends_after_span_start {
            hits.push(idx);
        }
    }
    hits
}
/// Test-only helper: labels `labels[start..end]` with `entity`.
///
/// The first label becomes `B-<entity>` unless the label just before `start` already carries
/// the same entity, in which case the whole span continues with `I-<entity>`.
#[cfg(test)]
fn label_span(labels: &mut [String], start: usize, end: usize, entity: &str) {
    let continues_previous = start > 0 && label_entity(&labels[start - 1]) == Some(entity);
    for (position, label) in labels.iter_mut().take(end).skip(start).enumerate() {
        let prefix = if position == 0 && !continues_previous {
            "B"
        } else {
            "I"
        };
        *label = format!("{prefix}-{entity}");
    }
}
/// Labels the tokens at `indices` with `entity`.
///
/// The first index gets `B-<entity>` unless the token right before it already carries the
/// same entity; every other index gets `I-<entity>`. An empty index list is a no-op.
fn label_span_indices(labels: &mut [String], indices: &[usize], entity: &str) {
    let Some(&head) = indices.first() else {
        return;
    };
    let continues_previous = head > 0 && label_entity(&labels[head - 1]) == Some(entity);
    for (position, &idx) in indices.iter().enumerate() {
        let prefix = if position == 0 && !continues_previous {
            "B"
        } else {
            "I"
        };
        labels[idx] = format!("{prefix}-{entity}");
    }
}
/// Clears title labels on the separator tokens that directly surround a marker span.
///
/// Going backwards from the first marker token, title-labelled tokens consisting only of
/// whitespace are set to `O`. Going forwards from the last marker token, title-labelled
/// tokens consisting only of `SEPARATOR_CHARS` are set to `O`. Each walk stops at the first
/// token that does not qualify.
fn mark_adjacent_title_separators_o(
    tokens: &[String],
    labels: &mut [String],
    marker_indices: &[usize],
) {
    let (Some(&first), Some(&last)) = (marker_indices.first(), marker_indices.last()) else {
        return;
    };

    for prev in (0..first).rev() {
        let is_blank = tokens[prev].trim().is_empty();
        if !is_blank || !is_title_like_label(&labels[prev]) {
            break;
        }
        labels[prev] = "O".to_string();
    }

    for next in (last + 1)..tokens.len() {
        let is_separator = tokens[next]
            .chars()
            .all(|ch| SEPARATOR_CHARS.contains(&ch));
        if !is_separator || !is_title_like_label(&labels[next]) {
            break;
        }
        labels[next] = "O".to_string();
    }
}
/// Base titles after which a standalone `Ni` is treated as a season marker, each paired with
/// the season number that marker stands for.
fn standalone_ni_season_bases() -> &'static [(&'static str, u8)] {
    static BASES: [(&str, u8); 1] = [("Kakuriyo no Yadomeshi", 2)];
    &BASES
}
/// Looks for a literal `Ni` after `base_end` (a byte offset), skipping any whitespace first.
///
/// Returns the byte span of `Ni` when present; the comparison is case-sensitive.
fn standalone_ni_after_base(text: &str, base_end: usize) -> Option<(usize, usize)> {
    let after_base = &text[base_end..];
    let candidate = after_base.trim_start();
    let skipped = after_base.len() - candidate.len();
    let ni_start = base_end + skipped;
    let ni_end = ni_start.checked_add(2)?;
    if candidate.starts_with("Ni") {
        Some((ni_start, ni_end))
    } else {
        None
    }
}
/// Reports whether a BIO label carries one of the title or path-title entities.
fn is_title_like_label(label: &str) -> bool {
    const TITLE_ENTITIES: [&str; 11] = [
        "TITLE",
        "TITLE_CHS",
        "TITLE_CHT",
        "TITLE_JPN",
        "TITLE_LATIN",
        "TITLE_MIXED",
        "PATH_TITLE_CHS",
        "PATH_TITLE_CHT",
        "PATH_TITLE_JPN",
        "PATH_TITLE_LATIN",
        "PATH_TITLE_MIXED",
    ];
    match label_entity(label) {
        Some(entity) => TITLE_ENTITIES.contains(&entity),
        None => false,
    }
}
/// Reports whether a BIO label carries the `SEASON` or `PATH_SEASON` entity.
fn is_season_like_label(label: &str) -> bool {
    let entity = label_entity(label);
    entity == Some("SEASON") || entity == Some("PATH_SEASON")
}
/// Returns the character offset where the first episode span ends.
///
/// The smallest end offset among `EPISODE`-labelled tokens is preferred. Without such a token
/// the end of the first `episode_span_re` match is used, and 0 when that is absent as well.
fn first_episode_span_end(labels: &[String], offsets: &[(usize, usize)], text: &str) -> usize {
    let earliest_labelled_end = labels
        .iter()
        .zip(offsets)
        .filter(|(label, _)| label_entity(label) == Some("EPISODE"))
        .map(|(_, &(_, token_end))| token_end)
        .min();
    match earliest_labelled_end {
        Some(end) => end,
        None => first_episode_regex_end(text).unwrap_or(0),
    }
}
/// Returns the character spans (end exclusive) of the content of every bracketed group.
///
/// Recognised pairs are `[]`, `()`, `【】` and `《》`. A group runs from an opening bracket to
/// the first matching closing character; scanning resumes after that closer, so nested
/// openers inside a group are not reported. An opener without a closer is skipped.
fn bracket_inner_spans(text: &str) -> Vec<(usize, usize)> {
    fn closer_for(open: char) -> Option<char> {
        match open {
            '[' => Some(']'),
            '(' => Some(')'),
            '【' => Some('】'),
            '《' => Some('》'),
            _ => None,
        }
    }

    let glyphs: Vec<char> = text.chars().collect();
    let mut spans = Vec::new();
    let mut cursor = 0usize;
    while let Some(&current) = glyphs.get(cursor) {
        let inner_start = cursor + 1;
        let closing_distance = closer_for(current)
            .and_then(|close| glyphs[inner_start..].iter().position(|ch| *ch == close));
        match closing_distance {
            Some(distance) => {
                let inner_end = inner_start + distance;
                spans.push((inner_start, inner_end));
                cursor = inner_end + 1;
            }
            None => cursor = inner_start,
        }
    }
    spans
}
/// Returns the character range `(start, end)` of `text` without leading and trailing
/// whitespace. For all-whitespace input both values equal the character count.
fn trimmed_bounds(text: &str) -> (usize, usize) {
    let total = text.chars().count();
    let leading = text.chars().take_while(|ch| ch.is_whitespace()).count();
    if leading == total {
        return (total, total);
    }
    let trailing = text
        .chars()
        .rev()
        .take_while(|ch| ch.is_whitespace())
        .count();
    (leading, total - trailing)
}
/// Copies the characters of `text` with character index in `start..end` into a new string.
/// Indices past the end are ignored; an empty or inverted range yields an empty string.
fn chars_range_to_string(text: &str, start: usize, end: usize) -> String {
    text.chars()
        .enumerate()
        .skip_while(|(index, _)| *index < start)
        .take_while(|(index, _)| *index < end)
        .map(|(_, ch)| ch)
        .collect()
}
/// Builds a string from the characters of `text` at character positions `start..end`.
/// Positions beyond the text are ignored; when `end <= start` the result is empty.
fn chars_slice_to_string(text: &str, start: usize, end: usize) -> String {
    let wanted = end.saturating_sub(start);
    let mut piece = String::new();
    for ch in text.chars().skip(start).take(wanted) {
        piece.push(ch);
    }
    piece
}
/// Labels the tokens at `indices` with `entity`, unless the list is empty or any of those
/// tokens already carries a `GROUP`, `EPISODE`, `SEASON` or `PATH_SEASON` label.
fn label_span_if_safe(labels: &mut [String], indices: &[usize], entity: &str) {
    const PROTECTED: [&str; 4] = ["GROUP", "EPISODE", "SEASON", "PATH_SEASON"];
    if indices.is_empty() {
        return;
    }
    let touches_protected = indices.iter().any(|&idx| match label_entity(&labels[idx]) {
        Some(existing) => PROTECTED.contains(&existing),
        None => false,
    });
    if !touches_protected {
        label_span_indices(labels, indices, entity);
    }
}
/// Reports whether the byte range `start..end` of `text` stands alone as a token: neither
/// the character just before it nor the one just after it is an ASCII letter or digit.
/// The start and end of the text count as boundaries.
fn has_ascii_token_boundaries(text: &str, start: usize, end: usize) -> bool {
    let glued_before = text[..start].ends_with(|ch: char| ch.is_ascii_alphanumeric());
    let glued_after = text[end..].starts_with(|ch: char| ch.is_ascii_alphanumeric());
    !glued_before && !glued_after
}
/// Returns the entity part of a `B-<entity>` or `I-<entity>` label, borrowed from `label`.
/// Any other label, including `O`, yields `None`.
fn label_entity(label: &str) -> Option<&str> {
    match label.split_once('-') {
        Some(("B" | "I", entity)) => Some(entity),
        _ => None,
    }
}
/// Compiled-once, case-insensitive regex for resolution tokens.
///
/// Matches 3-4 digits followed by `p`, a single digit followed by `k`, or `<3-4 digits>x<3-4
/// digits>` (with `x`, `X` or `×`), provided the match is not directly adjacent to an ASCII
/// letter or digit on either side. Panics on first use if the pattern fails to compile.
fn resolution_re() -> &'static FancyRegex {
RESOLUTION_RE.get_or_init(|| {
FancyRegex::new(
r"(?i)(?<![A-Za-z0-9])(?:\d{3,4}[pP]|\d[Kk]|\d{3,4}[xX×]\d{3,4})(?![A-Za-z0-9])",
)
.unwrap()
})
}
/// Compiled-once, case-insensitive, unanchored regex that finds any single alternative of
/// `SOURCE_TOKEN_PATTERN`. It has no boundary assertions of its own; callers in this file pair
/// it with `has_ascii_token_boundaries`. Panics on first use if the pattern fails to compile.
fn source_re() -> &'static Regex {
SOURCE_RE.get_or_init(|| Regex::new(&format!(r"(?i)(?:{SOURCE_TOKEN_PATTERN})")).unwrap())
}
/// Compiled-once, case-insensitive regex that matches a whole string consisting only of
/// source tokens: one `SOURCE_TOKEN_PATTERN` alternative, followed by any number of further
/// alternatives each introduced by one of `& + / , _ -` with optional surrounding whitespace.
/// Panics on first use if the pattern fails to compile.
fn source_tag_re() -> &'static Regex {
SOURCE_TAG_RE.get_or_init(|| {
Regex::new(&format!(
r"(?i)^(?:{SOURCE_TOKEN_PATTERN})(?:\s*(?:[&+/,_-]|,\s*)\s*(?:{SOURCE_TOKEN_PATTERN}))*$"
))
.unwrap()
})
}
/// Compiled-once, case-insensitive regex for "search keyword / alias" tags.
///
/// Matches text that starts with one of the listed keywords (Chinese search/alias words,
/// `alias`, `search`, `keyword`), optional whitespace, a colon, and at least one further
/// character. Both the ASCII colon and the fullwidth colon (U+FF1A) are accepted; the
/// fullwidth one is written as `\x{FF1A}` so it cannot be folded into the ASCII colon by
/// Unicode normalisation (the class previously contained the ASCII colon twice).
/// Panics on first use if the pattern fails to compile.
fn special_tag_re() -> &'static Regex {
    SPECIAL_TAG_RE.get_or_init(|| {
        Regex::new(
            r"(?i)^(?:檢索|检索|搜索|搜寻|搜尋|别名|別名|alias|search|keyword)\s*[:\x{FF1A}].+",
        )
        .unwrap()
    })
}
/// Compiled-once, case-insensitive regex that matches a whole string that is a special-content
/// code: `NCOP`, `NCED`, `OP`, `ED`, `PV`, `CM`, `OVA`, `OAD` or `SP` with optional trailing
/// digits, or `IV` followed by at least one digit.
/// Panics on first use if the pattern fails to compile.
fn special_code_re() -> &'static Regex {
SPECIAL_CODE_RE.get_or_init(|| {
Regex::new(r"(?i)^(?:NCOP|NCED|OP|ED|PV|CM)\d*$|^IV\d+$|^(?:OVA|OAD|SP)\d*$").unwrap()
})
}
/// Compiled-once, case-insensitive regex anchored at the start of the input that recognises
/// the beginning of an episode indicator.
///
/// After optional whitespace it accepts: `-` or `_` followed by 1-4 digits or one of
/// `NCOP/NCED/OP/ED/OVA/OAD/SP/END` ending at a word boundary; `#` followed by 1-4 digits; or
/// an opening bracket followed by an optional `E`/`EP`/`#` prefix and 1-4 digits.
/// Panics on first use if the pattern fails to compile.
fn episode_context_re() -> &'static Regex {
EPISODE_CONTEXT_RE.get_or_init(|| {
Regex::new(
r"(?i)^\s*(?:[-_]\s*(?:\d{1,4}|NCOP|NCED|OP|ED|OVA|OAD|SP|END)\b|#\s*\d{1,4}|[\[\(【《]\s*(?:EP?|#)?\d{1,4})",
)
.unwrap()
})
}
/// Compiled-once, case-insensitive regex that finds a complete episode span.
///
/// Alternatives, in order:
/// 1. `S<1-2 digits>E<1-4 digits>` with an optional `v<digits>` version suffix;
/// 2. a `-` or `_` (at the start or after whitespace, `.` or `_`) followed by 1-4 digits and
///    an optional version, ending at the end of input or before a separator/bracket;
/// 3. a bracketed number with an optional `E`/`EP`/`#` prefix and optional version;
/// 4. an `E`/`EP`/`第`/`#`-prefixed number with optional version and optional `话`/`話`/`集`
///    suffix, bounded by start/separator before and end/separator/closing bracket after.
/// Panics on first use if the pattern fails to compile.
fn episode_span_re() -> &'static FancyRegex {
EPISODE_SPAN_RE.get_or_init(|| {
FancyRegex::new(
r"(?i)(?:[Ss]\d{1,2}[Ee]\d{1,4}(?:v\d+)?|(?:^|[\s._])[-_]\s*\d{1,4}(?:v\d+)?(?=$|[\s._\-\]\)】》\[])|[\[\(【《](?:EP?|#)?\d{1,4}(?:v\d+)?[\]\)】》]|(?:^|[\s._\-\[\(【《#])(?:EP?|第|#)\d{1,4}(?:v\d+)?(?:[话話集])?(?=$|[\s._\-\]\)】》]))",
)
.unwrap()
})
}
/// Compiled-once, case-insensitive regex for romanised season phrases such as `Ni no Sara`,
/// `San no Shou`, `Ni Gakki` or `Sono Ni`, with any amount of whitespace between the words.
/// A match must not be directly adjacent to an ASCII letter or digit on either side.
/// Panics on first use if the pattern fails to compile.
fn reading_marker_re() -> &'static FancyRegex {
READING_MARKER_RE.get_or_init(|| {
FancyRegex::new(
r"(?i)(?<![A-Za-z0-9])(?P<marker>Ni\s+no\s+(?:Sara|Shou|Sho|Syo|Shō)|San\s+no\s+(?:Sara|Shou|Sho|Syo)|(?:Yon|Shi|Shin)\s+no\s+Sara|(?:Go|Gou)\s+no\s+Sara|Ni\s+Gakki|Sono\s+Ni)(?![A-Za-z0-9])",
)
.unwrap()
})
}
/// Compiled-once, case-sensitive regex for roman numerals II through IX, written in ASCII
/// capitals or as a single Unicode roman-numeral character.
/// A match must not be directly adjacent to an ASCII letter or digit on either side.
/// Panics on first use if the pattern fails to compile.
fn roman_marker_re() -> &'static FancyRegex {
ROMAN_MARKER_RE.get_or_init(|| {
FancyRegex::new(
r"(?<![A-Za-z0-9])(?P<marker>II|III|IV|V|VI|VII|VIII|IX|[ⅡⅢⅣⅤⅥⅦⅧⅨ])(?![A-Za-z0-9])",
)
.unwrap()
})
}
/// Compiled-once regex for CJK season markers.
///
/// Matches either a single CJK numeral optionally followed by a particle (`ノ`, `の`, `之`)
/// and a unit (`章`, `期`, `季`, `部`) with optional whitespace between them, or `第` followed
/// by one or more CJK numerals/digits and a unit character. The pattern has no boundary
/// assertions, so a lone numeral anywhere in the text is a candidate.
/// Panics on first use if the pattern fails to compile.
fn cjk_marker_re() -> &'static Regex {
CJK_MARKER_RE.get_or_init(|| {
Regex::new(
r"(?:[一二三四五六七八九十兩两貳贰弐弍參叁参肆伍陸陆柒捌玖](?:\s*(?:ノ|の|之)\s*(?:章|期|季|部))?|第[一二三四五六七八九十兩两貳贰弐弍參叁参肆伍陸陆柒捌玖\d]+[季期部章])",
)
.unwrap()
})
}
/// Compiled-once, case-insensitive regex anchored at the start of the input that consumes up
/// to two bracketed special markers (`menu`, `menus`, `bdmenu`, `ncop`, `nced`, `op`, `ed`,
/// `ova`, `oad`, `sp`), each optionally followed by whitespace. Because the repetition count
/// may be zero, the regex always matches, possibly with an empty match.
/// Panics on first use if the pattern fails to compile.
fn special_context_prefix_re() -> &'static Regex {
SPECIAL_CONTEXT_PREFIX_RE.get_or_init(|| {
Regex::new(
r"(?i)^(?:[\[\(【《]\s*(?:menu|menus|bdmenu|ncop|nced|op|ed|ova|oad|sp)\s*[\]\)】》]\s*){0,2}",
)
.unwrap()
})
}
/// Splits `text` into one owned string per Unicode scalar value.
fn chars_as_strings(text: &str) -> Vec<String> {
    let mut pieces = Vec::new();
    for ch in text.chars() {
        pieces.push(String::from(ch));
    }
    pieces
}
/// Converts a byte offset into `text` to a character offset by counting the characters that
/// precede it. Panics when `byte_index` is out of range or not on a character boundary.
fn char_index_at_byte(text: &str, byte_index: usize) -> usize {
    let prefix = &text[..byte_index];
    prefix.char_indices().count()
}
/// Returns an owned copy of the entity part of a `B-<entity>` or `I-<entity>` label, or
/// `None` for any other label.
fn bio_entity(label: &str) -> Option<String> {
    match label.split_once('-') {
        Some(("B" | "I", entity)) => Some(entity.to_owned()),
        _ => None,
    }
}
/// Reports whether `token` starts with an opening bracket (`[`, `【`, `(`, `《`) and ends with
/// a closing bracket (`]`, `】`, `)`, `》`). The two brackets need not be of the same kind.
fn is_wrapped_token(token: &str) -> bool {
    let opens = token.starts_with(|ch: char| matches!(ch, '[' | '【' | '(' | '《'));
    let closes = token.ends_with(|ch: char| matches!(ch, ']' | '】' | ')' | '》'));
    opens && closes
}
/// Maps legacy entity names to their canonical form.
///
/// `B-`/`I-` labels with entity `TITLE` become `TITLE_MIXED` and those with `PATH_TITLE`
/// become `PATH_TITLE_MIXED`; every other label is returned unchanged.
fn canonical_bio_label(label: &str) -> String {
    match label.split_once('-') {
        Some((prefix @ ("B" | "I"), "TITLE")) => format!("{prefix}-TITLE_MIXED"),
        Some((prefix @ ("B" | "I"), "PATH_TITLE")) => format!("{prefix}-PATH_TITLE_MIXED"),
        _ => label.to_string(),
    }
}
/// Looks up `token` in the vocabulary, falling back to the unknown-token id.
fn token_id(vocab: &Vocab, token: &str) -> u16 {
    vocab.ids.get(token).copied().unwrap_or(vocab.unk_id)
}
/// Writes the raw source line of every row to `path`, one line per row.
///
/// Each `raw_line` is emitted verbatim followed by `\n`.
///
/// # Errors
/// Returns an error if the file cannot be created, if any write fails, or if the
/// buffered data cannot be flushed to the file.
fn write_eval_records(rows: &[SourceRow], path: &Path) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    for row in rows {
        writer.write_all(row.raw_line.as_bytes())?;
        writer.write_all(b"\n")?;
    }
    // Flush explicitly: `BufWriter`'s `Drop` ignores errors from the final write,
    // which would otherwise let a truncated file be reported as success.
    writer
        .flush()
        .with_context(|| format!("failed to flush {}", path.display()))?;
    Ok(())
}
/// Writes `data` to `path` as a NumPy `.npy` file of little-endian `u16` values.
///
/// The header declares dtype `<u2` and shape `(rows, cols)`; the values follow in
/// slice order.
/// NOTE(review): nothing here verifies that `data.len() == rows * cols` — callers
/// are expected to pass consistent dimensions.
///
/// # Errors
/// Returns an error if the file cannot be created, if the header or any value
/// cannot be written, or if the buffered data cannot be flushed to the file.
fn write_npy_u16(path: &Path, data: &[u16], rows: usize, cols: usize) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    write_npy_header(&mut writer, "<u2", rows, cols)?;
    for value in data {
        writer.write_all(&value.to_le_bytes())?;
    }
    // Flush explicitly: `BufWriter`'s `Drop` ignores errors from the final write,
    // which would otherwise let a truncated array be reported as success.
    writer
        .flush()
        .with_context(|| format!("failed to flush {}", path.display()))?;
    Ok(())
}
/// Writes `data` to `path` as a NumPy `.npy` file of `u8` values.
///
/// The header declares dtype `|u1` and shape `(rows, cols)`; the bytes follow
/// unchanged.
/// NOTE(review): nothing here verifies that `data.len() == rows * cols` — callers
/// are expected to pass consistent dimensions.
///
/// # Errors
/// Returns an error if the file cannot be created, if the header or payload
/// cannot be written, or if the buffered data cannot be flushed to the file.
fn write_npy_u8(path: &Path, data: &[u8], rows: usize, cols: usize) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    write_npy_header(&mut writer, "|u1", rows, cols)?;
    writer.write_all(data)?;
    // Flush explicitly: `BufWriter`'s `Drop` ignores errors from the final write,
    // which would otherwise let a truncated array be reported as success.
    writer
        .flush()
        .with_context(|| format!("failed to flush {}", path.display()))?;
    Ok(())
}
/// Writes `data` to `path` as a NumPy `.npy` file of little-endian `i16` values.
///
/// The header declares dtype `<i2` and shape `(rows, cols)`; the values follow in
/// slice order.
/// NOTE(review): nothing here verifies that `data.len() == rows * cols` — callers
/// are expected to pass consistent dimensions.
///
/// # Errors
/// Returns an error if the file cannot be created, if the header or any value
/// cannot be written, or if the buffered data cannot be flushed to the file.
fn write_npy_i16(path: &Path, data: &[i16], rows: usize, cols: usize) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    write_npy_header(&mut writer, "<i2", rows, cols)?;
    for value in data {
        writer.write_all(&value.to_le_bytes())?;
    }
    // Flush explicitly: `BufWriter`'s `Drop` ignores errors from the final write,
    // which would otherwise let a truncated array be reported as success.
    writer
        .flush()
        .with_context(|| format!("failed to flush {}", path.display()))?;
    Ok(())
}
/// Emits a version 1.0 `.npy` header describing a 2-D, C-ordered array.
///
/// Layout written to `writer`: the magic string `\x93NUMPY`, the version bytes
/// `1, 0`, the header length as a little-endian `u16`, then the header dictionary
/// padded with spaces and terminated by `\n` so that preamble plus header is a
/// multiple of 16 bytes.
///
/// # Errors
/// Fails when the padded header does not fit in a `u16` length field or when any
/// write to `writer` fails.
fn write_npy_header<W: Write>(writer: &mut W, descr: &str, rows: usize, cols: usize) -> Result<()> {
    // Magic (6 bytes) + version (2 bytes) + header length field (2 bytes).
    const PREAMBLE_LEN: usize = 10;
    const ALIGNMENT: usize = 16;

    let dict = format!(
        "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}",
        descr, rows, cols
    );
    let mut header = dict.into_bytes();

    // The `+ 1` accounts for the trailing newline appended after the padding.
    let unpadded_total = PREAMBLE_LEN + header.len() + 1;
    let pad_len = (ALIGNMENT - unpadded_total % ALIGNMENT) % ALIGNMENT;
    header.resize(header.len() + pad_len, b' ');
    header.push(b'\n');

    if header.len() > u16::MAX as usize {
        bail!("npy header too large");
    }
    let header_len = header.len() as u16;

    writer.write_all(b"\x93NUMPY")?;
    writer.write_all(&[1, 0])?;
    writer.write_all(&header_len.to_le_bytes())?;
    writer.write_all(&header)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps already prepared tokens and labels in a char-tokenizer `SourceRow`.
    fn build_row(text: &str, tokens: Vec<String>, labels: Vec<String>) -> SourceRow {
        SourceRow {
            row_index: 0,
            raw_line: String::new(),
            filename: Some(text.to_string()),
            tokens,
            labels,
            tokenizer_variant: Some("char".to_string()),
        }
    }

    /// Builds a per-character row with `TITLE_LATIN` and `EPISODE` spans applied
    /// (titles first, then episodes) on top of an all-`O` labelling.
    fn char_row(
        text: &str,
        title_spans: &[(usize, usize)],
        episode_spans: &[(usize, usize)],
    ) -> SourceRow {
        let tokens = chars_as_strings(text);
        let mut labels = vec!["O".to_string(); tokens.len()];
        for &(start, end) in title_spans {
            label_span(&mut labels, start, end, "TITLE_LATIN");
        }
        for &(start, end) in episode_spans {
            label_span(&mut labels, start, end, "EPISODE");
        }
        build_row(text, tokens, labels)
    }

    #[test]
    fn repairs_cjk_sequel_marker_in_char_fast_path() {
        let text = "妖怪旅館營業中 貳 - 11";
        // Character (not byte) position of the first occurrence of `needle`.
        let char_pos = |needle: &str| char_index_at_byte(text, text.find(needle).unwrap());
        let title_end = char_pos(" - ");
        let episode_start = char_pos("11");
        let row = char_row(
            text,
            &[(0, title_end)],
            &[(episode_start, episode_start + 2)],
        );

        let (_, labels) = labels_for_char_tokenizer(&row);

        let marker = char_pos("貳");
        let before_marker = marker - 1;
        assert_eq!(labels[before_marker], "O");
        assert_eq!(labels[marker], "B-SEASON");
        assert_eq!(labels[episode_start], "B-EPISODE");
    }

    #[test]
    fn repairs_reading_sequel_marker() {
        // The text is pure ASCII, so byte offsets double as character indices.
        let text = "Shokugeki no Souma Ni no Sara - 13";
        let byte_pos = |needle: &str| text.find(needle).unwrap();
        let title_end = byte_pos(" - ");
        let episode_start = byte_pos("13");
        let row = char_row(
            text,
            &[(0, title_end)],
            &[(episode_start, episode_start + 2)],
        );

        let (_, labels) = labels_for_char_tokenizer(&row);

        let marker_start = byte_pos("Ni");
        let marker_end = byte_pos(" - ");
        assert_eq!(labels[marker_start - 1], "O");
        assert_eq!(labels[marker_start], "B-SEASON");
        assert!(labels[marker_start + 1..marker_end]
            .iter()
            .all(|label| label == "I-SEASON"));
    }

    #[test]
    fn keeps_numeric_title_suffix_out_of_sequel_repair() {
        // The text is pure ASCII, so byte offsets double as character indices.
        let text = "Kamisama Hajimemashita 2 - 01";
        let byte_pos = |needle: &str| text.find(needle).unwrap();
        let title_end = byte_pos(" - ");
        let episode_start = byte_pos("01");
        let row = char_row(
            text,
            &[(0, title_end)],
            &[(episode_start, episode_start + 2)],
        );

        let (_, labels) = labels_for_char_tokenizer(&row);

        let suffix = byte_pos("2");
        assert_eq!(labels[suffix], "I-TITLE_LATIN");
        assert!(labels
            .iter()
            .all(|label| label_entity(label) != Some("SEASON")));
    }

    #[test]
    fn skips_alias_marker_when_season_already_exists() {
        let text = "樱桃小丸子第二期(Chibi Maruko-chan II)[1439]";
        // Character (not byte) position of the first occurrence of `needle`.
        let char_pos = |needle: &str| char_index_at_byte(text, text.find(needle).unwrap());

        let tokens = chars_as_strings(text);
        let mut labels = vec!["O".to_string(); tokens.len()];

        let season_start = char_pos("第二期");
        let season_end = season_start + "第二期".chars().count();
        label_span(&mut labels, 0, season_start, "TITLE_CHS");
        label_span(&mut labels, season_start, season_end, "SEASON");

        let alias_start = char_pos("Chibi");
        let alias_end = char_pos(")");
        label_span(&mut labels, alias_start, alias_end, "TITLE_LATIN");

        let episode_start = char_pos("1439");
        label_span(&mut labels, episode_start, episode_start + 4, "EPISODE");

        let row = build_row(text, tokens, labels);
        let (_, labels) = labels_for_char_tokenizer(&row);

        let roman = char_pos("II");
        assert_eq!(labels[roman], "I-TITLE_LATIN");
    }
}