// Snapshot of commit ed49faa by ModerRAS: "Implement schema v2 anime filename labels".
// (Repository page listed this file as 73.9 kB.)
use anyhow::{bail, Context, Result};
use clap::{Parser, ValueEnum};
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::Rng;
use rand::SeedableRng;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::collections::{HashMap, HashSet};
use std::fs::{self, File};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use std::time::Instant;
/// The five `Title*` entity variants.
///
/// `first_file_title_field` scans them in this order, so the order here sets
/// the lookup priority when a sample carries several kinds of title.
const FILE_TITLE_ENTITIES: [Entity; 5] = [
Entity::TitleChs,
Entity::TitleCht,
Entity::TitleJpn,
Entity::TitleLatin,
Entity::TitleMixed,
];
/// The five `PathTitle*` entity variants, in the same suffix order as
/// `FILE_TITLE_ENTITIES`.
///
/// `choose_path_title_field` collects its candidates from these entities and
/// `count_path_variants` requires at least one of them to have a value.
const PATH_TITLE_ENTITIES: [Entity; 5] = [
Entity::PathTitleChs,
Entity::PathTitleCht,
Entity::PathTitleJpn,
Entity::PathTitleLatin,
Entity::PathTitleMixed,
];
/// Every `Entity` variant, exactly once.
///
/// `Entity::index` returns a variant's position in this array and
/// `extract_fields` allocates one slot per element, so `SourceSample::fields`
/// is indexed by position in this list.
const ENTITIES: [Entity; 18] = [
Entity::Group,
Entity::TitleChs,
Entity::TitleCht,
Entity::TitleJpn,
Entity::TitleLatin,
Entity::TitleMixed,
Entity::PathTitleChs,
Entity::PathTitleCht,
Entity::PathTitleJpn,
Entity::PathTitleLatin,
Entity::PathTitleMixed,
Entity::PathSeason,
Entity::Season,
Entity::Episode,
Entity::Special,
Entity::Resolution,
Entity::Source,
Entity::Tag,
];
/// Built-in BIO label list: `"O"` followed by a `B-`/`I-` pair for each of the
/// 18 entities (1 + 2 * 18 = 37 entries).
///
/// The entity order differs from `ENTITIES` (the `GROUP` pair comes after the
/// `SPECIAL` pair here).
/// NOTE(review): presumably a label's position is its numeric id when no
/// schema is supplied — the code that consumes this list is outside this view;
/// confirm before reordering.
const FALLBACK_LABELS: [&str; 37] = [
"O",
"B-TITLE_CHS",
"I-TITLE_CHS",
"B-TITLE_CHT",
"I-TITLE_CHT",
"B-TITLE_JPN",
"I-TITLE_JPN",
"B-TITLE_LATIN",
"I-TITLE_LATIN",
"B-TITLE_MIXED",
"I-TITLE_MIXED",
"B-PATH_TITLE_CHS",
"I-PATH_TITLE_CHS",
"B-PATH_TITLE_CHT",
"I-PATH_TITLE_CHT",
"B-PATH_TITLE_JPN",
"I-PATH_TITLE_JPN",
"B-PATH_TITLE_LATIN",
"I-PATH_TITLE_LATIN",
"B-PATH_TITLE_MIXED",
"I-PATH_TITLE_MIXED",
"B-PATH_SEASON",
"I-PATH_SEASON",
"B-SEASON",
"I-SEASON",
"B-EPISODE",
"I-EPISODE",
"B-SPECIAL",
"I-SPECIAL",
"B-GROUP",
"I-GROUP",
"B-RESOLUTION",
"I-RESOLUTION",
"B-SOURCE",
"I-SOURCE",
"B-TAG",
"I-TAG",
];
/// Process-wide, set-once map from label string to its `i16` id (the same
/// integer width `ShardWriter` stores labels in).
/// NOTE(review): the code that initialises and reads this cell is not in this
/// part of the file.
static LABEL_IDS: OnceLock<HashMap<String, i16>> = OnceLock::new();
/// Deserializable document holding a list of label strings.
/// NOTE(review): presumably the external label schema used to fill
/// `LABEL_IDS`; its loader is not visible here.
#[derive(Debug, Deserialize)]
struct LabelSchema {
/// Label strings as they appear in the document.
labels: Vec<String>,
}
// Command-line arguments.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into help text, which would change the program's `--help` output.
#[derive(Parser, Debug)]
#[command(
about = "Generate pre-encoded AniFileBERT virtual BIO permutation shards",
version
)]
struct Args {
// JSONL file of source rows, read by `load_samples`.
#[arg(long)]
input: PathBuf,
// JSON object mapping token -> id, read by `load_vocab`.
#[arg(long)]
vocab_file: PathBuf,
// Directory that receives the `.npy` shards and `manifest.json`.
#[arg(long)]
output_dir: PathBuf,
// Width of every encoded row; `main` rejects values below 4.
#[arg(long, default_value_t = 128)]
max_length: usize,
// Rows per shard before a writer flushes; `main` rejects 0.
#[arg(long, default_value_t = 25_000)]
shard_size: usize,
// Maximum number of source rows to load; 0 means no limit.
#[arg(long, default_value_t = 0)]
limit_rows: usize,
// When > 0, emit this many sampled variants per row instead of the
// exhaustive subset/permutation enumeration.
#[arg(long, default_value_t = 0)]
samples_per_source: usize,
#[arg(
long,
default_value_t = 0,
help = "Generate full-path context samples per source row; prefix directories are O labels"
)]
path_samples_per_source: usize,
// Seed for the row shuffle in `main` and for the per-row RNGs.
#[arg(long, default_value_t = 42)]
seed: u64,
// Rayon worker count; 0 leaves rayon's global pool unconfigured.
#[arg(long, default_value_t = 0)]
threads: usize,
#[arg(long, default_value = "global")]
separator_mode: SeparatorMode,
#[arg(long, default_value = "global")]
bracket_mode: BracketMode,
// Comma-separated separator strings; each is passed through
// `normalize_separator_arg` in `main`.
// NOTE(review): the default list as shown here contains `~` twice — confirm
// whether the last entry was meant to be a different character.
#[arg(long, value_delimiter = ',', default_value = " , - ,.,_,-,~,~")]
separators: Vec<String>,
// Comma-separated bracket style names parsed by `Bracket::from_name`.
#[arg(
long,
value_delimiter = ',',
default_value = "none,square,round,corner,angle"
)]
bracket_styles: Vec<String>,
#[arg(long, value_delimiter = ',', default_value = "windows,unix")]
path_styles: Vec<PathStyle>,
// `main` ANDs this with `!no_original` to decide whether originals are kept.
#[arg(long, default_value_t = true)]
include_original: bool,
#[arg(long, help = "Skip original source rows in generated shards")]
no_original: bool,
#[arg(long, help = "Skip ordinary BIO entity subset/permutation variants")]
no_bio_variants: bool,
// `main` ANDs this with `!no_special_fixtures`.
#[arg(long, default_value_t = true)]
include_special_fixtures: bool,
#[arg(long, help = "Skip built-in standalone special fixtures")]
no_special_fixtures: bool,
#[arg(long, help = "Only count rows; do not write shard files")]
dry_run: bool,
}
// How separators are assigned to the gaps between parts of a variant.
// (`//` rather than `///`: clap's ValueEnum derive reads doc comments as help.)
#[derive(Clone, Copy, Debug, Serialize, ValueEnum)]
enum SeparatorMode {
// One separator is used for every gap of a variant.
Global,
// Each gap gets its own separator.
PerGap,
}
// How bracket styles are assigned to the parts of a variant.
// (`//` rather than `///`: clap's ValueEnum derive reads doc comments as help.)
#[derive(Clone, Copy, Debug, Serialize, ValueEnum)]
enum BracketMode {
// One bracket style is used for every part of a variant.
Global,
// Each part gets its own bracket style.
PerPart,
}
// Path flavour for generated full-path samples; see `PathStyle::separator`.
// (`//` rather than `///`: clap's ValueEnum derive reads doc comments as help.)
#[derive(Clone, Copy, Debug, Serialize, ValueEnum)]
enum PathStyle {
// Components joined with a backslash.
Windows,
// Components joined with a forward slash.
Unix,
}
impl PathStyle {
fn separator(self) -> &'static str {
match self {
PathStyle::Windows => "\\",
PathStyle::Unix => "/",
}
}
}
/// Kinds of labelled span the generator works with.
///
/// Each variant has a `B-`/`I-` label pair (see `Entity::b_label` and
/// `Entity::i_label`). Variants are declared in the same order as the
/// `ENTITIES` array.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd, Serialize)]
enum Entity {
Group,
// File-name titles; each has a `PathTitle*` counterpart below.
TitleChs,
TitleCht,
TitleJpn,
TitleLatin,
TitleMixed,
// Path titles; `is_ordinary_variant_entity` excludes these.
PathTitleChs,
PathTitleCht,
PathTitleJpn,
PathTitleLatin,
PathTitleMixed,
// Counterpart of `Season`; also excluded from ordinary variants.
PathSeason,
Season,
Episode,
Special,
Resolution,
Source,
Tag,
}
impl Entity {
fn index(self) -> usize {
ENTITIES
.iter()
.position(|entity| *entity == self)
.expect("entity missing from ENTITIES")
}
fn from_name(name: &str) -> Option<Self> {
match name {
"GROUP" => Some(Entity::Group),
"TITLE" | "TITLE_MIXED" => Some(Entity::TitleMixed),
"TITLE_CHS" => Some(Entity::TitleChs),
"TITLE_CHT" => Some(Entity::TitleCht),
"TITLE_JPN" => Some(Entity::TitleJpn),
"TITLE_LATIN" => Some(Entity::TitleLatin),
"PATH_TITLE" | "PATH_TITLE_MIXED" => Some(Entity::PathTitleMixed),
"PATH_TITLE_CHS" => Some(Entity::PathTitleChs),
"PATH_TITLE_CHT" => Some(Entity::PathTitleCht),
"PATH_TITLE_JPN" => Some(Entity::PathTitleJpn),
"PATH_TITLE_LATIN" => Some(Entity::PathTitleLatin),
"PATH_SEASON" => Some(Entity::PathSeason),
"SEASON" => Some(Entity::Season),
"EPISODE" => Some(Entity::Episode),
"SPECIAL" => Some(Entity::Special),
"RESOLUTION" => Some(Entity::Resolution),
"SOURCE" => Some(Entity::Source),
"TAG" => Some(Entity::Tag),
_ => None,
}
}
fn b_label(self) -> &'static str {
match self {
Entity::Group => "B-GROUP",
Entity::TitleChs => "B-TITLE_CHS",
Entity::TitleCht => "B-TITLE_CHT",
Entity::TitleJpn => "B-TITLE_JPN",
Entity::TitleLatin => "B-TITLE_LATIN",
Entity::TitleMixed => "B-TITLE_MIXED",
Entity::PathTitleChs => "B-PATH_TITLE_CHS",
Entity::PathTitleCht => "B-PATH_TITLE_CHT",
Entity::PathTitleJpn => "B-PATH_TITLE_JPN",
Entity::PathTitleLatin => "B-PATH_TITLE_LATIN",
Entity::PathTitleMixed => "B-PATH_TITLE_MIXED",
Entity::PathSeason => "B-PATH_SEASON",
Entity::Season => "B-SEASON",
Entity::Episode => "B-EPISODE",
Entity::Special => "B-SPECIAL",
Entity::Resolution => "B-RESOLUTION",
Entity::Source => "B-SOURCE",
Entity::Tag => "B-TAG",
}
}
fn i_label(self) -> &'static str {
match self {
Entity::Group => "I-GROUP",
Entity::TitleChs => "I-TITLE_CHS",
Entity::TitleCht => "I-TITLE_CHT",
Entity::TitleJpn => "I-TITLE_JPN",
Entity::TitleLatin => "I-TITLE_LATIN",
Entity::TitleMixed => "I-TITLE_MIXED",
Entity::PathTitleChs => "I-PATH_TITLE_CHS",
Entity::PathTitleCht => "I-PATH_TITLE_CHT",
Entity::PathTitleJpn => "I-PATH_TITLE_JPN",
Entity::PathTitleLatin => "I-PATH_TITLE_LATIN",
Entity::PathTitleMixed => "I-PATH_TITLE_MIXED",
Entity::PathSeason => "I-PATH_SEASON",
Entity::Season => "I-SEASON",
Entity::Episode => "I-EPISODE",
Entity::Special => "I-SPECIAL",
Entity::Resolution => "I-RESOLUTION",
Entity::Source => "I-SOURCE",
Entity::Tag => "I-TAG",
}
}
fn is_file_title(self) -> bool {
matches!(
self,
Entity::TitleChs
| Entity::TitleCht
| Entity::TitleJpn
| Entity::TitleLatin
| Entity::TitleMixed
)
}
fn is_path_title(self) -> bool {
matches!(
self,
Entity::PathTitleChs
| Entity::PathTitleCht
| Entity::PathTitleJpn
| Entity::PathTitleLatin
| Entity::PathTitleMixed
)
}
fn is_ordinary_variant_entity(self) -> bool {
!self.is_path_title() && self != Entity::PathSeason
}
fn as_path_title(self) -> Option<Self> {
match self {
Entity::TitleChs => Some(Entity::PathTitleChs),
Entity::TitleCht => Some(Entity::PathTitleCht),
Entity::TitleJpn => Some(Entity::PathTitleJpn),
Entity::TitleLatin => Some(Entity::PathTitleLatin),
Entity::TitleMixed => Some(Entity::PathTitleMixed),
Entity::PathTitleChs
| Entity::PathTitleCht
| Entity::PathTitleJpn
| Entity::PathTitleLatin
| Entity::PathTitleMixed => Some(self),
_ => None,
}
}
fn as_file_title(self) -> Option<Self> {
match self {
Entity::PathTitleChs => Some(Entity::TitleChs),
Entity::PathTitleCht => Some(Entity::TitleCht),
Entity::PathTitleJpn => Some(Entity::TitleJpn),
Entity::PathTitleLatin => Some(Entity::TitleLatin),
Entity::PathTitleMixed => Some(Entity::TitleMixed),
Entity::TitleChs
| Entity::TitleCht
| Entity::TitleJpn
| Entity::TitleLatin
| Entity::TitleMixed => Some(self),
_ => None,
}
}
}
/// A bracket style: the strings placed before and after a part.
#[derive(Clone, Debug)]
struct Bracket {
/// Style name as given on the command line (trimmed); reported in manifests.
name: String,
/// Opening string; empty for the `none` style.
open: String,
/// Closing string; empty for the `none` style.
close: String,
}
impl Bracket {
    /// Builds a bracket style from its command-line name.
    ///
    /// Surrounding whitespace is ignored. Known names are `none`, `square`,
    /// `round`, `corner` and `angle`. Any other name containing `|` is treated
    /// as a custom `open|close` pair split at the first `|`.
    ///
    /// # Errors
    /// Fails for a name that is neither built in nor contains `|`.
    fn from_name(name: &str) -> Result<Self> {
        let trimmed = name.trim();
        let (open, close) = match trimmed {
            "none" => ("", ""),
            "square" => ("[", "]"),
            "round" => ("(", ")"),
            "corner" => ("【", "】"),
            "angle" => ("《", "》"),
            other => match other.split_once('|') {
                Some(custom_pair) => custom_pair,
                None => bail!("unknown bracket style '{other}'"),
            },
        };
        Ok(Self {
            name: trimmed.to_string(),
            open: open.to_string(),
            close: close.to_string(),
        })
    }
}
/// One JSONL line of the input file as parsed by `load_samples`.
#[derive(Deserialize)]
struct InputRow {
/// Original file name; when absent `load_samples` joins `tokens` instead.
filename: Option<String>,
/// Tokens of the row; must be as many as `labels`.
tokens: Vec<String>,
/// One BIO label per token.
labels: Vec<String>,
/// When present it must be `"char"`, otherwise loading fails.
tokenizer_variant: Option<String>,
}
/// A loaded input row plus the entity values extracted from it.
#[derive(Clone)]
struct SourceSample {
/// Zero-based index of the row's line in the input file (blank lines count);
/// mixed into the per-row RNG seeds.
row_index: usize,
filename: String,
tokens: Vec<String>,
/// Labels after `canonical_bio_label` normalisation.
labels: Vec<String>,
/// Distinct values per entity, indexed by `Entity::index`.
fields: Vec<Vec<String>>,
}
/// Generation settings resolved from `Args` in `main` and shared by all workers.
#[derive(Clone)]
struct GenConfig {
/// Width of every encoded row.
max_length: usize,
/// Rows per shard file.
shard_size: usize,
separator_mode: SeparatorMode,
bracket_mode: BracketMode,
/// Separator strings after `normalize_separator_arg`.
separators: Vec<String>,
/// Bracket styles parsed by `Bracket::from_name`.
brackets: Vec<Bracket>,
path_styles: Vec<PathStyle>,
/// `--include-original` combined with `--no-original`.
include_original: bool,
/// Negation of `--no-bio-variants`.
include_bio_variants: bool,
/// 0 selects exhaustive enumeration; otherwise the per-row sample budget.
samples_per_source: usize,
/// Per-row budget of full-path samples; 0 disables them.
path_samples_per_source: usize,
seed: u64,
}
/// Token vocabulary loaded by `load_vocab`; every id fits in `u16`.
#[derive(Clone)]
struct Vocab {
/// Token string -> id.
ids: HashMap<String, u16>,
/// Id of `[PAD]`.
pad_id: u16,
/// Id of `[UNK]`.
unk_id: u16,
/// Id of `[CLS]`.
cls_id: u16,
/// Id of `[SEP]`.
sep_id: u16,
}
/// Manifest entry describing one flushed shard.
#[derive(Serialize)]
struct ShardManifest {
/// Number of rows stored in the shard.
rows: usize,
/// File name (relative to the output directory) of the input-id array.
input_ids: String,
/// File name of the attention-mask array.
attention_mask: String,
/// File name of the label array.
labels: String,
}
/// Buffers encoded rows for one worker and writes them out as `.npy` shards.
struct ShardWriter {
output_dir: PathBuf,
/// Appears in shard file names (`part-w{worker}-s{seq}`).
worker_id: usize,
/// Sequence number of the next shard this writer will flush.
shard_seq: usize,
/// Row count at which `add` triggers a flush.
shard_size: usize,
/// Required length of every row passed to `add`.
max_length: usize,
/// Buffered rows, concatenated (each row contributes `max_length` values).
input_ids: Vec<u16>,
attention_mask: Vec<u8>,
labels: Vec<i16>,
/// Rows currently buffered (reset on flush).
rows: usize,
/// Rows accepted over the writer's lifetime.
total_rows: u64,
/// Manifest entries of the shards flushed so far.
shards: Vec<ShardManifest>,
}
impl ShardWriter {
    /// Creates an empty writer whose buffers are pre-sized for one full shard
    /// (`shard_size * max_length` values, saturating on overflow).
    fn new(output_dir: &Path, worker_id: usize, shard_size: usize, max_length: usize) -> Self {
        let cells_per_shard = shard_size.saturating_mul(max_length);
        Self {
            output_dir: output_dir.to_path_buf(),
            worker_id,
            shard_seq: 0,
            shard_size,
            max_length,
            input_ids: Vec::with_capacity(cells_per_shard),
            attention_mask: Vec::with_capacity(cells_per_shard),
            labels: Vec::with_capacity(cells_per_shard),
            rows: 0,
            total_rows: 0,
            shards: Vec::new(),
        }
    }

    /// Appends one encoded row and flushes once `shard_size` rows are buffered.
    ///
    /// # Errors
    /// Fails if any of the three slices is not exactly `max_length` long, or
    /// if the triggered flush fails.
    fn add(&mut self, input_ids: &[u16], attention_mask: &[u8], labels: &[i16]) -> Result<()> {
        let width = self.max_length;
        let shape_ok =
            input_ids.len() == width && attention_mask.len() == width && labels.len() == width;
        if !shape_ok {
            bail!("encoded sample has wrong shape");
        }

        self.input_ids.extend_from_slice(input_ids);
        self.attention_mask.extend_from_slice(attention_mask);
        self.labels.extend_from_slice(labels);
        self.rows += 1;
        self.total_rows += 1;

        if self.rows < self.shard_size {
            return Ok(());
        }
        self.flush()
    }

    /// Writes the buffered rows as three `.npy` files, records a manifest
    /// entry, and resets the buffers. Does nothing when no rows are buffered.
    fn flush(&mut self) -> Result<()> {
        let rows = self.rows;
        if rows == 0 {
            return Ok(());
        }
        let width = self.max_length;

        let base = format!("part-w{:03}-s{:06}", self.worker_id, self.shard_seq);
        let input_name = format!("{base}.input_ids.npy");
        let mask_name = format!("{base}.attention_mask.npy");
        let label_name = format!("{base}.labels.npy");

        // Write order: input ids, then attention mask, then labels.
        let input_path = self.output_dir.join(&input_name);
        write_npy_u16(&input_path, &self.input_ids, rows, width)?;
        let mask_path = self.output_dir.join(&mask_name);
        write_npy_u8(&mask_path, &self.attention_mask, rows, width)?;
        let label_path = self.output_dir.join(&label_name);
        write_npy_i16(&label_path, &self.labels, rows, width)?;

        self.shards.push(ShardManifest {
            rows,
            input_ids: input_name,
            attention_mask: mask_name,
            labels: label_name,
        });

        // Keep the allocations, drop the contents.
        self.input_ids.clear();
        self.attention_mask.clear();
        self.labels.clear();
        self.rows = 0;
        self.shard_seq += 1;
        Ok(())
    }
}
/// Entry point: parses arguments, loads the vocabulary and source rows, then
/// either prints a row-count preview (`--dry-run`) or generates shards in
/// parallel and writes `manifest.json` into the output directory.
fn main() -> Result<()> {
let args = Args::parse();
// Each feature has a default-true flag and a `--no-*` switch; both must agree.
let include_original = args.include_original && !args.no_original;
let include_bio_variants = !args.no_bio_variants;
let include_special_fixtures = args.include_special_fixtures && !args.no_special_fixtures;
if args.max_length < 4 {
bail!("--max-length must be at least 4");
}
if args.shard_size == 0 {
bail!("--shard-size must be positive");
}
// Only configure rayon's global pool when a thread count was requested.
if args.threads > 0 {
rayon::ThreadPoolBuilder::new()
.num_threads(args.threads)
.build_global()
.context("failed to configure rayon thread pool")?;
}
let started = Instant::now();
let vocab = load_vocab(&args.vocab_file)?;
let brackets = args
.bracket_styles
.iter()
.map(|style| Bracket::from_name(style))
.collect::<Result<Vec<_>>>()?;
let separators = args
.separators
.iter()
.map(|sep| normalize_separator_arg(sep))
.collect::<Vec<_>>();
let cfg = GenConfig {
max_length: args.max_length,
shard_size: args.shard_size,
separator_mode: args.separator_mode,
bracket_mode: args.bracket_mode,
separators,
brackets,
path_styles: args.path_styles.clone(),
include_original,
include_bio_variants,
samples_per_source: args.samples_per_source,
path_samples_per_source: args.path_samples_per_source,
seed: args.seed,
};
let mut samples = load_samples(&args.input, args.limit_rows)?;
let source_rows = samples.len();
// Seeded shuffle of the source rows before they are split into chunks.
let mut rng = StdRng::seed_from_u64(args.seed);
samples.shuffle(&mut rng);
// Sum of the per-row path-sample budgets as computed by `count_path_variants`.
let path_variant_rows: u128 = samples
.par_iter()
.map(|sample| count_path_variants(sample, &cfg) as u128)
.sum();
// Dry run: report the counts and return before creating any file.
if args.dry_run {
let generated: u128 = samples
.par_iter()
.map(|sample| count_variants(sample, &cfg))
.sum();
let special_fixtures = if include_special_fixtures {
count_special_fixtures(&cfg) as u128
} else {
0
};
let manifest = json!({
"format": "anifilebert.virtual_dataset.preview.v1",
"input": args.input,
"vocab_file": args.vocab_file,
"source_rows": source_rows,
"estimated_rows": generated + special_fixtures,
"source_variant_rows": generated,
"path_variant_rows": path_variant_rows,
"special_fixture_rows": special_fixtures,
"max_length": cfg.max_length,
"separator_mode": cfg.separator_mode,
"bracket_mode": cfg.bracket_mode,
"separators": cfg.separators,
"brackets": cfg.brackets.iter().map(|b| &b.name).collect::<Vec<_>>(),
"include_original": cfg.include_original,
"include_bio_variants": cfg.include_bio_variants,
"samples_per_source": cfg.samples_per_source,
"path_samples_per_source": cfg.path_samples_per_source,
"path_styles": cfg.path_styles,
"include_special_fixtures": include_special_fixtures,
"seed": args.seed,
"elapsed_seconds": started.elapsed().as_secs_f64(),
});
println!("{}", serde_json::to_string_pretty(&manifest)?);
return Ok(());
}
fs::create_dir_all(&args.output_dir).with_context(|| {
format!(
"failed to create output directory {}",
args.output_dir.display()
)
})?;
// Aim for four chunks per rayon thread; each chunk gets its own writer,
// with the chunk index as worker id.
let chunk_count = rayon::current_num_threads().max(1) * 4;
let chunk_size = samples.len().div_ceil(chunk_count).max(1);
let chunks = samples
.chunks(chunk_size)
.enumerate()
.collect::<Vec<(usize, &[SourceSample])>>();
let mut worker_results = chunks
.par_iter()
.map(|(chunk_idx, chunk)| {
let mut writer =
ShardWriter::new(&args.output_dir, *chunk_idx, cfg.shard_size, cfg.max_length);
for sample in *chunk {
generate_for_sample(sample, &cfg, &vocab, &mut writer)?;
}
// Write out the final, partially filled shard of this worker.
writer.flush()?;
Ok::<_, anyhow::Error>((writer.total_rows, writer.shards))
})
.collect::<Result<Vec<_>>>()?;
let mut total_rows: u64 = 0;
let mut shards: Vec<ShardManifest> = Vec::new();
for (rows, mut worker_shards) in worker_results.drain(..) {
total_rows += rows;
shards.append(&mut worker_shards);
}
// Built-in special fixtures are generated on this thread with a worker id
// (`chunk_count + 1`) above every chunk index, so file names cannot clash.
let special_rows = if include_special_fixtures {
let mut writer = ShardWriter::new(
&args.output_dir,
chunk_count + 1,
cfg.shard_size,
cfg.max_length,
);
for special in built_in_specials() {
let parts = vec![PartChoice {
entity: Entity::Special,
value: special,
}];
emit_syntax_variants(&parts, &cfg, &vocab, &mut writer)?;
}
writer.flush()?;
total_rows += writer.total_rows;
shards.append(&mut writer.shards);
writer.total_rows
} else {
0
};
// Order the manifest entries by input-id file name.
shards.sort_by(|a, b| a.input_ids.cmp(&b.input_ids));
let manifest = json!({
"format": "anifilebert.virtual_dataset.shards.v1",
"input": args.input,
"vocab_file": args.vocab_file,
"source_rows": source_rows,
"total_rows": total_rows,
"path_variant_rows": path_variant_rows,
"special_fixture_rows": special_rows,
"max_length": cfg.max_length,
"shard_size": cfg.shard_size,
"tokenizer_variant": "char",
"encoding": {
"input_ids_dtype": "uint16",
"attention_mask_dtype": "uint8",
"labels_dtype": "int16",
"layout": "row_major_npy"
},
"special_tokens": {
"pad_id": vocab.pad_id,
"unk_id": vocab.unk_id,
"cls_id": vocab.cls_id,
"sep_id": vocab.sep_id
},
"generation": {
"separator_mode": cfg.separator_mode,
"bracket_mode": cfg.bracket_mode,
"separators": cfg.separators,
"brackets": cfg.brackets.iter().map(|b| &b.name).collect::<Vec<_>>(),
"include_original": cfg.include_original,
"include_bio_variants": cfg.include_bio_variants,
"samples_per_source": cfg.samples_per_source,
"path_samples_per_source": cfg.path_samples_per_source,
"path_styles": cfg.path_styles,
"include_special_fixtures": include_special_fixtures,
"seed": args.seed,
"threads": rayon::current_num_threads()
},
"shards": shards,
"elapsed_seconds": started.elapsed().as_secs_f64(),
});
// The manifest is written to disk and echoed to stdout.
let manifest_path = args.output_dir.join("manifest.json");
fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
.with_context(|| format!("failed to write {}", manifest_path.display()))?;
println!("{}", serde_json::to_string_pretty(&manifest)?);
Ok(())
}
/// Translates the two-character escapes `\t` and `\s` (backslash followed by
/// a letter) into a tab and a space; every other value is returned unchanged.
fn normalize_separator_arg(value: &str) -> String {
    let resolved = if value == "\\t" {
        "\t"
    } else if value == "\\s" {
        " "
    } else {
        value
    };
    resolved.to_string()
}
/// Reads a JSON object of `token -> id` pairs and returns it as a `Vocab`.
///
/// # Errors
/// Fails when the file cannot be read or parsed, when an id does not fit in
/// `u16`, or when any of `[PAD]`, `[UNK]`, `[CLS]`, `[SEP]` is missing.
fn load_vocab(path: &Path) -> Result<Vocab> {
    let text = fs::read_to_string(path)
        .with_context(|| format!("failed to read vocab file {}", path.display()))?;
    let raw: HashMap<String, u64> =
        serde_json::from_str(&text).context("failed to parse vocab JSON")?;

    let mut ids: HashMap<String, u16> = HashMap::with_capacity(raw.len());
    for (token, id) in raw {
        // Shards store ids as uint16, so larger ids cannot be represented.
        if id > u64::from(u16::MAX) {
            bail!("vocab id {id} for token '{token}' exceeds uint16 storage");
        }
        ids.insert(token, id as u16);
    }

    let pad_id = ids.get("[PAD]").copied().context("vocab is missing [PAD]")?;
    let unk_id = ids.get("[UNK]").copied().context("vocab is missing [UNK]")?;
    let cls_id = ids.get("[CLS]").copied().context("vocab is missing [CLS]")?;
    let sep_id = ids.get("[SEP]").copied().context("vocab is missing [SEP]")?;

    Ok(Vocab {
        ids,
        pad_id,
        unk_id,
        cls_id,
        sep_id,
    })
}
fn load_samples(path: &Path, limit_rows: usize) -> Result<Vec<SourceSample>> {
let file = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
let reader = BufReader::new(file);
let mut samples = Vec::new();
for (idx, line) in reader.lines().enumerate() {
if limit_rows > 0 && samples.len() >= limit_rows {
break;
}
let line = line.with_context(|| format!("failed reading line {}", idx + 1))?;
if line.trim().is_empty() {
continue;
}
let row: InputRow = serde_json::from_str(&line)
.with_context(|| format!("failed to parse JSONL line {}", idx + 1))?;
if let Some(variant) = row.tokenizer_variant.as_deref() {
if variant != "char" {
bail!(
"line {} has tokenizer_variant={variant}; virtual shard generation currently requires char data",
idx + 1
);
}
}
if row.tokens.len() != row.labels.len() {
bail!(
"line {} has mismatched token/label lengths: {} vs {}",
idx + 1,
row.tokens.len(),
row.labels.len()
);
}
let filename = row.filename.clone().unwrap_or_else(|| row.tokens.join(""));
let labels = row
.labels
.iter()
.map(|label| canonical_bio_label(label))
.collect::<Vec<_>>();
let fields = extract_fields(&row.tokens, &labels);
samples.push(SourceSample {
row_index: idx,
filename,
tokens: row.tokens,
labels,
fields,
});
}
Ok(samples)
}
fn canonical_bio_label(label: &str) -> String {
if label == "O" {
return "O".to_string();
}
let Some((prefix, entity_name)) = label.split_once('-') else {
return label.to_string();
};
let Some(entity) = Entity::from_name(entity_name) else {
return label.to_string();
};
match prefix {
"B" => entity.b_label().to_string(),
"I" => entity.i_label().to_string(),
_ => label.to_string(),
}
}
/// Collects the distinct text of every labelled span, grouped by entity.
///
/// The result has one vector per element of `ENTITIES`, indexed by
/// `Entity::index`. A span starts at a `B-` label, or at an `I-` label whose
/// entity differs from the span currently open, and ends at the next label
/// that does not continue it. Span text is the concatenation of its tokens,
/// trimmed; storing (including de-duplication and mirroring to counterpart
/// entities) is delegated to `push_extracted_field`.
fn extract_fields(tokens: &[String], labels: &[String]) -> Vec<Vec<String>> {
    /// Stores a finished span, if there is one.
    fn commit(
        span: Option<(Entity, String)>,
        fields: &mut [Vec<String>],
        seen: &mut [HashSet<String>],
    ) {
        if let Some((entity, text)) = span {
            push_extracted_field(fields, seen, entity, text.trim().to_string());
        }
    }

    let slots = ENTITIES.len();
    let mut fields: Vec<Vec<String>> = vec![Vec::new(); slots];
    let mut seen: Vec<HashSet<String>> = vec![HashSet::new(); slots];
    // Entity and accumulated text of the span being read, if any.
    let mut open_span: Option<(Entity, String)> = None;

    for (token, label) in tokens.iter().zip(labels.iter()) {
        let begins = label.strip_prefix("B-").and_then(Entity::from_name);
        let continues = match begins {
            Some(_) => None,
            None => label.strip_prefix("I-").and_then(Entity::from_name),
        };

        if let Some(entity) = begins {
            commit(open_span.take(), &mut fields, &mut seen);
            open_span = Some((entity, token.clone()));
        } else if let Some(entity) = continues {
            let same_span = matches!(&open_span, Some((current, _)) if *current == entity);
            if same_span {
                if let Some((_, text)) = open_span.as_mut() {
                    text.push_str(token);
                }
            } else {
                // An `I-` label of another entity opens a new span.
                commit(open_span.take(), &mut fields, &mut seen);
                open_span = Some((entity, token.clone()));
            }
        } else {
            commit(open_span.take(), &mut fields, &mut seen);
        }
    }
    commit(open_span.take(), &mut fields, &mut seen);
    fields
}
/// Records a span value for `entity` and for its counterpart entities.
///
/// The trimmed value is stored under the entity itself, under both the
/// file-title and path-title form when the entity is a title, and under both
/// `Season` and `PathSeason` when it is either of those. A value is stored at
/// most once per entity (tracked in `seen`); empty values are ignored.
fn push_extracted_field(
    fields: &mut [Vec<String>],
    seen: &mut [HashSet<String>],
    entity: Entity,
    value: String,
) {
    let value = value.trim();
    if value.is_empty() {
        return;
    }

    let season_twin = match entity {
        Entity::Season => Some(Entity::PathSeason),
        Entity::PathSeason => Some(Entity::Season),
        _ => None,
    };
    let targets = [
        Some(entity),
        entity.as_path_title(),
        entity.as_file_title(),
        season_twin,
    ];

    for target in targets.iter().flatten().copied() {
        let slot = target.index();
        // `insert` returns false for a value already recorded for this entity.
        if seen[slot].insert(value.to_string()) {
            fields[slot].push(value.to_string());
        }
    }
}
/// Lists, in `ENTITIES` order, the ordinary (non-path) entities for which the
/// sample has at least one extracted value.
fn ordinary_available_entities(sample: &SourceSample) -> Vec<Entity> {
    let mut present = Vec::new();
    for entity in ENTITIES.iter().copied() {
        if !entity.is_ordinary_variant_entity() {
            continue;
        }
        if sample.fields[entity.index()].is_empty() {
            continue;
        }
        present.push(entity);
    }
    present
}
/// Returns the first non-blank file-title value of the sample, trimmed,
/// together with its entity. Entities are tried in `FILE_TITLE_ENTITIES`
/// order and values in stored order.
fn first_file_title_field(sample: &SourceSample) -> Option<(Entity, String)> {
    for entity in FILE_TITLE_ENTITIES.iter().copied() {
        for value in &sample.fields[entity.index()] {
            let trimmed = value.trim();
            if !trimmed.is_empty() {
                return Some((entity, trimmed.to_string()));
            }
        }
    }
    None
}
/// Picks one path-title value of the sample uniformly at random.
///
/// Candidates are all non-blank values of the `PATH_TITLE_ENTITIES`, trimmed.
/// Returns `None` when there are none.
fn choose_path_title_field(sample: &SourceSample, rng: &mut StdRng) -> Option<(Entity, String)> {
    let candidates: Vec<(Entity, String)> = PATH_TITLE_ENTITIES
        .iter()
        .copied()
        .flat_map(move |entity| {
            sample.fields[entity.index()]
                .iter()
                .map(|value| value.trim())
                .filter(|value| !value.is_empty())
                .map(move |value| (entity, value.to_string()))
        })
        .collect();
    candidates.choose(rng).cloned()
}
/// Counts the rows `generate_for_sample` is expected to emit for one sample.
///
/// The count is the original row (if enabled) plus the path-sample budget,
/// plus either `samples_per_source` (sampling mode) or, for exhaustive mode,
/// the sum over every non-empty subset of available ordinary entities of
/// `value combinations * orderings * separator choices * bracket choices`.
fn count_variants(sample: &SourceSample, cfg: &GenConfig) -> u128 {
    let base = u128::from(cfg.include_original) + count_path_variants(sample, cfg) as u128;
    if !cfg.include_bio_variants {
        return base;
    }
    let available = ordinary_available_entities(sample);
    if available.is_empty() {
        return base;
    }
    if cfg.samples_per_source > 0 {
        return base + cfg.samples_per_source as u128;
    }

    let separator_count = cfg.separators.len() as u128;
    let bracket_count = cfg.brackets.len() as u128;
    let mut total = base;

    // Each bit of `mask` selects one available entity.
    for mask in 1usize..(1usize << available.len()) {
        let mut picked: u32 = 0;
        let mut value_product: u128 = 1;
        for (bit, entity) in available.iter().enumerate() {
            if mask & (1usize << bit) != 0 {
                picked += 1;
                value_product *= sample.fields[entity.index()].len() as u128;
            }
        }

        let orderings = factorial(picked);
        // A single part has no gaps, hence no separator choice.
        let separator_factor = if picked <= 1 {
            1
        } else {
            match cfg.separator_mode {
                SeparatorMode::Global => separator_count,
                SeparatorMode::PerGap => separator_count.pow(picked - 1),
            }
        };
        let bracket_factor = match cfg.bracket_mode {
            BracketMode::Global => bracket_count,
            BracketMode::PerPart => bracket_count.pow(picked),
        };
        total += value_product * orderings * separator_factor * bracket_factor;
    }
    total
}
fn count_path_variants(sample: &SourceSample, cfg: &GenConfig) -> usize {
if cfg.path_samples_per_source == 0 || cfg.path_styles.is_empty() {
return 0;
}
if !PATH_TITLE_ENTITIES
.iter()
.any(|entity| !sample.fields[entity.index()].is_empty())
{
return 0;
}
if sample.fields[Entity::Episode.index()].is_empty()
&& sample.fields[Entity::Special.index()].is_empty()
{
return 0;
}
cfg.path_samples_per_source
}
/// Number of rows contributed by the built-in special fixtures: one per
/// fixture and bracket style.
fn count_special_fixtures(cfg: &GenConfig) -> usize {
    // The factor is the number of bracket styles in either bracket mode.
    let styles_per_fixture = cfg.brackets.len();
    built_in_specials().len() * styles_per_fixture
}
/// Returns `n!` as a `u128`, with `0! == 1`.
///
/// The result saturates at `u128::MAX` (from `n = 35` upwards) instead of
/// overflowing, so a row-count estimate built on it can never panic in debug
/// builds or silently wrap in release builds.
fn factorial(n: u32) -> u128 {
    (1..=u128::from(n)).fold(1u128, |acc, k| acc.saturating_mul(k))
}
fn generate_for_sample(
sample: &SourceSample,
cfg: &GenConfig,
vocab: &Vocab,
writer: &mut ShardWriter,
) -> Result<()> {
if cfg.include_original {
let (input_ids, attention_mask, labels) =
encode_original_sample(sample, vocab, cfg.max_length)?;
writer.add(&input_ids, &attention_mask, &labels)?;
}
if cfg.path_samples_per_source > 0 {
generate_path_context_variants(sample, cfg, vocab, writer)?;
}
if cfg.include_bio_variants && cfg.samples_per_source > 0 {
generate_sampled_variants(sample, cfg, vocab, writer)?;
return Ok(());
}
if !cfg.include_bio_variants {
return Ok(());
}
let available = ordinary_available_entities(sample);
let n = available.len();
for mask in 1usize..(1usize << n) {
let mut selected = available
.iter()
.enumerate()
.filter_map(|(idx, entity)| ((mask & (1usize << idx)) != 0).then_some(*entity))
.collect::<Vec<_>>();
permute_entities(&mut selected, 0, &mut |order| {
let mut parts: Vec<PartChoice> = Vec::with_capacity(order.len());
for_each_value_combo(order, &sample.fields, 0, &mut parts, &mut |combo| {
emit_syntax_variants(combo, cfg, vocab, writer)
})
})?;
}
Ok(())
}
fn generate_sampled_variants(
sample: &SourceSample,
cfg: &GenConfig,
vocab: &Vocab,
writer: &mut ShardWriter,
) -> Result<()> {
let mut rng = StdRng::seed_from_u64(
cfg.seed ^ ((sample.row_index as u64).wrapping_mul(0x9E37_79B9_7F4A_7C15)),
);
let available = ordinary_available_entities(sample);
if available.is_empty() {
return Ok(());
}
let mut seen = HashSet::new();
let mut emitted = 0usize;
let budget = cfg.samples_per_source;
let max_unique_attempts = budget.saturating_mul(32).max(64);
let mut attempts = 0usize;
let mut templates: Vec<Vec<PartChoice>> = Vec::new();
if let Some((title_entity, title)) = first_file_title_field(sample) {
templates.push(vec![PartChoice {
entity: title_entity,
value: title.clone(),
}]);
if let Some(season) = sample.fields[Entity::Season.index()].first() {
templates.push(vec![
PartChoice {
entity: title_entity,
value: title.clone(),
},
PartChoice {
entity: Entity::Season,
value: season.clone(),
},
]);
}
}
if let Some(episode) = sample.fields[Entity::Episode.index()].first() {
templates.push(vec![PartChoice {
entity: Entity::Episode,
value: episode.clone(),
}]);
}
if let Some(special) = sample.fields[Entity::Special.index()].first() {
templates.push(vec![PartChoice {
entity: Entity::Special,
value: special.clone(),
}]);
}
if let (Some((title_entity, title)), Some(special)) = (
first_file_title_field(sample),
sample.fields[Entity::Special.index()].first(),
) {
templates.push(vec![
PartChoice {
entity: title_entity,
value: title.clone(),
},
PartChoice {
entity: Entity::Special,
value: special.clone(),
},
]);
}
for parts in templates {
if emitted >= budget {
break;
}
emit_sample_variant(
parts,
cfg,
vocab,
writer,
&mut seen,
&mut emitted,
&mut rng,
false,
)?;
}
while emitted < budget && attempts < max_unique_attempts {
attempts += 1;
let subset_size = match rng.gen_range(0..100) {
0..=29 => 1,
30..=54 => 2,
55..=74 => 3,
75..=89 => 4.min(available.len()),
_ => available.len().min(5),
}
.max(1)
.min(available.len());
let mut chosen = available
.choose_multiple(&mut rng, subset_size)
.copied()
.collect::<Vec<_>>();
chosen.shuffle(&mut rng);
if !chosen.iter().any(|entity| {
entity.is_file_title() || matches!(entity, Entity::Episode | Entity::Special)
}) {
if let Some(fallback) = available.iter().copied().find(|entity| {
entity.is_file_title() || matches!(entity, Entity::Episode | Entity::Special)
}) {
if !chosen.contains(&fallback) {
chosen.push(fallback);
}
}
}
let mut parts = Vec::with_capacity(chosen.len());
for entity in chosen {
let values = &sample.fields[entity.index()];
let value = values.choose(&mut rng).cloned().unwrap_or_default();
parts.push(PartChoice { entity, value });
}
parts.shuffle(&mut rng);
emit_sample_variant(
parts,
cfg,
vocab,
writer,
&mut seen,
&mut emitted,
&mut rng,
false,
)?;
}
while emitted < budget {
let subset_size = match rng.gen_range(0..100) {
0..=29 => 1,
30..=54 => 2,
55..=74 => 3,
75..=89 => 4.min(available.len()),
_ => available.len().min(5),
}
.max(1)
.min(available.len());
let mut chosen = available
.choose_multiple(&mut rng, subset_size)
.copied()
.collect::<Vec<_>>();
chosen.shuffle(&mut rng);
if !chosen.iter().any(|entity| {
entity.is_file_title() || matches!(entity, Entity::Episode | Entity::Special)
}) {
if let Some(fallback) = available.iter().copied().find(|entity| {
entity.is_file_title() || matches!(entity, Entity::Episode | Entity::Special)
}) {
if !chosen.contains(&fallback) {
chosen.push(fallback);
}
}
}
let mut parts = Vec::with_capacity(chosen.len());
for entity in chosen {
let values = &sample.fields[entity.index()];
let value = values.choose(&mut rng).cloned().unwrap_or_default();
parts.push(PartChoice { entity, value });
}
parts.shuffle(&mut rng);
emit_sample_variant(
parts,
cfg,
vocab,
writer,
&mut seen,
&mut emitted,
&mut rng,
true,
)?;
}
Ok(())
}
fn emit_sample_variant(
parts: Vec<PartChoice>,
cfg: &GenConfig,
vocab: &Vocab,
writer: &mut ShardWriter,
seen: &mut HashSet<String>,
emitted: &mut usize,
rng: &mut StdRng,
allow_duplicate: bool,
) -> Result<()> {
if *emitted >= cfg.samples_per_source {
return Ok(());
}
if parts.is_empty() {
return Ok(());
}
let separators = match cfg.separator_mode {
SeparatorMode::Global => {
let sep = cfg
.separators
.choose(rng)
.cloned()
.unwrap_or_else(|| " ".to_string());
if parts.len() > 1 {
vec![sep; parts.len() - 1]
} else {
Vec::new()
}
}
SeparatorMode::PerGap => {
let mut values = Vec::with_capacity(parts.len().saturating_sub(1));
for _ in 0..parts.len().saturating_sub(1) {
values.push(
cfg.separators
.choose(rng)
.cloned()
.unwrap_or_else(|| " ".to_string()),
);
}
values
}
};
let brackets = match cfg.bracket_mode {
BracketMode::Global => {
let bracket = cfg
.brackets
.choose(rng)
.cloned()
.unwrap_or_else(|| Bracket {
name: "none".to_string(),
open: String::new(),
close: String::new(),
});
vec![bracket; parts.len()]
}
BracketMode::PerPart => {
let mut values = Vec::with_capacity(parts.len());
for _ in 0..parts.len() {
values.push(
cfg.brackets
.choose(rng)
.cloned()
.unwrap_or_else(|| Bracket {
name: "none".to_string(),
open: String::new(),
close: String::new(),
}),
);
}
values
}
};
let text = render_variant_text(&parts, &separators, &brackets);
if !allow_duplicate && !seen.insert(text) {
return Ok(());
}
let (input_ids, attention_mask, labels) =
encode_generated_sample(&parts, &separators, &brackets, vocab, cfg.max_length)?;
writer.add(&input_ids, &attention_mask, &labels)?;
*emitted += 1;
Ok(())
}
/// Emits up to `cfg.path_samples_per_source` full-path samples for one row.
///
/// Rows that `count_path_variants` rejects are skipped. The RNG is seeded
/// from `cfg.seed` and the row index. A first phase accepts only renderings
/// not emitted before, for at most `max(64, 32 * budget)` attempts; a second
/// phase fills the rest of the budget without the uniqueness check. If a path
/// cannot be built at any point, generation for this row stops.
fn generate_path_context_variants(
    sample: &SourceSample,
    cfg: &GenConfig,
    vocab: &Vocab,
    writer: &mut ShardWriter,
) -> Result<()> {
    if count_path_variants(sample, cfg) == 0 {
        return Ok(());
    }
    let mut rng = StdRng::seed_from_u64(
        cfg.seed
            ^ 0xA076_1D64_78BD_642F
            ^ ((sample.row_index as u64).wrapping_mul(0xE703_7ED1_A0B4_28DB)),
    );
    let budget = cfg.path_samples_per_source;
    let unique_attempt_cap = budget.saturating_mul(32).max(64);
    let mut seen = HashSet::new();
    let mut emitted = 0usize;

    // Phase 1: unique renderings only.
    for _ in 0..unique_attempt_cap {
        if emitted >= budget {
            break;
        }
        let Some(pieces) = build_path_context_pieces(sample, cfg, &mut rng) else {
            return Ok(());
        };
        if !seen.insert(render_labeled_pieces(&pieces)) {
            continue;
        }
        let (ids, attention, label_ids) = encode_labeled_pieces(&pieces, vocab, cfg.max_length)?;
        writer.add(&ids, &attention, &label_ids)?;
        emitted += 1;
    }

    // Phase 2: fill the remaining budget, duplicates allowed.
    while emitted < budget {
        let Some(pieces) = build_path_context_pieces(sample, cfg, &mut rng) else {
            return Ok(());
        };
        let (ids, attention, label_ids) = encode_labeled_pieces(&pieces, vocab, cfg.max_length)?;
        writer.add(&ids, &attention, &label_ids)?;
        emitted += 1;
    }
    Ok(())
}
fn build_path_context_pieces(
sample: &SourceSample,
cfg: &GenConfig,
rng: &mut StdRng,
) -> Option<Vec<LabeledPiece>> {
let (title_entity, title) = choose_path_title_field(sample, rng)?;
let style = *cfg.path_styles.choose(rng)?;
let sep = style.separator();
let mut components = path_prefix_components(style, rng);
components.push(vec![entity_piece(title.clone(), title_entity)]);
let season_component = choose_path_season_component(sample, rng);
if let Some(season) = season_component {
components.push(season);
}
let use_special = if sample.fields[Entity::Episode.index()].is_empty() {
true
} else if sample.fields[Entity::Special.index()].is_empty() {
false
} else {
rng.gen_bool(0.18)
};
let endpoint = if use_special {
let special = choose_field(sample, Entity::Special, rng)?;
entity_piece(random_special_path_text(&special, rng), Entity::Special)
} else {
let episode = choose_field(sample, Entity::Episode, rng)?;
entity_piece(random_episode_path_text(&episode, rng), Entity::Episode)
};
match rng.gen_range(0..6) {
0 => components.push(path_file_component(endpoint, sample, rng)),
1 => {
components.push(vec![endpoint]);
components.push(noise_file_component(rng));
}
2 => {
components.push(vec![endpoint]);
components.push(meta_file_component(sample, rng));
}
3 => components.push(compact_file_component(endpoint, sample, rng)),
4 => components.push(grouped_release_file_component(
&title, endpoint, sample, rng,
)),
_ => {
components.push(vec![endpoint]);
if rng.gen_bool(0.55) {
components.push(noise_file_component(rng));
}
}
}
Some(join_path_components(&components, sep))
}
/// Picks a random value of `entity` from `sample` and returns it trimmed.
///
/// Returns `None` when the field has no values or the picked value is blank after
/// trimming.
fn choose_field(sample: &SourceSample, entity: Entity, rng: &mut StdRng) -> Option<String> {
    let picked = sample.fields[entity.index()].choose(rng)?;
    let trimmed = picked.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_string())
    }
}
fn path_prefix_components(style: PathStyle, rng: &mut StdRng) -> Vec<Vec<LabeledPiece>> {
let templates: &[&[&str]] = match style {
PathStyle::Windows => &[
&["O:", "115open", "影音", "动漫"],
&["D:", "Media", "Anime"],
&["E:", "Downloads", "Bangumi"],
&["Z:", "Library", "Anime"],
&["C:", "Archive", "completed"],
],
PathStyle::Unix => &[
&["", "mnt", "media", "anime"],
&["", "volume1", "anime"],
&["home", "media", "Bangumi"],
&["library", "anime"],
&["srv", "downloads", "anime"],
],
};
let noise_dirs = [
"整理中",
"completed",
"old",
"temp",
"115",
"Bangumi",
"Library",
"_archive",
"2024",
"misc",
];
let selected = templates.choose(rng).copied().unwrap_or(&["Anime"]);
let mut components = selected
.iter()
.map(|component| vec![o_piece((*component).to_string())])
.collect::<Vec<_>>();
let extra_count = rng.gen_range(0..=2);
for _ in 0..extra_count {
let insert_at = components.len().saturating_sub(1);
let noise = noise_dirs
.choose(rng)
.copied()
.unwrap_or("Library")
.to_string();
components.insert(insert_at, vec![o_piece(noise)]);
}
components
}
fn choose_path_season_component(
sample: &SourceSample,
rng: &mut StdRng,
) -> Option<Vec<LabeledPiece>> {
let season = if let Some(source_season) = choose_field(sample, Entity::PathSeason, rng)
.or_else(|| choose_field(sample, Entity::Season, rng))
{
random_season_path_text(&source_season, rng)
} else {
let synthetic = ["01", "Season 1", "Season 01", "S01", "第1季"];
synthetic
.choose(rng)
.copied()
.unwrap_or("Season 1")
.to_string()
};
Some(vec![entity_piece(season, Entity::PathSeason)])
}
/// Builds a file-name component around `endpoint`.
///
/// With probability 0.25 an unlabeled `"Episode "` prefix precedes the endpoint; optional
/// metadata from `append_path_meta` and a random extension follow it.
fn path_file_component(
    endpoint: LabeledPiece,
    sample: &SourceSample,
    rng: &mut StdRng,
) -> Vec<LabeledPiece> {
    let prefix = rng
        .gen_bool(0.25)
        .then(|| o_piece("Episode ".to_string()));
    let mut pieces: Vec<LabeledPiece> = prefix
        .into_iter()
        .chain(std::iter::once(endpoint))
        .collect();
    append_path_meta(&mut pieces, sample, rng);
    let extension = random_extension(rng);
    pieces.push(o_piece(extension.to_string()));
    pieces
}
/// Builds a file-name component that starts directly with `endpoint`.
///
/// Metadata from `append_path_meta` is attempted with probability 0.75, and a random
/// extension always ends the component.
fn compact_file_component(
    endpoint: LabeledPiece,
    sample: &SourceSample,
    rng: &mut StdRng,
) -> Vec<LabeledPiece> {
    let mut pieces = Vec::new();
    pieces.push(endpoint);
    let with_meta = rng.gen_bool(0.75);
    if with_meta {
        append_path_meta(&mut pieces, sample, rng);
    }
    let extension = random_extension(rng);
    pieces.push(o_piece(extension.to_string()));
    pieces
}
fn grouped_release_file_component(
title: &str,
endpoint: LabeledPiece,
sample: &SourceSample,
rng: &mut StdRng,
) -> Vec<LabeledPiece> {
let mut pieces = Vec::new();
if let Some(group) = choose_field(sample, Entity::Group, rng) {
pieces.push(o_piece("[".to_string()));
pieces.push(entity_piece(group, Entity::Group));
pieces.push(o_piece("] ".to_string()));
}
pieces.push(o_piece(title.trim().to_string()));
let separator = [" - ", " ", " "].choose(rng).copied().unwrap_or(" - ");
pieces.push(o_piece(separator.to_string()));
pieces.push(endpoint);
append_path_meta(&mut pieces, sample, rng);
pieces.push(o_piece(random_extension(rng).to_string()));
pieces
}
/// Builds a file name with an unlabeled `metadata` or `video` stem (equal odds),
/// followed by optional metadata from `append_path_meta` and a random extension.
fn meta_file_component(sample: &SourceSample, rng: &mut StdRng) -> Vec<LabeledPiece> {
    let stem = if rng.gen_bool(0.5) {
        "metadata"
    } else {
        "video"
    };
    let mut pieces = vec![o_piece(stem.to_string())];
    append_path_meta(&mut pieces, sample, rng);
    let extension = random_extension(rng);
    pieces.push(o_piece(extension.to_string()));
    pieces
}
fn noise_file_component(rng: &mut StdRng) -> Vec<LabeledPiece> {
let stems = ["video", "default", "main", "feature", "movie", "episode"];
let stem = stems.choose(rng).copied().unwrap_or("video");
vec![o_piece(format!("{stem}{}", random_extension(rng)))]
}
fn append_path_meta(pieces: &mut Vec<LabeledPiece>, sample: &SourceSample, rng: &mut StdRng) {
if let Some(resolution) = choose_field(sample, Entity::Resolution, rng) {
if rng.gen_bool(0.72) {
pieces.push(o_piece(" [".to_string()));
pieces.push(entity_piece(resolution, Entity::Resolution));
pieces.push(o_piece("]".to_string()));
}
}
let source_count = if rng.gen_bool(0.35) { 2 } else { 1 };
for _ in 0..source_count {
if let Some(source) = choose_field(sample, Entity::Source, rng) {
if rng.gen_bool(0.62) {
pieces.push(o_piece("[".to_string()));
pieces.push(entity_piece(source, Entity::Source));
pieces.push(o_piece("]".to_string()));
}
}
}
if let Some(tag) = choose_field(sample, Entity::Tag, rng) {
if rng.gen_bool(0.55) {
pieces.push(o_piece("[".to_string()));
pieces.push(entity_piece(tag, Entity::Tag));
pieces.push(o_piece("]".to_string()));
}
}
}
/// Returns a random spelling of an episode value for use inside a path.
///
/// The trimmed input is always a candidate; when it contains an ASCII number, the
/// zero-padded forms `NN`, `ENN` and `EPNN` are candidates as well.
fn random_episode_path_text(value: &str, rng: &mut StdRng) -> String {
    let trimmed = value.trim();
    let mut candidates = vec![trimmed.to_string()];
    if let Some(number) = first_ascii_number(value) {
        candidates.extend([
            format!("{number:02}"),
            format!("E{number:02}"),
            format!("EP{number:02}"),
        ]);
    }
    match candidates.choose(rng) {
        Some(picked) => picked.clone(),
        None => trimmed.to_string(),
    }
}
/// Returns a random spelling of a special value for use inside a path.
///
/// The trimmed input is always a candidate; when it contains an ASCII number, `SPNN`
/// and `Special NN` (zero-padded) are candidates as well.
fn random_special_path_text(value: &str, rng: &mut StdRng) -> String {
    let trimmed = value.trim();
    let mut candidates = vec![trimmed.to_string()];
    if let Some(number) = first_ascii_number(value) {
        candidates.extend([
            format!("SP{number:02}"),
            format!("Special {number:02}"),
        ]);
    }
    match candidates.choose(rng) {
        Some(picked) => picked.clone(),
        None => trimmed.to_string(),
    }
}
/// Returns a random spelling of a season value for use as a directory name.
///
/// The trimmed input is always a candidate; when it contains an ASCII number, the forms
/// `NN`, `Season N`, `Season NN`, `SNN` and `第N季` are candidates as well.
fn random_season_path_text(value: &str, rng: &mut StdRng) -> String {
    let trimmed = value.trim();
    let mut candidates = vec![trimmed.to_string()];
    if let Some(number) = first_ascii_number(value) {
        candidates.extend([
            format!("{number:02}"),
            format!("Season {number}"),
            format!("Season {number:02}"),
            format!("S{number:02}"),
            format!("第{number}季"),
        ]);
    }
    match candidates.choose(rng) {
        Some(picked) => picked.clone(),
        None => trimmed.to_string(),
    }
}
/// Parses the first contiguous run of ASCII digits in `value`.
///
/// Returns `None` when `value` has no ASCII digit or when the run does not fit in a
/// `u32`. Non-ASCII digits (for example full-width ones) are not recognized.
fn first_ascii_number(value: &str) -> Option<u32> {
    let digits: String = value
        .chars()
        .skip_while(|ch| !ch.is_ascii_digit())
        .take_while(|ch| ch.is_ascii_digit())
        .collect();
    if digits.is_empty() {
        return None;
    }
    digits.parse().ok()
}
fn random_extension(rng: &mut StdRng) -> &'static str {
[".mkv", ".mp4", ".avi"]
.choose(rng)
.copied()
.unwrap_or(".mkv")
}
/// Flattens `components` into one piece list, inserting an unlabeled `separator` piece
/// between consecutive components (including empty ones).
fn join_path_components(components: &[Vec<LabeledPiece>], separator: &str) -> Vec<LabeledPiece> {
    let mut joined = Vec::new();
    let mut remaining = components.iter();
    if let Some(first) = remaining.next() {
        joined.extend_from_slice(first);
    }
    for component in remaining {
        joined.push(o_piece(separator.to_string()));
        joined.extend_from_slice(component);
    }
    joined
}
/// Concatenates the text of every piece, in order, ignoring the labels.
fn render_labeled_pieces(pieces: &[LabeledPiece]) -> String {
    pieces.iter().map(|piece| piece.text.as_str()).collect()
}
/// Calls `callback` once for every permutation of `values[start..]`, with
/// `values[..start]` held fixed.
///
/// Permutations are produced in place by swapping. After a successful run `values` is
/// back in its original order; if `callback` fails, the error is returned immediately
/// and the pending swaps are not undone.
fn permute_entities<F>(values: &mut [Entity], start: usize, callback: &mut F) -> Result<()>
where
    F: FnMut(&[Entity]) -> Result<()>,
{
    let len = values.len();
    if start >= len {
        return callback(values);
    }
    (start..len).try_for_each(|pick| -> Result<()> {
        values.swap(start, pick);
        permute_entities(values, start + 1, callback)?;
        values.swap(start, pick);
        Ok(())
    })
}
/// One value selected for a generated sample, together with the entity it is labeled as.
#[derive(Clone)]
struct PartChoice {
/// Entity whose B/I labels are applied to `value` when the sample is encoded.
entity: Entity,
/// Text of this part.
value: String,
}
/// A run of text with an optional entity label, used to assemble path-context samples.
#[derive(Clone)]
struct LabeledPiece {
/// Text of the piece.
text: String,
/// Entity the text is labeled as; `None` marks an unlabeled (`O`) piece.
entity: Option<Entity>,
}
/// Wraps `text` as an unlabeled (`O`) piece.
fn o_piece(text: String) -> LabeledPiece {
    let entity = None;
    LabeledPiece { text, entity }
}
/// Wraps `text` as a piece labeled with `entity`.
fn entity_piece(text: String, entity: Entity) -> LabeledPiece {
    let entity = Some(entity);
    LabeledPiece { text, entity }
}
/// Enumerates every combination of field values for the entities in `order[idx..]`.
///
/// `current` holds the choices made for `order[..idx]`; `callback` is invoked once per
/// complete combination. An entity with no values produces no combinations. The first
/// callback error stops the enumeration and is returned.
fn for_each_value_combo<F>(
    order: &[Entity],
    fields: &[Vec<String>],
    idx: usize,
    current: &mut Vec<PartChoice>,
    callback: &mut F,
) -> Result<()>
where
    F: FnMut(&[PartChoice]) -> Result<()>,
{
    let Some(&entity) = order.get(idx) else {
        // Every entity has a value assigned: report the combination.
        return callback(current);
    };
    fields[entity.index()]
        .iter()
        .try_for_each(|value| -> Result<()> {
            current.push(PartChoice {
                entity,
                value: value.clone(),
            });
            for_each_value_combo(order, fields, idx + 1, current, callback)?;
            current.pop();
            Ok(())
        })
}
fn emit_syntax_variants(
parts: &[PartChoice],
cfg: &GenConfig,
vocab: &Vocab,
writer: &mut ShardWriter,
) -> Result<()> {
let gaps = parts.len().saturating_sub(1);
let mut separators = Vec::with_capacity(gaps);
for_each_separator_combo(gaps, cfg, 0, &mut separators, &mut |sep_combo| {
let mut brackets = Vec::with_capacity(parts.len());
for_each_bracket_combo(parts.len(), cfg, 0, &mut brackets, &mut |bracket_combo| {
let (input_ids, attention_mask, labels) =
encode_generated_sample(parts, sep_combo, bracket_combo, vocab, cfg.max_length)?;
writer.add(&input_ids, &attention_mask, &labels)
})
})
}
/// Enumerates separator assignments for `gaps` gaps and passes each to `callback`.
///
/// * With no gaps, `callback` is invoked once with `current` as it is.
/// * `SeparatorMode::Global`: one assignment per configured separator, the same
///   separator in every gap. Only a call with `idx == 0` enumerates.
/// * `SeparatorMode::PerGap`: every combination, one separator chosen per gap, built
///   recursively with `idx` as the gap being filled.
///
/// The first callback error stops the enumeration and is returned.
fn for_each_separator_combo<F>(
    gaps: usize,
    cfg: &GenConfig,
    idx: usize,
    current: &mut Vec<String>,
    callback: &mut F,
) -> Result<()>
where
    F: FnMut(&[String]) -> Result<()>,
{
    if gaps == 0 {
        return callback(current);
    }
    match cfg.separator_mode {
        SeparatorMode::Global if idx != 0 => Ok(()),
        SeparatorMode::Global => cfg.separators.iter().try_for_each(|sep| -> Result<()> {
            current.clear();
            current.resize(gaps, sep.clone());
            callback(current.as_slice())
        }),
        SeparatorMode::PerGap if idx >= gaps => callback(current),
        SeparatorMode::PerGap => cfg.separators.iter().try_for_each(|sep| -> Result<()> {
            current.push(sep.clone());
            for_each_separator_combo(gaps, cfg, idx + 1, current, callback)?;
            current.pop();
            Ok(())
        }),
    }
}
/// Enumerates bracket assignments for `parts` parts and passes each to `callback`.
///
/// * `BracketMode::Global`: one assignment per configured bracket, the same bracket
///   around every part. Only a call with `idx == 0` enumerates.
/// * `BracketMode::PerPart`: every combination, one bracket chosen per part, built
///   recursively with `idx` as the part being assigned.
///
/// The first callback error stops the enumeration and is returned.
fn for_each_bracket_combo<F>(
    parts: usize,
    cfg: &GenConfig,
    idx: usize,
    current: &mut Vec<Bracket>,
    callback: &mut F,
) -> Result<()>
where
    F: FnMut(&[Bracket]) -> Result<()>,
{
    match cfg.bracket_mode {
        BracketMode::Global if idx != 0 => Ok(()),
        BracketMode::Global => cfg.brackets.iter().try_for_each(|bracket| -> Result<()> {
            current.clear();
            current.resize(parts, bracket.clone());
            callback(current.as_slice())
        }),
        BracketMode::PerPart if idx >= parts => callback(current),
        BracketMode::PerPart => cfg.brackets.iter().try_for_each(|bracket| -> Result<()> {
            current.push(bracket.clone());
            for_each_bracket_combo(parts, cfg, idx + 1, current, callback)?;
            current.pop();
            Ok(())
        }),
    }
}
fn render_variant_text(
parts: &[PartChoice],
separators: &[String],
brackets: &[Bracket],
) -> String {
let mut text = String::new();
for (idx, part) in parts.iter().enumerate() {
text.push_str(&brackets[idx].open);
text.push_str(&part.value);
text.push_str(&brackets[idx].close);
if idx < separators.len() {
text.push_str(&separators[idx]);
}
}
text
}
fn encode_original_sample(
sample: &SourceSample,
vocab: &Vocab,
max_length: usize,
) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> {
let mut input_ids = vec![vocab.pad_id; max_length];
let mut attention_mask = vec![0u8; max_length];
let mut labels = vec![-100i16; max_length];
input_ids[0] = vocab.cls_id;
attention_mask[0] = 1;
let available = max_length.saturating_sub(2);
let token_count = sample.tokens.len().min(available);
for idx in 0..token_count {
input_ids[idx + 1] = token_id(vocab, &sample.tokens[idx]);
attention_mask[idx + 1] = 1;
labels[idx + 1] = label_id(&sample.labels[idx]).with_context(|| {
format!(
"unknown label '{}' on source row {} ({})",
sample.labels[idx],
sample.row_index + 1,
sample.filename
)
})?;
}
let sep_pos = token_count + 1;
input_ids[sep_pos] = vocab.sep_id;
attention_mask[sep_pos] = 1;
Ok((input_ids, attention_mask, labels))
}
fn encode_generated_sample(
parts: &[PartChoice],
separators: &[String],
brackets: &[Bracket],
vocab: &Vocab,
max_length: usize,
) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> {
let mut input_ids = vec![vocab.pad_id; max_length];
let mut attention_mask = vec![0u8; max_length];
let mut labels = vec![-100i16; max_length];
input_ids[0] = vocab.cls_id;
attention_mask[0] = 1;
let available = max_length.saturating_sub(2);
let mut pos = 1usize;
for (idx, part) in parts.iter().enumerate() {
let bracket = &brackets[idx];
append_o_text(
&bracket.open,
vocab,
available,
&mut pos,
&mut input_ids,
&mut attention_mask,
&mut labels,
);
append_entity_text(
&part.value,
part.entity,
vocab,
available,
&mut pos,
&mut input_ids,
&mut attention_mask,
&mut labels,
)?;
append_o_text(
&bracket.close,
vocab,
available,
&mut pos,
&mut input_ids,
&mut attention_mask,
&mut labels,
);
if idx < separators.len() {
append_o_text(
&separators[idx],
vocab,
available,
&mut pos,
&mut input_ids,
&mut attention_mask,
&mut labels,
);
}
}
let sep_pos = pos.min(max_length - 1);
input_ids[sep_pos] = vocab.sep_id;
attention_mask[sep_pos] = 1;
labels[sep_pos] = -100;
Ok((input_ids, attention_mask, labels))
}
/// Encodes labeled pieces character by character into fixed-length model inputs.
///
/// Pieces with an entity receive that entity's B/I labels; pieces without one are
/// labeled `O`. Text that does not fit before the SEP slot is dropped. `vocab.cls_id`
/// is written at position 0 and `vocab.sep_id` after the last emitted character.
///
/// # Errors
/// Propagates failures from `append_entity_text`.
///
/// # Panics
/// Panics if `max_length` is 0.
fn encode_labeled_pieces(
    pieces: &[LabeledPiece],
    vocab: &Vocab,
    max_length: usize,
) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> {
    let mut input_ids = vec![vocab.pad_id; max_length];
    let mut attention_mask = vec![0u8; max_length];
    let mut labels = vec![-100i16; max_length];
    input_ids[0] = vocab.cls_id;
    attention_mask[0] = 1;

    let available = max_length.saturating_sub(2);
    let mut cursor = 1usize;
    for piece in pieces {
        match piece.entity {
            Some(entity) => append_entity_text(
                &piece.text,
                entity,
                vocab,
                available,
                &mut cursor,
                &mut input_ids,
                &mut attention_mask,
                &mut labels,
            )?,
            None => append_o_text(
                &piece.text,
                vocab,
                available,
                &mut cursor,
                &mut input_ids,
                &mut attention_mask,
                &mut labels,
            ),
        }
    }

    let sep_pos = cursor.min(max_length - 1);
    input_ids[sep_pos] = vocab.sep_id;
    attention_mask[sep_pos] = 1;
    labels[sep_pos] = -100;
    Ok((input_ids, attention_mask, labels))
}
/// Appends `text` one character per position, labeled with id 0, starting at `*pos`.
///
/// Writing stops once `*pos` exceeds `available`; `*pos` is left one past the last
/// written position.
fn append_o_text(
    text: &str,
    vocab: &Vocab,
    available: usize,
    pos: &mut usize,
    input_ids: &mut [u16],
    attention_mask: &mut [u8],
    labels: &mut [i16],
) {
    // Scratch space for the UTF-8 form of one character, so no `String` is allocated.
    let mut utf8 = [0u8; 4];
    for ch in text.chars() {
        let slot = *pos;
        if slot > available {
            break;
        }
        input_ids[slot] = token_id(vocab, ch.encode_utf8(&mut utf8));
        attention_mask[slot] = 1;
        labels[slot] = 0;
        *pos = slot + 1;
    }
}
/// Appends `text` one character per position with `entity`'s labels, starting at `*pos`.
///
/// The first character gets the B label and the rest the I label. Writing stops once
/// `*pos` exceeds `available`; `*pos` is left one past the last written position.
///
/// # Errors
/// Fails when the entity's B or I label is missing from the label table. Both are
/// looked up before any text is written, even when `text` is empty.
fn append_entity_text(
    text: &str,
    entity: Entity,
    vocab: &Vocab,
    available: usize,
    pos: &mut usize,
    input_ids: &mut [u16],
    attention_mask: &mut [u8],
    labels: &mut [i16],
) -> Result<()> {
    let begin = label_id(entity.b_label()).context("missing B label")?;
    let inside = label_id(entity.i_label()).context("missing I label")?;
    // Scratch space for the UTF-8 form of one character, so no `String` is allocated.
    let mut utf8 = [0u8; 4];
    for (offset, ch) in text.chars().enumerate() {
        let slot = *pos;
        if slot > available {
            break;
        }
        input_ids[slot] = token_id(vocab, ch.encode_utf8(&mut utf8));
        attention_mask[slot] = 1;
        labels[slot] = if offset == 0 { begin } else { inside };
        *pos = slot + 1;
    }
    Ok(())
}
/// Looks up `token` in the vocabulary, falling back to `vocab.unk_id` when it is absent.
fn token_id(vocab: &Vocab, token: &str) -> u16 {
    match vocab.ids.get(token) {
        Some(&id) => id,
        None => vocab.unk_id,
    }
}
/// Returns the numeric id of `label`, or `None` when it is not in the label table.
fn label_id(label: &str) -> Option<i16> {
    let table = label_ids();
    table.get(label).copied()
}
/// Returns the shared label-to-id table stored in `LABEL_IDS`, building it with
/// `load_label_ids` the first time it is requested.
fn label_ids() -> &'static HashMap<String, i16> {
LABEL_IDS.get_or_init(load_label_ids)
}
/// Builds the label-to-id table, assigning ids by position in the label list.
///
/// Labels come from `read_schema_labels` when it finds a usable schema, otherwise from
/// `FALLBACK_LABELS`. If a label occurs more than once, its last position wins.
fn load_label_ids() -> HashMap<String, i16> {
    let labels: Vec<String> = match read_schema_labels() {
        Some(from_schema) => from_schema,
        None => FALLBACK_LABELS
            .iter()
            .map(|label| (*label).to_string())
            .collect(),
    };
    let mut ids = HashMap::with_capacity(labels.len());
    for (idx, label) in labels.into_iter().enumerate() {
        ids.insert(label, idx as i16);
    }
    ids
}
/// Returns the labels of the first usable schema file among `label_schema_candidates`.
///
/// A candidate is skipped when it cannot be read, does not parse as a `LabelSchema`,
/// has no labels, or contains a blank label. Returns `None` when no candidate is usable.
fn read_schema_labels() -> Option<Vec<String>> {
    label_schema_candidates().into_iter().find_map(|path| {
        let text = fs::read_to_string(path).ok()?;
        let schema = serde_json::from_str::<LabelSchema>(&text).ok()?;
        let usable = !schema.labels.is_empty()
            && schema.labels.iter().all(|label| !label.trim().is_empty());
        usable.then_some(schema.labels)
    })
}
/// Lists the locations searched for `label_schema.json`, in priority order: the current
/// working directory (when it can be determined), then two levels above the crate's
/// manifest directory.
fn label_schema_candidates() -> Vec<PathBuf> {
    let manifest_relative = Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("..")
        .join("..")
        .join("label_schema.json");
    std::env::current_dir()
        .ok()
        .map(|dir| dir.join("label_schema.json"))
        .into_iter()
        .chain(std::iter::once(manifest_relative))
        .collect()
}
/// Returns the built-in list of special names: `Menu`, six menu/ending spellings for
/// each of 01..=24, `OP`/`NCOP`/`NCED` for 01..=06, and `CM`/`PV` for 01..=12.
fn built_in_specials() -> Vec<String> {
    let menus = (1..=24).flat_map(|idx| {
        [
            format!("Menu{idx:02}"),
            format!("Menu {idx:02}"),
            format!("BDMenu{idx:02}"),
            format!("BD Menu{idx:02}"),
            format!("Menu{idx:02}-01"),
            format!("ED E{idx:02}"),
        ]
    });
    let themes = (1..=6).flat_map(|idx| {
        [
            format!("OP{idx:02}"),
            format!("NCOP{idx:02}"),
            format!("NCED{idx:02}"),
        ]
    });
    let promos = (1..=12).flat_map(|idx| [format!("CM{idx:02}"), format!("PV{idx:02}")]);
    std::iter::once("Menu".to_string())
        .chain(menus)
        .chain(themes)
        .chain(promos)
        .collect()
}
/// Writes `data` to `path` as an NPY file of little-endian `u16` values with the header
/// shape `(rows, cols)`.
///
/// # Errors
/// Fails when the file cannot be created or when writing or flushing fails.
fn write_npy_u16(path: &Path, data: &[u16], rows: usize, cols: usize) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    write_npy_header(&mut writer, "<u2", rows, cols)?;
    for value in data {
        writer.write_all(&value.to_le_bytes())?;
    }
    // Flush explicitly: dropping a `BufWriter` discards any error from its final flush,
    // which would let a truncated file be reported as success.
    writer
        .flush()
        .with_context(|| format!("failed to flush {}", path.display()))?;
    Ok(())
}
/// Writes `data` to `path` as an NPY file of `u8` values with the header shape
/// `(rows, cols)`.
///
/// # Errors
/// Fails when the file cannot be created or when writing or flushing fails.
fn write_npy_u8(path: &Path, data: &[u8], rows: usize, cols: usize) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    write_npy_header(&mut writer, "|u1", rows, cols)?;
    writer.write_all(data)?;
    // Flush explicitly: dropping a `BufWriter` discards any error from its final flush,
    // which would let a truncated file be reported as success.
    writer
        .flush()
        .with_context(|| format!("failed to flush {}", path.display()))?;
    Ok(())
}
/// Writes `data` to `path` as an NPY file of little-endian `i16` values with the header
/// shape `(rows, cols)`.
///
/// # Errors
/// Fails when the file cannot be created or when writing or flushing fails.
fn write_npy_i16(path: &Path, data: &[i16], rows: usize, cols: usize) -> Result<()> {
    let mut writer = BufWriter::new(
        File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
    );
    write_npy_header(&mut writer, "<i2", rows, cols)?;
    for value in data {
        writer.write_all(&value.to_le_bytes())?;
    }
    // Flush explicitly: dropping a `BufWriter` discards any error from its final flush,
    // which would let a truncated file be reported as success.
    writer
        .flush()
        .with_context(|| format!("failed to flush {}", path.display()))?;
    Ok(())
}
fn write_npy_header<W: Write>(writer: &mut W, descr: &str, rows: usize, cols: usize) -> Result<()> {
let mut header = format!(
"{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}",
descr, rows, cols
)
.into_bytes();
let preamble_len = 10usize;
let pad_len = (16 - ((preamble_len + header.len() + 1) % 16)) % 16;
header.extend(std::iter::repeat(b' ').take(pad_len));
header.push(b'\n');
if header.len() > u16::MAX as usize {
bail!("npy header too large");
}
writer.write_all(b"\x93NUMPY")?;
writer.write_all(&[1, 0])?;
writer.write_all(&(header.len() as u16).to_le_bytes())?;
writer.write_all(&header)?;
Ok(())
}
// Unit tests for label ids, field extraction and the path-context generators.
#[cfg(test)]
mod tests {
use super::*;
// Minimal configuration: one separator, a no-op bracket, Windows paths, and one path
// sample per source.
fn test_config() -> GenConfig {
GenConfig {
max_length: 128,
shard_size: 16,
separator_mode: SeparatorMode::Global,
bracket_mode: BracketMode::Global,
separators: vec![" ".to_string()],
brackets: vec![Bracket {
name: "none".to_string(),
open: String::new(),
close: String::new(),
}],
path_styles: vec![PathStyle::Windows],
include_original: true,
include_bio_variants: true,
samples_per_source: 0,
path_samples_per_source: 1,
seed: 105,
}
}
// Sample with a Latin title, an episode, a resolution and a source, but no season,
// special, group or tag values.
fn sample_without_season() -> SourceSample {
let mut fields = vec![Vec::new(); ENTITIES.len()];
fields[Entity::TitleLatin.index()] = vec!["Example Show".to_string()];
fields[Entity::PathTitleLatin.index()] = vec!["Example Show".to_string()];
fields[Entity::Episode.index()] = vec!["1".to_string()];
fields[Entity::Resolution.index()] = vec!["1080P".to_string()];
fields[Entity::Source.index()] = vec!["WEB-DL".to_string()];
SourceSample {
row_index: 7,
filename: "Example Show 1 [1080P][WEB-DL].mkv".to_string(),
tokens: Vec::new(),
labels: Vec::new(),
fields,
}
}
// Same as `sample_without_season`, plus a release group.
fn sample_with_group() -> SourceSample {
let mut sample = sample_without_season();
sample.fields[Entity::Group.index()] = vec!["Erai-raws".to_string()];
sample
}
// Every generated prefix has at least two non-empty components and contains only
// unlabeled pieces, for both path styles across 32 seeds.
#[test]
fn path_prefix_has_at_least_two_noise_directories() {
for style in [PathStyle::Windows, PathStyle::Unix] {
for seed in 0..32 {
let mut rng = StdRng::seed_from_u64(seed);
let components = path_prefix_components(style, &mut rng);
let non_empty_components = components
.iter()
.filter(|component| !render_labeled_pieces(component).is_empty())
.count();
assert!(
non_empty_components >= 2,
"expected at least two noise directories for {style:?}: {}",
render_labeled_pieces(&join_path_components(&components, style.separator()))
);
assert!(components
.iter()
.flatten()
.all(|piece| piece.entity.is_none()));
}
}
}
// Spot-checks label ids against the positions in `FALLBACK_LABELS`; the legacy
// `B-TITLE` label must be unknown.
// NOTE(review): `label_id` prefers a `label_schema.json` found on disk, so this
// presumably relies on that file (if present) using the same order — confirm.
#[test]
fn fixed_label_schema_ids_match_v2_order() {
assert_eq!(label_id("O"), Some(0));
assert_eq!(label_id("B-TITLE_CHS"), Some(1));
assert_eq!(label_id("I-TITLE_MIXED"), Some(10));
assert_eq!(label_id("B-PATH_TITLE_CHS"), Some(11));
assert_eq!(label_id("I-PATH_TITLE_MIXED"), Some(20));
assert_eq!(label_id("B-PATH_SEASON"), Some(21));
assert_eq!(label_id("B-SEASON"), Some(23));
assert_eq!(label_id("B-EPISODE"), Some(25));
assert_eq!(label_id("B-GROUP"), Some(29));
assert_eq!(label_id("B-SOURCE"), Some(33));
assert_eq!(label_id("B-TAG"), Some(35));
assert_eq!(label_id("I-TAG"), Some(36));
assert_eq!(label_id("B-TITLE"), None);
}
// Legacy `TITLE` / `PATH_TITLE` labels map to their `_MIXED` counterparts; other labels
// pass through unchanged.
#[test]
fn legacy_source_title_labels_canonicalize_to_mixed_schema() {
assert_eq!(canonical_bio_label("B-TITLE"), "B-TITLE_MIXED");
assert_eq!(canonical_bio_label("I-TITLE"), "I-TITLE_MIXED");
assert_eq!(canonical_bio_label("B-PATH_TITLE"), "B-PATH_TITLE_MIXED");
assert_eq!(canonical_bio_label("B-SEASON"), "B-SEASON");
}
// No entity in `ENTITIES` produces the legacy `B-TITLE` / `I-TITLE` labels.
#[test]
fn generated_entities_do_not_emit_legacy_title_labels() {
for entity in ENTITIES {
assert_ne!(entity.b_label(), "B-TITLE");
assert_ne!(entity.i_label(), "I-TITLE");
}
}
// A title extracted under a file-title label is also available under the matching
// path-title entity, and vice versa.
#[test]
fn extraction_preserves_file_and_path_title_candidates() {
let tokens = ["A", "/", "僕", "ら"]
.iter()
.map(|value| value.to_string())
.collect::<Vec<_>>();
let labels = ["B-TITLE_LATIN", "O", "B-PATH_TITLE_JPN", "I-PATH_TITLE_JPN"]
.iter()
.map(|value| value.to_string())
.collect::<Vec<_>>();
let fields = extract_fields(&tokens, &labels);
assert_eq!(fields[Entity::TitleLatin.index()], vec!["A"]);
assert_eq!(fields[Entity::PathTitleLatin.index()], vec!["A"]);
assert_eq!(fields[Entity::PathTitleJpn.index()], vec!["僕ら"]);
assert_eq!(fields[Entity::TitleJpn.index()], vec!["僕ら"]);
}
// For a sample without season data, the generated path still contains a season piece,
// and the labeled pieces appear in the order title, season, episode.
#[test]
fn path_context_synthesizes_season_between_title_and_episode() {
let sample = sample_without_season();
let cfg = test_config();
let mut rng = StdRng::seed_from_u64(12);
let pieces = build_path_context_pieces(&sample, &cfg, &mut rng)
.expect("expected path context pieces");
let text = render_labeled_pieces(&pieces);
assert!(text.contains("Example Show"));
assert!(
text.contains("Season")
|| text.contains("S01")
|| text.contains("第1季")
|| text.contains("01"),
"missing synthetic season directory in {text}"
);
let mut seen_title = false;
let mut seen_season_after_title = false;
let mut seen_episode_after_season = false;
// Each flag is only set once the previous one is, which enforces the ordering.
for piece in &pieces {
match piece.entity {
None if !seen_title => {}
Some(Entity::PathTitleLatin) => seen_title = true,
Some(Entity::PathSeason) if seen_title => seen_season_after_title = true,
Some(Entity::Episode) if seen_season_after_title => {
seen_episode_after_season = true
}
_ => {}
}
}
assert!(seen_title);
assert!(seen_season_after_title);
assert!(seen_episode_after_season);
}
// Searches seeds for a `Title/01/03.mkv` layout and checks that the bare `01` is labeled
// as the path season and `03` as the episode.
#[test]
fn path_context_can_label_bare_numeric_path_season() {
let mut sample = sample_without_season();
sample.fields[Entity::Episode.index()] = vec!["3".to_string()];
let mut cfg = test_config();
cfg.path_styles = vec![PathStyle::Unix];
let mut found = None;
for seed in 0..2048 {
let mut rng = StdRng::seed_from_u64(seed);
let pieces = build_path_context_pieces(&sample, &cfg, &mut rng)
.expect("expected path context pieces");
let text = render_labeled_pieces(&pieces);
if text.contains("Example Show/01/03.mkv") {
found = Some(pieces);
break;
}
}
let pieces = found.expect("expected a Title/01/03.mkv-style path context");
assert!(pieces
.iter()
.any(|piece| piece.text == "01" && piece.entity == Some(Entity::PathSeason)));
assert!(pieces
.iter()
.any(|piece| piece.text == "03" && piece.entity == Some(Entity::Episode)));
}
// Across 128 seeds, the season rewrites of `S01` include the original and the `01`,
// `Season 1` and `Season 01` forms.
#[test]
fn path_season_variants_include_common_directory_forms() {
let mut variants = HashSet::new();
for seed in 0..128 {
let mut rng = StdRng::seed_from_u64(seed);
variants.insert(random_season_path_text("S01", &mut rng));
}
assert!(variants.contains("S01"));
assert!(variants.contains("01"));
assert!(variants.contains("Season 1"));
assert!(variants.contains("Season 01"));
}
// In a grouped release file name the group and episode are labeled, while the repeated
// title is emitted as an unlabeled piece.
#[test]
fn grouped_path_file_labels_group_but_not_duplicate_title() {
let sample = sample_with_group();
let mut rng = StdRng::seed_from_u64(21);
let endpoint = entity_piece("01".to_string(), Entity::Episode);
let pieces = grouped_release_file_component("Example Show", endpoint, &sample, &mut rng);
let text = render_labeled_pieces(&pieces);
assert!(text.contains("[Erai-raws]"));
assert!(text.contains("Example Show"));
assert!(text.contains("01"));
assert!(pieces
.iter()
.any(|piece| piece.entity == Some(Entity::Group)));
assert!(pieces
.iter()
.any(|piece| piece.entity == Some(Entity::Episode)));
assert!(pieces
.iter()
.any(|piece| piece.text == "Example Show" && piece.entity.is_none()));
}
}