use anyhow::{bail, Context, Result}; use clap::{Parser, ValueEnum}; use rand::rngs::StdRng; use rand::seq::SliceRandom; use rand::Rng; use rand::SeedableRng; use rayon::prelude::*; use serde::{Deserialize, Serialize}; use serde_json::json; use std::collections::{HashMap, HashSet}; use std::fs::{self, File}; use std::io::{BufRead, BufReader, BufWriter, Write}; use std::path::{Path, PathBuf}; use std::time::Instant; const ENTITIES: [Entity; 7] = [ Entity::Group, Entity::Title, Entity::Season, Entity::Episode, Entity::Special, Entity::Resolution, Entity::Source, ]; #[derive(Parser, Debug)] #[command( about = "Generate pre-encoded AniFileBERT virtual BIO permutation shards", version )] struct Args { #[arg(long)] input: PathBuf, #[arg(long)] vocab_file: PathBuf, #[arg(long)] output_dir: PathBuf, #[arg(long, default_value_t = 128)] max_length: usize, #[arg(long, default_value_t = 25_000)] shard_size: usize, #[arg(long, default_value_t = 0)] limit_rows: usize, #[arg(long, default_value_t = 0)] samples_per_source: usize, #[arg(long, default_value_t = 42)] seed: u64, #[arg(long, default_value_t = 0)] threads: usize, #[arg(long, default_value = "global")] separator_mode: SeparatorMode, #[arg(long, default_value = "global")] bracket_mode: BracketMode, #[arg(long, value_delimiter = ',', default_value = " , - ,.,_,-,~,~")] separators: Vec, #[arg( long, value_delimiter = ',', default_value = "none,square,round,corner,angle" )] bracket_styles: Vec, #[arg(long, default_value_t = true)] include_original: bool, #[arg(long, default_value_t = true)] include_special_fixtures: bool, #[arg(long, help = "Only count rows; do not write shard files")] dry_run: bool, } #[derive(Clone, Copy, Debug, Serialize, ValueEnum)] enum SeparatorMode { Global, PerGap, } #[derive(Clone, Copy, Debug, Serialize, ValueEnum)] enum BracketMode { Global, PerPart, } #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize)] enum Entity { Group, Title, Season, Episode, Special, Resolution, Source, } impl Entity { fn index(self) -> usize { match self { Entity::Group => 0, Entity::Title => 1, Entity::Season => 2, Entity::Episode => 3, Entity::Special => 4, Entity::Resolution => 5, Entity::Source => 6, } } fn from_name(name: &str) -> Option { match name { "GROUP" => Some(Entity::Group), "TITLE" => Some(Entity::Title), "SEASON" => Some(Entity::Season), "EPISODE" => Some(Entity::Episode), "SPECIAL" => Some(Entity::Special), "RESOLUTION" => Some(Entity::Resolution), "SOURCE" => Some(Entity::Source), _ => None, } } fn b_label(self) -> &'static str { match self { Entity::Group => "B-GROUP", Entity::Title => "B-TITLE", Entity::Season => "B-SEASON", Entity::Episode => "B-EPISODE", Entity::Special => "B-SPECIAL", Entity::Resolution => "B-RESOLUTION", Entity::Source => "B-SOURCE", } } fn i_label(self) -> &'static str { match self { Entity::Group => "I-GROUP", Entity::Title => "I-TITLE", Entity::Season => "I-SEASON", Entity::Episode => "I-EPISODE", Entity::Special => "I-SPECIAL", Entity::Resolution => "I-RESOLUTION", Entity::Source => "I-SOURCE", } } } #[derive(Clone, Debug)] struct Bracket { name: String, open: String, close: String, } impl Bracket { fn from_name(name: &str) -> Result { let trimmed = name.trim(); let pair = match trimmed { "none" => ("", ""), "square" => ("[", "]"), "round" => ("(", ")"), "corner" => ("【", "】"), "angle" => ("《", "》"), custom if custom.contains('|') => { let mut parts = custom.splitn(2, '|'); let open = parts.next().unwrap_or_default(); let close = parts.next().unwrap_or_default(); return Ok(Self { name: custom.to_string(), open: open.to_string(), close: close.to_string(), }); } other => bail!("unknown bracket style '{other}'"), }; Ok(Self { name: trimmed.to_string(), open: pair.0.to_string(), close: pair.1.to_string(), }) } } #[derive(Deserialize)] struct InputRow { filename: Option, tokens: Vec, labels: Vec, tokenizer_variant: Option, } #[derive(Clone)] struct SourceSample { row_index: usize, filename: String, tokens: Vec, labels: Vec, fields: Vec>, } #[derive(Clone)] struct GenConfig { max_length: usize, shard_size: usize, separator_mode: SeparatorMode, bracket_mode: BracketMode, separators: Vec, brackets: Vec, include_original: bool, samples_per_source: usize, seed: u64, } #[derive(Clone)] struct Vocab { ids: HashMap, pad_id: u16, unk_id: u16, cls_id: u16, sep_id: u16, } #[derive(Serialize)] struct ShardManifest { rows: usize, input_ids: String, attention_mask: String, labels: String, } struct ShardWriter { output_dir: PathBuf, worker_id: usize, shard_seq: usize, shard_size: usize, max_length: usize, input_ids: Vec, attention_mask: Vec, labels: Vec, rows: usize, total_rows: u64, shards: Vec, } impl ShardWriter { fn new(output_dir: &Path, worker_id: usize, shard_size: usize, max_length: usize) -> Self { let capacity = shard_size.saturating_mul(max_length); Self { output_dir: output_dir.to_path_buf(), worker_id, shard_seq: 0, shard_size, max_length, input_ids: Vec::with_capacity(capacity), attention_mask: Vec::with_capacity(capacity), labels: Vec::with_capacity(capacity), rows: 0, total_rows: 0, shards: Vec::new(), } } fn add(&mut self, input_ids: &[u16], attention_mask: &[u8], labels: &[i16]) -> Result<()> { if input_ids.len() != self.max_length || attention_mask.len() != self.max_length || labels.len() != self.max_length { bail!("encoded sample has wrong shape"); } self.input_ids.extend_from_slice(input_ids); self.attention_mask.extend_from_slice(attention_mask); self.labels.extend_from_slice(labels); self.rows += 1; self.total_rows += 1; if self.rows >= self.shard_size { self.flush()?; } Ok(()) } fn flush(&mut self) -> Result<()> { if self.rows == 0 { return Ok(()); } let base = format!("part-w{:03}-s{:06}", self.worker_id, self.shard_seq); let input_name = format!("{base}.input_ids.npy"); let mask_name = format!("{base}.attention_mask.npy"); let label_name = format!("{base}.labels.npy"); write_npy_u16( &self.output_dir.join(&input_name), &self.input_ids, self.rows, self.max_length, )?; write_npy_u8( &self.output_dir.join(&mask_name), &self.attention_mask, self.rows, self.max_length, )?; write_npy_i16( &self.output_dir.join(&label_name), &self.labels, self.rows, self.max_length, )?; self.shards.push(ShardManifest { rows: self.rows, input_ids: input_name, attention_mask: mask_name, labels: label_name, }); self.input_ids.clear(); self.attention_mask.clear(); self.labels.clear(); self.rows = 0; self.shard_seq += 1; Ok(()) } } fn main() -> Result<()> { let args = Args::parse(); if args.max_length < 4 { bail!("--max-length must be at least 4"); } if args.shard_size == 0 { bail!("--shard-size must be positive"); } if args.threads > 0 { rayon::ThreadPoolBuilder::new() .num_threads(args.threads) .build_global() .context("failed to configure rayon thread pool")?; } let started = Instant::now(); let vocab = load_vocab(&args.vocab_file)?; let brackets = args .bracket_styles .iter() .map(|style| Bracket::from_name(style)) .collect::>>()?; let separators = args .separators .iter() .map(|sep| normalize_separator_arg(sep)) .collect::>(); let cfg = GenConfig { max_length: args.max_length, shard_size: args.shard_size, separator_mode: args.separator_mode, bracket_mode: args.bracket_mode, separators, brackets, include_original: args.include_original, samples_per_source: args.samples_per_source, seed: args.seed, }; let mut samples = load_samples(&args.input, args.limit_rows)?; let source_rows = samples.len(); let mut rng = StdRng::seed_from_u64(args.seed); samples.shuffle(&mut rng); if args.dry_run { let generated: u128 = samples .par_iter() .map(|sample| count_variants(sample, &cfg)) .sum(); let special_fixtures = if args.include_special_fixtures { count_special_fixtures(&cfg) as u128 } else { 0 }; let manifest = json!({ "format": "anifilebert.virtual_dataset.preview.v1", "input": args.input, "vocab_file": args.vocab_file, "source_rows": source_rows, "estimated_rows": generated + special_fixtures, "source_variant_rows": generated, "special_fixture_rows": special_fixtures, "max_length": cfg.max_length, "separator_mode": cfg.separator_mode, "bracket_mode": cfg.bracket_mode, "separators": cfg.separators, "brackets": cfg.brackets.iter().map(|b| &b.name).collect::>(), "include_original": cfg.include_original, "samples_per_source": cfg.samples_per_source, "include_special_fixtures": args.include_special_fixtures, "seed": args.seed, "elapsed_seconds": started.elapsed().as_secs_f64(), }); println!("{}", serde_json::to_string_pretty(&manifest)?); return Ok(()); } fs::create_dir_all(&args.output_dir).with_context(|| { format!( "failed to create output directory {}", args.output_dir.display() ) })?; let chunk_count = rayon::current_num_threads().max(1) * 4; let chunk_size = samples.len().div_ceil(chunk_count).max(1); let chunks = samples .chunks(chunk_size) .enumerate() .collect::>(); let mut worker_results = chunks .par_iter() .map(|(chunk_idx, chunk)| { let mut writer = ShardWriter::new(&args.output_dir, *chunk_idx, cfg.shard_size, cfg.max_length); for sample in *chunk { generate_for_sample(sample, &cfg, &vocab, &mut writer)?; } writer.flush()?; Ok::<_, anyhow::Error>((writer.total_rows, writer.shards)) }) .collect::>>()?; let mut total_rows: u64 = 0; let mut shards: Vec = Vec::new(); for (rows, mut worker_shards) in worker_results.drain(..) { total_rows += rows; shards.append(&mut worker_shards); } let special_rows = if args.include_special_fixtures { let mut writer = ShardWriter::new( &args.output_dir, chunk_count + 1, cfg.shard_size, cfg.max_length, ); for special in built_in_specials() { let parts = vec![PartChoice { entity: Entity::Special, value: special, }]; emit_syntax_variants(&parts, &cfg, &vocab, &mut writer)?; } writer.flush()?; total_rows += writer.total_rows; shards.append(&mut writer.shards); writer.total_rows } else { 0 }; shards.sort_by(|a, b| a.input_ids.cmp(&b.input_ids)); let manifest = json!({ "format": "anifilebert.virtual_dataset.shards.v1", "input": args.input, "vocab_file": args.vocab_file, "source_rows": source_rows, "total_rows": total_rows, "special_fixture_rows": special_rows, "max_length": cfg.max_length, "shard_size": cfg.shard_size, "tokenizer_variant": "char", "encoding": { "input_ids_dtype": "uint16", "attention_mask_dtype": "uint8", "labels_dtype": "int16", "layout": "row_major_npy" }, "special_tokens": { "pad_id": vocab.pad_id, "unk_id": vocab.unk_id, "cls_id": vocab.cls_id, "sep_id": vocab.sep_id }, "generation": { "separator_mode": cfg.separator_mode, "bracket_mode": cfg.bracket_mode, "separators": cfg.separators, "brackets": cfg.brackets.iter().map(|b| &b.name).collect::>(), "include_original": cfg.include_original, "samples_per_source": cfg.samples_per_source, "include_special_fixtures": args.include_special_fixtures, "seed": args.seed, "threads": rayon::current_num_threads() }, "shards": shards, "elapsed_seconds": started.elapsed().as_secs_f64(), }); let manifest_path = args.output_dir.join("manifest.json"); fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?) .with_context(|| format!("failed to write {}", manifest_path.display()))?; println!("{}", serde_json::to_string_pretty(&manifest)?); Ok(()) } fn normalize_separator_arg(value: &str) -> String { match value { "\\t" => "\t".to_string(), "\\s" => " ".to_string(), other => other.to_string(), } } fn load_vocab(path: &Path) -> Result { let text = fs::read_to_string(path) .with_context(|| format!("failed to read vocab file {}", path.display()))?; let raw: HashMap = serde_json::from_str(&text).context("failed to parse vocab JSON")?; let mut ids = HashMap::with_capacity(raw.len()); for (token, id) in raw { if id > u16::MAX as u64 { bail!("vocab id {id} for token '{token}' exceeds uint16 storage"); } ids.insert(token, id as u16); } let pad_id = *ids.get("[PAD]").context("vocab is missing [PAD]")?; let unk_id = *ids.get("[UNK]").context("vocab is missing [UNK]")?; let cls_id = *ids.get("[CLS]").context("vocab is missing [CLS]")?; let sep_id = *ids.get("[SEP]").context("vocab is missing [SEP]")?; Ok(Vocab { ids, pad_id, unk_id, cls_id, sep_id, }) } fn load_samples(path: &Path, limit_rows: usize) -> Result> { let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?; let reader = BufReader::new(file); let mut samples = Vec::new(); for (idx, line) in reader.lines().enumerate() { if limit_rows > 0 && samples.len() >= limit_rows { break; } let line = line.with_context(|| format!("failed reading line {}", idx + 1))?; if line.trim().is_empty() { continue; } let row: InputRow = serde_json::from_str(&line) .with_context(|| format!("failed to parse JSONL line {}", idx + 1))?; if let Some(variant) = row.tokenizer_variant.as_deref() { if variant != "char" { bail!( "line {} has tokenizer_variant={variant}; virtual shard generation currently requires char data", idx + 1 ); } } if row.tokens.len() != row.labels.len() { bail!( "line {} has mismatched token/label lengths: {} vs {}", idx + 1, row.tokens.len(), row.labels.len() ); } let filename = row.filename.clone().unwrap_or_else(|| row.tokens.join("")); let fields = extract_fields(&row.tokens, &row.labels); samples.push(SourceSample { row_index: idx, filename, tokens: row.tokens, labels: row.labels, fields, }); } Ok(samples) } fn extract_fields(tokens: &[String], labels: &[String]) -> Vec> { let mut fields: Vec> = (0..ENTITIES.len()).map(|_| Vec::new()).collect(); let mut seen: Vec> = (0..ENTITIES.len()).map(|_| HashSet::new()).collect(); let mut active_entity: Option = None; let mut active_text = String::new(); let flush = |entity: Option, text: &mut String, fields: &mut Vec>, seen: &mut Vec>| { if let Some(entity) = entity { let value = text.trim().to_string(); if !value.is_empty() && seen[entity.index()].insert(value.clone()) { fields[entity.index()].push(value); } } text.clear(); }; for (token, label) in tokens.iter().zip(labels.iter()) { if let Some(entity) = label.strip_prefix("B-").and_then(Entity::from_name) { flush(active_entity, &mut active_text, &mut fields, &mut seen); active_entity = Some(entity); active_text.push_str(token); } else if let Some(entity) = label.strip_prefix("I-").and_then(Entity::from_name) { if active_entity == Some(entity) { active_text.push_str(token); } else { flush(active_entity, &mut active_text, &mut fields, &mut seen); active_entity = Some(entity); active_text.push_str(token); } } else { flush(active_entity, &mut active_text, &mut fields, &mut seen); active_entity = None; } } flush(active_entity, &mut active_text, &mut fields, &mut seen); fields } fn count_variants(sample: &SourceSample, cfg: &GenConfig) -> u128 { let mut count = if cfg.include_original { 1 } else { 0 }; let available = ENTITIES .iter() .copied() .filter(|entity| !sample.fields[entity.index()].is_empty()) .collect::>(); let n = available.len(); if n == 0 { return count; } if cfg.samples_per_source > 0 { return count + cfg.samples_per_source as u128; } for mask in 1usize..(1usize << n) { let selected = available .iter() .enumerate() .filter_map(|(idx, entity)| ((mask & (1usize << idx)) != 0).then_some(*entity)) .collect::>(); let m = selected.len(); let value_product: u128 = selected .iter() .map(|entity| sample.fields[entity.index()].len() as u128) .product(); let perm_count = factorial(m as u32); let sep_factor = if m <= 1 { 1 } else { match cfg.separator_mode { SeparatorMode::Global => cfg.separators.len() as u128, SeparatorMode::PerGap => (cfg.separators.len() as u128).pow((m - 1) as u32), } }; let bracket_factor = match cfg.bracket_mode { BracketMode::Global => cfg.brackets.len() as u128, BracketMode::PerPart => (cfg.brackets.len() as u128).pow(m as u32), }; count += value_product * perm_count * sep_factor * bracket_factor; } count } fn count_special_fixtures(cfg: &GenConfig) -> usize { let bracket_factor = match cfg.bracket_mode { BracketMode::Global => cfg.brackets.len(), BracketMode::PerPart => cfg.brackets.len(), }; built_in_specials().len() * bracket_factor } fn factorial(n: u32) -> u128 { (1..=n as u128).product::().max(1) } fn generate_for_sample( sample: &SourceSample, cfg: &GenConfig, vocab: &Vocab, writer: &mut ShardWriter, ) -> Result<()> { if cfg.include_original { let (input_ids, attention_mask, labels) = encode_original_sample(sample, vocab, cfg.max_length)?; writer.add(&input_ids, &attention_mask, &labels)?; } if cfg.samples_per_source > 0 { generate_sampled_variants(sample, cfg, vocab, writer)?; return Ok(()); } let available = ENTITIES .iter() .copied() .filter(|entity| !sample.fields[entity.index()].is_empty()) .collect::>(); let n = available.len(); for mask in 1usize..(1usize << n) { let mut selected = available .iter() .enumerate() .filter_map(|(idx, entity)| ((mask & (1usize << idx)) != 0).then_some(*entity)) .collect::>(); permute_entities(&mut selected, 0, &mut |order| { let mut parts: Vec = Vec::with_capacity(order.len()); for_each_value_combo(order, &sample.fields, 0, &mut parts, &mut |combo| { emit_syntax_variants(combo, cfg, vocab, writer) }) })?; } Ok(()) } fn generate_sampled_variants( sample: &SourceSample, cfg: &GenConfig, vocab: &Vocab, writer: &mut ShardWriter, ) -> Result<()> { let mut rng = StdRng::seed_from_u64( cfg.seed ^ ((sample.row_index as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)), ); let available = ENTITIES .iter() .copied() .filter(|entity| !sample.fields[entity.index()].is_empty()) .collect::>(); if available.is_empty() { return Ok(()); } let mut seen = HashSet::new(); let mut emitted = 0usize; let budget = cfg.samples_per_source; let max_unique_attempts = budget.saturating_mul(32).max(64); let mut attempts = 0usize; let mut templates: Vec> = Vec::new(); if let Some(title) = sample.fields[Entity::Title.index()].first() { templates.push(vec![PartChoice { entity: Entity::Title, value: title.clone(), }]); if let Some(season) = sample.fields[Entity::Season.index()].first() { templates.push(vec![ PartChoice { entity: Entity::Title, value: title.clone(), }, PartChoice { entity: Entity::Season, value: season.clone(), }, ]); } } if let Some(episode) = sample.fields[Entity::Episode.index()].first() { templates.push(vec![PartChoice { entity: Entity::Episode, value: episode.clone(), }]); } if let Some(special) = sample.fields[Entity::Special.index()].first() { templates.push(vec![PartChoice { entity: Entity::Special, value: special.clone(), }]); } if let (Some(title), Some(special)) = ( sample.fields[Entity::Title.index()].first(), sample.fields[Entity::Special.index()].first(), ) { templates.push(vec![ PartChoice { entity: Entity::Title, value: title.clone(), }, PartChoice { entity: Entity::Special, value: special.clone(), }, ]); } for parts in templates { if emitted >= budget { break; } emit_sample_variant( parts, cfg, vocab, writer, &mut seen, &mut emitted, &mut rng, false, )?; } while emitted < budget && attempts < max_unique_attempts { attempts += 1; let subset_size = match rng.gen_range(0..100) { 0..=29 => 1, 30..=54 => 2, 55..=74 => 3, 75..=89 => 4.min(available.len()), _ => available.len().min(5), } .max(1) .min(available.len()); let mut chosen = available .choose_multiple(&mut rng, subset_size) .copied() .collect::>(); chosen.shuffle(&mut rng); if !chosen .iter() .any(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special)) { if let Some(fallback) = available .iter() .copied() .find(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special)) { if !chosen.contains(&fallback) { chosen.push(fallback); } } } let mut parts = Vec::with_capacity(chosen.len()); for entity in chosen { let values = &sample.fields[entity.index()]; let value = values.choose(&mut rng).cloned().unwrap_or_default(); parts.push(PartChoice { entity, value }); } parts.shuffle(&mut rng); emit_sample_variant( parts, cfg, vocab, writer, &mut seen, &mut emitted, &mut rng, false, )?; } while emitted < budget { let subset_size = match rng.gen_range(0..100) { 0..=29 => 1, 30..=54 => 2, 55..=74 => 3, 75..=89 => 4.min(available.len()), _ => available.len().min(5), } .max(1) .min(available.len()); let mut chosen = available .choose_multiple(&mut rng, subset_size) .copied() .collect::>(); chosen.shuffle(&mut rng); if !chosen .iter() .any(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special)) { if let Some(fallback) = available .iter() .copied() .find(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special)) { if !chosen.contains(&fallback) { chosen.push(fallback); } } } let mut parts = Vec::with_capacity(chosen.len()); for entity in chosen { let values = &sample.fields[entity.index()]; let value = values.choose(&mut rng).cloned().unwrap_or_default(); parts.push(PartChoice { entity, value }); } parts.shuffle(&mut rng); emit_sample_variant( parts, cfg, vocab, writer, &mut seen, &mut emitted, &mut rng, true, )?; } Ok(()) } fn emit_sample_variant( parts: Vec, cfg: &GenConfig, vocab: &Vocab, writer: &mut ShardWriter, seen: &mut HashSet, emitted: &mut usize, rng: &mut StdRng, allow_duplicate: bool, ) -> Result<()> { if *emitted >= cfg.samples_per_source { return Ok(()); } if parts.is_empty() { return Ok(()); } let separators = match cfg.separator_mode { SeparatorMode::Global => { let sep = cfg .separators .choose(rng) .cloned() .unwrap_or_else(|| " ".to_string()); if parts.len() > 1 { vec![sep; parts.len() - 1] } else { Vec::new() } } SeparatorMode::PerGap => { let mut values = Vec::with_capacity(parts.len().saturating_sub(1)); for _ in 0..parts.len().saturating_sub(1) { values.push( cfg.separators .choose(rng) .cloned() .unwrap_or_else(|| " ".to_string()), ); } values } }; let brackets = match cfg.bracket_mode { BracketMode::Global => { let bracket = cfg .brackets .choose(rng) .cloned() .unwrap_or_else(|| Bracket { name: "none".to_string(), open: String::new(), close: String::new(), }); vec![bracket; parts.len()] } BracketMode::PerPart => { let mut values = Vec::with_capacity(parts.len()); for _ in 0..parts.len() { values.push( cfg.brackets .choose(rng) .cloned() .unwrap_or_else(|| Bracket { name: "none".to_string(), open: String::new(), close: String::new(), }), ); } values } }; let text = render_variant_text(&parts, &separators, &brackets); if !allow_duplicate && !seen.insert(text) { return Ok(()); } let (input_ids, attention_mask, labels) = encode_generated_sample(&parts, &separators, &brackets, vocab, cfg.max_length)?; writer.add(&input_ids, &attention_mask, &labels)?; *emitted += 1; Ok(()) } fn permute_entities(values: &mut [Entity], start: usize, callback: &mut F) -> Result<()> where F: FnMut(&[Entity]) -> Result<()>, { if start >= values.len() { return callback(values); } for idx in start..values.len() { values.swap(start, idx); permute_entities(values, start + 1, callback)?; values.swap(start, idx); } Ok(()) } #[derive(Clone)] struct PartChoice { entity: Entity, value: String, } fn for_each_value_combo( order: &[Entity], fields: &[Vec], idx: usize, current: &mut Vec, callback: &mut F, ) -> Result<()> where F: FnMut(&[PartChoice]) -> Result<()>, { if idx >= order.len() { return callback(current); } let entity = order[idx]; for value in &fields[entity.index()] { current.push(PartChoice { entity, value: value.clone(), }); for_each_value_combo(order, fields, idx + 1, current, callback)?; current.pop(); } Ok(()) } fn emit_syntax_variants( parts: &[PartChoice], cfg: &GenConfig, vocab: &Vocab, writer: &mut ShardWriter, ) -> Result<()> { let gaps = parts.len().saturating_sub(1); let mut separators = Vec::with_capacity(gaps); for_each_separator_combo(gaps, cfg, 0, &mut separators, &mut |sep_combo| { let mut brackets = Vec::with_capacity(parts.len()); for_each_bracket_combo(parts.len(), cfg, 0, &mut brackets, &mut |bracket_combo| { let (input_ids, attention_mask, labels) = encode_generated_sample(parts, sep_combo, bracket_combo, vocab, cfg.max_length)?; writer.add(&input_ids, &attention_mask, &labels) }) }) } fn for_each_separator_combo( gaps: usize, cfg: &GenConfig, idx: usize, current: &mut Vec, callback: &mut F, ) -> Result<()> where F: FnMut(&[String]) -> Result<()>, { if gaps == 0 { return callback(current); } match cfg.separator_mode { SeparatorMode::Global => { if idx == 0 { for sep in &cfg.separators { current.clear(); current.resize(gaps, sep.clone()); callback(current)?; } } Ok(()) } SeparatorMode::PerGap => { if idx >= gaps { return callback(current); } for sep in &cfg.separators { current.push(sep.clone()); for_each_separator_combo(gaps, cfg, idx + 1, current, callback)?; current.pop(); } Ok(()) } } } fn for_each_bracket_combo( parts: usize, cfg: &GenConfig, idx: usize, current: &mut Vec, callback: &mut F, ) -> Result<()> where F: FnMut(&[Bracket]) -> Result<()>, { match cfg.bracket_mode { BracketMode::Global => { if idx == 0 { for bracket in &cfg.brackets { current.clear(); current.resize(parts, bracket.clone()); callback(current)?; } } Ok(()) } BracketMode::PerPart => { if idx >= parts { return callback(current); } for bracket in &cfg.brackets { current.push(bracket.clone()); for_each_bracket_combo(parts, cfg, idx + 1, current, callback)?; current.pop(); } Ok(()) } } } fn render_variant_text( parts: &[PartChoice], separators: &[String], brackets: &[Bracket], ) -> String { let mut text = String::new(); for (idx, part) in parts.iter().enumerate() { text.push_str(&brackets[idx].open); text.push_str(&part.value); text.push_str(&brackets[idx].close); if idx < separators.len() { text.push_str(&separators[idx]); } } text } fn encode_original_sample( sample: &SourceSample, vocab: &Vocab, max_length: usize, ) -> Result<(Vec, Vec, Vec)> { let mut input_ids = vec![vocab.pad_id; max_length]; let mut attention_mask = vec![0u8; max_length]; let mut labels = vec![-100i16; max_length]; input_ids[0] = vocab.cls_id; attention_mask[0] = 1; let available = max_length.saturating_sub(2); let token_count = sample.tokens.len().min(available); for idx in 0..token_count { input_ids[idx + 1] = token_id(vocab, &sample.tokens[idx]); attention_mask[idx + 1] = 1; labels[idx + 1] = label_id(&sample.labels[idx]).with_context(|| { format!( "unknown label '{}' on source row {} ({})", sample.labels[idx], sample.row_index + 1, sample.filename ) })?; } let sep_pos = token_count + 1; input_ids[sep_pos] = vocab.sep_id; attention_mask[sep_pos] = 1; Ok((input_ids, attention_mask, labels)) } fn encode_generated_sample( parts: &[PartChoice], separators: &[String], brackets: &[Bracket], vocab: &Vocab, max_length: usize, ) -> Result<(Vec, Vec, Vec)> { let mut input_ids = vec![vocab.pad_id; max_length]; let mut attention_mask = vec![0u8; max_length]; let mut labels = vec![-100i16; max_length]; input_ids[0] = vocab.cls_id; attention_mask[0] = 1; let available = max_length.saturating_sub(2); let mut pos = 1usize; for (idx, part) in parts.iter().enumerate() { let bracket = &brackets[idx]; append_o_text( &bracket.open, vocab, available, &mut pos, &mut input_ids, &mut attention_mask, &mut labels, ); append_entity_text( &part.value, part.entity, vocab, available, &mut pos, &mut input_ids, &mut attention_mask, &mut labels, )?; append_o_text( &bracket.close, vocab, available, &mut pos, &mut input_ids, &mut attention_mask, &mut labels, ); if idx < separators.len() { append_o_text( &separators[idx], vocab, available, &mut pos, &mut input_ids, &mut attention_mask, &mut labels, ); } } let sep_pos = pos.min(max_length - 1); input_ids[sep_pos] = vocab.sep_id; attention_mask[sep_pos] = 1; labels[sep_pos] = -100; Ok((input_ids, attention_mask, labels)) } fn append_o_text( text: &str, vocab: &Vocab, available: usize, pos: &mut usize, input_ids: &mut [u16], attention_mask: &mut [u8], labels: &mut [i16], ) { for ch in text.chars() { if *pos > available { return; } let token = ch.to_string(); input_ids[*pos] = token_id(vocab, &token); attention_mask[*pos] = 1; labels[*pos] = 0; *pos += 1; } } fn append_entity_text( text: &str, entity: Entity, vocab: &Vocab, available: usize, pos: &mut usize, input_ids: &mut [u16], attention_mask: &mut [u8], labels: &mut [i16], ) -> Result<()> { let b = label_id(entity.b_label()).context("missing B label")?; let i = label_id(entity.i_label()).context("missing I label")?; let mut first = true; for ch in text.chars() { if *pos > available { return Ok(()); } let token = ch.to_string(); input_ids[*pos] = token_id(vocab, &token); attention_mask[*pos] = 1; labels[*pos] = if first { b } else { i }; first = false; *pos += 1; } Ok(()) } fn token_id(vocab: &Vocab, token: &str) -> u16 { *vocab.ids.get(token).unwrap_or(&vocab.unk_id) } fn label_id(label: &str) -> Option { Some(match label { "O" => 0, "B-TITLE" => 1, "I-TITLE" => 2, "B-SEASON" => 3, "I-SEASON" => 4, "B-EPISODE" => 5, "I-EPISODE" => 6, "B-SPECIAL" => 7, "I-SPECIAL" => 8, "B-GROUP" => 9, "I-GROUP" => 10, "B-RESOLUTION" => 11, "I-RESOLUTION" => 12, "B-SOURCE" => 13, "I-SOURCE" => 14, _ => return None, }) } fn built_in_specials() -> Vec { let mut values = Vec::new(); values.push("Menu".to_string()); for idx in 1..=24 { values.push(format!("Menu{idx:02}")); values.push(format!("Menu {idx:02}")); values.push(format!("BDMenu{idx:02}")); values.push(format!("BD Menu{idx:02}")); values.push(format!("Menu{idx:02}-01")); values.push(format!("ED E{idx:02}")); } for idx in 1..=6 { values.push(format!("OP{idx:02}")); values.push(format!("NCOP{idx:02}")); values.push(format!("NCED{idx:02}")); } for idx in 1..=12 { values.push(format!("CM{idx:02}")); values.push(format!("PV{idx:02}")); } values } fn write_npy_u16(path: &Path, data: &[u16], rows: usize, cols: usize) -> Result<()> { let mut writer = BufWriter::new( File::create(path).with_context(|| format!("failed to create {}", path.display()))?, ); write_npy_header(&mut writer, " Result<()> { let mut writer = BufWriter::new( File::create(path).with_context(|| format!("failed to create {}", path.display()))?, ); write_npy_header(&mut writer, "|u1", rows, cols)?; writer.write_all(data)?; Ok(()) } fn write_npy_i16(path: &Path, data: &[i16], rows: usize, cols: usize) -> Result<()> { let mut writer = BufWriter::new( File::create(path).with_context(|| format!("failed to create {}", path.display()))?, ); write_npy_header(&mut writer, "(writer: &mut W, descr: &str, rows: usize, cols: usize) -> Result<()> { let mut header = format!( "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}", descr, rows, cols ) .into_bytes(); let preamble_len = 10usize; let pad_len = (16 - ((preamble_len + header.len() + 1) % 16)) % 16; header.extend(std::iter::repeat(b' ').take(pad_len)); header.push(b'\n'); if header.len() > u16::MAX as usize { bail!("npy header too large"); } writer.write_all(b"\x93NUMPY")?; writer.write_all(&[1, 0])?; writer.write_all(&(header.len() as u16).to_le_bytes())?; writer.write_all(&header)?; Ok(()) }