// Author: ModerRAS
// Commit: Train virtual-shard anime parser
// Revision: 359ff82
use anyhow::{bail, Context, Result};
use clap::{Parser, ValueEnum};
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::{HashMap, HashSet};
use std::fs::{self, File};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::time::Instant;
/// All entity kinds, listed in the same order as the values returned by
/// `Entity::index`. Used to size and walk the per-entity field tables.
const ENTITIES: [Entity; 7] = [
Entity::Group,
Entity::Title,
Entity::Season,
Entity::Episode,
Entity::Special,
Entity::Resolution,
Entity::Source,
];
// Command-line options of the shard generator.
//
// Plain `//` comments are used on purpose throughout this struct: clap's
// derive macro would turn `///` doc comments into help text.
#[derive(Parser, Debug)]
#[command(
    about = "Generate pre-encoded AniFileBERT virtual BIO permutation shards",
    version
)]
struct Args {
    // JSONL file with one labelled source row per line.
    #[arg(long)]
    input: PathBuf,
    // JSON object mapping token -> id; must contain [PAD], [UNK], [CLS], [SEP].
    #[arg(long)]
    vocab_file: PathBuf,
    // Directory that receives the .npy shards and manifest.json.
    #[arg(long)]
    output_dir: PathBuf,
    // Fixed row width of every encoded sample (validated to be >= 4).
    #[arg(long, default_value_t = 128)]
    max_length: usize,
    // Maximum number of rows per shard file (must be positive).
    #[arg(long, default_value_t = 25_000)]
    shard_size: usize,
    // Stop after this many source rows; 0 means no limit.
    #[arg(long, default_value_t = 0)]
    limit_rows: usize,
    // Number of randomly sampled variants per source row; 0 switches to
    // exhaustive enumeration of every variant.
    #[arg(long, default_value_t = 0)]
    samples_per_source: usize,
    // Seed for the source shuffle and for per-row variant sampling.
    #[arg(long, default_value_t = 42)]
    seed: u64,
    // Size of the rayon thread pool; 0 keeps rayon's default.
    #[arg(long, default_value_t = 0)]
    threads: usize,
    #[arg(long, default_value = "global")]
    separator_mode: SeparatorMode,
    #[arg(long, default_value = "global")]
    bracket_mode: BracketMode,
    // Comma-separated separator strings placed between parts.
    // NOTE(review): the default list contains `~` twice, which yields duplicate
    // variants; one entry may originally have been a different character
    // (e.g. a full-width tilde) — confirm before changing.
    #[arg(long, value_delimiter = ',', default_value = " , - ,.,_,-,~,~")]
    separators: Vec<String>,
    // Comma-separated bracket style names, or custom `open|close` pairs.
    #[arg(
        long,
        value_delimiter = ',',
        default_value = "none,square,round,corner,angle"
    )]
    bracket_styles: Vec<String>,
    // A `bool` field defaults to the `SetTrue` action, which combined with a
    // `true` default made it impossible to switch this option off. `Set` with
    // an optional value keeps the bare `--include-original` working and adds
    // `--include-original=false`.
    #[arg(
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        num_args = 0..=1,
        default_missing_value = "true"
    )]
    include_original: bool,
    // Same fix as `include_original`: allow `--include-special-fixtures=false`.
    #[arg(
        long,
        default_value_t = true,
        action = clap::ArgAction::Set,
        num_args = 0..=1,
        default_missing_value = "true"
    )]
    include_special_fixtures: bool,
    #[arg(long, help = "Only count rows; do not write shard files")]
    dry_run: bool,
}
// How separators are chosen for the gaps between parts of one variant.
// (Plain `//` comments: `ValueEnum` would turn doc comments into CLI help.)
#[derive(Clone, Copy, Debug, Serialize, ValueEnum)]
enum SeparatorMode {
// One separator is used for every gap of a variant.
Global,
// Each gap gets its own separator choice.
PerGap,
}
// How bracket styles are chosen for the parts of one variant.
// (Plain `//` comments: `ValueEnum` would turn doc comments into CLI help.)
#[derive(Clone, Copy, Debug, Serialize, ValueEnum)]
enum BracketMode {
// One bracket style wraps every part of a variant.
Global,
// Each part gets its own bracket style choice.
PerPart,
}
/// Entity kinds that can be labelled in a filename. Each kind has a `B-`/`I-`
/// label pair (see `b_label` / `i_label`) and a slot in the per-entity field
/// tables (see `index`).
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize)]
enum Entity {
Group,
Title,
Season,
Episode,
Special,
Resolution,
Source,
}
impl Entity {
    /// Slot of this entity in per-entity tables such as `SourceSample::fields`.
    fn index(self) -> usize {
        match self {
            Self::Group => 0,
            Self::Title => 1,
            Self::Season => 2,
            Self::Episode => 3,
            Self::Special => 4,
            Self::Resolution => 5,
            Self::Source => 6,
        }
    }

    /// Parses the upper-case entity part of a BIO label (the text after `B-`
    /// or `I-`). Returns `None` for any other string.
    fn from_name(name: &str) -> Option<Self> {
        const NAMED: [(&str, Entity); 7] = [
            ("GROUP", Entity::Group),
            ("TITLE", Entity::Title),
            ("SEASON", Entity::Season),
            ("EPISODE", Entity::Episode),
            ("SPECIAL", Entity::Special),
            ("RESOLUTION", Entity::Resolution),
            ("SOURCE", Entity::Source),
        ];
        NAMED
            .iter()
            .find(|(key, _)| *key == name)
            .map(|(_, entity)| *entity)
    }

    /// Label given to the first token of a span of this entity.
    fn b_label(self) -> &'static str {
        match self {
            Self::Group => "B-GROUP",
            Self::Title => "B-TITLE",
            Self::Season => "B-SEASON",
            Self::Episode => "B-EPISODE",
            Self::Special => "B-SPECIAL",
            Self::Resolution => "B-RESOLUTION",
            Self::Source => "B-SOURCE",
        }
    }

    /// Label given to every token after the first in a span of this entity.
    fn i_label(self) -> &'static str {
        match self {
            Self::Group => "I-GROUP",
            Self::Title => "I-TITLE",
            Self::Season => "I-SEASON",
            Self::Episode => "I-EPISODE",
            Self::Special => "I-SPECIAL",
            Self::Resolution => "I-RESOLUTION",
            Self::Source => "I-SOURCE",
        }
    }
}
/// A bracket style: the strings written before (`open`) and after (`close`)
/// a part, plus the name the style was requested under.
#[derive(Clone, Debug)]
struct Bracket {
name: String,
open: String,
close: String,
}
impl Bracket {
    /// Resolves a bracket style from its CLI spelling.
    ///
    /// Surrounding whitespace is ignored. Accepted values are the built-in
    /// names `none`, `square`, `round`, `corner`, `angle`, or a custom
    /// `open|close` pair split at the first `|`.
    ///
    /// # Errors
    /// Fails when the value is neither a built-in name nor contains `|`.
    fn from_name(name: &str) -> Result<Self> {
        let style = name.trim();
        let (open, close) = match style {
            "none" => ("", ""),
            "square" => ("[", "]"),
            "round" => ("(", ")"),
            "corner" => ("【", "】"),
            "angle" => ("《", "》"),
            other => match other.split_once('|') {
                Some(custom_pair) => custom_pair,
                None => bail!("unknown bracket style '{other}'"),
            },
        };
        Ok(Self {
            name: style.to_string(),
            open: open.to_string(),
            close: close.to_string(),
        })
    }
}
/// One line of the input JSONL file as deserialized by serde.
#[derive(Deserialize)]
struct InputRow {
// Optional; `load_samples` falls back to the concatenated tokens when absent.
filename: Option<String>,
// Tokens and their labels; `load_samples` rejects rows where the lengths differ.
tokens: Vec<String>,
labels: Vec<String>,
// When present it must equal "char", otherwise `load_samples` fails.
tokenizer_variant: Option<String>,
}
/// A validated source row together with the entity values extracted from it.
#[derive(Clone)]
struct SourceSample {
// Zero-based line index of the row in the input file.
row_index: usize,
filename: String,
tokens: Vec<String>,
labels: Vec<String>,
// Distinct extracted values per entity, indexed by `Entity::index`.
fields: Vec<Vec<String>>,
}
/// Generation settings derived from the command line in `main` and shared by
/// the counting and generating code paths.
#[derive(Clone)]
struct GenConfig {
max_length: usize,
shard_size: usize,
separator_mode: SeparatorMode,
bracket_mode: BracketMode,
// Separator strings after `normalize_separator_arg` has been applied.
separators: Vec<String>,
brackets: Vec<Bracket>,
include_original: bool,
// 0 selects exhaustive enumeration; a positive value selects random sampling.
samples_per_source: usize,
seed: u64,
}
/// Token vocabulary loaded by `load_vocab`, with the ids of the four special
/// tokens cached for direct access.
#[derive(Clone)]
struct Vocab {
// Token text -> id; every id fits in `u16`.
ids: HashMap<String, u16>,
pad_id: u16,
unk_id: u16,
cls_id: u16,
sep_id: u16,
}
/// Manifest entry describing one written shard: its row count and the file
/// names (relative to the output directory) of its three arrays.
#[derive(Serialize)]
struct ShardManifest {
rows: usize,
input_ids: String,
attention_mask: String,
labels: String,
}
/// Buffers encoded rows in memory and writes them out as `.npy` shard files
/// once `shard_size` rows have accumulated (or on an explicit `flush`).
struct ShardWriter {
output_dir: PathBuf,
// Used together with `shard_seq` to build unique shard file names.
worker_id: usize,
shard_seq: usize,
shard_size: usize,
max_length: usize,
// Flattened row-major buffers; each row contributes `max_length` elements.
input_ids: Vec<u16>,
attention_mask: Vec<u8>,
labels: Vec<i16>,
// Rows currently buffered (reset on flush).
rows: usize,
// Rows accepted over the lifetime of this writer.
total_rows: u64,
// Manifest entries of the shards written so far.
shards: Vec<ShardManifest>,
}
impl ShardWriter {
    /// Creates an empty writer whose buffers are pre-sized for one full shard.
    fn new(output_dir: &Path, worker_id: usize, shard_size: usize, max_length: usize) -> Self {
        let reserve = shard_size.saturating_mul(max_length);
        Self {
            output_dir: output_dir.to_path_buf(),
            worker_id,
            shard_seq: 0,
            shard_size,
            max_length,
            input_ids: Vec::with_capacity(reserve),
            attention_mask: Vec::with_capacity(reserve),
            labels: Vec::with_capacity(reserve),
            rows: 0,
            total_rows: 0,
            shards: Vec::new(),
        }
    }

    /// Appends one encoded row and flushes when the shard is full.
    ///
    /// # Errors
    /// Fails when any slice is not exactly `max_length` long, or when the
    /// triggered flush fails.
    fn add(&mut self, input_ids: &[u16], attention_mask: &[u8], labels: &[i16]) -> Result<()> {
        let expected = self.max_length;
        let shape_ok = [input_ids.len(), attention_mask.len(), labels.len()]
            .iter()
            .all(|&len| len == expected);
        if !shape_ok {
            bail!("encoded sample has wrong shape");
        }

        self.input_ids.extend_from_slice(input_ids);
        self.attention_mask.extend_from_slice(attention_mask);
        self.labels.extend_from_slice(labels);
        self.rows += 1;
        self.total_rows += 1;

        if self.rows < self.shard_size {
            return Ok(());
        }
        self.flush()
    }

    /// Writes the buffered rows as one shard (three `.npy` files), records a
    /// manifest entry and resets the buffers. Does nothing when no rows are
    /// buffered.
    fn flush(&mut self) -> Result<()> {
        if self.rows == 0 {
            return Ok(());
        }

        let base = format!("part-w{:03}-s{:06}", self.worker_id, self.shard_seq);
        let input_name = format!("{base}.input_ids.npy");
        let mask_name = format!("{base}.attention_mask.npy");
        let label_name = format!("{base}.labels.npy");

        let (rows, cols) = (self.rows, self.max_length);
        write_npy_u16(&self.output_dir.join(&input_name), &self.input_ids, rows, cols)?;
        write_npy_u8(&self.output_dir.join(&mask_name), &self.attention_mask, rows, cols)?;
        write_npy_i16(&self.output_dir.join(&label_name), &self.labels, rows, cols)?;

        self.shards.push(ShardManifest {
            rows,
            input_ids: input_name,
            attention_mask: mask_name,
            labels: label_name,
        });

        // Keep the allocations; only the contents are discarded.
        self.input_ids.clear();
        self.attention_mask.clear();
        self.labels.clear();
        self.rows = 0;
        self.shard_seq += 1;
        Ok(())
    }
}
/// Entry point: validates the arguments, loads the vocabulary and source rows,
/// then either prints a row-count preview (`--dry-run`) or generates the
/// shards in parallel and writes `manifest.json` into the output directory.
fn main() -> Result<()> {
let args = Args::parse();
if args.max_length < 4 {
bail!("--max-length must be at least 4");
}
if args.shard_size == 0 {
bail!("--shard-size must be positive");
}
// Only override rayon's global pool when a thread count was requested.
if args.threads > 0 {
rayon::ThreadPoolBuilder::new()
.num_threads(args.threads)
.build_global()
.context("failed to configure rayon thread pool")?;
}
let started = Instant::now();
let vocab = load_vocab(&args.vocab_file)?;
let brackets = args
.bracket_styles
.iter()
.map(|style| Bracket::from_name(style))
.collect::<Result<Vec<_>>>()?;
let separators = args
.separators
.iter()
.map(|sep| normalize_separator_arg(sep))
.collect::<Vec<_>>();
let cfg = GenConfig {
max_length: args.max_length,
shard_size: args.shard_size,
separator_mode: args.separator_mode,
bracket_mode: args.bracket_mode,
separators,
brackets,
include_original: args.include_original,
samples_per_source: args.samples_per_source,
seed: args.seed,
};
let mut samples = load_samples(&args.input, args.limit_rows)?;
let source_rows = samples.len();
// Shuffle the source rows deterministically from the seed before chunking.
let mut rng = StdRng::seed_from_u64(args.seed);
samples.shuffle(&mut rng);
// Dry run: count the rows that would be generated, print a preview manifest
// to stdout and exit without touching the output directory.
if args.dry_run {
let generated: u128 = samples
.par_iter()
.map(|sample| count_variants(sample, &cfg))
.sum();
let special_fixtures = if args.include_special_fixtures {
count_special_fixtures(&cfg) as u128
} else {
0
};
let manifest = json!({
"format": "anifilebert.virtual_dataset.preview.v1",
"input": args.input,
"vocab_file": args.vocab_file,
"source_rows": source_rows,
"estimated_rows": generated + special_fixtures,
"source_variant_rows": generated,
"special_fixture_rows": special_fixtures,
"max_length": cfg.max_length,
"separator_mode": cfg.separator_mode,
"bracket_mode": cfg.bracket_mode,
"separators": cfg.separators,
"brackets": cfg.brackets.iter().map(|b| &b.name).collect::<Vec<_>>(),
"include_original": cfg.include_original,
"samples_per_source": cfg.samples_per_source,
"include_special_fixtures": args.include_special_fixtures,
"seed": args.seed,
"elapsed_seconds": started.elapsed().as_secs_f64(),
});
println!("{}", serde_json::to_string_pretty(&manifest)?);
return Ok(());
}
fs::create_dir_all(&args.output_dir).with_context(|| {
format!(
"failed to create output directory {}",
args.output_dir.display()
)
})?;
// Split the rows into up to four chunks per rayon thread. Each chunk gets its
// own ShardWriter keyed by the chunk index, so shard file names never clash.
let chunk_count = rayon::current_num_threads().max(1) * 4;
let chunk_size = samples.len().div_ceil(chunk_count).max(1);
let chunks = samples
.chunks(chunk_size)
.enumerate()
.collect::<Vec<(usize, &[SourceSample])>>();
let mut worker_results = chunks
.par_iter()
.map(|(chunk_idx, chunk)| {
let mut writer =
ShardWriter::new(&args.output_dir, *chunk_idx, cfg.shard_size, cfg.max_length);
for sample in *chunk {
generate_for_sample(sample, &cfg, &vocab, &mut writer)?;
}
// Write out the final, partially filled shard of this chunk.
writer.flush()?;
Ok::<_, anyhow::Error>((writer.total_rows, writer.shards))
})
.collect::<Result<Vec<_>>>()?;
let mut total_rows: u64 = 0;
let mut shards: Vec<ShardManifest> = Vec::new();
for (rows, mut worker_shards) in worker_results.drain(..) {
total_rows += rows;
shards.append(&mut worker_shards);
}
// Built-in SPECIAL fixtures are written sequentially by a dedicated writer.
// Its id (chunk_count + 1) lies above every chunk index used above.
let special_rows = if args.include_special_fixtures {
let mut writer = ShardWriter::new(
&args.output_dir,
chunk_count + 1,
cfg.shard_size,
cfg.max_length,
);
for special in built_in_specials() {
let parts = vec![PartChoice {
entity: Entity::Special,
value: special,
}];
emit_syntax_variants(&parts, &cfg, &vocab, &mut writer)?;
}
writer.flush()?;
total_rows += writer.total_rows;
shards.append(&mut writer.shards);
writer.total_rows
} else {
0
};
// Order the manifest entries by input_ids file name so the listing does not
// depend on the order in which workers finished.
shards.sort_by(|a, b| a.input_ids.cmp(&b.input_ids));
let manifest = json!({
"format": "anifilebert.virtual_dataset.shards.v1",
"input": args.input,
"vocab_file": args.vocab_file,
"source_rows": source_rows,
"total_rows": total_rows,
"special_fixture_rows": special_rows,
"max_length": cfg.max_length,
"shard_size": cfg.shard_size,
"tokenizer_variant": "char",
"encoding": {
"input_ids_dtype": "uint16",
"attention_mask_dtype": "uint8",
"labels_dtype": "int16",
"layout": "row_major_npy"
},
"special_tokens": {
"pad_id": vocab.pad_id,
"unk_id": vocab.unk_id,
"cls_id": vocab.cls_id,
"sep_id": vocab.sep_id
},
"generation": {
"separator_mode": cfg.separator_mode,
"bracket_mode": cfg.bracket_mode,
"separators": cfg.separators,
"brackets": cfg.brackets.iter().map(|b| &b.name).collect::<Vec<_>>(),
"include_original": cfg.include_original,
"samples_per_source": cfg.samples_per_source,
"include_special_fixtures": args.include_special_fixtures,
"seed": args.seed,
"threads": rayon::current_num_threads()
},
"shards": shards,
"elapsed_seconds": started.elapsed().as_secs_f64(),
});
// The manifest is written next to the shards and echoed to stdout.
let manifest_path = args.output_dir.join("manifest.json");
fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
.with_context(|| format!("failed to write {}", manifest_path.display()))?;
println!("{}", serde_json::to_string_pretty(&manifest)?);
Ok(())
}
/// Expands the two escape spellings accepted on the command line: a literal
/// backslash-`t` becomes a tab and a literal backslash-`s` becomes a single
/// space. Every other value is returned unchanged.
fn normalize_separator_arg(value: &str) -> String {
    let resolved = if value == "\\t" {
        "\t"
    } else if value == "\\s" {
        " "
    } else {
        value
    };
    resolved.to_string()
}
/// Loads the token vocabulary from a JSON object of `token -> id`.
///
/// # Errors
/// Fails when the file cannot be read or parsed, when an id does not fit in
/// `u16`, or when one of `[PAD]`, `[UNK]`, `[CLS]`, `[SEP]` is missing.
fn load_vocab(path: &Path) -> Result<Vocab> {
    let text = fs::read_to_string(path)
        .with_context(|| format!("failed to read vocab file {}", path.display()))?;
    let raw: HashMap<String, u64> =
        serde_json::from_str(&text).context("failed to parse vocab JSON")?;

    // Narrow every id to u16, rejecting values the shard dtype cannot hold.
    let ids = raw
        .into_iter()
        .map(|(token, id)| {
            if id > u64::from(u16::MAX) {
                bail!("vocab id {id} for token '{token}' exceeds uint16 storage");
            }
            Ok((token, id as u16))
        })
        .collect::<Result<HashMap<String, u16>>>()?;

    let pad_id = ids.get("[PAD]").copied().context("vocab is missing [PAD]")?;
    let unk_id = ids.get("[UNK]").copied().context("vocab is missing [UNK]")?;
    let cls_id = ids.get("[CLS]").copied().context("vocab is missing [CLS]")?;
    let sep_id = ids.get("[SEP]").copied().context("vocab is missing [SEP]")?;

    Ok(Vocab {
        ids,
        pad_id,
        unk_id,
        cls_id,
        sep_id,
    })
}
fn load_samples(path: &Path, limit_rows: usize) -> Result<Vec<SourceSample>> {
let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
let reader = BufReader::new(file);
let mut samples = Vec::new();
for (idx, line) in reader.lines().enumerate() {
if limit_rows > 0 && samples.len() >= limit_rows {
break;
}
let line = line.with_context(|| format!("failed reading line {}", idx + 1))?;
if line.trim().is_empty() {
continue;
}
let row: InputRow = serde_json::from_str(&line)
.with_context(|| format!("failed to parse JSONL line {}", idx + 1))?;
if let Some(variant) = row.tokenizer_variant.as_deref() {
if variant != "char" {
bail!(
"line {} has tokenizer_variant={variant}; virtual shard generation currently requires char data",
idx + 1
);
}
}
if row.tokens.len() != row.labels.len() {
bail!(
"line {} has mismatched token/label lengths: {} vs {}",
idx + 1,
row.tokens.len(),
row.labels.len()
);
}
let filename = row.filename.clone().unwrap_or_else(|| row.tokens.join(""));
let fields = extract_fields(&row.tokens, &row.labels);
samples.push(SourceSample {
row_index: idx,
filename,
tokens: row.tokens,
labels: row.labels,
fields,
});
}
Ok(samples)
}
/// Collects the distinct entity values spelled out by a BIO-labelled token
/// sequence.
///
/// The result has one list per entity, indexed by `Entity::index`. A span
/// starts at a `B-` label, or at an `I-` label whose entity differs from the
/// span currently open; it ends at the next span start or at any other label.
/// Values are trimmed; empty values and repeats are dropped, keeping
/// first-seen order.
fn extract_fields(tokens: &[String], labels: &[String]) -> Vec<Vec<String>> {
    /// Stores the finished span text (if any) under its entity and resets the buffer.
    fn commit(
        entity: Option<Entity>,
        buffer: &mut String,
        fields: &mut [Vec<String>],
        seen: &mut [HashSet<String>],
    ) {
        if let Some(entity) = entity {
            let slot = entity.index();
            let value = buffer.trim();
            if !value.is_empty() && seen[slot].insert(value.to_string()) {
                fields[slot].push(value.to_string());
            }
        }
        buffer.clear();
    }

    let mut fields: Vec<Vec<String>> = vec![Vec::new(); ENTITIES.len()];
    let mut seen: Vec<HashSet<String>> = vec![HashSet::new(); ENTITIES.len()];
    let mut current: Option<Entity> = None;
    let mut buffer = String::new();

    for (token, label) in tokens.iter().zip(labels.iter()) {
        let begins = label.strip_prefix("B-").and_then(Entity::from_name);
        let continues = match begins {
            Some(_) => None,
            None => label.strip_prefix("I-").and_then(Entity::from_name),
        };
        match (begins, continues) {
            // Continuation of the span that is already open.
            (None, Some(entity)) if current == Some(entity) => buffer.push_str(token),
            // Explicit start, or an `I-` label that does not match the open span.
            (Some(entity), _) | (None, Some(entity)) => {
                commit(current, &mut buffer, &mut fields, &mut seen);
                current = Some(entity);
                buffer.push_str(token);
            }
            // Any other label closes the open span.
            (None, None) => {
                commit(current, &mut buffer, &mut fields, &mut seen);
                current = None;
            }
        }
    }
    commit(current, &mut buffer, &mut fields, &mut seen);
    fields
}
fn count_variants(sample: &SourceSample, cfg: &GenConfig) -> u128 {
let mut count = if cfg.include_original { 1 } else { 0 };
let available = ENTITIES
.iter()
.copied()
.filter(|entity| !sample.fields[entity.index()].is_empty())
.collect::<Vec<_>>();
let n = available.len();
if n == 0 {
return count;
}
if cfg.samples_per_source > 0 {
return count + cfg.samples_per_source as u128;
}
for mask in 1usize..(1usize << n) {
let selected = available
.iter()
.enumerate()
.filter_map(|(idx, entity)| ((mask & (1usize << idx)) != 0).then_some(*entity))
.collect::<Vec<_>>();
let m = selected.len();
let value_product: u128 = selected
.iter()
.map(|entity| sample.fields[entity.index()].len() as u128)
.product();
let perm_count = factorial(m as u32);
let sep_factor = if m <= 1 {
1
} else {
match cfg.separator_mode {
SeparatorMode::Global => cfg.separators.len() as u128,
SeparatorMode::PerGap => (cfg.separators.len() as u128).pow((m - 1) as u32),
}
};
let bracket_factor = match cfg.bracket_mode {
BracketMode::Global => cfg.brackets.len() as u128,
BracketMode::PerPart => (cfg.brackets.len() as u128).pow(m as u32),
};
count += value_product * perm_count * sep_factor * bracket_factor;
}
count
}
/// Number of rows produced for the built-in SPECIAL fixtures. Each fixture is
/// a single part, so both bracket modes yield one row per configured bracket
/// style and no separator is involved.
fn count_special_fixtures(cfg: &GenConfig) -> usize {
    let rows_per_fixture = match cfg.bracket_mode {
        BracketMode::Global | BracketMode::PerPart => cfg.brackets.len(),
    };
    rows_per_fixture * built_in_specials().len()
}
/// Returns `n!` as a `u128`, with `0! == 1`.
fn factorial(n: u32) -> u128 {
    let mut acc: u128 = 1;
    for factor in 2..=u128::from(n) {
        acc *= factor;
    }
    acc
}
fn generate_for_sample(
sample: &SourceSample,
cfg: &GenConfig,
vocab: &Vocab,
writer: &mut ShardWriter,
) -> Result<()> {
if cfg.include_original {
let (input_ids, attention_mask, labels) =
encode_original_sample(sample, vocab, cfg.max_length)?;
writer.add(&input_ids, &attention_mask, &labels)?;
}
if cfg.samples_per_source > 0 {
generate_sampled_variants(sample, cfg, vocab, writer)?;
return Ok(());
}
let available = ENTITIES
.iter()
.copied()
.filter(|entity| !sample.fields[entity.index()].is_empty())
.collect::<Vec<_>>();
let n = available.len();
for mask in 1usize..(1usize << n) {
let mut selected = available
.iter()
.enumerate()
.filter_map(|(idx, entity)| ((mask & (1usize << idx)) != 0).then_some(*entity))
.collect::<Vec<_>>();
permute_entities(&mut selected, 0, &mut |order| {
let mut parts: Vec<PartChoice> = Vec::with_capacity(order.len());
for_each_value_combo(order, &sample.fields, 0, &mut parts, &mut |combo| {
emit_syntax_variants(combo, cfg, vocab, writer)
})
})?;
}
Ok(())
}
fn generate_sampled_variants(
sample: &SourceSample,
cfg: &GenConfig,
vocab: &Vocab,
writer: &mut ShardWriter,
) -> Result<()> {
let mut rng = StdRng::seed_from_u64(
cfg.seed ^ ((sample.row_index as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)),
);
let available = ENTITIES
.iter()
.copied()
.filter(|entity| !sample.fields[entity.index()].is_empty())
.collect::<Vec<_>>();
if available.is_empty() {
return Ok(());
}
let mut seen = HashSet::new();
let mut emitted = 0usize;
let budget = cfg.samples_per_source;
let max_unique_attempts = budget.saturating_mul(32).max(64);
let mut attempts = 0usize;
let mut templates: Vec<Vec<PartChoice>> = Vec::new();
if let Some(title) = sample.fields[Entity::Title.index()].first() {
templates.push(vec![PartChoice {
entity: Entity::Title,
value: title.clone(),
}]);
if let Some(season) = sample.fields[Entity::Season.index()].first() {
templates.push(vec![
PartChoice {
entity: Entity::Title,
value: title.clone(),
},
PartChoice {
entity: Entity::Season,
value: season.clone(),
},
]);
}
}
if let Some(episode) = sample.fields[Entity::Episode.index()].first() {
templates.push(vec![PartChoice {
entity: Entity::Episode,
value: episode.clone(),
}]);
}
if let Some(special) = sample.fields[Entity::Special.index()].first() {
templates.push(vec![PartChoice {
entity: Entity::Special,
value: special.clone(),
}]);
}
if let (Some(title), Some(special)) = (
sample.fields[Entity::Title.index()].first(),
sample.fields[Entity::Special.index()].first(),
) {
templates.push(vec![
PartChoice {
entity: Entity::Title,
value: title.clone(),
},
PartChoice {
entity: Entity::Special,
value: special.clone(),
},
]);
}
for parts in templates {
if emitted >= budget {
break;
}
emit_sample_variant(
parts,
cfg,
vocab,
writer,
&mut seen,
&mut emitted,
&mut rng,
false,
)?;
}
while emitted < budget && attempts < max_unique_attempts {
attempts += 1;
let subset_size = match rng.gen_range(0..100) {
0..=29 => 1,
30..=54 => 2,
55..=74 => 3,
75..=89 => 4.min(available.len()),
_ => available.len().min(5),
}
.max(1)
.min(available.len());
let mut chosen = available
.choose_multiple(&mut rng, subset_size)
.copied()
.collect::<Vec<_>>();
chosen.shuffle(&mut rng);
if !chosen
.iter()
.any(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special))
{
if let Some(fallback) = available
.iter()
.copied()
.find(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special))
{
if !chosen.contains(&fallback) {
chosen.push(fallback);
}
}
}
let mut parts = Vec::with_capacity(chosen.len());
for entity in chosen {
let values = &sample.fields[entity.index()];
let value = values.choose(&mut rng).cloned().unwrap_or_default();
parts.push(PartChoice { entity, value });
}
parts.shuffle(&mut rng);
emit_sample_variant(
parts,
cfg,
vocab,
writer,
&mut seen,
&mut emitted,
&mut rng,
false,
)?;
}
while emitted < budget {
let subset_size = match rng.gen_range(0..100) {
0..=29 => 1,
30..=54 => 2,
55..=74 => 3,
75..=89 => 4.min(available.len()),
_ => available.len().min(5),
}
.max(1)
.min(available.len());
let mut chosen = available
.choose_multiple(&mut rng, subset_size)
.copied()
.collect::<Vec<_>>();
chosen.shuffle(&mut rng);
if !chosen
.iter()
.any(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special))
{
if let Some(fallback) = available
.iter()
.copied()
.find(|entity| matches!(entity, Entity::Title | Entity::Episode | Entity::Special))
{
if !chosen.contains(&fallback) {
chosen.push(fallback);
}
}
}
let mut parts = Vec::with_capacity(chosen.len());
for entity in chosen {
let values = &sample.fields[entity.index()];
let value = values.choose(&mut rng).cloned().unwrap_or_default();
parts.push(PartChoice { entity, value });
}
parts.shuffle(&mut rng);
emit_sample_variant(
parts,
cfg,
vocab,
writer,
&mut seen,
&mut emitted,
&mut rng,
true,
)?;
}
Ok(())
}
fn emit_sample_variant(
parts: Vec<PartChoice>,
cfg: &GenConfig,
vocab: &Vocab,
writer: &mut ShardWriter,
seen: &mut HashSet<String>,
emitted: &mut usize,
rng: &mut StdRng,
allow_duplicate: bool,
) -> Result<()> {
if *emitted >= cfg.samples_per_source {
return Ok(());
}
if parts.is_empty() {
return Ok(());
}
let separators = match cfg.separator_mode {
SeparatorMode::Global => {
let sep = cfg
.separators
.choose(rng)
.cloned()
.unwrap_or_else(|| " ".to_string());
if parts.len() > 1 {
vec![sep; parts.len() - 1]
} else {
Vec::new()
}
}
SeparatorMode::PerGap => {
let mut values = Vec::with_capacity(parts.len().saturating_sub(1));
for _ in 0..parts.len().saturating_sub(1) {
values.push(
cfg.separators
.choose(rng)
.cloned()
.unwrap_or_else(|| " ".to_string()),
);
}
values
}
};
let brackets = match cfg.bracket_mode {
BracketMode::Global => {
let bracket = cfg
.brackets
.choose(rng)
.cloned()
.unwrap_or_else(|| Bracket {
name: "none".to_string(),
open: String::new(),
close: String::new(),
});
vec![bracket; parts.len()]
}
BracketMode::PerPart => {
let mut values = Vec::with_capacity(parts.len());
for _ in 0..parts.len() {
values.push(
cfg.brackets
.choose(rng)
.cloned()
.unwrap_or_else(|| Bracket {
name: "none".to_string(),
open: String::new(),
close: String::new(),
}),
);
}
values
}
};
let text = render_variant_text(&parts, &separators, &brackets);
if !allow_duplicate && !seen.insert(text) {
return Ok(());
}
let (input_ids, attention_mask, labels) =
encode_generated_sample(&parts, &separators, &brackets, vocab, cfg.max_length)?;
writer.add(&input_ids, &attention_mask, &labels)?;
*emitted += 1;
Ok(())
}
/// Calls `callback` once for every permutation of `values[start..]`, with the
/// prefix `values[..start]` left in place. Permutations are produced by
/// in-place swaps; the first error returned by `callback` stops the
/// enumeration and is propagated.
fn permute_entities<F>(values: &mut [Entity], start: usize, callback: &mut F) -> Result<()>
where
    F: FnMut(&[Entity]) -> Result<()>,
{
    let len = values.len();
    if start < len {
        for pick in start..len {
            // Move candidate `pick` into position `start`, recurse, then undo.
            values.swap(start, pick);
            permute_entities(values, start + 1, callback)?;
            values.swap(start, pick);
        }
        Ok(())
    } else {
        callback(values)
    }
}
/// One part of a generated filename: an entity kind and the text chosen for it.
#[derive(Clone)]
struct PartChoice {
entity: Entity,
value: String,
}
/// Calls `callback` for every way of picking one value per entity in `order`.
///
/// `fields` is indexed by `Entity::index`. `current` is the working buffer
/// holding the choices made for `order[..idx]`; start with an empty buffer
/// and `idx == 0`. The first callback error stops the enumeration.
fn for_each_value_combo<F>(
    order: &[Entity],
    fields: &[Vec<String>],
    idx: usize,
    current: &mut Vec<PartChoice>,
    callback: &mut F,
) -> Result<()>
where
    F: FnMut(&[PartChoice]) -> Result<()>,
{
    let Some(&entity) = order.get(idx) else {
        // Every position has a value: hand the finished combination over.
        return callback(current);
    };
    for candidate in fields[entity.index()].iter() {
        current.push(PartChoice {
            entity,
            value: candidate.clone(),
        });
        for_each_value_combo(order, fields, idx + 1, current, callback)?;
        current.pop();
    }
    Ok(())
}
fn emit_syntax_variants(
parts: &[PartChoice],
cfg: &GenConfig,
vocab: &Vocab,
writer: &mut ShardWriter,
) -> Result<()> {
let gaps = parts.len().saturating_sub(1);
let mut separators = Vec::with_capacity(gaps);
for_each_separator_combo(gaps, cfg, 0, &mut separators, &mut |sep_combo| {
let mut brackets = Vec::with_capacity(parts.len());
for_each_bracket_combo(parts.len(), cfg, 0, &mut brackets, &mut |bracket_combo| {
let (input_ids, attention_mask, labels) =
encode_generated_sample(parts, sep_combo, bracket_combo, vocab, cfg.max_length)?;
writer.add(&input_ids, &attention_mask, &labels)
})
})
}
/// Calls `callback` with every separator assignment for `gaps` gaps.
///
/// * No gaps: the callback runs once with `current` as it is.
/// * `Global`: one call per configured separator, repeated in every gap
///   (only when entered with `idx == 0`).
/// * `PerGap`: one call per element of the cartesian product over the gaps,
///   built recursively with `idx` as the gap being filled.
fn for_each_separator_combo<F>(
    gaps: usize,
    cfg: &GenConfig,
    idx: usize,
    current: &mut Vec<String>,
    callback: &mut F,
) -> Result<()>
where
    F: FnMut(&[String]) -> Result<()>,
{
    if gaps == 0 {
        return callback(current);
    }
    match cfg.separator_mode {
        SeparatorMode::Global if idx != 0 => Ok(()),
        SeparatorMode::Global => {
            for separator in cfg.separators.iter() {
                current.clear();
                current.resize(gaps, separator.clone());
                callback(current)?;
            }
            Ok(())
        }
        SeparatorMode::PerGap if idx >= gaps => callback(current),
        SeparatorMode::PerGap => {
            for separator in cfg.separators.iter() {
                current.push(separator.clone());
                for_each_separator_combo(gaps, cfg, idx + 1, current, callback)?;
                current.pop();
            }
            Ok(())
        }
    }
}
/// Calls `callback` with every bracket assignment for `parts` parts.
///
/// * `Global`: one call per configured style, applied to every part (only
///   when entered with `idx == 0`).
/// * `PerPart`: one call per element of the cartesian product over the
///   parts, built recursively with `idx` as the part being assigned.
fn for_each_bracket_combo<F>(
    parts: usize,
    cfg: &GenConfig,
    idx: usize,
    current: &mut Vec<Bracket>,
    callback: &mut F,
) -> Result<()>
where
    F: FnMut(&[Bracket]) -> Result<()>,
{
    match cfg.bracket_mode {
        BracketMode::Global if idx != 0 => Ok(()),
        BracketMode::Global => {
            for style in cfg.brackets.iter() {
                current.clear();
                current.resize(parts, style.clone());
                callback(current)?;
            }
            Ok(())
        }
        BracketMode::PerPart if idx >= parts => callback(current),
        BracketMode::PerPart => {
            for style in cfg.brackets.iter() {
                current.push(style.clone());
                for_each_bracket_combo(parts, cfg, idx + 1, current, callback)?;
                current.pop();
            }
            Ok(())
        }
    }
}
/// Builds the plain text of a variant: each part wrapped in its bracket pair,
/// followed by the separator of the same index when one exists.
///
/// `brackets` must provide an entry for every part; indexing panics otherwise.
fn render_variant_text(
    parts: &[PartChoice],
    separators: &[String],
    brackets: &[Bracket],
) -> String {
    let mut rendered = String::new();
    for (position, part) in parts.iter().enumerate() {
        let style = &brackets[position];
        rendered.push_str(&style.open);
        rendered.push_str(&part.value);
        rendered.push_str(&style.close);
        if let Some(separator) = separators.get(position) {
            rendered.push_str(separator);
        }
    }
    rendered
}
/// Encodes a source row unchanged as `[CLS] tokens... [SEP]`, padded to
/// `max_length`.
///
/// Tokens beyond `max_length - 2` are dropped. Padding, `[CLS]` and `[SEP]`
/// positions keep label `-100`. `max_length` must be at least 2 (`main`
/// enforces 4).
///
/// # Errors
/// Fails when a label of the row is not known to `label_id`.
fn encode_original_sample(
    sample: &SourceSample,
    vocab: &Vocab,
    max_length: usize,
) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> {
    let mut input_ids = vec![vocab.pad_id; max_length];
    let mut attention_mask = vec![0u8; max_length];
    let mut labels = vec![-100i16; max_length];

    input_ids[0] = vocab.cls_id;
    attention_mask[0] = 1;

    let capacity = max_length.saturating_sub(2);
    let kept = sample.tokens.len().min(capacity);
    for (idx, token) in sample.tokens.iter().take(kept).enumerate() {
        let slot = idx + 1;
        let label = &sample.labels[idx];
        input_ids[slot] = token_id(vocab, token);
        attention_mask[slot] = 1;
        labels[slot] = label_id(label).with_context(|| {
            format!(
                "unknown label '{}' on source row {} ({})",
                sample.labels[idx],
                sample.row_index + 1,
                sample.filename
            )
        })?;
    }

    // The separator goes directly after the last kept token.
    let sep_slot = kept + 1;
    input_ids[sep_slot] = vocab.sep_id;
    attention_mask[sep_slot] = 1;
    Ok((input_ids, attention_mask, labels))
}
/// Encodes a generated variant character by character as
/// `[CLS] open value close separator ... [SEP]`, padded to `max_length`.
///
/// Bracket and separator characters get the `O` label (0); value characters
/// get the entity's `B-`/`I-` labels. Characters that do not fit into
/// `max_length - 2` positions are dropped. `brackets` must provide an entry
/// for every part, and `max_length` must be at least 2 (`main` enforces 4).
fn encode_generated_sample(
    parts: &[PartChoice],
    separators: &[String],
    brackets: &[Bracket],
    vocab: &Vocab,
    max_length: usize,
) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> {
    let mut input_ids = vec![vocab.pad_id; max_length];
    let mut attention_mask = vec![0u8; max_length];
    let mut labels = vec![-100i16; max_length];

    input_ids[0] = vocab.cls_id;
    attention_mask[0] = 1;

    let capacity = max_length.saturating_sub(2);
    // Next free position; position 0 is taken by [CLS].
    let mut cursor = 1usize;

    for (position, part) in parts.iter().enumerate() {
        let style = &brackets[position];
        append_o_text(
            &style.open,
            vocab,
            capacity,
            &mut cursor,
            &mut input_ids,
            &mut attention_mask,
            &mut labels,
        );
        append_entity_text(
            &part.value,
            part.entity,
            vocab,
            capacity,
            &mut cursor,
            &mut input_ids,
            &mut attention_mask,
            &mut labels,
        )?;
        append_o_text(
            &style.close,
            vocab,
            capacity,
            &mut cursor,
            &mut input_ids,
            &mut attention_mask,
            &mut labels,
        );
        if let Some(separator) = separators.get(position) {
            append_o_text(
                separator,
                vocab,
                capacity,
                &mut cursor,
                &mut input_ids,
                &mut attention_mask,
                &mut labels,
            );
        }
    }

    let sep_slot = cursor.min(max_length - 1);
    input_ids[sep_slot] = vocab.sep_id;
    attention_mask[sep_slot] = 1;
    labels[sep_slot] = -100;
    Ok((input_ids, attention_mask, labels))
}
/// Appends the characters of `text` at `*pos`, one token per character, with
/// label 0 (`O`) and attention 1. Stops silently once `*pos` exceeds
/// `available`; `*pos` is advanced past every character written.
fn append_o_text(
    text: &str,
    vocab: &Vocab,
    available: usize,
    pos: &mut usize,
    input_ids: &mut [u16],
    attention_mask: &mut [u8],
    labels: &mut [i16],
) {
    let mut utf8 = [0u8; 4];
    for ch in text.chars() {
        let slot = *pos;
        if slot > available {
            break;
        }
        input_ids[slot] = token_id(vocab, ch.encode_utf8(&mut utf8));
        attention_mask[slot] = 1;
        labels[slot] = 0;
        *pos = slot + 1;
    }
}
/// Appends the characters of an entity value at `*pos`, one token per
/// character. The first character gets the entity's `B-` label and the rest
/// its `I-` label. Stops silently once `*pos` exceeds `available`.
///
/// # Errors
/// Fails when `label_id` does not know the entity's labels; the lookup
/// happens before anything is written, even for empty text.
fn append_entity_text(
    text: &str,
    entity: Entity,
    vocab: &Vocab,
    available: usize,
    pos: &mut usize,
    input_ids: &mut [u16],
    attention_mask: &mut [u8],
    labels: &mut [i16],
) -> Result<()> {
    let begin_label = label_id(entity.b_label()).context("missing B label")?;
    let inside_label = label_id(entity.i_label()).context("missing I label")?;

    let mut utf8 = [0u8; 4];
    for (nth, ch) in text.chars().enumerate() {
        let slot = *pos;
        if slot > available {
            break;
        }
        input_ids[slot] = token_id(vocab, ch.encode_utf8(&mut utf8));
        attention_mask[slot] = 1;
        labels[slot] = if nth == 0 { begin_label } else { inside_label };
        *pos = slot + 1;
    }
    Ok(())
}
/// Looks up the id of `token`, falling back to the `[UNK]` id for tokens that
/// are not in the vocabulary.
fn token_id(vocab: &Vocab, token: &str) -> u16 {
    match vocab.ids.get(token) {
        Some(&id) => id,
        None => vocab.unk_id,
    }
}
/// Maps a BIO label string to its numeric class id, or `None` for a label
/// that is not part of the scheme. The id is the label's position in the
/// table below.
fn label_id(label: &str) -> Option<i16> {
    const LABELS: [&str; 15] = [
        "O",
        "B-TITLE",
        "I-TITLE",
        "B-SEASON",
        "I-SEASON",
        "B-EPISODE",
        "I-EPISODE",
        "B-SPECIAL",
        "I-SPECIAL",
        "B-GROUP",
        "I-GROUP",
        "B-RESOLUTION",
        "I-RESOLUTION",
        "B-SOURCE",
        "I-SOURCE",
    ];
    LABELS
        .iter()
        .position(|known| *known == label)
        .map(|id| id as i16)
}
/// Returns the fixed list of SPECIAL fixture strings: the bare `Menu`, six
/// numbered menu/ending spellings for 01..=24, `OP`/`NCOP`/`NCED` for 01..=06
/// and `CM`/`PV` for 01..=12, in that order.
fn built_in_specials() -> Vec<String> {
    let mut specials = vec!["Menu".to_string()];
    for idx in 1..=24 {
        specials.extend([
            format!("Menu{idx:02}"),
            format!("Menu {idx:02}"),
            format!("BDMenu{idx:02}"),
            format!("BD Menu{idx:02}"),
            format!("Menu{idx:02}-01"),
            format!("ED E{idx:02}"),
        ]);
    }
    for idx in 1..=6 {
        specials.extend([
            format!("OP{idx:02}"),
            format!("NCOP{idx:02}"),
            format!("NCED{idx:02}"),
        ]);
    }
    for idx in 1..=12 {
        specials.extend([format!("CM{idx:02}"), format!("PV{idx:02}")]);
    }
    specials
}
/// Writes `data` as a little-endian `uint16` `.npy` array of shape
/// `(rows, cols)` to `path`, replacing any existing file.
///
/// # Errors
/// Fails when the file cannot be created or any write, including the final
/// buffer flush, fails.
fn write_npy_u16(path: &Path, data: &[u16], rows: usize, cols: usize) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    write_npy_header(&mut writer, "<u2", rows, cols)?;
    for value in data {
        writer.write_all(&value.to_le_bytes())?;
    }
    // Dropping a BufWriter ignores flush errors; flush explicitly so that a
    // failed final write is reported instead of leaving a truncated shard.
    writer
        .flush()
        .with_context(|| format!("failed to write {}", path.display()))?;
    Ok(())
}
/// Writes `data` as a `uint8` `.npy` array of shape `(rows, cols)` to `path`,
/// replacing any existing file.
///
/// # Errors
/// Fails when the file cannot be created or any write, including the final
/// buffer flush, fails.
fn write_npy_u8(path: &Path, data: &[u8], rows: usize, cols: usize) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    write_npy_header(&mut writer, "|u1", rows, cols)?;
    writer.write_all(data)?;
    // Dropping a BufWriter ignores flush errors; flush explicitly so that a
    // failed final write is reported instead of leaving a truncated shard.
    writer
        .flush()
        .with_context(|| format!("failed to write {}", path.display()))?;
    Ok(())
}
/// Writes `data` as a little-endian `int16` `.npy` array of shape
/// `(rows, cols)` to `path`, replacing any existing file.
///
/// # Errors
/// Fails when the file cannot be created or any write, including the final
/// buffer flush, fails.
fn write_npy_i16(path: &Path, data: &[i16], rows: usize, cols: usize) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    write_npy_header(&mut writer, "<i2", rows, cols)?;
    for value in data {
        writer.write_all(&value.to_le_bytes())?;
    }
    // Dropping a BufWriter ignores flush errors; flush explicitly so that a
    // failed final write is reported instead of leaving a truncated shard.
    writer
        .flush()
        .with_context(|| format!("failed to write {}", path.display()))?;
    Ok(())
}
/// Writes a version 1.0 `.npy` header for a 2-D, C-ordered array.
///
/// The header dictionary is padded with spaces and terminated by a newline so
/// that magic + version + length field + header is a multiple of 16 bytes.
///
/// # Errors
/// Fails when the padded header does not fit the 16-bit length field, or on
/// any write error.
fn write_npy_header<W: Write>(writer: &mut W, descr: &str, rows: usize, cols: usize) -> Result<()> {
    // 6 magic bytes + 2 version bytes + 2 bytes of header length.
    const PREAMBLE_LEN: usize = 10;
    const ALIGNMENT: usize = 16;

    let dict = format!(
        "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}",
        descr, rows, cols
    );
    let mut header = dict.into_bytes();

    // `+ 1` accounts for the trailing newline appended after the padding.
    let unpadded = PREAMBLE_LEN + header.len() + 1;
    let pad_len = (ALIGNMENT - unpadded % ALIGNMENT) % ALIGNMENT;
    header.resize(header.len() + pad_len, b' ');
    header.push(b'\n');

    if header.len() > u16::MAX as usize {
        bail!("npy header too large");
    }
    let header_len = (header.len() as u16).to_le_bytes();

    writer.write_all(b"\x93NUMPY")?;
    writer.write_all(&[1, 0])?;
    writer.write_all(&header_len)?;
    writer.write_all(&header)?;
    Ok(())
}