Token Classification
Transformers
ONNX
Safetensors
English
Japanese
Chinese
bert
anime
filename-parsing
Eval Results (legacy)
Instructions to use ModerRAS/AniFileBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ModerRAS/AniFileBERT with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("token-classification", model="ModerRAS/AniFileBERT")

# Load model directly
from transformers import AutoTokenizer, AutoModelForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("ModerRAS/AniFileBERT")
model = AutoModelForTokenClassification.from_pretrained("ModerRAS/AniFileBERT")
```
- Notebooks
- Google Colab
- Kaggle
| use anyhow::{bail, Context, Result}; | |
| use clap::{Parser, ValueEnum}; | |
| use rand::rngs::StdRng; | |
| use rand::seq::SliceRandom; | |
| use rand::Rng; | |
| use rand::SeedableRng; | |
| use rayon::prelude::*; | |
| use serde::{Deserialize, Serialize}; | |
| use serde_json::json; | |
| use std::collections::{HashMap, HashSet}; | |
| use std::fs::{self, File}; | |
| use std::io::{BufRead, BufReader, BufWriter, Write}; | |
| use std::path::{Path, PathBuf}; | |
| use std::time::Instant; | |
/// Every entity kind, listed in `Entity::index` order (Group = 0 … Source = 6).
///
/// Per-entity tables such as `SourceSample::fields` are built with `ENTITIES.len()` slots
/// and indexed with `entity.index()`, so this order must stay in sync with `Entity::index`.
const ENTITIES: [Entity; 7] = [
    Entity::Group,
    Entity::Title,
    Entity::Season,
    Entity::Episode,
    Entity::Special,
    Entity::Resolution,
    Entity::Source,
];
// Command-line options for the virtual shard generator.
//
// NOTE(review): `main` calls `Args::parse()` and its error messages refer to `--max-length`
// and `--shard-size`. That needs a clap `#[derive(Parser)]` with `#[arg(...)]` attributes,
// and none are visible in this copy of the file. Confirm flag names and defaults against
// the real source.
// These are plain `//` comments on purpose: `///` doc comments would become clap help text.
struct Args {
    // Source JSONL file: one `InputRow` JSON object per line (see `load_samples`).
    input: PathBuf,
    // JSON object mapping token -> integer id; must contain [PAD], [UNK], [CLS] and [SEP].
    vocab_file: PathBuf,
    // Directory for the `.npy` shards and `manifest.json` (created if missing).
    output_dir: PathBuf,
    // Fixed encoded row length including [CLS] and [SEP]; must be at least 4.
    max_length: usize,
    // Rows per shard file; must be positive.
    shard_size: usize,
    // Stop after this many non-blank source rows; 0 reads the whole file.
    limit_rows: usize,
    // 0 = enumerate every variant exhaustively; N > 0 = emit N random variants per source row.
    samples_per_source: usize,
    // Seed for the source-row shuffle and for per-row variant sampling.
    seed: u64,
    // Size of rayon's global thread pool; 0 keeps rayon's default pool.
    threads: usize,
    // Whether one separator is reused across all gaps or chosen per gap.
    separator_mode: SeparatorMode,
    // Whether one bracket style is reused for all parts or chosen per part.
    bracket_mode: BracketMode,
    // Separator strings placed between parts; the escapes `\t` and `\s` mean tab and space.
    separators: Vec<String>,
    // Bracket style names accepted by `Bracket::from_name` (e.g. "square" or a custom "open|close").
    bracket_styles: Vec<String>,
    // Also emit each source row exactly as labelled in the input.
    include_original: bool,
    // Also emit the built-in SPECIAL fixtures returned by `built_in_specials`.
    include_special_fixtures: bool,
    // Only print a JSON estimate of the row counts; write no files.
    dry_run: bool,
}
// How separators are chosen for the gaps between the rendered parts of a generated filename.
// (Plain `//` comments: if this enum derives clap's `ValueEnum`, `///` would change help text.)
// NOTE(review): `Args` holds this type and `main` serializes it with `json!`, so it needs
// derives (at least `Serialize`, presumably clap's `ValueEnum`) that are not visible in this
// copy. Confirm.
enum SeparatorMode {
    // One separator string, reused for every gap of a variant.
    Global,
    // Every gap gets its own independently chosen separator.
    PerGap,
}
// How bracket styles are chosen for the rendered parts of a generated filename.
// (Plain `//` comments: if this enum derives clap's `ValueEnum`, `///` would change help text.)
// NOTE(review): `Args` holds this type and `main` serializes it with `json!`, so it needs
// derives (at least `Serialize`, presumably clap's `ValueEnum`) that are not visible in this
// copy. Confirm.
enum BracketMode {
    // One bracket style, applied to every part of a variant.
    Global,
    // Every part gets its own independently chosen bracket style.
    PerPart,
}
/// Entity kinds of the BIO label scheme (`B-<NAME>` / `I-<NAME>`, see `Entity::from_name`).
///
/// NOTE(review): other code copies values of this type (`ENTITIES.iter().copied()`) and
/// compares them (`active_entity == Some(entity)`), so it must derive at least
/// `Clone, Copy, PartialEq`. No derive attribute is visible in this copy of the file; confirm.
enum Entity {
    Group,
    Title,
    Season,
    Episode,
    Special,
    Resolution,
    Source,
}
| impl Entity { | |
| fn index(self) -> usize { | |
| match self { | |
| Entity::Group => 0, | |
| Entity::Title => 1, | |
| Entity::Season => 2, | |
| Entity::Episode => 3, | |
| Entity::Special => 4, | |
| Entity::Resolution => 5, | |
| Entity::Source => 6, | |
| } | |
| } | |
| fn from_name(name: &str) -> Option<Self> { | |
| match name { | |
| "GROUP" => Some(Entity::Group), | |
| "TITLE" => Some(Entity::Title), | |
| "SEASON" => Some(Entity::Season), | |
| "EPISODE" => Some(Entity::Episode), | |
| "SPECIAL" => Some(Entity::Special), | |
| "RESOLUTION" => Some(Entity::Resolution), | |
| "SOURCE" => Some(Entity::Source), | |
| _ => None, | |
| } | |
| } | |
| fn b_label(self) -> &'static str { | |
| match self { | |
| Entity::Group => "B-GROUP", | |
| Entity::Title => "B-TITLE", | |
| Entity::Season => "B-SEASON", | |
| Entity::Episode => "B-EPISODE", | |
| Entity::Special => "B-SPECIAL", | |
| Entity::Resolution => "B-RESOLUTION", | |
| Entity::Source => "B-SOURCE", | |
| } | |
| } | |
| fn i_label(self) -> &'static str { | |
| match self { | |
| Entity::Group => "I-GROUP", | |
| Entity::Title => "I-TITLE", | |
| Entity::Season => "I-SEASON", | |
| Entity::Episode => "I-EPISODE", | |
| Entity::Special => "I-SPECIAL", | |
| Entity::Resolution => "I-RESOLUTION", | |
| Entity::Source => "I-SOURCE", | |
| } | |
| } | |
| } | |
/// A bracket style wrapped around each rendered part of a generated filename, e.g. `[` … `]`.
///
/// NOTE(review): values are cloned (`cfg.brackets.choose(rng).cloned()`, `bracket.clone()`),
/// which requires a `Clone` derive that is not visible in this copy of the file. Confirm.
struct Bracket {
    /// Style name as given on the command line (trimmed); reported in the manifest.
    name: String,
    /// Text inserted before the part (empty for the "none" style).
    open: String,
    /// Text inserted after the part (empty for the "none" style).
    close: String,
}
| impl Bracket { | |
| fn from_name(name: &str) -> Result<Self> { | |
| let trimmed = name.trim(); | |
| let pair = match trimmed { | |
| "none" => ("", ""), | |
| "square" => ("[", "]"), | |
| "round" => ("(", ")"), | |
| "corner" => ("【", "】"), | |
| "angle" => ("《", "》"), | |
| custom if custom.contains('|') => { | |
| let mut parts = custom.splitn(2, '|'); | |
| let open = parts.next().unwrap_or_default(); | |
| let close = parts.next().unwrap_or_default(); | |
| return Ok(Self { | |
| name: custom.to_string(), | |
| open: open.to_string(), | |
| close: close.to_string(), | |
| }); | |
| } | |
| other => bail!("unknown bracket style '{other}'"), | |
| }; | |
| Ok(Self { | |
| name: trimmed.to_string(), | |
| open: pair.0.to_string(), | |
| close: pair.1.to_string(), | |
| }) | |
| } | |
| } | |
/// One line of the source JSONL file, as parsed by `load_samples`.
///
/// NOTE(review): this is parsed with `serde_json::from_str`, so it needs a `Deserialize`
/// derive that is not visible in this copy of the file. Confirm.
struct InputRow {
    /// Original filename; when absent, the concatenated `tokens` are used instead.
    filename: Option<String>,
    /// Tokens of the filename (character-level; see `tokenizer_variant`).
    tokens: Vec<String>,
    /// One BIO label per token; `load_samples` rejects rows whose lengths differ.
    labels: Vec<String>,
    /// Tokenizer that produced `tokens`; when present it must be `"char"`.
    tokenizer_variant: Option<String>,
}
/// A validated source row plus the entity values extracted from its labels.
struct SourceSample {
    /// Zero-based line index in the input file (blank lines included). Used in error
    /// messages and to seed per-row sampling.
    row_index: usize,
    /// `InputRow::filename`, or the concatenated tokens when that field was missing.
    filename: String,
    /// Original tokens.
    tokens: Vec<String>,
    /// Original per-token labels (same length as `tokens`).
    labels: Vec<String>,
    /// Distinct entity values, one list per entity in `Entity::index` order (see `extract_fields`).
    fields: Vec<Vec<String>>,
}
/// Generation settings shared by all workers, built from `Args` in `main`.
struct GenConfig {
    /// Encoded row length, including [CLS] and [SEP].
    max_length: usize,
    /// Rows per shard file.
    shard_size: usize,
    /// How separators are assigned to the gaps between parts.
    separator_mode: SeparatorMode,
    /// How bracket styles are assigned to parts.
    bracket_mode: BracketMode,
    /// Candidate separator strings (`\t` / `\s` escapes already resolved).
    separators: Vec<String>,
    /// Candidate bracket styles.
    brackets: Vec<Bracket>,
    /// Also emit each source row unchanged.
    include_original: bool,
    /// 0 = exhaustive enumeration; N > 0 = N random variants per source row.
    samples_per_source: usize,
    /// Base seed for per-row sampling.
    seed: u64,
}
/// Token -> id table plus the ids of the special tokens.
///
/// Ids are stored as `u16`; `load_vocab` rejects vocabularies with larger ids.
struct Vocab {
    /// Token string -> id.
    ids: HashMap<String, u16>,
    /// Id of `[PAD]`; fills every unused position of an encoded row.
    pad_id: u16,
    /// Id of `[UNK]`; also reported in the manifest.
    unk_id: u16,
    /// Id of `[CLS]`; written at position 0 of every encoded row.
    cls_id: u16,
    /// Id of `[SEP]`; written right after the last kept token.
    sep_id: u16,
}
/// Manifest entry for one written shard; listed under `"shards"` in `manifest.json`.
///
/// NOTE(review): serialized via `json!` in `main`, so it needs a `Serialize` derive that is
/// not visible in this copy of the file. Confirm.
struct ShardManifest {
    /// Number of rows in each of the shard's three arrays.
    rows: usize,
    /// File name (relative to the output directory) of the `uint16` input-id array.
    input_ids: String,
    /// File name of the `uint8` attention-mask array.
    attention_mask: String,
    /// File name of the `int16` label array.
    labels: String,
}
/// Per-worker buffer that accumulates encoded rows and writes them out as `.npy` shards.
///
/// Rows are stored flattened, one after another, in the three buffers. Once `shard_size`
/// rows are buffered, `flush` writes `part-w{worker_id}-s{shard_seq}.*.npy` files and
/// records them in `shards`.
struct ShardWriter {
    /// Directory the shard files are written to.
    output_dir: PathBuf,
    /// Worker number embedded in shard file names.
    worker_id: usize,
    /// Sequence number of the next shard this writer will write.
    shard_seq: usize,
    /// Rows per shard.
    shard_size: usize,
    /// Length of every row.
    max_length: usize,
    /// Buffered input ids, `rows * max_length` values.
    input_ids: Vec<u16>,
    /// Buffered attention mask, same layout as `input_ids`.
    attention_mask: Vec<u8>,
    /// Buffered label ids (the encoders use -100 for positions without a label), same layout.
    labels: Vec<i16>,
    /// Rows currently buffered and not yet flushed.
    rows: usize,
    /// Rows added over this writer's whole lifetime.
    total_rows: u64,
    /// Shards written so far.
    shards: Vec<ShardManifest>,
}
| impl ShardWriter { | |
| fn new(output_dir: &Path, worker_id: usize, shard_size: usize, max_length: usize) -> Self { | |
| let capacity = shard_size.saturating_mul(max_length); | |
| Self { | |
| output_dir: output_dir.to_path_buf(), | |
| worker_id, | |
| shard_seq: 0, | |
| shard_size, | |
| max_length, | |
| input_ids: Vec::with_capacity(capacity), | |
| attention_mask: Vec::with_capacity(capacity), | |
| labels: Vec::with_capacity(capacity), | |
| rows: 0, | |
| total_rows: 0, | |
| shards: Vec::new(), | |
| } | |
| } | |
| fn add(&mut self, input_ids: &[u16], attention_mask: &[u8], labels: &[i16]) -> Result<()> { | |
| if input_ids.len() != self.max_length | |
| || attention_mask.len() != self.max_length | |
| || labels.len() != self.max_length | |
| { | |
| bail!("encoded sample has wrong shape"); | |
| } | |
| self.input_ids.extend_from_slice(input_ids); | |
| self.attention_mask.extend_from_slice(attention_mask); | |
| self.labels.extend_from_slice(labels); | |
| self.rows += 1; | |
| self.total_rows += 1; | |
| if self.rows >= self.shard_size { | |
| self.flush()?; | |
| } | |
| Ok(()) | |
| } | |
| fn flush(&mut self) -> Result<()> { | |
| if self.rows == 0 { | |
| return Ok(()); | |
| } | |
| let base = format!("part-w{:03}-s{:06}", self.worker_id, self.shard_seq); | |
| let input_name = format!("{base}.input_ids.npy"); | |
| let mask_name = format!("{base}.attention_mask.npy"); | |
| let label_name = format!("{base}.labels.npy"); | |
| write_npy_u16( | |
| &self.output_dir.join(&input_name), | |
| &self.input_ids, | |
| self.rows, | |
| self.max_length, | |
| )?; | |
| write_npy_u8( | |
| &self.output_dir.join(&mask_name), | |
| &self.attention_mask, | |
| self.rows, | |
| self.max_length, | |
| )?; | |
| write_npy_i16( | |
| &self.output_dir.join(&label_name), | |
| &self.labels, | |
| self.rows, | |
| self.max_length, | |
| )?; | |
| self.shards.push(ShardManifest { | |
| rows: self.rows, | |
| input_ids: input_name, | |
| attention_mask: mask_name, | |
| labels: label_name, | |
| }); | |
| self.input_ids.clear(); | |
| self.attention_mask.clear(); | |
| self.labels.clear(); | |
| self.rows = 0; | |
| self.shard_seq += 1; | |
| Ok(()) | |
| } | |
| } | |
| fn main() -> Result<()> { | |
| let args = Args::parse(); | |
| if args.max_length < 4 { | |
| bail!("--max-length must be at least 4"); | |
| } | |
| if args.shard_size == 0 { | |
| bail!("--shard-size must be positive"); | |
| } | |
| if args.threads > 0 { | |
| rayon::ThreadPoolBuilder::new() | |
| .num_threads(args.threads) | |
| .build_global() | |
| .context("failed to configure rayon thread pool")?; | |
| } | |
| let started = Instant::now(); | |
| let vocab = load_vocab(&args.vocab_file)?; | |
| let brackets = args | |
| .bracket_styles | |
| .iter() | |
| .map(|style| Bracket::from_name(style)) | |
| .collect::<Result<Vec<_>>>()?; | |
| let separators = args | |
| .separators | |
| .iter() | |
| .map(|sep| normalize_separator_arg(sep)) | |
| .collect::<Vec<_>>(); | |
| let cfg = GenConfig { | |
| max_length: args.max_length, | |
| shard_size: args.shard_size, | |
| separator_mode: args.separator_mode, | |
| bracket_mode: args.bracket_mode, | |
| separators, | |
| brackets, | |
| include_original: args.include_original, | |
| samples_per_source: args.samples_per_source, | |
| seed: args.seed, | |
| }; | |
| let mut samples = load_samples(&args.input, args.limit_rows)?; | |
| let source_rows = samples.len(); | |
| let mut rng = StdRng::seed_from_u64(args.seed); | |
| samples.shuffle(&mut rng); | |
| if args.dry_run { | |
| let generated: u128 = samples | |
| .par_iter() | |
| .map(|sample| count_variants(sample, &cfg)) | |
| .sum(); | |
| let special_fixtures = if args.include_special_fixtures { | |
| count_special_fixtures(&cfg) as u128 | |
| } else { | |
| 0 | |
| }; | |
| let manifest = json!({ | |
| "format": "anifilebert.virtual_dataset.preview.v1", | |
| "input": args.input, | |
| "vocab_file": args.vocab_file, | |
| "source_rows": source_rows, | |
| "estimated_rows": generated + special_fixtures, | |
| "source_variant_rows": generated, | |
| "special_fixture_rows": special_fixtures, | |
| "max_length": cfg.max_length, | |
| "separator_mode": cfg.separator_mode, | |
| "bracket_mode": cfg.bracket_mode, | |
| "separators": cfg.separators, | |
| "brackets": cfg.brackets.iter().map(|b| &b.name).collect::<Vec<_>>(), | |
| "include_original": cfg.include_original, | |
| "samples_per_source": cfg.samples_per_source, | |
| "include_special_fixtures": args.include_special_fixtures, | |
| "seed": args.seed, | |
| "elapsed_seconds": started.elapsed().as_secs_f64(), | |
| }); | |
| println!("{}", serde_json::to_string_pretty(&manifest)?); | |
| return Ok(()); | |
| } | |
| fs::create_dir_all(&args.output_dir).with_context(|| { | |
| format!( | |
| "failed to create output directory {}", | |
| args.output_dir.display() | |
| ) | |
| })?; | |
| let chunk_count = rayon::current_num_threads().max(1) * 4; | |
| let chunk_size = samples.len().div_ceil(chunk_count).max(1); | |
| let chunks = samples | |
| .chunks(chunk_size) | |
| .enumerate() | |
| .collect::<Vec<(usize, &[SourceSample])>>(); | |
| let mut worker_results = chunks | |
| .par_iter() | |
| .map(|(chunk_idx, chunk)| { | |
| let mut writer = | |
| ShardWriter::new(&args.output_dir, *chunk_idx, cfg.shard_size, cfg.max_length); | |
| for sample in *chunk { | |
| generate_for_sample(sample, &cfg, &vocab, &mut writer)?; | |
| } | |
| writer.flush()?; | |
| Ok::<_, anyhow::Error>((writer.total_rows, writer.shards)) | |
| }) | |
| .collect::<Result<Vec<_>>>()?; | |
| let mut total_rows: u64 = 0; | |
| let mut shards: Vec<ShardManifest> = Vec::new(); | |
| for (rows, mut worker_shards) in worker_results.drain(..) { | |
| total_rows += rows; | |
| shards.append(&mut worker_shards); | |
| } | |
| let special_rows = if args.include_special_fixtures { | |
| let mut writer = ShardWriter::new( | |
| &args.output_dir, | |
| chunk_count + 1, | |
| cfg.shard_size, | |
| cfg.max_length, | |
| ); | |
| for special in built_in_specials() { | |
| let parts = vec![PartChoice { | |
| entity: Entity::Special, | |
| value: special, | |
| }]; | |
| emit_syntax_variants(&parts, &cfg, &vocab, &mut writer)?; | |
| } | |
| writer.flush()?; | |
| total_rows += writer.total_rows; | |
| shards.append(&mut writer.shards); | |
| writer.total_rows | |
| } else { | |
| 0 | |
| }; | |
| shards.sort_by(|a, b| a.input_ids.cmp(&b.input_ids)); | |
| let manifest = json!({ | |
| "format": "anifilebert.virtual_dataset.shards.v1", | |
| "input": args.input, | |
| "vocab_file": args.vocab_file, | |
| "source_rows": source_rows, | |
| "total_rows": total_rows, | |
| "special_fixture_rows": special_rows, | |
| "max_length": cfg.max_length, | |
| "shard_size": cfg.shard_size, | |
| "tokenizer_variant": "char", | |
| "encoding": { | |
| "input_ids_dtype": "uint16", | |
| "attention_mask_dtype": "uint8", | |
| "labels_dtype": "int16", | |
| "layout": "row_major_npy" | |
| }, | |
| "special_tokens": { | |
| "pad_id": vocab.pad_id, | |
| "unk_id": vocab.unk_id, | |
| "cls_id": vocab.cls_id, | |
| "sep_id": vocab.sep_id | |
| }, | |
| "generation": { | |
| "separator_mode": cfg.separator_mode, | |
| "bracket_mode": cfg.bracket_mode, | |
| "separators": cfg.separators, | |
| "brackets": cfg.brackets.iter().map(|b| &b.name).collect::<Vec<_>>(), | |
| "include_original": cfg.include_original, | |
| "samples_per_source": cfg.samples_per_source, | |
| "include_special_fixtures": args.include_special_fixtures, | |
| "seed": args.seed, | |
| "threads": rayon::current_num_threads() | |
| }, | |
| "shards": shards, | |
| "elapsed_seconds": started.elapsed().as_secs_f64(), | |
| }); | |
| let manifest_path = args.output_dir.join("manifest.json"); | |
| fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?) | |
| .with_context(|| format!("failed to write {}", manifest_path.display()))?; | |
| println!("{}", serde_json::to_string_pretty(&manifest)?); | |
| Ok(()) | |
| } | |
/// Resolves the escape spellings accepted on the command line for awkward separators.
///
/// The literal two-character strings `\t` and `\s` become a tab and a single space. Every
/// other value is returned verbatim.
fn normalize_separator_arg(value: &str) -> String {
    let resolved = match value {
        r"\t" => "\t",
        r"\s" => " ",
        literal => literal,
    };
    resolved.to_owned()
}
| fn load_vocab(path: &Path) -> Result<Vocab> { | |
| let text = fs::read_to_string(path) | |
| .with_context(|| format!("failed to read vocab file {}", path.display()))?; | |
| let raw: HashMap<String, u64> = | |
| serde_json::from_str(&text).context("failed to parse vocab JSON")?; | |
| let mut ids = HashMap::with_capacity(raw.len()); | |
| for (token, id) in raw { | |
| if id > u16::MAX as u64 { | |
| bail!("vocab id {id} for token '{token}' exceeds uint16 storage"); | |
| } | |
| ids.insert(token, id as u16); | |
| } | |
| let pad_id = *ids.get("[PAD]").context("vocab is missing [PAD]")?; | |
| let unk_id = *ids.get("[UNK]").context("vocab is missing [UNK]")?; | |
| let cls_id = *ids.get("[CLS]").context("vocab is missing [CLS]")?; | |
| let sep_id = *ids.get("[SEP]").context("vocab is missing [SEP]")?; | |
| Ok(Vocab { | |
| ids, | |
| pad_id, | |
| unk_id, | |
| cls_id, | |
| sep_id, | |
| }) | |
| } | |
| fn load_samples(path: &Path, limit_rows: usize) -> Result<Vec<SourceSample>> { | |
| let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?; | |
| let reader = BufReader::new(file); | |
| let mut samples = Vec::new(); | |
| for (idx, line) in reader.lines().enumerate() { | |
| if limit_rows > 0 && samples.len() >= limit_rows { | |
| break; | |
| } | |
| let line = line.with_context(|| format!("failed reading line {}", idx + 1))?; | |
| if line.trim().is_empty() { | |
| continue; | |
| } | |
| let row: InputRow = serde_json::from_str(&line) | |
| .with_context(|| format!("failed to parse JSONL line {}", idx + 1))?; | |
| if let Some(variant) = row.tokenizer_variant.as_deref() { | |
| if variant != "char" { | |
| bail!( | |
| "line {} has tokenizer_variant={variant}; virtual shard generation currently requires char data", | |
| idx + 1 | |
| ); | |
| } | |
| } | |
| if row.tokens.len() != row.labels.len() { | |
| bail!( | |
| "line {} has mismatched token/label lengths: {} vs {}", | |
| idx + 1, | |
| row.tokens.len(), | |
| row.labels.len() | |
| ); | |
| } | |
| let filename = row.filename.clone().unwrap_or_else(|| row.tokens.join("")); | |
| let fields = extract_fields(&row.tokens, &row.labels); | |
| samples.push(SourceSample { | |
| row_index: idx, | |
| filename, | |
| tokens: row.tokens, | |
| labels: row.labels, | |
| fields, | |
| }); | |
| } | |
| Ok(samples) | |
| } | |
| fn extract_fields(tokens: &[String], labels: &[String]) -> Vec<Vec<String>> { | |
| let mut fields: Vec<Vec<String>> = (0..ENTITIES.len()).map(|_| Vec::new()).collect(); | |
| let mut seen: Vec<HashSet<String>> = (0..ENTITIES.len()).map(|_| HashSet::new()).collect(); | |
| let mut active_entity: Option<Entity> = None; | |
| let mut active_text = String::new(); | |
| let flush = |entity: Option<Entity>, | |
| text: &mut String, | |
| fields: &mut Vec<Vec<String>>, | |
| seen: &mut Vec<HashSet<String>>| { | |
| if let Some(entity) = entity { | |
| let value = text.trim().to_string(); | |
| if !value.is_empty() && seen[entity.index()].insert(value.clone()) { | |
| fields[entity.index()].push(value); | |
| } | |
| } | |
| text.clear(); | |
| }; | |
| for (token, label) in tokens.iter().zip(labels.iter()) { | |
| if let Some(entity) = label.strip_prefix("B-").and_then(Entity::from_name) { | |
| flush(active_entity, &mut active_text, &mut fields, &mut seen); | |
| active_entity = Some(entity); | |
| active_text.push_str(token); | |
| } else if let Some(entity) = label.strip_prefix("I-").and_then(Entity::from_name) { | |
| if active_entity == Some(entity) { | |
| active_text.push_str(token); | |
| } else { | |
| flush(active_entity, &mut active_text, &mut fields, &mut seen); | |
| active_entity = Some(entity); | |
| active_text.push_str(token); | |
| } | |
| } else { | |
| flush(active_entity, &mut active_text, &mut fields, &mut seen); | |
| active_entity = None; | |
| } | |
| } | |
| flush(active_entity, &mut active_text, &mut fields, &mut seen); | |
| fields | |
| } | |
| fn count_variants(sample: &SourceSample, cfg: &GenConfig) -> u128 { | |
| let mut count = if cfg.include_original { 1 } else { 0 }; | |
| let available = ENTITIES | |
| .iter() | |
| .copied() | |
| .filter(|entity| !sample.fields[entity.index()].is_empty()) | |
| .collect::<Vec<_>>(); | |
| let n = available.len(); | |
| if n == 0 { | |
| return count; | |
| } | |
| if cfg.samples_per_source > 0 { | |
| return count + cfg.samples_per_source as u128; | |
| } | |
| for mask in 1usize..(1usize << n) { | |
| let selected = available | |
| .iter() | |
| .enumerate() | |
| .filter_map(|(idx, entity)| ((mask & (1usize << idx)) != 0).then_some(*entity)) | |
| .collect::<Vec<_>>(); | |
| let m = selected.len(); | |
| let value_product: u128 = selected | |
| .iter() | |
| .map(|entity| sample.fields[entity.index()].len() as u128) | |
| .product(); | |
| let perm_count = factorial(m as u32); | |
| let sep_factor = if m <= 1 { | |
| 1 | |
| } else { | |
| match cfg.separator_mode { | |
| SeparatorMode::Global => cfg.separators.len() as u128, | |
| SeparatorMode::PerGap => (cfg.separators.len() as u128).pow((m - 1) as u32), | |
| } | |
| }; | |
| let bracket_factor = match cfg.bracket_mode { | |
| BracketMode::Global => cfg.brackets.len() as u128, | |
| BracketMode::PerPart => (cfg.brackets.len() as u128).pow(m as u32), | |
| }; | |
| count += value_product * perm_count * sep_factor * bracket_factor; | |
| } | |
| count | |
| } | |
| fn count_special_fixtures(cfg: &GenConfig) -> usize { | |
| let bracket_factor = match cfg.bracket_mode { | |
| BracketMode::Global => cfg.brackets.len(), | |
| BracketMode::PerPart => cfg.brackets.len(), | |
| }; | |
| built_in_specials().len() * bracket_factor | |
| } | |
/// Returns `n!`, saturating at `u128::MAX` instead of overflowing.
///
/// `0!` is 1. Saturation starts at `n = 35`, the first factorial that no longer fits in
/// `u128`. Previously that case panicked in debug builds and silently wrapped in release.
fn factorial(n: u32) -> u128 {
    (1..=u128::from(n)).fold(1u128, u128::saturating_mul)
}
| fn generate_for_sample( | |
| sample: &SourceSample, | |
| cfg: &GenConfig, | |
| vocab: &Vocab, | |
| writer: &mut ShardWriter, | |
| ) -> Result<()> { | |
| if cfg.include_original { | |
| let (input_ids, attention_mask, labels) = | |
| encode_original_sample(sample, vocab, cfg.max_length)?; | |
| writer.add(&input_ids, &attention_mask, &labels)?; | |
| } | |
| if cfg.samples_per_source > 0 { | |
| generate_sampled_variants(sample, cfg, vocab, writer)?; | |
| return Ok(()); | |
| } | |
| let available = ENTITIES | |
| .iter() | |
| .copied() | |
| .filter(|entity| !sample.fields[entity.index()].is_empty()) | |
| .collect::<Vec<_>>(); | |
| let n = available.len(); | |
| for mask in 1usize..(1usize << n) { | |
| let mut selected = available | |
| .iter() | |
| .enumerate() | |
| .filter_map(|(idx, entity)| ((mask & (1usize << idx)) != 0).then_some(*entity)) | |
| .collect::<Vec<_>>(); | |
| permute_entities(&mut selected, 0, &mut |order| { | |
| let mut parts: Vec<PartChoice> = Vec::with_capacity(order.len()); | |
| for_each_value_combo(order, &sample.fields, 0, &mut parts, &mut |combo| { | |
| emit_syntax_variants(combo, cfg, vocab, writer) | |
| }) | |
| })?; | |
| } | |
| Ok(()) | |
| } | |
| fn generate_sampled_variants( | |
| sample: &SourceSample, | |
| cfg: &GenConfig, | |
| vocab: &Vocab, | |
| writer: &mut ShardWriter, | |
| ) -> Result<()> { | |
| let mut rng = StdRng::seed_from_u64( | |
| cfg.seed ^ ((sample.row_index as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)), | |
| ); | |
| let available = ENTITIES | |
| .iter() | |
| .copied() | |
| .filter(|entity| !sample.fields[entity.index()].is_empty()) | |
| .collect::<Vec<_>>(); | |
| if available.is_empty() { | |
| return Ok(()); | |
| } | |
| let mut seen = HashSet::new(); | |
| let mut emitted = 0usize; | |
| let budget = cfg.samples_per_source; | |
| let max_unique_attempts = budget.saturating_mul(32).max(64); | |
| let mut attempts = 0usize; | |
| let mut templates: Vec<Vec<PartChoice>> = Vec::new(); | |
| if let Some(title) = sample.fields[Entity::Title.index()].first() { | |
| templates.push(vec![PartChoice { | |
| entity: Entity::Title, | |
| value: title.clone(), | |
| }]); | |
| if let Some(season) = sample.fields[Entity::Season.index()].first() { | |
| templates.push(vec![ | |
| PartChoice { | |
| entity: Entity::Title, | |
| value: title.clone(), | |
| }, | |
| PartChoice { | |
| entity: Entity::Season, | |
| value: season.clone(), | |
| }, | |
| ]); | |
| } | |
| } | |
| if let Some(episode) = sample.fields[Entity::Episode.index()].first() { | |
| templates.push(vec![PartChoice { | |
| entity: Entity::Episode, | |
| value: episode.clone(), | |
| }]); | |
| } | |
| if let Some(special) = sample.fields[Entity::Special.index()].first() { | |
| templates.push(vec![PartChoice { | |
| entity: Entity::Special, | |
| value: special.clone(), | |
| }]); | |
| } | |
| if let (Some(title), Some(special)) = ( | |
| sample.fields[Entity::Title.index()].first(), | |
| sample.fields[Entity::Special.index()].first(), | |
| ) { | |
| templates.push(vec![ | |
| PartChoice { | |
| entity: Entity::Title, | |
| value: title.clone(), | |
| }, | |
| PartChoice { | |
| entity: Entity::Special, | |
| value: special.clone(), | |
| }, | |
| ]); | |
| } | |
| for parts in templates { | |
| if emitted >= budget { | |
| break; | |
| } | |
| emit_sample_variant( | |
| parts, | |
| cfg, | |
| vocab, | |
| writer, | |
| &mut seen, | |
| &mut emitted, | |
| &mut rng, | |
| false, | |
| )?; | |
| } | |
| while emitted < budget && attempts < max_unique_attempts { | |
| attempts += 1; | |
| let subset_size = match rng.gen_range(0..100) { | |
| 0..=29 => 1, | |
| 30..=54 => 2, | |
| 55..=74 => 3, | |
| 75..=89 => 4.min(available.len()), | |
| _ => available.len().min(5), | |
| } | |
| .max(1) | |
| .min(available.len()); | |
| let mut chosen = available | |
| .choose_multiple(&mut rng, subset_size) | |
| .copied() | |
| .collect::<Vec<_>>(); | |
| chosen.shuffle(&mut rng); | |
| if !chosen | |
| .iter() | |
| .any(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special)) | |
| { | |
| if let Some(fallback) = available | |
| .iter() | |
| .copied() | |
| .find(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special)) | |
| { | |
| if !chosen.contains(&fallback) { | |
| chosen.push(fallback); | |
| } | |
| } | |
| } | |
| let mut parts = Vec::with_capacity(chosen.len()); | |
| for entity in chosen { | |
| let values = &sample.fields[entity.index()]; | |
| let value = values.choose(&mut rng).cloned().unwrap_or_default(); | |
| parts.push(PartChoice { entity, value }); | |
| } | |
| parts.shuffle(&mut rng); | |
| emit_sample_variant( | |
| parts, | |
| cfg, | |
| vocab, | |
| writer, | |
| &mut seen, | |
| &mut emitted, | |
| &mut rng, | |
| false, | |
| )?; | |
| } | |
| while emitted < budget { | |
| let subset_size = match rng.gen_range(0..100) { | |
| 0..=29 => 1, | |
| 30..=54 => 2, | |
| 55..=74 => 3, | |
| 75..=89 => 4.min(available.len()), | |
| _ => available.len().min(5), | |
| } | |
| .max(1) | |
| .min(available.len()); | |
| let mut chosen = available | |
| .choose_multiple(&mut rng, subset_size) | |
| .copied() | |
| .collect::<Vec<_>>(); | |
| chosen.shuffle(&mut rng); | |
| if !chosen | |
| .iter() | |
| .any(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special)) | |
| { | |
| if let Some(fallback) = available | |
| .iter() | |
| .copied() | |
| .find(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special)) | |
| { | |
| if !chosen.contains(&fallback) { | |
| chosen.push(fallback); | |
| } | |
| } | |
| } | |
| let mut parts = Vec::with_capacity(chosen.len()); | |
| for entity in chosen { | |
| let values = &sample.fields[entity.index()]; | |
| let value = values.choose(&mut rng).cloned().unwrap_or_default(); | |
| parts.push(PartChoice { entity, value }); | |
| } | |
| parts.shuffle(&mut rng); | |
| emit_sample_variant( | |
| parts, | |
| cfg, | |
| vocab, | |
| writer, | |
| &mut seen, | |
| &mut emitted, | |
| &mut rng, | |
| true, | |
| )?; | |
| } | |
| Ok(()) | |
| } | |
| fn emit_sample_variant( | |
| parts: Vec<PartChoice>, | |
| cfg: &GenConfig, | |
| vocab: &Vocab, | |
| writer: &mut ShardWriter, | |
| seen: &mut HashSet<String>, | |
| emitted: &mut usize, | |
| rng: &mut StdRng, | |
| allow_duplicate: bool, | |
| ) -> Result<()> { | |
| if *emitted >= cfg.samples_per_source { | |
| return Ok(()); | |
| } | |
| if parts.is_empty() { | |
| return Ok(()); | |
| } | |
| let separators = match cfg.separator_mode { | |
| SeparatorMode::Global => { | |
| let sep = cfg | |
| .separators | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| " ".to_string()); | |
| if parts.len() > 1 { | |
| vec![sep; parts.len() - 1] | |
| } else { | |
| Vec::new() | |
| } | |
| } | |
| SeparatorMode::PerGap => { | |
| let mut values = Vec::with_capacity(parts.len().saturating_sub(1)); | |
| for _ in 0..parts.len().saturating_sub(1) { | |
| values.push( | |
| cfg.separators | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| " ".to_string()), | |
| ); | |
| } | |
| values | |
| } | |
| }; | |
| let brackets = match cfg.bracket_mode { | |
| BracketMode::Global => { | |
| let bracket = cfg | |
| .brackets | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| Bracket { | |
| name: "none".to_string(), | |
| open: String::new(), | |
| close: String::new(), | |
| }); | |
| vec![bracket; parts.len()] | |
| } | |
| BracketMode::PerPart => { | |
| let mut values = Vec::with_capacity(parts.len()); | |
| for _ in 0..parts.len() { | |
| values.push( | |
| cfg.brackets | |
| .choose(rng) | |
| .cloned() | |
| .unwrap_or_else(|| Bracket { | |
| name: "none".to_string(), | |
| open: String::new(), | |
| close: String::new(), | |
| }), | |
| ); | |
| } | |
| values | |
| } | |
| }; | |
| let text = render_variant_text(&parts, &separators, &brackets); | |
| if !allow_duplicate && !seen.insert(text) { | |
| return Ok(()); | |
| } | |
| let (input_ids, attention_mask, labels) = | |
| encode_generated_sample(&parts, &separators, &brackets, vocab, cfg.max_length)?; | |
| writer.add(&input_ids, &attention_mask, &labels)?; | |
| *emitted += 1; | |
| Ok(()) | |
| } | |
| fn permute_entities<F>(values: &mut [Entity], start: usize, callback: &mut F) -> Result<()> | |
| where | |
| F: FnMut(&[Entity]) -> Result<()>, | |
| { | |
| if start >= values.len() { | |
| return callback(values); | |
| } | |
| for idx in start..values.len() { | |
| values.swap(start, idx); | |
| permute_entities(values, start + 1, callback)?; | |
| values.swap(start, idx); | |
| } | |
| Ok(()) | |
| } | |
/// One part of a generated filename: which entity it is and the exact text rendered for it.
struct PartChoice {
    /// Entity whose B-/I- labels the rendered text receives.
    entity: Entity,
    /// Text rendered for this part (one of the row's extracted values for `entity`).
    value: String,
}
| fn for_each_value_combo<F>( | |
| order: &[Entity], | |
| fields: &[Vec<String>], | |
| idx: usize, | |
| current: &mut Vec<PartChoice>, | |
| callback: &mut F, | |
| ) -> Result<()> | |
| where | |
| F: FnMut(&[PartChoice]) -> Result<()>, | |
| { | |
| if idx >= order.len() { | |
| return callback(current); | |
| } | |
| let entity = order[idx]; | |
| for value in &fields[entity.index()] { | |
| current.push(PartChoice { | |
| entity, | |
| value: value.clone(), | |
| }); | |
| for_each_value_combo(order, fields, idx + 1, current, callback)?; | |
| current.pop(); | |
| } | |
| Ok(()) | |
| } | |
| fn emit_syntax_variants( | |
| parts: &[PartChoice], | |
| cfg: &GenConfig, | |
| vocab: &Vocab, | |
| writer: &mut ShardWriter, | |
| ) -> Result<()> { | |
| let gaps = parts.len().saturating_sub(1); | |
| let mut separators = Vec::with_capacity(gaps); | |
| for_each_separator_combo(gaps, cfg, 0, &mut separators, &mut |sep_combo| { | |
| let mut brackets = Vec::with_capacity(parts.len()); | |
| for_each_bracket_combo(parts.len(), cfg, 0, &mut brackets, &mut |bracket_combo| { | |
| let (input_ids, attention_mask, labels) = | |
| encode_generated_sample(parts, sep_combo, bracket_combo, vocab, cfg.max_length)?; | |
| writer.add(&input_ids, &attention_mask, &labels) | |
| }) | |
| }) | |
| } | |
| fn for_each_separator_combo<F>( | |
| gaps: usize, | |
| cfg: &GenConfig, | |
| idx: usize, | |
| current: &mut Vec<String>, | |
| callback: &mut F, | |
| ) -> Result<()> | |
| where | |
| F: FnMut(&[String]) -> Result<()>, | |
| { | |
| if gaps == 0 { | |
| return callback(current); | |
| } | |
| match cfg.separator_mode { | |
| SeparatorMode::Global => { | |
| if idx == 0 { | |
| for sep in &cfg.separators { | |
| current.clear(); | |
| current.resize(gaps, sep.clone()); | |
| callback(current)?; | |
| } | |
| } | |
| Ok(()) | |
| } | |
| SeparatorMode::PerGap => { | |
| if idx >= gaps { | |
| return callback(current); | |
| } | |
| for sep in &cfg.separators { | |
| current.push(sep.clone()); | |
| for_each_separator_combo(gaps, cfg, idx + 1, current, callback)?; | |
| current.pop(); | |
| } | |
| Ok(()) | |
| } | |
| } | |
| } | |
| fn for_each_bracket_combo<F>( | |
| parts: usize, | |
| cfg: &GenConfig, | |
| idx: usize, | |
| current: &mut Vec<Bracket>, | |
| callback: &mut F, | |
| ) -> Result<()> | |
| where | |
| F: FnMut(&[Bracket]) -> Result<()>, | |
| { | |
| match cfg.bracket_mode { | |
| BracketMode::Global => { | |
| if idx == 0 { | |
| for bracket in &cfg.brackets { | |
| current.clear(); | |
| current.resize(parts, bracket.clone()); | |
| callback(current)?; | |
| } | |
| } | |
| Ok(()) | |
| } | |
| BracketMode::PerPart => { | |
| if idx >= parts { | |
| return callback(current); | |
| } | |
| for bracket in &cfg.brackets { | |
| current.push(bracket.clone()); | |
| for_each_bracket_combo(parts, cfg, idx + 1, current, callback)?; | |
| current.pop(); | |
| } | |
| Ok(()) | |
| } | |
| } | |
| } | |
| fn render_variant_text( | |
| parts: &[PartChoice], | |
| separators: &[String], | |
| brackets: &[Bracket], | |
| ) -> String { | |
| let mut text = String::new(); | |
| for (idx, part) in parts.iter().enumerate() { | |
| text.push_str(&brackets[idx].open); | |
| text.push_str(&part.value); | |
| text.push_str(&brackets[idx].close); | |
| if idx < separators.len() { | |
| text.push_str(&separators[idx]); | |
| } | |
| } | |
| text | |
| } | |
| fn encode_original_sample( | |
| sample: &SourceSample, | |
| vocab: &Vocab, | |
| max_length: usize, | |
| ) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> { | |
| let mut input_ids = vec![vocab.pad_id; max_length]; | |
| let mut attention_mask = vec![0u8; max_length]; | |
| let mut labels = vec![-100i16; max_length]; | |
| input_ids[0] = vocab.cls_id; | |
| attention_mask[0] = 1; | |
| let available = max_length.saturating_sub(2); | |
| let token_count = sample.tokens.len().min(available); | |
| for idx in 0..token_count { | |
| input_ids[idx + 1] = token_id(vocab, &sample.tokens[idx]); | |
| attention_mask[idx + 1] = 1; | |
| labels[idx + 1] = label_id(&sample.labels[idx]).with_context(|| { | |
| format!( | |
| "unknown label '{}' on source row {} ({})", | |
| sample.labels[idx], | |
| sample.row_index + 1, | |
| sample.filename | |
| ) | |
| })?; | |
| } | |
| let sep_pos = token_count + 1; | |
| input_ids[sep_pos] = vocab.sep_id; | |
| attention_mask[sep_pos] = 1; | |
| Ok((input_ids, attention_mask, labels)) | |
| } | |
| fn encode_generated_sample( | |
| parts: &[PartChoice], | |
| separators: &[String], | |
| brackets: &[Bracket], | |
| vocab: &Vocab, | |
| max_length: usize, | |
| ) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> { | |
| let mut input_ids = vec![vocab.pad_id; max_length]; | |
| let mut attention_mask = vec![0u8; max_length]; | |
| let mut labels = vec![-100i16; max_length]; | |
| input_ids[0] = vocab.cls_id; | |
| attention_mask[0] = 1; | |
| let available = max_length.saturating_sub(2); | |
| let mut pos = 1usize; | |
| for (idx, part) in parts.iter().enumerate() { | |
| let bracket = &brackets[idx]; | |
| append_o_text( | |
| &bracket.open, | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| ); | |
| append_entity_text( | |
| &part.value, | |
| part.entity, | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| )?; | |
| append_o_text( | |
| &bracket.close, | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| ); | |
| if idx < separators.len() { | |
| append_o_text( | |
| &separators[idx], | |
| vocab, | |
| available, | |
| &mut pos, | |
| &mut input_ids, | |
| &mut attention_mask, | |
| &mut labels, | |
| ); | |
| } | |
| } | |
| let sep_pos = pos.min(max_length - 1); | |
| input_ids[sep_pos] = vocab.sep_id; | |
| attention_mask[sep_pos] = 1; | |
| labels[sep_pos] = -100; | |
| Ok((input_ids, attention_mask, labels)) | |
| } | |
| fn append_o_text( | |
| text: &str, | |
| vocab: &Vocab, | |
| available: usize, | |
| pos: &mut usize, | |
| input_ids: &mut [u16], | |
| attention_mask: &mut [u8], | |
| labels: &mut [i16], | |
| ) { | |
| for ch in text.chars() { | |
| if *pos > available { | |
| return; | |
| } | |
| let token = ch.to_string(); | |
| input_ids[*pos] = token_id(vocab, &token); | |
| attention_mask[*pos] = 1; | |
| labels[*pos] = 0; | |
| *pos += 1; | |
| } | |
| } | |
| fn append_entity_text( | |
| text: &str, | |
| entity: Entity, | |
| vocab: &Vocab, | |
| available: usize, | |
| pos: &mut usize, | |
| input_ids: &mut [u16], | |
| attention_mask: &mut [u8], | |
| labels: &mut [i16], | |
| ) -> Result<()> { | |
| let b = label_id(entity.b_label()).context("missing B label")?; | |
| let i = label_id(entity.i_label()).context("missing I label")?; | |
| let mut first = true; | |
| for ch in text.chars() { | |
| if *pos > available { | |
| return Ok(()); | |
| } | |
| let token = ch.to_string(); | |
| input_ids[*pos] = token_id(vocab, &token); | |
| attention_mask[*pos] = 1; | |
| labels[*pos] = if first { b } else { i }; | |
| first = false; | |
| *pos += 1; | |
| } | |
| Ok(()) | |
| } | |
| fn token_id(vocab: &Vocab, token: &str) -> u16 { | |
| *vocab.ids.get(token).unwrap_or(&vocab.unk_id) | |
| } | |
/// Maps a BIO tag string to its label index, or `None` for unknown tags.
///
/// The index of a tag is its position in `LABELS`.
fn label_id(label: &str) -> Option<i16> {
    const LABELS: [&str; 15] = [
        "O",
        "B-TITLE",
        "I-TITLE",
        "B-SEASON",
        "I-SEASON",
        "B-EPISODE",
        "I-EPISODE",
        "B-SPECIAL",
        "I-SPECIAL",
        "B-GROUP",
        "I-GROUP",
        "B-RESOLUTION",
        "I-RESOLUTION",
        "B-SOURCE",
        "I-SOURCE",
    ];
    LABELS
        .iter()
        .position(|&known| known == label)
        // Positions are below 15, so the cast cannot truncate.
        .map(|idx| idx as i16)
}
/// Returns the built-in list of special-episode names in a fixed order.
///
/// The order is:
/// 1. `"Menu"`.
/// 2. Six menu/ED spellings for each number 01–24.
/// 3. `OP`/`NCOP`/`NCED` for 01–06.
/// 4. `CM`/`PV` for 01–12.
fn built_in_specials() -> Vec<String> {
    let menus = (1..=24).flat_map(|n| {
        [
            format!("Menu{n:02}"),
            format!("Menu {n:02}"),
            format!("BDMenu{n:02}"),
            format!("BD Menu{n:02}"),
            format!("Menu{n:02}-01"),
            format!("ED E{n:02}"),
        ]
    });
    let op_ed = (1..=6).flat_map(|n| {
        [
            format!("OP{n:02}"),
            format!("NCOP{n:02}"),
            format!("NCED{n:02}"),
        ]
    });
    let cm_pv = (1..=12).flat_map(|n| [format!("CM{n:02}"), format!("PV{n:02}")]);

    std::iter::once("Menu".to_string())
        .chain(menus)
        .chain(op_ed)
        .chain(cm_pv)
        .collect()
}
| fn write_npy_u16(path: &Path, data: &[u16], rows: usize, cols: usize) -> Result<()> { | |
| let mut writer = BufWriter::new( | |
| File::create(path).with_context(|| format!("failed to create {}", path.display()))?, | |
| ); | |
| write_npy_header(&mut writer, "<u2", rows, cols)?; | |
| for value in data { | |
| writer.write_all(&value.to_le_bytes())?; | |
| } | |
| Ok(()) | |
| } | |
| fn write_npy_u8(path: &Path, data: &[u8], rows: usize, cols: usize) -> Result<()> { | |
| let mut writer = BufWriter::new( | |
| File::create(path).with_context(|| format!("failed to create {}", path.display()))?, | |
| ); | |
| write_npy_header(&mut writer, "|u1", rows, cols)?; | |
| writer.write_all(data)?; | |
| Ok(()) | |
| } | |
| fn write_npy_i16(path: &Path, data: &[i16], rows: usize, cols: usize) -> Result<()> { | |
| let mut writer = BufWriter::new( | |
| File::create(path).with_context(|| format!("failed to create {}", path.display()))?, | |
| ); | |
| write_npy_header(&mut writer, "<i2", rows, cols)?; | |
| for value in data { | |
| writer.write_all(&value.to_le_bytes())?; | |
| } | |
| Ok(()) | |
| } | |
| fn write_npy_header<W: Write>(writer: &mut W, descr: &str, rows: usize, cols: usize) -> Result<()> { | |
| let mut header = format!( | |
| "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}", | |
| descr, rows, cols | |
| ) | |
| .into_bytes(); | |
| let preamble_len = 10usize; | |
| let pad_len = (16 - ((preamble_len + header.len() + 1) % 16)) % 16; | |
| header.extend(std::iter::repeat(b' ').take(pad_len)); | |
| header.push(b'\n'); | |
| if header.len() > u16::MAX as usize { | |
| bail!("npy header too large"); | |
| } | |
| writer.write_all(b"\x93NUMPY")?; | |
| writer.write_all(&[1, 0])?; | |
| writer.write_all(&(header.len() as u16).to_le_bytes())?; | |
| writer.write_all(&header)?; | |
| Ok(()) | |
| } | |