Token Classification
Transformers
ONNX
Safetensors
English
Japanese
Chinese
bert
anime
filename-parsing
Eval Results (legacy)
Instructions to use ModerRAS/AniFileBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ModerRAS/AniFileBERT with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="ModerRAS/AniFileBERT")

# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("ModerRAS/AniFileBERT")
model = AutoModelForTokenClassification.from_pretrained("ModerRAS/AniFileBERT")
```
- Notebooks
- Google Colab
- Kaggle
| use anyhow::{bail, Context, Result}; | |
| use clap::{Parser, ValueEnum}; | |
| use rand::rngs::StdRng; | |
| use rand::seq::SliceRandom; | |
| use rand::Rng; | |
| use rand::SeedableRng; | |
| use rayon::prelude::*; | |
| use serde::{Deserialize, Serialize}; | |
| use serde_json::json; | |
| use std::collections::{HashMap, HashSet}; | |
| use std::fs::{self, File}; | |
| use std::io::{BufRead, BufReader, BufWriter, Write}; | |
| use std::path::{Path, PathBuf}; | |
| use std::sync::OnceLock; | |
| use std::time::Instant; | |
/// File-name (non-path) title entities.
///
/// The order matters: `first_file_title_field` scans these in this order and
/// returns the first entity that has a non-blank value.
const FILE_TITLE_ENTITIES: [Entity; 5] = [
    Entity::TitleChs,
    Entity::TitleCht,
    Entity::TitleJpn,
    Entity::TitleLatin,
    Entity::TitleMixed,
];
/// Directory-level (path) title entities.
///
/// `count_path_variants` checks whether any of these has a value, and
/// `choose_path_title_field` draws its candidates from them in this order.
const PATH_TITLE_ENTITIES: [Entity; 5] = [
    Entity::PathTitleChs,
    Entity::PathTitleCht,
    Entity::PathTitleJpn,
    Entity::PathTitleLatin,
    Entity::PathTitleMixed,
];
/// Every entity kind, in canonical order.
///
/// `Entity::index` returns an entity's position in this array. Per-entity
/// tables such as `SourceSample::fields` and the dedup sets in
/// `extract_fields` are sized by `ENTITIES.len()` and indexed by that
/// position. Reordering this array changes every such slot index.
const ENTITIES: [Entity; 18] = [
    Entity::Group,
    Entity::TitleChs,
    Entity::TitleCht,
    Entity::TitleJpn,
    Entity::TitleLatin,
    Entity::TitleMixed,
    Entity::PathTitleChs,
    Entity::PathTitleCht,
    Entity::PathTitleJpn,
    Entity::PathTitleLatin,
    Entity::PathTitleMixed,
    Entity::PathSeason,
    Entity::Season,
    Entity::Episode,
    Entity::Special,
    Entity::Resolution,
    Entity::Source,
    Entity::Tag,
];
/// Built-in BIO label list: `"O"` followed by a `B-`/`I-` pair for each of the
/// 18 entity kinds (1 + 18 * 2 = 37 entries).
///
/// NOTE(review): not referenced in the visible part of this file. Presumably
/// this is the label schema used when no external schema is supplied, with
/// array position serving as the label id. Confirm before reordering. The
/// order here differs from `ENTITIES` (GROUP comes after SPECIAL).
const FALLBACK_LABELS: [&str; 37] = [
    "O",
    "B-TITLE_CHS",
    "I-TITLE_CHS",
    "B-TITLE_CHT",
    "I-TITLE_CHT",
    "B-TITLE_JPN",
    "I-TITLE_JPN",
    "B-TITLE_LATIN",
    "I-TITLE_LATIN",
    "B-TITLE_MIXED",
    "I-TITLE_MIXED",
    "B-PATH_TITLE_CHS",
    "I-PATH_TITLE_CHS",
    "B-PATH_TITLE_CHT",
    "I-PATH_TITLE_CHT",
    "B-PATH_TITLE_JPN",
    "I-PATH_TITLE_JPN",
    "B-PATH_TITLE_LATIN",
    "I-PATH_TITLE_LATIN",
    "B-PATH_TITLE_MIXED",
    "I-PATH_TITLE_MIXED",
    "B-PATH_SEASON",
    "I-PATH_SEASON",
    "B-SEASON",
    "I-SEASON",
    "B-EPISODE",
    "I-EPISODE",
    "B-SPECIAL",
    "I-SPECIAL",
    "B-GROUP",
    "I-GROUP",
    "B-RESOLUTION",
    "I-RESOLUTION",
    "B-SOURCE",
    "I-SOURCE",
    "B-TAG",
    "I-TAG",
];
/// Process-wide, lazily initialised map from BIO label string to label id.
///
/// The `i16` value type matches the `labels` buffer of `ShardWriter`, which is
/// written as `int16`. NOTE(review): it is initialised and read outside the
/// visible part of this file.
static LABEL_IDS: OnceLock<HashMap<String, i16>> = OnceLock::new();
/// An ordered list of label names.
///
/// NOTE(review): it is built and consumed outside the visible part of this
/// file. Presumably position = label id, as with `FALLBACK_LABELS`; confirm.
struct LabelSchema {
    labels: Vec<String>,
}
// Command-line options for the virtual shard generator.
// (Plain `//` comments are used on purpose: with a clap derive, `///` doc
// comments would become CLI help text.)
struct Args {
    // JSONL file of labelled samples, read by `load_samples`.
    input: PathBuf,
    // JSON `token -> id` vocabulary, read by `load_vocab`.
    vocab_file: PathBuf,
    // Directory that receives the `.npy` shards and `manifest.json`.
    output_dir: PathBuf,
    // Fixed row width of every encoded sample; `main` rejects values below 4.
    max_length: usize,
    // Rows per shard file; `main` rejects 0.
    shard_size: usize,
    // Maximum number of source samples to load; 0 means no limit.
    limit_rows: usize,
    // When > 0, emit this many sampled BIO variants per source instead of the
    // exhaustive enumeration.
    samples_per_source: usize,
    // Path-context variants per source; 0 disables them.
    path_samples_per_source: usize,
    // Seeds the sample shuffle and the per-sample RNGs.
    seed: u64,
    // Rayon worker count; 0 keeps rayon's default global pool.
    threads: usize,
    // Whether one separator applies to all gaps or each gap chooses its own.
    separator_mode: SeparatorMode,
    // Whether one bracket style applies to all parts or each part chooses its own.
    bracket_mode: BracketMode,
    // Separator strings; the escapes `\t` and `\s` are expanded by
    // `normalize_separator_arg`.
    separators: Vec<String>,
    // Bracket style names, parsed by `Bracket::from_name`.
    bracket_styles: Vec<String>,
    // Directory-separator styles used for path-context variants.
    path_styles: Vec<PathStyle>,
    // Emit the original row; ignored when `no_original` is set.
    include_original: bool,
    // Overrides `include_original`.
    no_original: bool,
    // Disables all BIO recombination variants.
    no_bio_variants: bool,
    // Emit the built-in special fixtures; ignored when `no_special_fixtures` is set.
    include_special_fixtures: bool,
    // Overrides `include_special_fixtures`.
    no_special_fixtures: bool,
    // Print a size estimate and exit without writing any files.
    dry_run: bool,
}
// How separators between filename parts are chosen.
// `count_variants` counts `Global` as one choice shared by all gaps (a factor
// of `separators.len()`) and `PerGap` as an independent choice per gap (a
// factor of `separators.len()^(parts - 1)`).
enum SeparatorMode {
    Global,
    PerGap,
}
// How bracket styles around filename parts are chosen.
// `count_variants` counts `Global` as one style shared by all parts (a factor
// of `brackets.len()`) and `PerPart` as an independent style per part (a
// factor of `brackets.len()^parts`).
enum BracketMode {
    Global,
    PerPart,
}
// Directory-separator convention used when rendering synthetic paths.
enum PathStyle {
    Windows,
    Unix,
}

impl PathStyle {
    /// The directory separator for this style: a backslash on Windows, a
    /// forward slash on Unix.
    fn separator(self) -> &'static str {
        match self {
            Self::Unix => "/",
            Self::Windows => "\\",
        }
    }
}
| enum Entity { | |
| Group, | |
| TitleChs, | |
| TitleCht, | |
| TitleJpn, | |
| TitleLatin, | |
| TitleMixed, | |
| PathTitleChs, | |
| PathTitleCht, | |
| PathTitleJpn, | |
| PathTitleLatin, | |
| PathTitleMixed, | |
| PathSeason, | |
| Season, | |
| Episode, | |
| Special, | |
| Resolution, | |
| Source, | |
| Tag, | |
| } | |
| impl Entity { | |
| fn index(self) -> usize { | |
| ENTITIES | |
| .iter() | |
| .position(|entity| *entity == self) | |
| .expect("entity missing from ENTITIES") | |
| } | |
| fn from_name(name: &str) -> Option<Self> { | |
| match name { | |
| "GROUP" => Some(Entity::Group), | |
| "TITLE" | "TITLE_MIXED" => Some(Entity::TitleMixed), | |
| "TITLE_CHS" => Some(Entity::TitleChs), | |
| "TITLE_CHT" => Some(Entity::TitleCht), | |
| "TITLE_JPN" => Some(Entity::TitleJpn), | |
| "TITLE_LATIN" => Some(Entity::TitleLatin), | |
| "PATH_TITLE" | "PATH_TITLE_MIXED" => Some(Entity::PathTitleMixed), | |
| "PATH_TITLE_CHS" => Some(Entity::PathTitleChs), | |
| "PATH_TITLE_CHT" => Some(Entity::PathTitleCht), | |
| "PATH_TITLE_JPN" => Some(Entity::PathTitleJpn), | |
| "PATH_TITLE_LATIN" => Some(Entity::PathTitleLatin), | |
| "PATH_SEASON" => Some(Entity::PathSeason), | |
| "SEASON" => Some(Entity::Season), | |
| "EPISODE" => Some(Entity::Episode), | |
| "SPECIAL" => Some(Entity::Special), | |
| "RESOLUTION" => Some(Entity::Resolution), | |
| "SOURCE" => Some(Entity::Source), | |
| "TAG" => Some(Entity::Tag), | |
| _ => None, | |
| } | |
| } | |
| fn b_label(self) -> &'static str { | |
| match self { | |
| Entity::Group => "B-GROUP", | |
| Entity::TitleChs => "B-TITLE_CHS", | |
| Entity::TitleCht => "B-TITLE_CHT", | |
| Entity::TitleJpn => "B-TITLE_JPN", | |
| Entity::TitleLatin => "B-TITLE_LATIN", | |
| Entity::TitleMixed => "B-TITLE_MIXED", | |
| Entity::PathTitleChs => "B-PATH_TITLE_CHS", | |
| Entity::PathTitleCht => "B-PATH_TITLE_CHT", | |
| Entity::PathTitleJpn => "B-PATH_TITLE_JPN", | |
| Entity::PathTitleLatin => "B-PATH_TITLE_LATIN", | |
| Entity::PathTitleMixed => "B-PATH_TITLE_MIXED", | |
| Entity::PathSeason => "B-PATH_SEASON", | |
| Entity::Season => "B-SEASON", | |
| Entity::Episode => "B-EPISODE", | |
| Entity::Special => "B-SPECIAL", | |
| Entity::Resolution => "B-RESOLUTION", | |
| Entity::Source => "B-SOURCE", | |
| Entity::Tag => "B-TAG", | |
| } | |
| } | |
| fn i_label(self) -> &'static str { | |
| match self { | |
| Entity::Group => "I-GROUP", | |
| Entity::TitleChs => "I-TITLE_CHS", | |
| Entity::TitleCht => "I-TITLE_CHT", | |
| Entity::TitleJpn => "I-TITLE_JPN", | |
| Entity::TitleLatin => "I-TITLE_LATIN", | |
| Entity::TitleMixed => "I-TITLE_MIXED", | |
| Entity::PathTitleChs => "I-PATH_TITLE_CHS", | |
| Entity::PathTitleCht => "I-PATH_TITLE_CHT", | |
| Entity::PathTitleJpn => "I-PATH_TITLE_JPN", | |
| Entity::PathTitleLatin => "I-PATH_TITLE_LATIN", | |
| Entity::PathTitleMixed => "I-PATH_TITLE_MIXED", | |
| Entity::PathSeason => "I-PATH_SEASON", | |
| Entity::Season => "I-SEASON", | |
| Entity::Episode => "I-EPISODE", | |
| Entity::Special => "I-SPECIAL", | |
| Entity::Resolution => "I-RESOLUTION", | |
| Entity::Source => "I-SOURCE", | |
| Entity::Tag => "I-TAG", | |
| } | |
| } | |
| fn is_file_title(self) -> bool { | |
| matches!( | |
| self, | |
| Entity::TitleChs | |
| | Entity::TitleCht | |
| | Entity::TitleJpn | |
| | Entity::TitleLatin | |
| | Entity::TitleMixed | |
| ) | |
| } | |
| fn is_path_title(self) -> bool { | |
| matches!( | |
| self, | |
| Entity::PathTitleChs | |
| | Entity::PathTitleCht | |
| | Entity::PathTitleJpn | |
| | Entity::PathTitleLatin | |
| | Entity::PathTitleMixed | |
| ) | |
| } | |
| fn is_ordinary_variant_entity(self) -> bool { | |
| !self.is_path_title() && self != Entity::PathSeason | |
| } | |
| fn as_path_title(self) -> Option<Self> { | |
| match self { | |
| Entity::TitleChs => Some(Entity::PathTitleChs), | |
| Entity::TitleCht => Some(Entity::PathTitleCht), | |
| Entity::TitleJpn => Some(Entity::PathTitleJpn), | |
| Entity::TitleLatin => Some(Entity::PathTitleLatin), | |
| Entity::TitleMixed => Some(Entity::PathTitleMixed), | |
| Entity::PathTitleChs | |
| | Entity::PathTitleCht | |
| | Entity::PathTitleJpn | |
| | Entity::PathTitleLatin | |
| | Entity::PathTitleMixed => Some(self), | |
| _ => None, | |
| } | |
| } | |
| fn as_file_title(self) -> Option<Self> { | |
| match self { | |
| Entity::PathTitleChs => Some(Entity::TitleChs), | |
| Entity::PathTitleCht => Some(Entity::TitleCht), | |
| Entity::PathTitleJpn => Some(Entity::TitleJpn), | |
| Entity::PathTitleLatin => Some(Entity::TitleLatin), | |
| Entity::PathTitleMixed => Some(Entity::TitleMixed), | |
| Entity::TitleChs | |
| | Entity::TitleCht | |
| | Entity::TitleJpn | |
| | Entity::TitleLatin | |
| | Entity::TitleMixed => Some(self), | |
| _ => None, | |
| } | |
| } | |
| } | |
| struct Bracket { | |
| name: String, | |
| open: String, | |
| close: String, | |
| } | |
| impl Bracket { | |
| fn from_name(name: &str) -> Result<Self> { | |
| let trimmed = name.trim(); | |
| let pair = match trimmed { | |
| "none" => ("", ""), | |
| "square" => ("[", "]"), | |
| "round" => ("(", ")"), | |
| "corner" => ("【", "】"), | |
| "angle" => ("《", "》"), | |
| custom if custom.contains('|') => { | |
| let mut parts = custom.splitn(2, '|'); | |
| let open = parts.next().unwrap_or_default(); | |
| let close = parts.next().unwrap_or_default(); | |
| return Ok(Self { | |
| name: custom.to_string(), | |
| open: open.to_string(), | |
| close: close.to_string(), | |
| }); | |
| } | |
| other => bail!("unknown bracket style '{other}'"), | |
| }; | |
| Ok(Self { | |
| name: trimmed.to_string(), | |
| open: pair.0.to_string(), | |
| close: pair.1.to_string(), | |
| }) | |
| } | |
| } | |
/// One JSONL record as read by `load_samples`.
struct InputRow {
    /// Original filename. When absent, `load_samples` uses the tokens joined
    /// with no separator.
    filename: Option<String>,
    /// Per-character tokens; must have the same length as `labels`.
    tokens: Vec<String>,
    /// BIO labels parallel to `tokens`.
    labels: Vec<String>,
    /// If present, must be `"char"`; any other value is rejected by `load_samples`.
    tokenizer_variant: Option<String>,
}
/// A loaded source row with canonicalised labels and extracted field values.
struct SourceSample {
    /// Zero-based line index in the input file (blank lines are counted).
    /// It also seeds the per-sample RNG.
    row_index: usize,
    filename: String,
    tokens: Vec<String>,
    /// Labels normalised by `canonical_bio_label`.
    labels: Vec<String>,
    /// Deduplicated values per entity, indexed by `Entity::index`
    /// (length `ENTITIES.len()`).
    fields: Vec<Vec<String>>,
}
/// Resolved generation settings, built in `main` from `Args`.
struct GenConfig {
    max_length: usize,
    shard_size: usize,
    separator_mode: SeparatorMode,
    bracket_mode: BracketMode,
    /// Separators after escape expansion by `normalize_separator_arg`.
    separators: Vec<String>,
    /// Bracket styles parsed by `Bracket::from_name`.
    brackets: Vec<Bracket>,
    path_styles: Vec<PathStyle>,
    include_original: bool,
    include_bio_variants: bool,
    /// 0 selects exhaustive enumeration; > 0 is a per-source sampling budget.
    samples_per_source: usize,
    path_samples_per_source: usize,
    seed: u64,
}
/// Character vocabulary with ids narrowed to `u16` (the `input_ids` storage type).
struct Vocab {
    ids: HashMap<String, u16>,
    /// Id of `[PAD]`.
    pad_id: u16,
    /// Id of `[UNK]`.
    unk_id: u16,
    /// Id of `[CLS]`.
    cls_id: u16,
    /// Id of `[SEP]`.
    sep_id: u16,
}
/// Manifest entry for one written shard.
///
/// The file names are stored without a directory. `ShardWriter::flush`
/// writes each file to `output_dir.join(name)`.
struct ShardManifest {
    rows: usize,
    input_ids: String,
    attention_mask: String,
    labels: String,
}
| struct ShardWriter { | |
| output_dir: PathBuf, | |
| worker_id: usize, | |
| shard_seq: usize, | |
| shard_size: usize, | |
| max_length: usize, | |
| input_ids: Vec<u16>, | |
| attention_mask: Vec<u8>, | |
| labels: Vec<i16>, | |
| rows: usize, | |
| total_rows: u64, | |
| shards: Vec<ShardManifest>, | |
| } | |
| impl ShardWriter { | |
| fn new(output_dir: &Path, worker_id: usize, shard_size: usize, max_length: usize) -> Self { | |
| let capacity = shard_size.saturating_mul(max_length); | |
| Self { | |
| output_dir: output_dir.to_path_buf(), | |
| worker_id, | |
| shard_seq: 0, | |
| shard_size, | |
| max_length, | |
| input_ids: Vec::with_capacity(capacity), | |
| attention_mask: Vec::with_capacity(capacity), | |
| labels: Vec::with_capacity(capacity), | |
| rows: 0, | |
| total_rows: 0, | |
| shards: Vec::new(), | |
| } | |
| } | |
| fn add(&mut self, input_ids: &[u16], attention_mask: &[u8], labels: &[i16]) -> Result<()> { | |
| if input_ids.len() != self.max_length | |
| || attention_mask.len() != self.max_length | |
| || labels.len() != self.max_length | |
| { | |
| bail!("encoded sample has wrong shape"); | |
| } | |
| self.input_ids.extend_from_slice(input_ids); | |
| self.attention_mask.extend_from_slice(attention_mask); | |
| self.labels.extend_from_slice(labels); | |
| self.rows += 1; | |
| self.total_rows += 1; | |
| if self.rows >= self.shard_size { | |
| self.flush()?; | |
| } | |
| Ok(()) | |
| } | |
| fn flush(&mut self) -> Result<()> { | |
| if self.rows == 0 { | |
| return Ok(()); | |
| } | |
| let base = format!("part-w{:03}-s{:06}", self.worker_id, self.shard_seq); | |
| let input_name = format!("{base}.input_ids.npy"); | |
| let mask_name = format!("{base}.attention_mask.npy"); | |
| let label_name = format!("{base}.labels.npy"); | |
| write_npy_u16( | |
| &self.output_dir.join(&input_name), | |
| &self.input_ids, | |
| self.rows, | |
| self.max_length, | |
| )?; | |
| write_npy_u8( | |
| &self.output_dir.join(&mask_name), | |
| &self.attention_mask, | |
| self.rows, | |
| self.max_length, | |
| )?; | |
| write_npy_i16( | |
| &self.output_dir.join(&label_name), | |
| &self.labels, | |
| self.rows, | |
| self.max_length, | |
| )?; | |
| self.shards.push(ShardManifest { | |
| rows: self.rows, | |
| input_ids: input_name, | |
| attention_mask: mask_name, | |
| labels: label_name, | |
| }); | |
| self.input_ids.clear(); | |
| self.attention_mask.clear(); | |
| self.labels.clear(); | |
| self.rows = 0; | |
| self.shard_seq += 1; | |
| Ok(()) | |
| } | |
| } | |
| fn main() -> Result<()> { | |
| let args = Args::parse(); | |
| let include_original = args.include_original && !args.no_original; | |
| let include_bio_variants = !args.no_bio_variants; | |
| let include_special_fixtures = args.include_special_fixtures && !args.no_special_fixtures; | |
| if args.max_length < 4 { | |
| bail!("--max-length must be at least 4"); | |
| } | |
| if args.shard_size == 0 { | |
| bail!("--shard-size must be positive"); | |
| } | |
| if args.threads > 0 { | |
| rayon::ThreadPoolBuilder::new() | |
| .num_threads(args.threads) | |
| .build_global() | |
| .context("failed to configure rayon thread pool")?; | |
| } | |
| let started = Instant::now(); | |
| let vocab = load_vocab(&args.vocab_file)?; | |
| let brackets = args | |
| .bracket_styles | |
| .iter() | |
| .map(|style| Bracket::from_name(style)) | |
| .collect::<Result<Vec<_>>>()?; | |
| let separators = args | |
| .separators | |
| .iter() | |
| .map(|sep| normalize_separator_arg(sep)) | |
| .collect::<Vec<_>>(); | |
| let cfg = GenConfig { | |
| max_length: args.max_length, | |
| shard_size: args.shard_size, | |
| separator_mode: args.separator_mode, | |
| bracket_mode: args.bracket_mode, | |
| separators, | |
| brackets, | |
| path_styles: args.path_styles.clone(), | |
| include_original, | |
| include_bio_variants, | |
| samples_per_source: args.samples_per_source, | |
| path_samples_per_source: args.path_samples_per_source, | |
| seed: args.seed, | |
| }; | |
| let mut samples = load_samples(&args.input, args.limit_rows)?; | |
| let source_rows = samples.len(); | |
| let mut rng = StdRng::seed_from_u64(args.seed); | |
| samples.shuffle(&mut rng); | |
| let path_variant_rows: u128 = samples | |
| .par_iter() | |
| .map(|sample| count_path_variants(sample, &cfg) as u128) | |
| .sum(); | |
| if args.dry_run { | |
| let generated: u128 = samples | |
| .par_iter() | |
| .map(|sample| count_variants(sample, &cfg)) | |
| .sum(); | |
| let special_fixtures = if include_special_fixtures { | |
| count_special_fixtures(&cfg) as u128 | |
| } else { | |
| 0 | |
| }; | |
| let manifest = json!({ | |
| "format": "anifilebert.virtual_dataset.preview.v1", | |
| "input": args.input, | |
| "vocab_file": args.vocab_file, | |
| "source_rows": source_rows, | |
| "estimated_rows": generated + special_fixtures, | |
| "source_variant_rows": generated, | |
| "path_variant_rows": path_variant_rows, | |
| "special_fixture_rows": special_fixtures, | |
| "max_length": cfg.max_length, | |
| "separator_mode": cfg.separator_mode, | |
| "bracket_mode": cfg.bracket_mode, | |
| "separators": cfg.separators, | |
| "brackets": cfg.brackets.iter().map(|b| &b.name).collect::<Vec<_>>(), | |
| "include_original": cfg.include_original, | |
| "include_bio_variants": cfg.include_bio_variants, | |
| "samples_per_source": cfg.samples_per_source, | |
| "path_samples_per_source": cfg.path_samples_per_source, | |
| "path_styles": cfg.path_styles, | |
| "include_special_fixtures": include_special_fixtures, | |
| "seed": args.seed, | |
| "elapsed_seconds": started.elapsed().as_secs_f64(), | |
| }); | |
| println!("{}", serde_json::to_string_pretty(&manifest)?); | |
| return Ok(()); | |
| } | |
| fs::create_dir_all(&args.output_dir).with_context(|| { | |
| format!( | |
| "failed to create output directory {}", | |
| args.output_dir.display() | |
| ) | |
| })?; | |
| let chunk_count = rayon::current_num_threads().max(1) * 4; | |
| let chunk_size = samples.len().div_ceil(chunk_count).max(1); | |
| let chunks = samples | |
| .chunks(chunk_size) | |
| .enumerate() | |
| .collect::<Vec<(usize, &[SourceSample])>>(); | |
| let mut worker_results = chunks | |
| .par_iter() | |
| .map(|(chunk_idx, chunk)| { | |
| let mut writer = | |
| ShardWriter::new(&args.output_dir, *chunk_idx, cfg.shard_size, cfg.max_length); | |
| for sample in *chunk { | |
| generate_for_sample(sample, &cfg, &vocab, &mut writer)?; | |
| } | |
| writer.flush()?; | |
| Ok::<_, anyhow::Error>((writer.total_rows, writer.shards)) | |
| }) | |
| .collect::<Result<Vec<_>>>()?; | |
| let mut total_rows: u64 = 0; | |
| let mut shards: Vec<ShardManifest> = Vec::new(); | |
| for (rows, mut worker_shards) in worker_results.drain(..) { | |
| total_rows += rows; | |
| shards.append(&mut worker_shards); | |
| } | |
| let special_rows = if include_special_fixtures { | |
| let mut writer = ShardWriter::new( | |
| &args.output_dir, | |
| chunk_count + 1, | |
| cfg.shard_size, | |
| cfg.max_length, | |
| ); | |
| for special in built_in_specials() { | |
| let parts = vec![PartChoice { | |
| entity: Entity::Special, | |
| value: special, | |
| }]; | |
| emit_syntax_variants(&parts, &cfg, &vocab, &mut writer)?; | |
| } | |
| writer.flush()?; | |
| total_rows += writer.total_rows; | |
| shards.append(&mut writer.shards); | |
| writer.total_rows | |
| } else { | |
| 0 | |
| }; | |
| shards.sort_by(|a, b| a.input_ids.cmp(&b.input_ids)); | |
| let manifest = json!({ | |
| "format": "anifilebert.virtual_dataset.shards.v1", | |
| "input": args.input, | |
| "vocab_file": args.vocab_file, | |
| "source_rows": source_rows, | |
| "total_rows": total_rows, | |
| "path_variant_rows": path_variant_rows, | |
| "special_fixture_rows": special_rows, | |
| "max_length": cfg.max_length, | |
| "shard_size": cfg.shard_size, | |
| "tokenizer_variant": "char", | |
| "encoding": { | |
| "input_ids_dtype": "uint16", | |
| "attention_mask_dtype": "uint8", | |
| "labels_dtype": "int16", | |
| "layout": "row_major_npy" | |
| }, | |
| "special_tokens": { | |
| "pad_id": vocab.pad_id, | |
| "unk_id": vocab.unk_id, | |
| "cls_id": vocab.cls_id, | |
| "sep_id": vocab.sep_id | |
| }, | |
| "generation": { | |
| "separator_mode": cfg.separator_mode, | |
| "bracket_mode": cfg.bracket_mode, | |
| "separators": cfg.separators, | |
| "brackets": cfg.brackets.iter().map(|b| &b.name).collect::<Vec<_>>(), | |
| "include_original": cfg.include_original, | |
| "include_bio_variants": cfg.include_bio_variants, | |
| "samples_per_source": cfg.samples_per_source, | |
| "path_samples_per_source": cfg.path_samples_per_source, | |
| "path_styles": cfg.path_styles, | |
| "include_special_fixtures": include_special_fixtures, | |
| "seed": args.seed, | |
| "threads": rayon::current_num_threads() | |
| }, | |
| "shards": shards, | |
| "elapsed_seconds": started.elapsed().as_secs_f64(), | |
| }); | |
| let manifest_path = args.output_dir.join("manifest.json"); | |
| fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?) | |
| .with_context(|| format!("failed to write {}", manifest_path.display()))?; | |
| println!("{}", serde_json::to_string_pretty(&manifest)?); | |
| Ok(()) | |
| } | |
/// Expands the command-line escape spellings for separators.
///
/// The two-character text `\t` becomes a tab and `\s` becomes a single space.
/// Any other value is returned unchanged.
fn normalize_separator_arg(value: &str) -> String {
    let expanded = if value == "\\t" {
        "\t"
    } else if value == "\\s" {
        " "
    } else {
        value
    };
    expanded.to_owned()
}
| fn load_vocab(path: &Path) -> Result<Vocab> { | |
| let text = fs::read_to_string(path) | |
| .with_context(|| format!("failed to read vocab file {}", path.display()))?; | |
| let raw: HashMap<String, u64> = | |
| serde_json::from_str(&text).context("failed to parse vocab JSON")?; | |
| let mut ids = HashMap::with_capacity(raw.len()); | |
| for (token, id) in raw { | |
| if id > u16::MAX as u64 { | |
| bail!("vocab id {id} for token '{token}' exceeds uint16 storage"); | |
| } | |
| ids.insert(token, id as u16); | |
| } | |
| let pad_id = *ids.get("[PAD]").context("vocab is missing [PAD]")?; | |
| let unk_id = *ids.get("[UNK]").context("vocab is missing [UNK]")?; | |
| let cls_id = *ids.get("[CLS]").context("vocab is missing [CLS]")?; | |
| let sep_id = *ids.get("[SEP]").context("vocab is missing [SEP]")?; | |
| Ok(Vocab { | |
| ids, | |
| pad_id, | |
| unk_id, | |
| cls_id, | |
| sep_id, | |
| }) | |
| } | |
| fn load_samples(path: &Path, limit_rows: usize) -> Result<Vec<SourceSample>> { | |
| let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?; | |
| let reader = BufReader::new(file); | |
| let mut samples = Vec::new(); | |
| for (idx, line) in reader.lines().enumerate() { | |
| if limit_rows > 0 && samples.len() >= limit_rows { | |
| break; | |
| } | |
| let line = line.with_context(|| format!("failed reading line {}", idx + 1))?; | |
| if line.trim().is_empty() { | |
| continue; | |
| } | |
| let row: InputRow = serde_json::from_str(&line) | |
| .with_context(|| format!("failed to parse JSONL line {}", idx + 1))?; | |
| if let Some(variant) = row.tokenizer_variant.as_deref() { | |
| if variant != "char" { | |
| bail!( | |
| "line {} has tokenizer_variant={variant}; virtual shard generation currently requires char data", | |
| idx + 1 | |
| ); | |
| } | |
| } | |
| if row.tokens.len() != row.labels.len() { | |
| bail!( | |
| "line {} has mismatched token/label lengths: {} vs {}", | |
| idx + 1, | |
| row.tokens.len(), | |
| row.labels.len() | |
| ); | |
| } | |
| let filename = row.filename.clone().unwrap_or_else(|| row.tokens.join("")); | |
| let labels = row | |
| .labels | |
| .iter() | |
| .map(|label| canonical_bio_label(label)) | |
| .collect::<Vec<_>>(); | |
| let fields = extract_fields(&row.tokens, &labels); | |
| samples.push(SourceSample { | |
| row_index: idx, | |
| filename, | |
| tokens: row.tokens, | |
| labels, | |
| fields, | |
| }); | |
| } | |
| Ok(samples) | |
| } | |
| fn canonical_bio_label(label: &str) -> String { | |
| if label == "O" { | |
| return "O".to_string(); | |
| } | |
| let Some((prefix, entity_name)) = label.split_once('-') else { | |
| return label.to_string(); | |
| }; | |
| let Some(entity) = Entity::from_name(entity_name) else { | |
| return label.to_string(); | |
| }; | |
| match prefix { | |
| "B" => entity.b_label().to_string(), | |
| "I" => entity.i_label().to_string(), | |
| _ => label.to_string(), | |
| } | |
| } | |
| fn extract_fields(tokens: &[String], labels: &[String]) -> Vec<Vec<String>> { | |
| let mut fields: Vec<Vec<String>> = (0..ENTITIES.len()).map(|_| Vec::new()).collect(); | |
| let mut seen: Vec<HashSet<String>> = (0..ENTITIES.len()).map(|_| HashSet::new()).collect(); | |
| let mut active_entity: Option<Entity> = None; | |
| let mut active_text = String::new(); | |
| let flush = |entity: Option<Entity>, | |
| text: &mut String, | |
| fields: &mut Vec<Vec<String>>, | |
| seen: &mut Vec<HashSet<String>>| { | |
| if let Some(entity) = entity { | |
| let value = text.trim().to_string(); | |
| push_extracted_field(fields, seen, entity, value); | |
| } | |
| text.clear(); | |
| }; | |
| for (token, label) in tokens.iter().zip(labels.iter()) { | |
| if let Some(entity) = label.strip_prefix("B-").and_then(Entity::from_name) { | |
| flush(active_entity, &mut active_text, &mut fields, &mut seen); | |
| active_entity = Some(entity); | |
| active_text.push_str(token); | |
| } else if let Some(entity) = label.strip_prefix("I-").and_then(Entity::from_name) { | |
| if active_entity == Some(entity) { | |
| active_text.push_str(token); | |
| } else { | |
| flush(active_entity, &mut active_text, &mut fields, &mut seen); | |
| active_entity = Some(entity); | |
| active_text.push_str(token); | |
| } | |
| } else { | |
| flush(active_entity, &mut active_text, &mut fields, &mut seen); | |
| active_entity = None; | |
| } | |
| } | |
| flush(active_entity, &mut active_text, &mut fields, &mut seen); | |
| fields | |
| } | |
| fn push_extracted_field( | |
| fields: &mut [Vec<String>], | |
| seen: &mut [HashSet<String>], | |
| entity: Entity, | |
| value: String, | |
| ) { | |
| fn add(fields: &mut [Vec<String>], seen: &mut [HashSet<String>], entity: Entity, value: &str) { | |
| if !value.is_empty() && seen[entity.index()].insert(value.to_string()) { | |
| fields[entity.index()].push(value.to_string()); | |
| } | |
| } | |
| let value = value.trim(); | |
| if value.is_empty() { | |
| return; | |
| } | |
| add(fields, seen, entity, value); | |
| if let Some(path_title) = entity.as_path_title() { | |
| add(fields, seen, path_title, value); | |
| } | |
| if let Some(file_title) = entity.as_file_title() { | |
| add(fields, seen, file_title, value); | |
| } | |
| match entity { | |
| Entity::Season => add(fields, seen, Entity::PathSeason, value), | |
| Entity::PathSeason => add(fields, seen, Entity::Season, value), | |
| _ => {} | |
| } | |
| } | |
| fn ordinary_available_entities(sample: &SourceSample) -> Vec<Entity> { | |
| ENTITIES | |
| .iter() | |
| .copied() | |
| .filter(|entity| { | |
| entity.is_ordinary_variant_entity() && !sample.fields[entity.index()].is_empty() | |
| }) | |
| .collect() | |
| } | |
| fn first_file_title_field(sample: &SourceSample) -> Option<(Entity, String)> { | |
| FILE_TITLE_ENTITIES.iter().copied().find_map(|entity| { | |
| sample.fields[entity.index()] | |
| .iter() | |
| .find(|value| !value.trim().is_empty()) | |
| .map(|value| (entity, value.trim().to_string())) | |
| }) | |
| } | |
| fn choose_path_title_field(sample: &SourceSample, rng: &mut StdRng) -> Option<(Entity, String)> { | |
| let mut candidates = Vec::new(); | |
| for entity in PATH_TITLE_ENTITIES { | |
| for value in &sample.fields[entity.index()] { | |
| let value = value.trim(); | |
| if !value.is_empty() { | |
| candidates.push((entity, value.to_string())); | |
| } | |
| } | |
| } | |
| candidates.choose(rng).cloned() | |
| } | |
| fn count_variants(sample: &SourceSample, cfg: &GenConfig) -> u128 { | |
| let mut count = if cfg.include_original { 1 } else { 0 }; | |
| count += count_path_variants(sample, cfg) as u128; | |
| let available = ordinary_available_entities(sample); | |
| let n = available.len(); | |
| if n == 0 || !cfg.include_bio_variants { | |
| return count; | |
| } | |
| if cfg.samples_per_source > 0 { | |
| return count + cfg.samples_per_source as u128; | |
| } | |
| for mask in 1usize..(1usize << n) { | |
| let selected = available | |
| .iter() | |
| .enumerate() | |
| .filter_map(|(idx, entity)| ((mask & (1usize << idx)) != 0).then_some(*entity)) | |
| .collect::<Vec<_>>(); | |
| let m = selected.len(); | |
| let value_product: u128 = selected | |
| .iter() | |
| .map(|entity| sample.fields[entity.index()].len() as u128) | |
| .product(); | |
| let perm_count = factorial(m as u32); | |
| let sep_factor = if m <= 1 { | |
| 1 | |
| } else { | |
| match cfg.separator_mode { | |
| SeparatorMode::Global => cfg.separators.len() as u128, | |
| SeparatorMode::PerGap => (cfg.separators.len() as u128).pow((m - 1) as u32), | |
| } | |
| }; | |
| let bracket_factor = match cfg.bracket_mode { | |
| BracketMode::Global => cfg.brackets.len() as u128, | |
| BracketMode::PerPart => (cfg.brackets.len() as u128).pow(m as u32), | |
| }; | |
| count += value_product * perm_count * sep_factor * bracket_factor; | |
| } | |
| count | |
| } | |
| fn count_path_variants(sample: &SourceSample, cfg: &GenConfig) -> usize { | |
| if cfg.path_samples_per_source == 0 || cfg.path_styles.is_empty() { | |
| return 0; | |
| } | |
| if !PATH_TITLE_ENTITIES | |
| .iter() | |
| .any(|entity| !sample.fields[entity.index()].is_empty()) | |
| { | |
| return 0; | |
| } | |
| if sample.fields[Entity::Episode.index()].is_empty() | |
| && sample.fields[Entity::Special.index()].is_empty() | |
| { | |
| return 0; | |
| } | |
| cfg.path_samples_per_source | |
| } | |
| fn count_special_fixtures(cfg: &GenConfig) -> usize { | |
| let bracket_factor = match cfg.bracket_mode { | |
| BracketMode::Global => cfg.brackets.len(), | |
| BracketMode::PerPart => cfg.brackets.len(), | |
| }; | |
| built_in_specials().len() * bracket_factor | |
| } | |
/// Returns `n!`, saturating at `u128::MAX` instead of overflowing.
///
/// `0! == 1` comes from the empty product. `n!` no longer fits in `u128`
/// once `n >= 35`. Saturating keeps variant estimates monotone, whereas plain
/// multiplication would wrap in release builds or panic in debug builds.
fn factorial(n: u32) -> u128 {
    (1..=u128::from(n)).fold(1, u128::saturating_mul)
}
| fn generate_for_sample( | |
| sample: &SourceSample, | |
| cfg: &GenConfig, | |
| vocab: &Vocab, | |
| writer: &mut ShardWriter, | |
| ) -> Result<()> { | |
| if cfg.include_original { | |
| let (input_ids, attention_mask, labels) = | |
| encode_original_sample(sample, vocab, cfg.max_length)?; | |
| writer.add(&input_ids, &attention_mask, &labels)?; | |
| } | |
| if cfg.path_samples_per_source > 0 { | |
| generate_path_context_variants(sample, cfg, vocab, writer)?; | |
| } | |
| if cfg.include_bio_variants && cfg.samples_per_source > 0 { | |
| generate_sampled_variants(sample, cfg, vocab, writer)?; | |
| return Ok(()); | |
| } | |
| if !cfg.include_bio_variants { | |
| return Ok(()); | |
| } | |
| let available = ordinary_available_entities(sample); | |
| let n = available.len(); | |
| for mask in 1usize..(1usize << n) { | |
| let mut selected = available | |
| .iter() | |
| .enumerate() | |
| .filter_map(|(idx, entity)| ((mask & (1usize << idx)) != 0).then_some(*entity)) | |
| .collect::<Vec<_>>(); | |
| permute_entities(&mut selected, 0, &mut |order| { | |
| let mut parts: Vec<PartChoice> = Vec::with_capacity(order.len()); | |
| for_each_value_combo(order, &sample.fields, 0, &mut parts, &mut |combo| { | |
| emit_syntax_variants(combo, cfg, vocab, writer) | |
| }) | |
| })?; | |
| } | |
| Ok(()) | |
| } | |
| fn generate_sampled_variants( | |
| sample: &SourceSample, | |
| cfg: &GenConfig, | |
| vocab: &Vocab, | |
| writer: &mut ShardWriter, | |
| ) -> Result<()> { | |
| let mut rng = StdRng::seed_from_u64( | |
| cfg.seed ^ ((sample.row_index as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)), | |
| ); | |
| let available = ordinary_available_entities(sample); | |
| if available.is_empty() { | |
| return Ok(()); | |
| } | |
| let mut seen = HashSet::new(); | |
| let mut emitted = 0usize; | |
| let budget = cfg.samples_per_source; | |
| let max_unique_attempts = budget.saturating_mul(32).max(64); | |
| let mut attempts = 0usize; | |
| let mut templates: Vec<Vec<PartChoice>> = Vec::new(); | |
| if let Some((title_entity, title)) = first_file_title_field(sample) { | |
| templates.push(vec![PartChoice { | |
| entity: title_entity, | |
| value: title.clone(), | |
| }]); | |
| if let Some(season) = sample.fields[Entity::Season.index()].first() { | |
| templates.push(vec![ | |
| PartChoice { | |
| entity: title_entity, | |
| value: title.clone(), | |
| }, | |
| PartChoice { | |
| entity: Entity::Season, | |
| value: season.clone(), | |
| }, | |
| ]); | |
| } | |
| } | |
| if let Some(episode) = sample.fields[Entity::Episode.index()].first() { | |
| templates.push(vec![PartChoice { | |
| entity: Entity::Episode, | |
| value: episode.clone(), | |
| }]); | |
| } | |
| if let Some(special) = sample.fields[Entity::Special.index()].first() { | |
| templates.push(vec![PartChoice { | |
| entity: Entity::Special, | |
| value: special.clone(), | |
| }]); | |
| } | |
| if let (Some((title_entity, title)), Some(special)) = ( | |
| first_file_title_field(sample), | |
| sample.fields[Entity::Special.index()].first(), | |
| ) { | |
| templates.push(vec![ | |
| PartChoice { | |
| entity: title_entity, | |
| value: title.clone(), | |
| }, | |
| PartChoice { | |
| entity: Entity::Special, | |
| value: special.clone(), | |
| }, | |
| ]); | |
| } | |
| for parts in templates { | |
| if emitted >= budget { | |
| break; | |
| } | |
| emit_sample_variant( | |
| parts, | |
| cfg, | |
| vocab, | |
| writer, | |
| &mut seen, | |
| &mut emitted, | |
| &mut rng, | |
| false, | |
| )?; | |
| } | |
| while emitted < budget && attempts < max_unique_attempts { | |
| attempts += 1; | |
| let subset_size = match rng.gen_range(0..100) { | |
| 0..=29 => 1, | |
| 30..=54 => 2, | |
| 55..=74 => 3, | |
| 75..=89 => 4.min(available.len()), | |
| _ => available.len().min(5), | |
| } | |
| .max(1) | |
| .min(available.len()); | |
| let mut chosen = available | |
| .choose_multiple(&mut rng, subset_size) | |
| .copied() | |
| .collect::<Vec<_>>(); | |
| chosen.shuffle(&mut rng); | |
| if !chosen.iter().any(|entity| { | |
| entity.is_file_title() || matches!(entity, Entity::Episode | Entity::Special) | |
| }) { | |
| if let Some(fallback) = available.iter().copied().find(|entity| { | |
| entity.is_file_title() || matches!(entity, Entity::Episode | Entity::Special) | |
| }) { | |
| if !chosen.contains(&fallback) { | |
| chosen.push(fallback); | |
| } | |
| } | |
| } | |
| let mut parts = Vec::with_capacity(chosen.len()); | |
| for entity in chosen { | |
| let values = &sample.fields[entity.index()]; | |
| let value = values.choose(&mut rng).cloned().unwrap_or_default(); | |
| parts.push(PartChoice { entity, value }); | |
| } | |
| parts.shuffle(&mut rng); | |
| emit_sample_variant( | |
| parts, | |
| cfg, | |
| vocab, | |
| writer, | |
| &mut seen, | |
| &mut emitted, | |
| &mut rng, | |
| false, | |
| )?; | |
| } | |
| while emitted < budget { | |
| let subset_size = match rng.gen_range(0..100) { | |
| 0..=29 => 1, | |
| 30..=54 => 2, | |
| 55..=74 => 3, | |
| 75..=89 => 4.min(available.len()), | |
| _ => available.len().min(5), | |
| } | |
| .max(1) | |
| .min(available.len()); | |
| let mut chosen = available | |
| .choose_multiple(&mut rng, subset_size) | |
| .copied() | |
| .collect::<Vec<_>>(); | |
| chosen.shuffle(&mut rng); | |
| if !chosen.iter().any(|entity| { | |
| entity.is_file_title() || matches!(entity, Entity::Episode | Entity::Special) | |
| }) { | |
| if let Some(fallback) = available.iter().copied().find(|entity| { | |
| entity.is_file_title() || matches!(entity, Entity::Episode | Entity::Special) | |
| }) { | |
| if !chosen.contains(&fallback) { | |
| chosen.push(fallback); | |
| } | |
| } | |
| } | |
| let mut parts = Vec::with_capacity(chosen.len()); | |
| for entity in chosen { | |
| let values = &sample.fields[entity.index()]; | |
| let value = values.choose(&mut rng).cloned().unwrap_or_default(); | |
| parts.push(PartChoice { entity, value }); | |
| } | |
| parts.shuffle(&mut rng); | |
| emit_sample_variant( | |
| parts, | |
| cfg, | |
| vocab, | |
| writer, | |
| &mut seen, | |
| &mut emitted, | |
| &mut rng, | |
| true, | |
| )?; | |
| } | |
| Ok(()) | |
| } | |
| fn emit_sample_variant( | |
| parts: Vec<PartChoice>, | |
| cfg: &GenConfig, | |
| vocab: &Vocab, | |
| writer: &mut ShardWriter, | |
| seen: &mut HashSet<String>, | |
| emitted: &mut usize, | |
| rng: &mut StdRng, | |
| allow_duplicate: bool, | |
| ) -> Result<()> { | |
| if *emitted >= cfg.samples_per_source { | |
| return Ok(()); | |
| } | |
| if parts.is_empty() { | |
| return Ok(()); | |
| } | |
| let separators = match cfg.separator_mode { | |
| SeparatorMode::Global => { | |
| let sep = cfg | |
| .separators | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| " ".to_string()); | |
| if parts.len() > 1 { | |
| vec![sep; parts.len() - 1] | |
| } else { | |
| Vec::new() | |
| } | |
| } | |
| SeparatorMode::PerGap => { | |
| let mut values = Vec::with_capacity(parts.len().saturating_sub(1)); | |
| for _ in 0..parts.len().saturating_sub(1) { | |
| values.push( | |
| cfg.separators | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| " ".to_string()), | |
| ); | |
| } | |
| values | |
| } | |
| }; | |
| let brackets = match cfg.bracket_mode { | |
| BracketMode::Global => { | |
| let bracket = cfg | |
| .brackets | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| Bracket { | |
| name: "none".to_string(), | |
| open: String::new(), | |
| close: String::new(), | |
| }); | |
| vec![bracket; parts.len()] | |
| } | |
| BracketMode::PerPart => { | |
| let mut values = Vec::with_capacity(parts.len()); | |
| for _ in 0..parts.len() { | |
| values.push( | |
| cfg.brackets | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| Bracket { | |
| name: "none".to_string(), | |
| open: String::new(), | |
| close: String::new(), | |
| }), | |
| ); | |
| } | |
| values | |
| } | |
| }; | |
| let text = render_variant_text(&parts, &separators, &brackets); | |
| if !allow_duplicate && !seen.insert(text) { | |
| return Ok(()); | |
| } | |
| let (input_ids, attention_mask, labels) = | |
| encode_generated_sample(&parts, &separators, &brackets, vocab, cfg.max_length)?; | |
| writer.add(&input_ids, &attention_mask, &labels)?; | |
| *emitted += 1; | |
| Ok(()) | |
| } | |
| fn generate_path_context_variants( | |
| sample: &SourceSample, | |
| cfg: &GenConfig, | |
| vocab: &Vocab, | |
| writer: &mut ShardWriter, | |
| ) -> Result<()> { | |
| if count_path_variants(sample, cfg) == 0 { | |
| return Ok(()); | |
| } | |
| let mut rng = StdRng::seed_from_u64( | |
| cfg.seed | |
| ^ 0xA076_1D64_78BD_642F | |
| ^ ((sample.row_index as u64).wrapping_mul(0xE703_7ED1_A0B4_28DB)), | |
| ); | |
| let mut seen = HashSet::new(); | |
| let mut emitted = 0usize; | |
| let budget = cfg.path_samples_per_source; | |
| let max_unique_attempts = budget.saturating_mul(32).max(64); | |
| let mut attempts = 0usize; | |
| while emitted < budget && attempts < max_unique_attempts { | |
| attempts += 1; | |
| if let Some(pieces) = build_path_context_pieces(sample, cfg, &mut rng) { | |
| let text = render_labeled_pieces(&pieces); | |
| if seen.insert(text) { | |
| let (input_ids, attention_mask, labels) = | |
| encode_labeled_pieces(&pieces, vocab, cfg.max_length)?; | |
| writer.add(&input_ids, &attention_mask, &labels)?; | |
| emitted += 1; | |
| } | |
| } else { | |
| return Ok(()); | |
| } | |
| } | |
| while emitted < budget { | |
| if let Some(pieces) = build_path_context_pieces(sample, cfg, &mut rng) { | |
| let (input_ids, attention_mask, labels) = | |
| encode_labeled_pieces(&pieces, vocab, cfg.max_length)?; | |
| writer.add(&input_ids, &attention_mask, &labels)?; | |
| emitted += 1; | |
| } else { | |
| return Ok(()); | |
| } | |
| } | |
| Ok(()) | |
| } | |
| fn build_path_context_pieces( | |
| sample: &SourceSample, | |
| cfg: &GenConfig, | |
| rng: &mut StdRng, | |
| ) -> Option<Vec<LabeledPiece>> { | |
| let (title_entity, title) = choose_path_title_field(sample, rng)?; | |
| let style = *cfg.path_styles.choose(rng)?; | |
| let sep = style.separator(); | |
| let mut components = path_prefix_components(style, rng); | |
| components.push(vec![entity_piece(title.clone(), title_entity)]); | |
| let season_component = choose_path_season_component(sample, rng); | |
| if let Some(season) = season_component { | |
| components.push(season); | |
| } | |
| let use_special = if sample.fields[Entity::Episode.index()].is_empty() { | |
| true | |
| } else if sample.fields[Entity::Special.index()].is_empty() { | |
| false | |
| } else { | |
| rng.gen_bool(0.18) | |
| }; | |
| let endpoint = if use_special { | |
| let special = choose_field(sample, Entity::Special, rng)?; | |
| entity_piece(random_special_path_text(&special, rng), Entity::Special) | |
| } else { | |
| let episode = choose_field(sample, Entity::Episode, rng)?; | |
| entity_piece(random_episode_path_text(&episode, rng), Entity::Episode) | |
| }; | |
| match rng.gen_range(0..6) { | |
| 0 => components.push(path_file_component(endpoint, sample, rng)), | |
| 1 => { | |
| components.push(vec![endpoint]); | |
| components.push(noise_file_component(rng)); | |
| } | |
| 2 => { | |
| components.push(vec![endpoint]); | |
| components.push(meta_file_component(sample, rng)); | |
| } | |
| 3 => components.push(compact_file_component(endpoint, sample, rng)), | |
| 4 => components.push(grouped_release_file_component( | |
| &title, endpoint, sample, rng, | |
| )), | |
| _ => { | |
| components.push(vec![endpoint]); | |
| if rng.gen_bool(0.55) { | |
| components.push(noise_file_component(rng)); | |
| } | |
| } | |
| } | |
| Some(join_path_components(&components, sep)) | |
| } | |
| fn choose_field(sample: &SourceSample, entity: Entity, rng: &mut StdRng) -> Option<String> { | |
| sample.fields[entity.index()] | |
| .choose(rng) | |
| .map(|value| value.trim().to_string()) | |
| .filter(|value| !value.is_empty()) | |
| } | |
| fn path_prefix_components(style: PathStyle, rng: &mut StdRng) -> Vec<Vec<LabeledPiece>> { | |
| let templates: &[&[&str]] = match style { | |
| PathStyle::Windows => &[ | |
| &["O:", "115open", "影音", "动漫"], | |
| &["D:", "Media", "Anime"], | |
| &["E:", "Downloads", "Bangumi"], | |
| &["Z:", "Library", "Anime"], | |
| &["C:", "Archive", "completed"], | |
| ], | |
| PathStyle::Unix => &[ | |
| &["", "mnt", "media", "anime"], | |
| &["", "volume1", "anime"], | |
| &["home", "media", "Bangumi"], | |
| &["library", "anime"], | |
| &["srv", "downloads", "anime"], | |
| ], | |
| }; | |
| let noise_dirs = [ | |
| "整理中", | |
| "completed", | |
| "old", | |
| "temp", | |
| "115", | |
| "Bangumi", | |
| "Library", | |
| "_archive", | |
| "2024", | |
| "misc", | |
| ]; | |
| let selected = templates.choose(rng).copied().unwrap_or(&["Anime"]); | |
| let mut components = selected | |
| .iter() | |
| .map(|component| vec![o_piece((*component).to_string())]) | |
| .collect::<Vec<_>>(); | |
| let extra_count = rng.gen_range(0..=2); | |
| for _ in 0..extra_count { | |
| let insert_at = components.len().saturating_sub(1); | |
| let noise = noise_dirs | |
| .choose(rng) | |
| .copied() | |
| .unwrap_or("Library") | |
| .to_string(); | |
| components.insert(insert_at, vec![o_piece(noise)]); | |
| } | |
| components | |
| } | |
| fn choose_path_season_component( | |
| sample: &SourceSample, | |
| rng: &mut StdRng, | |
| ) -> Option<Vec<LabeledPiece>> { | |
| let season = if let Some(source_season) = choose_field(sample, Entity::PathSeason, rng) | |
| .or_else(|| choose_field(sample, Entity::Season, rng)) | |
| { | |
| random_season_path_text(&source_season, rng) | |
| } else { | |
| let synthetic = ["01", "Season 1", "Season 01", "S01", "第1季"]; | |
| synthetic | |
| .choose(rng) | |
| .copied() | |
| .unwrap_or("Season 1") | |
| .to_string() | |
| }; | |
| Some(vec![entity_piece(season, Entity::PathSeason)]) | |
| } | |
| fn path_file_component( | |
| endpoint: LabeledPiece, | |
| sample: &SourceSample, | |
| rng: &mut StdRng, | |
| ) -> Vec<LabeledPiece> { | |
| let mut pieces = Vec::new(); | |
| if rng.gen_bool(0.25) { | |
| pieces.push(o_piece("Episode ".to_string())); | |
| } | |
| pieces.push(endpoint); | |
| append_path_meta(&mut pieces, sample, rng); | |
| pieces.push(o_piece(random_extension(rng).to_string())); | |
| pieces | |
| } | |
| fn compact_file_component( | |
| endpoint: LabeledPiece, | |
| sample: &SourceSample, | |
| rng: &mut StdRng, | |
| ) -> Vec<LabeledPiece> { | |
| let mut pieces = vec![endpoint]; | |
| if rng.gen_bool(0.75) { | |
| append_path_meta(&mut pieces, sample, rng); | |
| } | |
| pieces.push(o_piece(random_extension(rng).to_string())); | |
| pieces | |
| } | |
| fn grouped_release_file_component( | |
| title: &str, | |
| endpoint: LabeledPiece, | |
| sample: &SourceSample, | |
| rng: &mut StdRng, | |
| ) -> Vec<LabeledPiece> { | |
| let mut pieces = Vec::new(); | |
| if let Some(group) = choose_field(sample, Entity::Group, rng) { | |
| pieces.push(o_piece("[".to_string())); | |
| pieces.push(entity_piece(group, Entity::Group)); | |
| pieces.push(o_piece("] ".to_string())); | |
| } | |
| pieces.push(o_piece(title.trim().to_string())); | |
| let separator = [" - ", " ", " "].choose(rng).copied().unwrap_or(" - "); | |
| pieces.push(o_piece(separator.to_string())); | |
| pieces.push(endpoint); | |
| append_path_meta(&mut pieces, sample, rng); | |
| pieces.push(o_piece(random_extension(rng).to_string())); | |
| pieces | |
| } | |
| fn meta_file_component(sample: &SourceSample, rng: &mut StdRng) -> Vec<LabeledPiece> { | |
| let mut pieces = Vec::new(); | |
| if rng.gen_bool(0.5) { | |
| pieces.push(o_piece("metadata".to_string())); | |
| } else { | |
| pieces.push(o_piece("video".to_string())); | |
| } | |
| append_path_meta(&mut pieces, sample, rng); | |
| pieces.push(o_piece(random_extension(rng).to_string())); | |
| pieces | |
| } | |
| fn noise_file_component(rng: &mut StdRng) -> Vec<LabeledPiece> { | |
| let stems = ["video", "default", "main", "feature", "movie", "episode"]; | |
| let stem = stems.choose(rng).copied().unwrap_or("video"); | |
| vec![o_piece(format!("{stem}{}", random_extension(rng)))] | |
| } | |
| fn append_path_meta(pieces: &mut Vec<LabeledPiece>, sample: &SourceSample, rng: &mut StdRng) { | |
| if let Some(resolution) = choose_field(sample, Entity::Resolution, rng) { | |
| if rng.gen_bool(0.72) { | |
| pieces.push(o_piece(" [".to_string())); | |
| pieces.push(entity_piece(resolution, Entity::Resolution)); | |
| pieces.push(o_piece("]".to_string())); | |
| } | |
| } | |
| let source_count = if rng.gen_bool(0.35) { 2 } else { 1 }; | |
| for _ in 0..source_count { | |
| if let Some(source) = choose_field(sample, Entity::Source, rng) { | |
| if rng.gen_bool(0.62) { | |
| pieces.push(o_piece("[".to_string())); | |
| pieces.push(entity_piece(source, Entity::Source)); | |
| pieces.push(o_piece("]".to_string())); | |
| } | |
| } | |
| } | |
| if let Some(tag) = choose_field(sample, Entity::Tag, rng) { | |
| if rng.gen_bool(0.55) { | |
| pieces.push(o_piece("[".to_string())); | |
| pieces.push(entity_piece(tag, Entity::Tag)); | |
| pieces.push(o_piece("]".to_string())); | |
| } | |
| } | |
| } | |
| fn random_episode_path_text(value: &str, rng: &mut StdRng) -> String { | |
| let mut variants = vec![value.trim().to_string()]; | |
| if let Some(number) = first_ascii_number(value) { | |
| variants.push(format!("{number:02}")); | |
| variants.push(format!("E{number:02}")); | |
| variants.push(format!("EP{number:02}")); | |
| } | |
| variants | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| value.trim().to_string()) | |
| } | |
| fn random_special_path_text(value: &str, rng: &mut StdRng) -> String { | |
| let mut variants = vec![value.trim().to_string()]; | |
| if let Some(number) = first_ascii_number(value) { | |
| variants.push(format!("SP{number:02}")); | |
| variants.push(format!("Special {number:02}")); | |
| } | |
| variants | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| value.trim().to_string()) | |
| } | |
| fn random_season_path_text(value: &str, rng: &mut StdRng) -> String { | |
| let mut variants = vec![value.trim().to_string()]; | |
| if let Some(number) = first_ascii_number(value) { | |
| variants.push(format!("{number:02}")); | |
| variants.push(format!("Season {number}")); | |
| variants.push(format!("Season {number:02}")); | |
| variants.push(format!("S{number:02}")); | |
| variants.push(format!("第{number}季")); | |
| } | |
| variants | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| value.trim().to_string()) | |
| } | |
/// Parses the first run of ASCII digits in `value` as a `u32`.
///
/// Returns `None` when `value` has no ASCII digit or the run does not fit in `u32`.
/// Non-ASCII digits are treated as ordinary characters.
fn first_ascii_number(value: &str) -> Option<u32> {
    let start = value.find(|c: char| c.is_ascii_digit())?;
    let rest = &value[start..];
    let end = rest
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(rest.len());
    rest[..end].parse().ok()
}
| fn random_extension(rng: &mut StdRng) -> &'static str { | |
| [".mkv", ".mp4", ".avi"] | |
| .choose(rng) | |
| .copied() | |
| .unwrap_or(".mkv") | |
| } | |
| fn join_path_components(components: &[Vec<LabeledPiece>], separator: &str) -> Vec<LabeledPiece> { | |
| let mut pieces = Vec::new(); | |
| for (idx, component) in components.iter().enumerate() { | |
| if idx > 0 { | |
| pieces.push(o_piece(separator.to_string())); | |
| } | |
| pieces.extend(component.iter().cloned()); | |
| } | |
| pieces | |
| } | |
| fn render_labeled_pieces(pieces: &[LabeledPiece]) -> String { | |
| let mut text = String::new(); | |
| for piece in pieces { | |
| text.push_str(&piece.text); | |
| } | |
| text | |
| } | |
| fn permute_entities<F>(values: &mut [Entity], start: usize, callback: &mut F) -> Result<()> | |
| where | |
| F: FnMut(&[Entity]) -> Result<()>, | |
| { | |
| if start >= values.len() { | |
| return callback(values); | |
| } | |
| for idx in start..values.len() { | |
| values.swap(start, idx); | |
| permute_entities(values, start + 1, callback)?; | |
| values.swap(start, idx); | |
| } | |
| Ok(()) | |
| } | |
| struct PartChoice { | |
| entity: Entity, | |
| value: String, | |
| } | |
| struct LabeledPiece { | |
| text: String, | |
| entity: Option<Entity>, | |
| } | |
| fn o_piece(text: String) -> LabeledPiece { | |
| LabeledPiece { text, entity: None } | |
| } | |
| fn entity_piece(text: String, entity: Entity) -> LabeledPiece { | |
| LabeledPiece { | |
| text, | |
| entity: Some(entity), | |
| } | |
| } | |
| fn for_each_value_combo<F>( | |
| order: &[Entity], | |
| fields: &[Vec<String>], | |
| idx: usize, | |
| current: &mut Vec<PartChoice>, | |
| callback: &mut F, | |
| ) -> Result<()> | |
| where | |
| F: FnMut(&[PartChoice]) -> Result<()>, | |
| { | |
| if idx >= order.len() { | |
| return callback(current); | |
| } | |
| let entity = order[idx]; | |
| for value in &fields[entity.index()] { | |
| current.push(PartChoice { | |
| entity, | |
| value: value.clone(), | |
| }); | |
| for_each_value_combo(order, fields, idx + 1, current, callback)?; | |
| current.pop(); | |
| } | |
| Ok(()) | |
| } | |
| fn emit_syntax_variants( | |
| parts: &[PartChoice], | |
| cfg: &GenConfig, | |
| vocab: &Vocab, | |
| writer: &mut ShardWriter, | |
| ) -> Result<()> { | |
| let gaps = parts.len().saturating_sub(1); | |
| let mut separators = Vec::with_capacity(gaps); | |
| for_each_separator_combo(gaps, cfg, 0, &mut separators, &mut |sep_combo| { | |
| let mut brackets = Vec::with_capacity(parts.len()); | |
| for_each_bracket_combo(parts.len(), cfg, 0, &mut brackets, &mut |bracket_combo| { | |
| let (input_ids, attention_mask, labels) = | |
| encode_generated_sample(parts, sep_combo, bracket_combo, vocab, cfg.max_length)?; | |
| writer.add(&input_ids, &attention_mask, &labels) | |
| }) | |
| }) | |
| } | |
| fn for_each_separator_combo<F>( | |
| gaps: usize, | |
| cfg: &GenConfig, | |
| idx: usize, | |
| current: &mut Vec<String>, | |
| callback: &mut F, | |
| ) -> Result<()> | |
| where | |
| F: FnMut(&[String]) -> Result<()>, | |
| { | |
| if gaps == 0 { | |
| return callback(current); | |
| } | |
| match cfg.separator_mode { | |
| SeparatorMode::Global => { | |
| if idx == 0 { | |
| for sep in &cfg.separators { | |
| current.clear(); | |
| current.resize(gaps, sep.clone()); | |
| callback(current)?; | |
| } | |
| } | |
| Ok(()) | |
| } | |
| SeparatorMode::PerGap => { | |
| if idx >= gaps { | |
| return callback(current); | |
| } | |
| for sep in &cfg.separators { | |
| current.push(sep.clone()); | |
| for_each_separator_combo(gaps, cfg, idx + 1, current, callback)?; | |
| current.pop(); | |
| } | |
| Ok(()) | |
| } | |
| } | |
| } | |
| fn for_each_bracket_combo<F>( | |
| parts: usize, | |
| cfg: &GenConfig, | |
| idx: usize, | |
| current: &mut Vec<Bracket>, | |
| callback: &mut F, | |
| ) -> Result<()> | |
| where | |
| F: FnMut(&[Bracket]) -> Result<()>, | |
| { | |
| match cfg.bracket_mode { | |
| BracketMode::Global => { | |
| if idx == 0 { | |
| for bracket in &cfg.brackets { | |
| current.clear(); | |
| current.resize(parts, bracket.clone()); | |
| callback(current)?; | |
| } | |
| } | |
| Ok(()) | |
| } | |
| BracketMode::PerPart => { | |
| if idx >= parts { | |
| return callback(current); | |
| } | |
| for bracket in &cfg.brackets { | |
| current.push(bracket.clone()); | |
| for_each_bracket_combo(parts, cfg, idx + 1, current, callback)?; | |
| current.pop(); | |
| } | |
| Ok(()) | |
| } | |
| } | |
| } | |
| fn render_variant_text( | |
| parts: &[PartChoice], | |
| separators: &[String], | |
| brackets: &[Bracket], | |
| ) -> String { | |
| let mut text = String::new(); | |
| for (idx, part) in parts.iter().enumerate() { | |
| text.push_str(&brackets[idx].open); | |
| text.push_str(&part.value); | |
| text.push_str(&brackets[idx].close); | |
| if idx < separators.len() { | |
| text.push_str(&separators[idx]); | |
| } | |
| } | |
| text | |
| } | |
| fn encode_original_sample( | |
| sample: &SourceSample, | |
| vocab: &Vocab, | |
| max_length: usize, | |
| ) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> { | |
| let mut input_ids = vec![vocab.pad_id; max_length]; | |
| let mut attention_mask = vec![0u8; max_length]; | |
| let mut labels = vec![-100i16; max_length]; | |
| input_ids[0] = vocab.cls_id; | |
| attention_mask[0] = 1; | |
| let available = max_length.saturating_sub(2); | |
| let token_count = sample.tokens.len().min(available); | |
| for idx in 0..token_count { | |
| input_ids[idx + 1] = token_id(vocab, &sample.tokens[idx]); | |
| attention_mask[idx + 1] = 1; | |
| labels[idx + 1] = label_id(&sample.labels[idx]).with_context(|| { | |
| format!( | |
| "unknown label '{}' on source row {} ({})", | |
| sample.labels[idx], | |
| sample.row_index + 1, | |
| sample.filename | |
| ) | |
| })?; | |
| } | |
| let sep_pos = token_count + 1; | |
| input_ids[sep_pos] = vocab.sep_id; | |
| attention_mask[sep_pos] = 1; | |
| Ok((input_ids, attention_mask, labels)) | |
| } | |
| fn encode_generated_sample( | |
| parts: &[PartChoice], | |
| separators: &[String], | |
| brackets: &[Bracket], | |
| vocab: &Vocab, | |
| max_length: usize, | |
| ) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> { | |
| let mut input_ids = vec![vocab.pad_id; max_length]; | |
| let mut attention_mask = vec![0u8; max_length]; | |
| let mut labels = vec![-100i16; max_length]; | |
| input_ids[0] = vocab.cls_id; | |
| attention_mask[0] = 1; | |
| let available = max_length.saturating_sub(2); | |
| let mut pos = 1usize; | |
| for (idx, part) in parts.iter().enumerate() { | |
| let bracket = &brackets[idx]; | |
| append_o_text( | |
| &bracket.open, | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| ); | |
| append_entity_text( | |
| &part.value, | |
| part.entity, | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| )?; | |
| append_o_text( | |
| &bracket.close, | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| ); | |
| if idx < separators.len() { | |
| append_o_text( | |
| &separators[idx], | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| ); | |
| } | |
| } | |
| let sep_pos = pos.min(max_length - 1); | |
| input_ids[sep_pos] = vocab.sep_id; | |
| attention_mask[sep_pos] = 1; | |
| labels[sep_pos] = -100; | |
| Ok((input_ids, attention_mask, labels)) | |
| } | |
| fn encode_labeled_pieces( | |
| pieces: &[LabeledPiece], | |
| vocab: &Vocab, | |
| max_length: usize, | |
| ) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> { | |
| let mut input_ids = vec![vocab.pad_id; max_length]; | |
| let mut attention_mask = vec![0u8; max_length]; | |
| let mut labels = vec![-100i16; max_length]; | |
| input_ids[0] = vocab.cls_id; | |
| attention_mask[0] = 1; | |
| let available = max_length.saturating_sub(2); | |
| let mut pos = 1usize; | |
| for piece in pieces { | |
| if let Some(entity) = piece.entity { | |
| append_entity_text( | |
| &piece.text, | |
| entity, | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| )?; | |
| } else { | |
| append_o_text( | |
| &piece.text, | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| ); | |
| } | |
| } | |
| let sep_pos = pos.min(max_length - 1); | |
| input_ids[sep_pos] = vocab.sep_id; | |
| attention_mask[sep_pos] = 1; | |
| labels[sep_pos] = -100; | |
| Ok((input_ids, attention_mask, labels)) | |
| } | |
| fn append_o_text( | |
| text: &str, | |
| vocab: &Vocab, | |
| available: usize, | |
| pos: &mut usize, | |
| input_ids: &mut [u16], | |
| attention_mask: &mut [u8], | |
| labels: &mut [i16], | |
| ) { | |
| for ch in text.chars() { | |
| if *pos > available { | |
| return; | |
| } | |
| let token = ch.to_string(); | |
| input_ids[*pos] = token_id(vocab, &token); | |
| attention_mask[*pos] = 1; | |
| labels[*pos] = 0; | |
| *pos += 1; | |
| } | |
| } | |
| fn append_entity_text( | |
| text: &str, | |
| entity: Entity, | |
| vocab: &Vocab, | |
| available: usize, | |
| pos: &mut usize, | |
| input_ids: &mut [u16], | |
| attention_mask: &mut [u8], | |
| labels: &mut [i16], | |
| ) -> Result<()> { | |
| let b = label_id(entity.b_label()).context("missing B label")?; | |
| let i = label_id(entity.i_label()).context("missing I label")?; | |
| let mut first = true; | |
| for ch in text.chars() { | |
| if *pos > available { | |
| return Ok(()); | |
| } | |
| let token = ch.to_string(); | |
| input_ids[*pos] = token_id(vocab, &token); | |
| attention_mask[*pos] = 1; | |
| labels[*pos] = if first { b } else { i }; | |
| first = false; | |
| *pos += 1; | |
| } | |
| Ok(()) | |
| } | |
| fn token_id(vocab: &Vocab, token: &str) -> u16 { | |
| *vocab.ids.get(token).unwrap_or(&vocab.unk_id) | |
| } | |
| fn label_id(label: &str) -> Option<i16> { | |
| label_ids().get(label).copied() | |
| } | |
| fn label_ids() -> &'static HashMap<String, i16> { | |
| LABEL_IDS.get_or_init(load_label_ids) | |
| } | |
| fn load_label_ids() -> HashMap<String, i16> { | |
| let labels = read_schema_labels().unwrap_or_else(|| { | |
| FALLBACK_LABELS | |
| .iter() | |
| .map(|label| (*label).to_string()) | |
| .collect() | |
| }); | |
| labels | |
| .into_iter() | |
| .enumerate() | |
| .map(|(idx, label)| (label, idx as i16)) | |
| .collect() | |
| } | |
| fn read_schema_labels() -> Option<Vec<String>> { | |
| for path in label_schema_candidates() { | |
| let Ok(text) = fs::read_to_string(path) else { | |
| continue; | |
| }; | |
| let Ok(schema) = serde_json::from_str::<LabelSchema>(&text) else { | |
| continue; | |
| }; | |
| if schema.labels.is_empty() || schema.labels.iter().any(|label| label.trim().is_empty()) { | |
| continue; | |
| } | |
| return Some(schema.labels); | |
| } | |
| None | |
| } | |
| fn label_schema_candidates() -> Vec<PathBuf> { | |
| let mut candidates = Vec::new(); | |
| if let Ok(current_dir) = std::env::current_dir() { | |
| candidates.push(current_dir.join("label_schema.json")); | |
| } | |
| candidates.push( | |
| Path::new(env!("CARGO_MANIFEST_DIR")) | |
| .join("..") | |
| .join("..") | |
| .join("label_schema.json"), | |
| ); | |
| candidates | |
| } | |
/// Built-in special-content names (menus, OP/ED clean versions, commercials, previews).
///
/// Order: `"Menu"`; six menu/ED spellings for 01..=24; `OPnn`, `NCOPnn`, `NCEDnn` for
/// 01..=06; `CMnn`, `PVnn` for 01..=12. That is 187 values in total.
fn built_in_specials() -> Vec<String> {
    let menus = (1..=24).flat_map(|n| {
        [
            format!("Menu{n:02}"),
            format!("Menu {n:02}"),
            format!("BDMenu{n:02}"),
            format!("BD Menu{n:02}"),
            format!("Menu{n:02}-01"),
            format!("ED E{n:02}"),
        ]
    });
    let openings =
        (1..=6).flat_map(|n| [format!("OP{n:02}"), format!("NCOP{n:02}"), format!("NCED{n:02}")]);
    let promos = (1..=12).flat_map(|n| [format!("CM{n:02}"), format!("PV{n:02}")]);
    std::iter::once("Menu".to_string())
        .chain(menus)
        .chain(openings)
        .chain(promos)
        .collect()
}
| fn write_npy_u16(path: &Path, data: &[u16], rows: usize, cols: usize) -> Result<()> { | |
| let mut writer = BufWriter::new( | |
| File::create(path).with_context(|| format!("failed to create {}", path.display()))?, | |
| ); | |
| write_npy_header(&mut writer, "<u2", rows, cols)?; | |
| for value in data { | |
| writer.write_all(&value.to_le_bytes())?; | |
| } | |
| Ok(()) | |
| } | |
| fn write_npy_u8(path: &Path, data: &[u8], rows: usize, cols: usize) -> Result<()> { | |
| let mut writer = BufWriter::new( | |
| File::create(path).with_context(|| format!("failed to create {}", path.display()))?, | |
| ); | |
| write_npy_header(&mut writer, "|u1", rows, cols)?; | |
| writer.write_all(data)?; | |
| Ok(()) | |
| } | |
| fn write_npy_i16(path: &Path, data: &[i16], rows: usize, cols: usize) -> Result<()> { | |
| let mut writer = BufWriter::new( | |
| File::create(path).with_context(|| format!("failed to create {}", path.display()))?, | |
| ); | |
| write_npy_header(&mut writer, "<i2", rows, cols)?; | |
| for value in data { | |
| writer.write_all(&value.to_le_bytes())?; | |
| } | |
| Ok(()) | |
| } | |
| fn write_npy_header<W: Write>(writer: &mut W, descr: &str, rows: usize, cols: usize) -> Result<()> { | |
| let mut header = format!( | |
| "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}", | |
| descr, rows, cols | |
| ) | |
| .into_bytes(); | |
| let preamble_len = 10usize; | |
| let pad_len = (16 - ((preamble_len + header.len() + 1) % 16)) % 16; | |
| header.extend(std::iter::repeat(b' ').take(pad_len)); | |
| header.push(b'\n'); | |
| if header.len() > u16::MAX as usize { | |
| bail!("npy header too large"); | |
| } | |
| writer.write_all(b"\x93NUMPY")?; | |
| writer.write_all(&[1, 0])?; | |
| writer.write_all(&(header.len() as u16).to_le_bytes())?; | |
| writer.write_all(&header)?; | |
| Ok(()) | |
| } | |
| mod tests { | |
| use super::*; | |
| fn test_config() -> GenConfig { | |
| GenConfig { | |
| max_length: 128, | |
| shard_size: 16, | |
| separator_mode: SeparatorMode::Global, | |
| bracket_mode: BracketMode::Global, | |
| separators: vec![" ".to_string()], | |
| brackets: vec![Bracket { | |
| name: "none".to_string(), | |
| open: String::new(), | |
| close: String::new(), | |
| }], | |
| path_styles: vec![PathStyle::Windows], | |
| include_original: true, | |
| include_bio_variants: true, | |
| samples_per_source: 0, | |
| path_samples_per_source: 1, | |
| seed: 105, | |
| } | |
| } | |
| fn sample_without_season() -> SourceSample { | |
| let mut fields = vec![Vec::new(); ENTITIES.len()]; | |
| fields[Entity::TitleLatin.index()] = vec!["Example Show".to_string()]; | |
| fields[Entity::PathTitleLatin.index()] = vec!["Example Show".to_string()]; | |
| fields[Entity::Episode.index()] = vec!["1".to_string()]; | |
| fields[Entity::Resolution.index()] = vec!["1080P".to_string()]; | |
| fields[Entity::Source.index()] = vec!["WEB-DL".to_string()]; | |
| SourceSample { | |
| row_index: 7, | |
| filename: "Example Show 1 [1080P][WEB-DL].mkv".to_string(), | |
| tokens: Vec::new(), | |
| labels: Vec::new(), | |
| fields, | |
| } | |
| } | |
| fn sample_with_group() -> SourceSample { | |
| let mut sample = sample_without_season(); | |
| sample.fields[Entity::Group.index()] = vec!["Erai-raws".to_string()]; | |
| sample | |
| } | |
/// `path_prefix_components` yields at least two non-empty directory components, none of
/// them labeled as an entity, for both path styles across 32 seeds.
///
/// NOTE(review): despite the name and the assertion message, this counts all non-empty
/// prefix components (template directories included), not only injected noise dirs.
#[test]
fn path_prefix_has_at_least_two_noise_directories() {
    for style in [PathStyle::Windows, PathStyle::Unix] {
        for seed in 0..32 {
            let mut rng = StdRng::seed_from_u64(seed);
            let components = path_prefix_components(style, &mut rng);
            let non_empty_components = components
                .iter()
                .filter(|component| !render_labeled_pieces(component).is_empty())
                .count();
            assert!(
                non_empty_components >= 2,
                "expected at least two noise directories for {style:?}: {}",
                render_labeled_pieces(&join_path_components(&components, style.separator()))
            );
            assert!(components
                .iter()
                .flatten()
                .all(|piece| piece.entity.is_none()));
        }
    }
}
/// Label ids follow the v2 schema order, and the legacy `B-TITLE` label is not part of it.
#[test]
fn fixed_label_schema_ids_match_v2_order() {
    assert_eq!(label_id("O"), Some(0));
    assert_eq!(label_id("B-TITLE_CHS"), Some(1));
    assert_eq!(label_id("I-TITLE_MIXED"), Some(10));
    assert_eq!(label_id("B-PATH_TITLE_CHS"), Some(11));
    assert_eq!(label_id("I-PATH_TITLE_MIXED"), Some(20));
    assert_eq!(label_id("B-PATH_SEASON"), Some(21));
    assert_eq!(label_id("B-SEASON"), Some(23));
    assert_eq!(label_id("B-EPISODE"), Some(25));
    assert_eq!(label_id("B-GROUP"), Some(29));
    assert_eq!(label_id("B-SOURCE"), Some(33));
    assert_eq!(label_id("B-TAG"), Some(35));
    assert_eq!(label_id("I-TAG"), Some(36));
    assert_eq!(label_id("B-TITLE"), None);
}
| fn legacy_source_title_labels_canonicalize_to_mixed_schema() { | |
| assert_eq!(canonical_bio_label("B-TITLE"), "B-TITLE_MIXED"); | |
| assert_eq!(canonical_bio_label("I-TITLE"), "I-TITLE_MIXED"); | |
| assert_eq!(canonical_bio_label("B-PATH_TITLE"), "B-PATH_TITLE_MIXED"); | |
| assert_eq!(canonical_bio_label("B-SEASON"), "B-SEASON"); | |
| } | |
| fn generated_entities_do_not_emit_legacy_title_labels() { | |
| for entity in ENTITIES { | |
| assert_ne!(entity.b_label(), "B-TITLE"); | |
| assert_ne!(entity.i_label(), "I-TITLE"); | |
| } | |
| } | |
| fn extraction_preserves_file_and_path_title_candidates() { | |
| let tokens = ["A", "/", "僕", "ら"] | |
| .iter() | |
| .map(|value| value.to_string()) | |
| .collect::<Vec<_>>(); | |
| let labels = ["B-TITLE_LATIN", "O", "B-PATH_TITLE_JPN", "I-PATH_TITLE_JPN"] | |
| .iter() | |
| .map(|value| value.to_string()) | |
| .collect::<Vec<_>>(); | |
| let fields = extract_fields(&tokens, &labels); | |
| assert_eq!(fields[Entity::TitleLatin.index()], vec!["A"]); | |
| assert_eq!(fields[Entity::PathTitleLatin.index()], vec!["A"]); | |
| assert_eq!(fields[Entity::PathTitleJpn.index()], vec!["僕ら"]); | |
| assert_eq!(fields[Entity::TitleJpn.index()], vec!["僕ら"]); | |
| } | |
| fn path_context_synthesizes_season_between_title_and_episode() { | |
| let sample = sample_without_season(); | |
| let cfg = test_config(); | |
| let mut rng = StdRng::seed_from_u64(12); | |
| let pieces = build_path_context_pieces(&sample, &cfg, &mut rng) | |
| .expect("expected path context pieces"); | |
| let text = render_labeled_pieces(&pieces); | |
| assert!(text.contains("Example Show")); | |
| assert!( | |
| text.contains("Season") | |
| || text.contains("S01") | |
| || text.contains("第1季") | |
| || text.contains("01"), | |
| "missing synthetic season directory in {text}" | |
| ); | |
| let mut seen_title = false; | |
| let mut seen_season_after_title = false; | |
| let mut seen_episode_after_season = false; | |
| for piece in &pieces { | |
| match piece.entity { | |
| None if !seen_title => {} | |
| Some(Entity::PathTitleLatin) => seen_title = true, | |
| Some(Entity::PathSeason) if seen_title => seen_season_after_title = true, | |
| Some(Entity::Episode) if seen_season_after_title => { | |
| seen_episode_after_season = true | |
| } | |
| _ => {} | |
| } | |
| } | |
| assert!(seen_title); | |
| assert!(seen_season_after_title); | |
| assert!(seen_episode_after_season); | |
| } | |
| fn path_context_can_label_bare_numeric_path_season() { | |
| let mut sample = sample_without_season(); | |
| sample.fields[Entity::Episode.index()] = vec!["3".to_string()]; | |
| let mut cfg = test_config(); | |
| cfg.path_styles = vec![PathStyle::Unix]; | |
| let mut found = None; | |
| for seed in 0..2048 { | |
| let mut rng = StdRng::seed_from_u64(seed); | |
| let pieces = build_path_context_pieces(&sample, &cfg, &mut rng) | |
| .expect("expected path context pieces"); | |
| let text = render_labeled_pieces(&pieces); | |
| if text.contains("Example Show/01/03.mkv") { | |
| found = Some(pieces); | |
| break; | |
| } | |
| } | |
| let pieces = found.expect("expected a Title/01/03.mkv-style path context"); | |
| assert!(pieces | |
| .iter() | |
| .any(|piece| piece.text == "01" && piece.entity == Some(Entity::PathSeason))); | |
| assert!(pieces | |
| .iter() | |
| .any(|piece| piece.text == "03" && piece.entity == Some(Entity::Episode))); | |
| } | |
| fn path_season_variants_include_common_directory_forms() { | |
| let mut variants = HashSet::new(); | |
| for seed in 0..128 { | |
| let mut rng = StdRng::seed_from_u64(seed); | |
| variants.insert(random_season_path_text("S01", &mut rng)); | |
| } | |
| assert!(variants.contains("S01")); | |
| assert!(variants.contains("01")); | |
| assert!(variants.contains("Season 1")); | |
| assert!(variants.contains("Season 01")); | |
| } | |
| fn grouped_path_file_labels_group_but_not_duplicate_title() { | |
| let sample = sample_with_group(); | |
| let mut rng = StdRng::seed_from_u64(21); | |
| let endpoint = entity_piece("01".to_string(), Entity::Episode); | |
| let pieces = grouped_release_file_component("Example Show", endpoint, &sample, &mut rng); | |
| let text = render_labeled_pieces(&pieces); | |
| assert!(text.contains("[Erai-raws]")); | |
| assert!(text.contains("Example Show")); | |
| assert!(text.contains("01")); | |
| assert!(pieces | |
| .iter() | |
| .any(|piece| piece.entity == Some(Entity::Group))); | |
| assert!(pieces | |
| .iter() | |
| .any(|piece| piece.entity == Some(Entity::Episode))); | |
| assert!(pieces | |
| .iter() | |
| .any(|piece| piece.text == "Example Show" && piece.entity.is_none())); | |
| } | |
| } | |