mod regex_benchmark; use anyhow::{bail, Context, Result}; use clap::Parser; use fancy_regex::Regex as FancyRegex; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::SeedableRng; use rayon::prelude::*; use regex::Regex; use serde::{Deserialize, Serialize}; use serde_json::{json, Value}; use std::collections::HashMap; use std::fs::{self, File}; use std::io::{BufRead, BufReader, BufWriter, Write}; use std::path::{Path, PathBuf}; use std::sync::OnceLock; use std::time::Instant; const FALLBACK_LABELS: [&str; 37] = [ "O", "B-TITLE_CHS", "I-TITLE_CHS", "B-TITLE_CHT", "I-TITLE_CHT", "B-TITLE_JPN", "I-TITLE_JPN", "B-TITLE_LATIN", "I-TITLE_LATIN", "B-TITLE_MIXED", "I-TITLE_MIXED", "B-PATH_TITLE_CHS", "I-PATH_TITLE_CHS", "B-PATH_TITLE_CHT", "I-PATH_TITLE_CHT", "B-PATH_TITLE_JPN", "I-PATH_TITLE_JPN", "B-PATH_TITLE_LATIN", "I-PATH_TITLE_LATIN", "B-PATH_TITLE_MIXED", "I-PATH_TITLE_MIXED", "B-PATH_SEASON", "I-PATH_SEASON", "B-SEASON", "I-SEASON", "B-EPISODE", "I-EPISODE", "B-SPECIAL", "I-SPECIAL", "B-GROUP", "I-GROUP", "B-RESOLUTION", "I-RESOLUTION", "B-SOURCE", "I-SOURCE", "B-TAG", "I-TAG", ]; const SOURCE_TOKEN_PATTERN: &str = r"WEB[-_ ]?DL|WEB[-_ ]?Rip|BDRip|BluRay|BDMV|BD|DVDRip|DVD|TVRip|HDTV|Netflix|NF|AMZN|Baha|CR|ABEMA|DSNP|U[-_ ]?NEXT|Hulu|AT[-_ ]?X|x26[45]|h\.?26[45]|HEVC|AVC|AV1|AAC\d*(?:\.\d+)?|AAC|FLAC|MP3|DTS|Opus|CHS|CHT|GB|BIG5|JPN?|JPSC|JPTC|繁中|简中"; static RESOLUTION_RE: OnceLock = OnceLock::new(); static SOURCE_RE: OnceLock = OnceLock::new(); static SOURCE_TAG_RE: OnceLock = OnceLock::new(); static SPECIAL_TAG_RE: OnceLock = OnceLock::new(); static SPECIAL_CODE_RE: OnceLock = OnceLock::new(); static EPISODE_CONTEXT_RE: OnceLock = OnceLock::new(); static EPISODE_SPAN_RE: OnceLock = OnceLock::new(); static READING_MARKER_RE: OnceLock = OnceLock::new(); static ROMAN_MARKER_RE: OnceLock = OnceLock::new(); static CJK_MARKER_RE: OnceLock = OnceLock::new(); static SPECIAL_CONTEXT_PREFIX_RE: OnceLock = OnceLock::new(); const SEPARATOR_CHARS: &[char] = &[' ', 
'\t', '-', '_', '.', '|', '~', '~']; #[derive(Parser, Debug)] #[command( about = "Build split train/eval encoded AniFileBERT shard caches", version )] struct Args { #[arg(long)] input: Vec, #[arg(long, value_name = "N")] input_repeat: Vec, #[arg(long)] vocab_file: Option, #[arg(long)] output_dir: Option, #[arg(long, default_value = "label_schema.json")] label_schema_file: PathBuf, #[arg(long, default_value_t = 128)] max_length: usize, #[arg(long, default_value_t = 25_000)] shard_size: usize, #[arg(long, default_value_t = 0)] limit_rows: usize, #[arg(long, default_value_t = 0.98)] train_split: f64, #[arg(long, default_value_t = 42)] seed: u64, #[arg(long)] no_shuffle: bool, #[arg(long, default_value_t = 0)] threads: usize, #[arg(long)] regex_benchmark_input: Option, #[arg(long, default_value_t = 0)] regex_benchmark_limit_rows: usize, #[arg(long, default_value_t = 3)] regex_benchmark_repeat: usize, } #[derive(Debug, Deserialize)] struct LabelSchema { labels: Vec, } #[derive(Clone)] struct SourceRow { row_index: usize, raw_line: String, filename: Option, tokens: Vec, labels: Vec, tokenizer_variant: Option, } #[derive(Clone)] struct Vocab { ids: HashMap, pad_id: u16, unk_id: u16, cls_id: u16, sep_id: u16, } #[derive(Clone)] struct EncodeContext { vocab: Vocab, label_ids: HashMap, max_length: usize, } #[derive(Serialize)] struct ShardManifest { rows: usize, input_ids: String, attention_mask: String, labels: String, } #[derive(Serialize)] struct SplitSummary { split: String, rows: usize, shards: usize, directory: String, } fn main() -> Result<()> { let args = Args::parse(); if let Some(input) = &args.regex_benchmark_input { return regex_benchmark::run( input, args.regex_benchmark_limit_rows, args.regex_benchmark_repeat, ); } if args.input.is_empty() { bail!("at least one --input is required"); } let vocab_file = args .vocab_file .as_ref() .context("--vocab-file is required when building an encoded cache")?; let output_dir = args .output_dir .as_ref() 
.context("--output-dir is required when building an encoded cache")?; if args.max_length < 4 { bail!("--max-length must be at least 4"); } if args.shard_size == 0 { bail!("--shard-size must be positive"); } if !(0.0..1.0).contains(&args.train_split) { bail!("--train-split must be > 0 and < 1"); } if args.threads > 0 { rayon::ThreadPoolBuilder::new() .num_threads(args.threads) .build_global() .context("failed to configure rayon thread pool")?; } let started = Instant::now(); let vocab = load_vocab(vocab_file)?; let label_ids = load_label_ids(&args.label_schema_file)?; let input_repeats = resolve_input_repeats(&args.input, &args.input_repeat)?; let (mut rows, input_summaries) = load_input_rows(&args.input, &input_repeats, args.limit_rows)?; if rows.len() < 2 { bail!("need at least two rows to build train/eval cache"); } if !args.no_shuffle { let mut rng = StdRng::seed_from_u64(args.seed); rows.shuffle(&mut rng); } let split_idx = ((rows.len() as f64) * args.train_split) as usize; let split_idx = split_idx.max(1).min(rows.len() - 1); let (train_rows, eval_rows) = rows.split_at(split_idx); fs::create_dir_all(output_dir).with_context(|| { format!( "failed to create output directory {}", output_dir.display() ) })?; let context = EncodeContext { vocab, label_ids, max_length: args.max_length, }; let train_summary = write_split( "train", train_rows, output_dir, &context, args.shard_size, )?; let eval_summary = write_split( "eval", eval_rows, output_dir, &context, args.shard_size, )?; write_eval_records(eval_rows, &output_dir.join("eval_records.jsonl"))?; let manifest = json!({ "format": "anifilebert.encoded_dataset_cache.v1", "input": args.input.first(), "inputs": input_summaries, "vocab_file": vocab_file, "label_schema_file": args.label_schema_file, "output_dir": output_dir, "max_length": args.max_length, "shard_size": args.shard_size, "limit_rows": args.limit_rows, "source_rows": train_rows.len() + eval_rows.len(), "train_split": args.train_split, "seed": args.seed, 
"shuffle": !args.no_shuffle, "train": train_summary, "eval": eval_summary, "eval_records": "eval_records.jsonl", "elapsed_seconds": started.elapsed().as_secs_f64(), }); let manifest_path = output_dir.join("manifest.json"); fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?) .with_context(|| format!("failed to write {}", manifest_path.display()))?; println!("{}", serde_json::to_string_pretty(&manifest)?); Ok(()) } fn load_vocab(path: &Path) -> Result { let text = fs::read_to_string(path) .with_context(|| format!("failed to read vocab {}", path.display()))?; let raw: HashMap = serde_json::from_str(&text).with_context(|| format!("invalid vocab {}", path.display()))?; let mut ids = HashMap::with_capacity(raw.len()); for (token, id) in raw { if id > u16::MAX as u64 { bail!("vocab id for token '{token}' exceeds u16: {id}"); } ids.insert(token, id as u16); } let special = |token: &str| -> Result { ids.get(token) .copied() .with_context(|| format!("vocab is missing special token {token}")) }; Ok(Vocab { pad_id: special("[PAD]")?, unk_id: special("[UNK]")?, cls_id: special("[CLS]")?, sep_id: special("[SEP]")?, ids, }) } fn load_label_ids(path: &Path) -> Result> { let labels = match fs::read_to_string(path) { Ok(text) => { serde_json::from_str::(&text) .with_context(|| format!("invalid label schema {}", path.display()))? 
.labels } Err(_) => FALLBACK_LABELS .iter() .map(|label| (*label).to_string()) .collect(), }; if labels.is_empty() { bail!("label schema has no labels"); } Ok(labels .into_iter() .enumerate() .map(|(idx, label)| (label, idx as i16)) .collect()) } fn resolve_input_repeats(inputs: &[PathBuf], repeats: &[usize]) -> Result> { if repeats.is_empty() { return Ok(vec![1; inputs.len()]); } if repeats.len() == 1 { return Ok(vec![repeats[0].max(1); inputs.len()]); } if repeats.len() != inputs.len() { bail!( "--input-repeat must be omitted, passed once for all inputs, or passed once per --input ({} inputs, {} repeats)", inputs.len(), repeats.len() ); } Ok(repeats.iter().map(|repeat| (*repeat).max(1)).collect()) } fn load_input_rows( inputs: &[PathBuf], repeats: &[usize], limit_rows: usize, ) -> Result<(Vec, Vec)> { let mut combined = Vec::new(); let mut summaries = Vec::new(); for (path, repeat) in inputs.iter().zip(repeats.iter()) { let rows = load_rows(path)?; let samples = rows.len(); let mut written = 0usize; for _ in 0..*repeat { for row in &rows { if limit_rows > 0 && combined.len() >= limit_rows { break; } let mut row = row.clone(); row.row_index = combined.len(); combined.push(row); written += 1; } if limit_rows > 0 && combined.len() >= limit_rows { break; } } summaries.push(json!({ "path": path, "samples": samples, "repeat": repeat, "effective_samples": samples * repeat, "written_rows": written, })); if limit_rows > 0 && combined.len() >= limit_rows { break; } } Ok((combined, summaries)) } fn load_rows(path: &Path) -> Result> { let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?; let reader = BufReader::new(file); let mut rows = Vec::new(); for (idx, line) in reader.lines().enumerate() { let raw_line = line.with_context(|| format!("failed reading line {}", idx + 1))?; if raw_line.trim().is_empty() { continue; } let value: Value = serde_json::from_str(&raw_line) .with_context(|| format!("failed to parse JSONL line {}", idx + 1))?; 
let tokens = string_array_field(&value, "tokens", idx + 1)?; let labels = string_array_field(&value, "labels", idx + 1)?; if tokens.len() != labels.len() { bail!( "line {} has mismatched token/label lengths: {} vs {}", idx + 1, tokens.len(), labels.len() ); } rows.push(SourceRow { row_index: idx, raw_line, filename: value .get("filename") .and_then(Value::as_str) .map(ToOwned::to_owned), tokens, labels, tokenizer_variant: value .get("tokenizer_variant") .and_then(Value::as_str) .map(ToOwned::to_owned), }); } Ok(rows) } fn string_array_field(value: &Value, field: &str, line_no: usize) -> Result> { let array = value .get(field) .and_then(Value::as_array) .with_context(|| format!("line {line_no} missing array field '{field}'"))?; array .iter() .map(|item| match item { Value::String(text) => Ok(text.clone()), other => Ok(match other { Value::Null => String::new(), _ => other.to_string(), }), }) .collect() } fn write_split( split: &str, rows: &[SourceRow], output_dir: &Path, context: &EncodeContext, shard_size: usize, ) -> Result { let split_dir = output_dir.join(split); fs::create_dir_all(&split_dir) .with_context(|| format!("failed to create {}", split_dir.display()))?; let chunks = rows .chunks(shard_size) .enumerate() .collect::>(); let shards = chunks .par_iter() .map(|(shard_idx, chunk)| write_shard(split, *shard_idx, chunk, &split_dir, context)) .collect::>>()?; let manifest = json!({ "format": "anifilebert.virtual_dataset.shards.v1", "generated_by": "tools/encoded_dataset_cache", "split": split, "max_length": context.max_length, "total_rows": rows.len(), "shards": shards, }); let manifest_path = split_dir.join("manifest.json"); fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?) 
.with_context(|| format!("failed to write {}", manifest_path.display()))?; Ok(SplitSummary { split: split.to_string(), rows: rows.len(), shards: chunks.len(), directory: split.to_string(), }) } fn write_shard( split: &str, shard_idx: usize, rows: &[SourceRow], split_dir: &Path, context: &EncodeContext, ) -> Result { let capacity = rows.len().saturating_mul(context.max_length); let mut input_ids = Vec::with_capacity(capacity); let mut attention_mask = Vec::with_capacity(capacity); let mut labels = Vec::with_capacity(capacity); for row in rows { let encoded = encode_row(row, context) .with_context(|| format!("failed to encode source line {}", row.row_index + 1))?; input_ids.extend_from_slice(&encoded.0); attention_mask.extend_from_slice(&encoded.1); labels.extend_from_slice(&encoded.2); } let base = format!("part-{split}-s{shard_idx:06}"); let input_name = format!("{base}.input_ids.npy"); let mask_name = format!("{base}.attention_mask.npy"); let label_name = format!("{base}.labels.npy"); write_npy_u16( &split_dir.join(&input_name), &input_ids, rows.len(), context.max_length, )?; write_npy_u8( &split_dir.join(&mask_name), &attention_mask, rows.len(), context.max_length, )?; write_npy_i16( &split_dir.join(&label_name), &labels, rows.len(), context.max_length, )?; Ok(ShardManifest { rows: rows.len(), input_ids: input_name, attention_mask: mask_name, labels: label_name, }) } fn encode_row(row: &SourceRow, context: &EncodeContext) -> Result<(Vec, Vec, Vec)> { let (tokens, labels) = labels_for_char_tokenizer(row); let mut input_ids = vec![context.vocab.pad_id; context.max_length]; let mut attention_mask = vec![0u8; context.max_length]; let mut label_ids = vec![-100i16; context.max_length]; input_ids[0] = context.vocab.cls_id; attention_mask[0] = 1; let available = context.max_length.saturating_sub(2); let token_count = tokens.len().min(labels.len()).min(available); for idx in 0..token_count { input_ids[idx + 1] = token_id(&context.vocab, &tokens[idx]); attention_mask[idx + 
1] = 1; let label = canonical_bio_label(&labels[idx]); label_ids[idx + 1] = context .label_ids .get(&label) .copied() .with_context(|| format!("unknown label '{label}'"))?; } let sep_pos = token_count + 1; input_ids[sep_pos] = context.vocab.sep_id; attention_mask[sep_pos] = 1; Ok((input_ids, attention_mask, label_ids)) } fn labels_for_char_tokenizer(row: &SourceRow) -> (Vec, Vec) { let mut source_labels = row.labels.clone(); if let Some(filename) = row.filename.as_deref() { repair_known_label_issues(filename, &row.tokens, &mut source_labels); if row.tokenizer_variant.as_deref() == Some("char") { let filename_chars = chars_as_strings(filename); if row.tokens == filename_chars { return (row.tokens.clone(), source_labels); } } if let Some(projected) = project_labels_from_filename(filename, &row.tokens, &source_labels) { let (tokens, labels) = projected; return (tokens, labels); } } align_tokens_to_chars(&row.tokens, &source_labels) } fn project_labels_from_filename( filename: &str, source_tokens: &[String], source_labels: &[String], ) -> Option<(Vec, Vec)> { let offsets = token_offsets_in_text(filename, source_tokens)?; if offsets.len() != source_labels.len() { return None; } let char_len = filename.chars().count(); let mut char_entities: Vec> = vec![None; char_len]; for ((token, label), (mut start, mut end)) in source_tokens .iter() .zip(source_labels.iter()) .zip(offsets.into_iter()) { let Some(entity) = bio_entity(label) else { continue; }; if is_wrapped_token(token) && end > start + 1 { start += 1; end -= 1; } for pos in start..end.min(char_entities.len()) { char_entities[pos] = Some(entity.clone()); } } let tokens = chars_as_strings(filename); let mut labels = Vec::with_capacity(tokens.len()); let mut active_entity: Option = None; for entity in char_entities { match entity { Some(entity) => { let prefix = if active_entity.as_deref() == Some(entity.as_str()) { "I" } else { "B" }; labels.push(format!("{prefix}-{entity}")); active_entity = Some(entity); } None => { 
labels.push("O".to_string()); active_entity = None; } } } Some((tokens, labels)) } fn token_offsets_in_text(text: &str, tokens: &[String]) -> Option> { let mut offsets = Vec::with_capacity(tokens.len()); let mut cursor = 0usize; for token in tokens { if token.is_empty() { let char_cursor = char_index_at_byte(text, cursor); offsets.push((char_cursor, char_cursor)); continue; } let relative = text.get(cursor..)?.find(token)?; let start_byte = cursor + relative; let end_byte = start_byte + token.len(); offsets.push(( char_index_at_byte(text, start_byte), char_index_at_byte(text, end_byte), )); cursor = end_byte; } Some(offsets) } fn align_tokens_to_chars(tokens: &[String], labels: &[String]) -> (Vec, Vec) { let mut char_tokens = Vec::new(); let mut char_labels = Vec::new(); for (token, label) in tokens.iter().zip(labels.iter()) { let chars = chars_as_strings(token); if chars.is_empty() { continue; } let label = label.as_str(); if label.starts_with("B-") { let entity = label .split_once('-') .map(|(_, entity)| entity) .unwrap_or(""); char_labels.push(label.to_string()); char_labels.extend((1..chars.len()).map(|_| format!("I-{entity}"))); } else if label.starts_with("I-") { char_labels.extend((0..chars.len()).map(|_| label.to_string())); } else { char_labels.extend((0..chars.len()).map(|_| label.to_string())); } char_tokens.extend(chars); } (char_tokens, char_labels) } fn repair_structural_meta_labels( text: &str, _tokens: &[String], labels: &mut [String], offsets: &[(usize, usize)], ) { let episode_end = first_episode_span_end(labels, offsets, text); for (inner_start, inner_end) in bracket_inner_spans(text) { let bracket_start = inner_start.saturating_sub(1); if bracket_start < episode_end { continue; } let inner = chars_range_to_string(text, inner_start, inner_end); let (trim_start, trim_end) = trimmed_bounds(&inner); if trim_start >= trim_end { continue; } let clean = chars_slice_to_string(&inner, trim_start, trim_end); if special_tag_re().is_match(&clean) || 
special_code_re().is_match(&clean) { let indices = token_indices_for_span(offsets, inner_start, inner_end); label_span_if_safe(labels, &indices, "SPECIAL"); continue; } if source_tag_re().is_match(&clean) { let indices = token_indices_for_span(offsets, inner_start, inner_end); label_span_if_safe(labels, &indices, "SOURCE"); continue; } for mat in resolution_re() .find_iter(&clean) .filter_map(|item| item.ok()) { let start = inner_start + char_index_at_byte(&clean, mat.start()); let end = inner_start + char_index_at_byte(&clean, mat.end()); let indices = token_indices_for_span(offsets, start, end); label_span_if_safe(labels, &indices, "RESOLUTION"); } for mat in source_re().find_iter(&clean) { if !has_ascii_token_boundaries(&clean, mat.start(), mat.end()) { continue; } let start = inner_start + char_index_at_byte(&clean, mat.start()); let end = inner_start + char_index_at_byte(&clean, mat.end()); let indices = token_indices_for_span(offsets, start, end); label_span_if_safe(labels, &indices, "SOURCE"); } } for mat in resolution_re().find_iter(text).filter_map(|item| item.ok()) { let start = char_index_at_byte(text, mat.start()); if start < episode_end { continue; } let end = char_index_at_byte(text, mat.end()); let indices = token_indices_for_span(offsets, start, end); label_span_if_safe(labels, &indices, "RESOLUTION"); } for mat in source_re().find_iter(text) { if !has_ascii_token_boundaries(text, mat.start(), mat.end()) { continue; } let start = char_index_at_byte(text, mat.start()); if start < episode_end { continue; } let end = char_index_at_byte(text, mat.end()); let indices = token_indices_for_span(offsets, start, end); label_span_if_safe(labels, &indices, "SOURCE"); } } fn repair_known_label_issues(text: &str, tokens: &[String], labels: &mut [String]) { if tokens.len() != labels.len() { return; } let Some(offsets) = token_offsets_in_text(text, tokens) else { return; }; let quick_text = text.to_lowercase(); let has_sequel_marker_hint = [ " II", " III", " IV", " 
V", " VI", " VII", " VIII", " IX", "Ⅱ", "Ⅲ", "Ⅳ", "Ⅴ", "Ⅵ", "Ⅶ", "Ⅷ", "Ⅸ", "之章", "之期", "之季", "之部", "ノ章", "ノ期", "の章", "の期", "貳", "贰", "弐", "弍", "參", "叁", "参", "肆", "陸", "陆", "Ni ", " ni ", " no Sara", "Gakki", ] .iter() .any(|needle| text.contains(needle) || quick_text.contains(&needle.to_lowercase())); if has_sequel_marker_hint { for (start, end) in find_sequel_season_markers(text) { if labels_have_season_before(labels, &offsets, start) { continue; } let indices = token_indices_for_span(&offsets, start, end); if indices.is_empty() { continue; } if indices.iter().any(|idx| { matches!( label_entity(&labels[*idx]), Some( "GROUP" | "EPISODE" | "RESOLUTION" | "SOURCE" | "SPECIAL" | "TAG" | "PATH_SEASON" ) ) }) { continue; } if !indices.iter().any(|idx| is_title_like_label(&labels[*idx])) { continue; } label_span_indices(labels, &indices, "SEASON"); mark_adjacent_title_separators_o(tokens, labels, &indices); } } repair_structural_meta_labels(text, tokens, labels, &offsets); } fn find_sequel_season_markers(text: &str) -> Vec<(usize, usize)> { let mut repairs = Vec::new(); for mat in reading_marker_re() .find_iter(text) .filter_map(|item| item.ok()) { let marker = mat.as_str(); if season_marker_number(marker).is_none() || !has_episode_context(text, mat.end()) { continue; } repairs.push(( char_index_at_byte(text, mat.start()), char_index_at_byte(text, mat.end()), )); } for mat in roman_marker_re() .find_iter(text) .filter_map(|item| item.ok()) { let marker = mat.as_str(); if season_marker_number(marker).is_none() || !has_episode_context(text, mat.end()) { continue; } repairs.push(( char_index_at_byte(text, mat.start()), char_index_at_byte(text, mat.end()), )); } for mat in cjk_marker_re().find_iter(text) { let marker = mat.as_str(); if season_marker_number(marker).is_none() || !has_episode_context(text, mat.end()) { continue; } repairs.push(( char_index_at_byte(text, mat.start()), char_index_at_byte(text, mat.end()), )); } for (base, value) in standalone_ni_season_bases() { 
let mut search_start = 0usize; while let Some(relative) = text[search_start..].find(base) { let base_start = search_start + relative; let base_end = base_start + base.len(); let Some((ni_start, ni_end)) = standalone_ni_after_base(text, base_end) else { search_start = base_end; continue; }; if *value == 2 && has_episode_context(text, ni_end) && has_ascii_token_boundaries(text, ni_start, ni_end) { repairs.push(( char_index_at_byte(text, ni_start), char_index_at_byte(text, ni_end), )); } search_start = base_end; } } repairs.sort_by_key(|(start, end)| (*start, *end)); let mut deduped: Vec<(usize, usize)> = Vec::new(); for repair in repairs { if let Some(previous) = deduped.last_mut() { if repair.0 < previous.1 { if repair.1.saturating_sub(repair.0) > previous.1.saturating_sub(previous.0) { *previous = repair; } continue; } } deduped.push(repair); } deduped } fn season_marker_number(text: &str) -> Option { let clean = clean_marker_text(text); if clean.is_empty() { return None; } if let Some(value) = roman_numeral_value(&clean) { return Some(value); } let lowered = clean .split_whitespace() .collect::>() .join(" ") .to_lowercase(); if let Some(value) = reading_marker_value(&lowered) { return Some(value); } if lowered == "ni" { return Some(2); } if clean.starts_with('第') { if let Some(last) = clean.chars().last() { if matches!(last, '季' | '期' | '部' | '章') { let inner = clean .chars() .skip(1) .take(clean.chars().count().saturating_sub(2)) .collect::(); return cn_number_to_int(&inner); } } } let cjk_chars = clean.chars().collect::>(); if let Some(first) = cjk_chars.first() { if let Some(value) = cn_number_to_int(&first.to_string()) { let rest = cjk_chars.iter().skip(1).collect::(); if rest.trim().is_empty() || cjk_marker_suffix_remainder_ok(&rest) { return Some(value); } } } None } fn clean_marker_text(text: &str) -> String { text.trim() .trim_matches(|ch| { matches!( ch, '[' | ']' | '(' | ')' | '【' | '】' | '《' | '》' | '(' | ')' ) }) .trim() .to_string() } fn 
cn_number_to_int(text: &str) -> Option { let text = text.trim(); if text.is_empty() { return None; } if let Ok(value) = text.parse::() { return Some(value); } if let Some(value) = cn_digit_value(text) { return Some(value); } let chars = text.chars().collect::>(); if chars.len() == 2 && chars[0] == '十' { return Some(10 + cn_digit_value(&chars[1].to_string()).unwrap_or(0)); } if chars.len() == 2 && chars[1] == '十' { return Some(cn_digit_value(&chars[0].to_string()).unwrap_or(0) * 10); } if chars.len() == 3 && chars[1] == '十' { return Some( cn_digit_value(&chars[0].to_string()).unwrap_or(0) * 10 + cn_digit_value(&chars[2].to_string()).unwrap_or(0), ); } None } fn cn_digit_value(text: &str) -> Option { match text { "一" => Some(1), "二" | "兩" | "两" | "貳" | "贰" | "弐" | "弍" => Some(2), "三" | "參" | "叁" | "参" => Some(3), "四" | "肆" => Some(4), "五" | "伍" => Some(5), "六" | "陸" | "陆" => Some(6), "七" | "柒" => Some(7), "八" | "捌" => Some(8), "九" | "玖" => Some(9), "十" => Some(10), _ => None, } } fn roman_numeral_value(text: &str) -> Option { match text { "II" | "Ⅱ" => Some(2), "III" | "Ⅲ" => Some(3), "IV" | "Ⅳ" => Some(4), "V" | "Ⅴ" => Some(5), "VI" | "Ⅵ" => Some(6), "VII" | "Ⅶ" => Some(7), "VIII" | "Ⅷ" => Some(8), "IX" | "Ⅸ" => Some(9), _ => None, } } fn reading_marker_value(text: &str) -> Option { match text { "ni no sara" | "ni no shou" | "ni no sho" | "ni no syo" | "ni no shō" | "ni gakki" | "sono ni" => Some(2), "san no sara" | "san no shou" | "san no sho" | "san no syo" => Some(3), "yon no sara" | "shi no sara" | "shin no sara" => Some(4), "go no sara" | "gou no sara" => Some(5), _ => None, } } fn cjk_marker_suffix_remainder_ok(rest: &str) -> bool { let compact = rest.split_whitespace().collect::(); matches!( compact.as_str(), "ノ章" | "ノ期" | "ノ季" | "ノ部" | "の章" | "の期" | "の季" | "の部" | "之章" | "之期" | "之季" | "之部" ) } fn has_episode_context(text: &str, marker_end_byte: usize) -> bool { let tail = &text[marker_end_byte..]; if episode_context_re().is_match(tail) { return true; } let 
mut tail = tail.trim_start(); if let Some(ch) = tail.chars().next() { if matches!(ch, ']' | ')' | '】' | '》') { tail = &tail[ch.len_utf8()..]; tail = tail.trim_start(); } } if let Some(mat) = special_context_prefix_re().find(tail) { tail = &tail[mat.end()..]; } episode_context_re().is_match(tail) } fn first_episode_regex_end(text: &str) -> Option { episode_span_re() .find_iter(text) .filter_map(|item| item.ok()) .map(|mat| char_index_at_byte(text, mat.end())) .next() } fn labels_have_season_before( labels: &[String], offsets: &[(usize, usize)], marker_start: usize, ) -> bool { labels .iter() .zip(offsets.iter()) .any(|(label, (_start, end))| is_season_like_label(label) && *end <= marker_start) } fn token_indices_for_span(offsets: &[(usize, usize)], start: usize, end: usize) -> Vec { offsets .iter() .enumerate() .filter_map(|(idx, (token_start, token_end))| { if *token_start < end && *token_end > start { Some(idx) } else { None } }) .collect() } #[cfg(test)] fn label_span(labels: &mut [String], start: usize, end: usize, entity: &str) { let previous_same = start > 0 && label_entity(&labels[start - 1]) == Some(entity); let mut first = !previous_same; for label in labels.iter_mut().take(end).skip(start) { *label = if first { format!("B-{entity}") } else { format!("I-{entity}") }; first = false; } } fn label_span_indices(labels: &mut [String], indices: &[usize], entity: &str) { if indices.is_empty() { return; } let previous_same = indices[0] > 0 && label_entity(&labels[indices[0] - 1]) == Some(entity); let mut first = !previous_same; for idx in indices { labels[*idx] = if first { format!("B-{entity}") } else { format!("I-{entity}") }; first = false; } } fn mark_adjacent_title_separators_o( tokens: &[String], labels: &mut [String], marker_indices: &[usize], ) { if marker_indices.is_empty() { return; } let mut idx = marker_indices[0]; while idx > 0 { let prev = idx - 1; if !tokens[prev].trim().is_empty() || !is_title_like_label(&labels[prev]) { break; } labels[prev] = 
"O".to_string(); idx = prev; } let mut idx = marker_indices[marker_indices.len() - 1] + 1; while idx < tokens.len() && tokens[idx].chars().all(|ch| SEPARATOR_CHARS.contains(&ch)) && is_title_like_label(&labels[idx]) { labels[idx] = "O".to_string(); idx += 1; } } fn standalone_ni_season_bases() -> &'static [(&'static str, u8)] { &[("Kakuriyo no Yadomeshi", 2)] } fn standalone_ni_after_base(text: &str, base_end: usize) -> Option<(usize, usize)> { let mut cursor = base_end; while let Some(ch) = text[cursor..].chars().next() { if !ch.is_whitespace() { break; } cursor += ch.len_utf8(); } let ni_end = cursor.checked_add(2)?; if text.get(cursor..ni_end)? == "Ni" { Some((cursor, ni_end)) } else { None } } fn is_title_like_label(label: &str) -> bool { matches!( label_entity(label), Some( "TITLE" | "TITLE_CHS" | "TITLE_CHT" | "TITLE_JPN" | "TITLE_LATIN" | "TITLE_MIXED" | "PATH_TITLE_CHS" | "PATH_TITLE_CHT" | "PATH_TITLE_JPN" | "PATH_TITLE_LATIN" | "PATH_TITLE_MIXED" ) ) } fn is_season_like_label(label: &str) -> bool { matches!(label_entity(label), Some("SEASON" | "PATH_SEASON")) } fn first_episode_span_end(labels: &[String], offsets: &[(usize, usize)], text: &str) -> usize { let ends = labels .iter() .zip(offsets.iter()) .filter_map(|(label, (_start, end))| { if label_entity(label) == Some("EPISODE") { Some(*end) } else { None } }) .collect::>(); if let Some(end) = ends.into_iter().min() { return end; } first_episode_regex_end(text).unwrap_or(0) } fn bracket_inner_spans(text: &str) -> Vec<(usize, usize)> { let chars = text.chars().collect::>(); let mut spans = Vec::new(); let mut idx = 0usize; while idx < chars.len() { let close = match chars[idx] { '[' => ']', '(' => ')', '【' => '】', '《' => '》', _ => { idx += 1; continue; } }; if let Some(relative_end) = chars[idx + 1..].iter().position(|ch| *ch == close) { let end = idx + 1 + relative_end; spans.push((idx + 1, end)); idx = end + 1; } else { idx += 1; } } spans } fn trimmed_bounds(text: &str) -> (usize, usize) { let chars = 
text.chars().collect::>(); let mut start = 0usize; let mut end = chars.len(); while start < end && chars[start].is_whitespace() { start += 1; } while end > start && chars[end - 1].is_whitespace() { end -= 1; } (start, end) } fn chars_range_to_string(text: &str, start: usize, end: usize) -> String { text.chars() .skip(start) .take(end.saturating_sub(start)) .collect() } fn chars_slice_to_string(text: &str, start: usize, end: usize) -> String { text.chars() .skip(start) .take(end.saturating_sub(start)) .collect() } fn label_span_if_safe(labels: &mut [String], indices: &[usize], entity: &str) { if indices.is_empty() { return; } if indices.iter().any(|idx| { matches!( label_entity(&labels[*idx]), Some("GROUP" | "EPISODE" | "SEASON" | "PATH_SEASON") ) }) { return; } label_span_indices(labels, indices, entity); } fn has_ascii_token_boundaries(text: &str, start: usize, end: usize) -> bool { let previous_ok = text[..start] .chars() .next_back() .map(|ch| !ch.is_ascii_alphanumeric()) .unwrap_or(true); let next_ok = text[end..] .chars() .next() .map(|ch| !ch.is_ascii_alphanumeric()) .unwrap_or(true); previous_ok && next_ok } fn label_entity(label: &str) -> Option<&str> { let (prefix, entity) = label.split_once('-')?; if prefix == "B" || prefix == "I" { Some(entity) } else { None } } fn resolution_re() -> &'static FancyRegex { RESOLUTION_RE.get_or_init(|| { FancyRegex::new( r"(?i)(? 
&'static Regex { SOURCE_RE.get_or_init(|| Regex::new(&format!(r"(?i)(?:{SOURCE_TOKEN_PATTERN})")).unwrap()) } fn source_tag_re() -> &'static Regex { SOURCE_TAG_RE.get_or_init(|| { Regex::new(&format!( r"(?i)^(?:{SOURCE_TOKEN_PATTERN})(?:\s*(?:[&+/,_-]|,\s*)\s*(?:{SOURCE_TOKEN_PATTERN}))*$" )) .unwrap() }) } fn special_tag_re() -> &'static Regex { SPECIAL_TAG_RE.get_or_init(|| { Regex::new(r"(?i)^(?:檢索|检索|搜索|搜寻|搜尋|别名|別名|alias|search|keyword)\s*[::].+") .unwrap() }) } fn special_code_re() -> &'static Regex { SPECIAL_CODE_RE.get_or_init(|| { Regex::new(r"(?i)^(?:NCOP|NCED|OP|ED|PV|CM)\d*$|^IV\d+$|^(?:OVA|OAD|SP)\d*$").unwrap() }) } fn episode_context_re() -> &'static Regex { EPISODE_CONTEXT_RE.get_or_init(|| { Regex::new( r"(?i)^\s*(?:[-_]\s*(?:\d{1,4}|NCOP|NCED|OP|ED|OVA|OAD|SP|END)\b|#\s*\d{1,4}|[\[\(【《]\s*(?:EP?|#)?\d{1,4})", ) .unwrap() }) } fn episode_span_re() -> &'static FancyRegex { EPISODE_SPAN_RE.get_or_init(|| { FancyRegex::new( r"(?i)(?:[Ss]\d{1,2}[Ee]\d{1,4}(?:v\d+)?|(?:^|[\s._])[-_]\s*\d{1,4}(?:v\d+)?(?=$|[\s._\-\]\)】》\[])|[\[\(【《](?:EP?|#)?\d{1,4}(?:v\d+)?[\]\)】》]|(?:^|[\s._\-\[\(【《#])(?:EP?|第|#)\d{1,4}(?:v\d+)?(?:[话話集])?(?=$|[\s._\-\]\)】》]))", ) .unwrap() }) } fn reading_marker_re() -> &'static FancyRegex { READING_MARKER_RE.get_or_init(|| { FancyRegex::new( r"(?i)(?Ni\s+no\s+(?:Sara|Shou|Sho|Syo|Shō)|San\s+no\s+(?:Sara|Shou|Sho|Syo)|(?:Yon|Shi|Shin)\s+no\s+Sara|(?:Go|Gou)\s+no\s+Sara|Ni\s+Gakki|Sono\s+Ni)(?![A-Za-z0-9])", ) .unwrap() }) } fn roman_marker_re() -> &'static FancyRegex { ROMAN_MARKER_RE.get_or_init(|| { FancyRegex::new( r"(?II|III|IV|V|VI|VII|VIII|IX|[ⅡⅢⅣⅤⅥⅦⅧⅨ])(?![A-Za-z0-9])", ) .unwrap() }) } fn cjk_marker_re() -> &'static Regex { CJK_MARKER_RE.get_or_init(|| { Regex::new( r"(?:[一二三四五六七八九十兩两貳贰弐弍參叁参肆伍陸陆柒捌玖](?:\s*(?:ノ|の|之)\s*(?:章|期|季|部))?|第[一二三四五六七八九十兩两貳贰弐弍參叁参肆伍陸陆柒捌玖\d]+[季期部章])", ) .unwrap() }) } fn special_context_prefix_re() -> &'static Regex { SPECIAL_CONTEXT_PREFIX_RE.get_or_init(|| { Regex::new( 
r"(?i)^(?:[\[\(【《]\s*(?:menu|menus|bdmenu|ncop|nced|op|ed|ova|oad|sp)\s*[\]\)】》]\s*){0,2}", ) .unwrap() }) } fn chars_as_strings(text: &str) -> Vec { text.chars().map(|ch| ch.to_string()).collect() } fn char_index_at_byte(text: &str, byte_index: usize) -> usize { text[..byte_index].chars().count() } fn bio_entity(label: &str) -> Option { let (prefix, entity) = label.split_once('-')?; if prefix == "B" || prefix == "I" { Some(entity.to_string()) } else { None } } fn is_wrapped_token(token: &str) -> bool { let mut chars = token.chars(); let Some(first) = chars.next() else { return false; }; let Some(last) = token.chars().last() else { return false; }; matches!(first, '[' | '【' | '(' | '《') && matches!(last, ']' | '】' | ')' | '》') } fn canonical_bio_label(label: &str) -> String { let Some((prefix, entity)) = label.split_once('-') else { return if label == "O" { "O".to_string() } else { label.to_string() }; }; if prefix != "B" && prefix != "I" { return label.to_string(); } let canonical_entity = match entity { "TITLE" => "TITLE_MIXED", "PATH_TITLE" => "PATH_TITLE_MIXED", other => other, }; format!("{prefix}-{canonical_entity}") } fn token_id(vocab: &Vocab, token: &str) -> u16 { *vocab.ids.get(token).unwrap_or(&vocab.unk_id) } fn write_eval_records(rows: &[SourceRow], path: &Path) -> Result<()> { let mut writer = BufWriter::new( File::create(path).with_context(|| format!("failed to create {}", path.display()))?, ); for row in rows { writer.write_all(row.raw_line.as_bytes())?; writer.write_all(b"\n")?; } Ok(()) } fn write_npy_u16(path: &Path, data: &[u16], rows: usize, cols: usize) -> Result<()> { let mut writer = BufWriter::new( File::create(path).with_context(|| format!("failed to create {}", path.display()))?, ); write_npy_header(&mut writer, " Result<()> { let mut writer = BufWriter::new( File::create(path).with_context(|| format!("failed to create {}", path.display()))?, ); write_npy_header(&mut writer, "|u1", rows, cols)?; writer.write_all(data)?; Ok(()) } fn 
write_npy_i16(path: &Path, data: &[i16], rows: usize, cols: usize) -> Result<()> { let mut writer = BufWriter::new( File::create(path).with_context(|| format!("failed to create {}", path.display()))?, ); write_npy_header(&mut writer, "(writer: &mut W, descr: &str, rows: usize, cols: usize) -> Result<()> { let mut header = format!( "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}", descr, rows, cols ) .into_bytes(); let preamble_len = 10usize; let pad_len = (16 - ((preamble_len + header.len() + 1) % 16)) % 16; header.extend(std::iter::repeat(b' ').take(pad_len)); header.push(b'\n'); if header.len() > u16::MAX as usize { bail!("npy header too large"); } writer.write_all(b"\x93NUMPY")?; writer.write_all(&[1, 0])?; writer.write_all(&(header.len() as u16).to_le_bytes())?; writer.write_all(&header)?; Ok(()) } #[cfg(test)] mod tests { use super::*; fn char_row( text: &str, title_spans: &[(usize, usize)], episode_spans: &[(usize, usize)], ) -> SourceRow { let tokens = chars_as_strings(text); let mut labels = vec!["O".to_string(); tokens.len()]; for (start, end) in title_spans { label_span(&mut labels, *start, *end, "TITLE_LATIN"); } for (start, end) in episode_spans { label_span(&mut labels, *start, *end, "EPISODE"); } SourceRow { row_index: 0, raw_line: String::new(), filename: Some(text.to_string()), tokens, labels, tokenizer_variant: Some("char".to_string()), } } #[test] fn repairs_cjk_sequel_marker_in_char_fast_path() { let text = "妖怪旅館營業中 貳 - 11"; let title_end = char_index_at_byte(text, text.find(" - ").unwrap()); let episode_start = char_index_at_byte(text, text.find("11").unwrap()); let row = char_row( text, &[(0, title_end)], &[(episode_start, episode_start + 2)], ); let (_tokens, labels) = labels_for_char_tokenizer(&row); let marker = char_index_at_byte(text, text.find('貳').unwrap()); let before_marker = marker - 1; assert_eq!(labels[before_marker], "O"); assert_eq!(labels[marker], "B-SEASON"); assert_eq!(labels[episode_start], "B-EPISODE"); } 
/// Returns the char index of the first occurrence of `needle` in `text`.
///
/// Labels produced by the char tokenizer are indexed per `char`, while
/// `str::find` returns a byte offset. The offset is converted here so fixtures
/// containing multi-byte characters index the right label.
///
/// # Panics
/// Panics with the fixture text if `needle` does not occur in `text`.
fn char_index_of(text: &str, needle: &str) -> usize {
    let byte = text
        .find(needle)
        .unwrap_or_else(|| panic!("fixture {text:?} does not contain {needle:?}"));
    char_index_at_byte(text, byte)
}

/// A reading-style sequel marker ("Ni no Sara") inside the title span is
/// expected to be relabelled as a SEASON span, with the preceding separator
/// left as "O".
#[test]
fn repairs_reading_sequel_marker() {
    let text = "Shokugeki no Souma Ni no Sara - 13";
    let title_end = char_index_of(text, " - ");
    let episode_start = char_index_of(text, "13");
    let row = char_row(
        text,
        &[(0, title_end)],
        &[(episode_start, episode_start + 2)],
    );
    let (_tokens, labels) = labels_for_char_tokenizer(&row);
    // Anchor on the whole marker phrase so an unrelated earlier "Ni" cannot match.
    let marker_start = char_index_of(text, "Ni no Sara");
    // The marker runs up to the " - " separator, which is where the title ended.
    let marker_end = title_end;
    assert_eq!(labels[marker_start - 1], "O");
    assert_eq!(labels[marker_start], "B-SEASON");
    assert!(labels[marker_start + 1..marker_end]
        .iter()
        .all(|label| label == "I-SEASON"));
}

/// A bare numeral after the title ("... 2 - 01") is expected to stay inside
/// the title span, and no SEASON label may be introduced anywhere.
#[test]
fn keeps_numeric_title_suffix_out_of_sequel_repair() {
    let text = "Kamisama Hajimemashita 2 - 01";
    let title_end = char_index_of(text, " - ");
    let episode_start = char_index_of(text, "01");
    let row = char_row(
        text,
        &[(0, title_end)],
        &[(episode_start, episode_start + 2)],
    );
    let (_tokens, labels) = labels_for_char_tokenizer(&row);
    let suffix = char_index_of(text, "2");
    assert_eq!(labels[suffix], "I-TITLE_LATIN");
    assert!(!labels
        .iter()
        .any(|label| label_entity(label) == Some("SEASON")));
}

/// When the row already carries a SEASON span (第二期), the Roman numeral in
/// the Latin alias ("Chibi Maruko-chan II") is expected to remain part of the
/// alias title rather than being relabelled as another sequel marker.
#[test]
fn skips_alias_marker_when_season_already_exists() {
    let text = "樱桃小丸子第二期(Chibi Maruko-chan II)[1439]";
    let tokens = chars_as_strings(text);
    let mut labels = vec!["O".to_string(); tokens.len()];

    // The title runs up to the existing season marker, which is labelled directly.
    let season_start = char_index_of(text, "第二期");
    let season_end = season_start + "第二期".chars().count();
    label_span(&mut labels, 0, season_start, "TITLE_CHS");
    label_span(&mut labels, season_start, season_end, "SEASON");

    // Latin alias inside the parentheses (end is exclusive, at the ")").
    let alias_start = char_index_of(text, "Chibi");
    let alias_end = char_index_of(text, ")");
    label_span(&mut labels, alias_start, alias_end, "TITLE_LATIN");

    let episode_start = char_index_of(text, "1439");
    label_span(&mut labels, episode_start, episode_start + 4, "EPISODE");

    let row = SourceRow {
        row_index: 0,
        raw_line: String::new(),
        filename: Some(text.to_string()),
        tokens,
        labels,
        tokenizer_variant: Some("char".to_string()),
    };
    let (_tokens, labels) = labels_for_char_tokenizer(&row);
    let roman = char_index_of(text, "II");
    assert_eq!(labels[roman], "I-TITLE_LATIN");
}
} // closes `mod tests`