ModerRAS's picture
Add Rust encoded dataset cache
c705a32
raw
history blame
28 kB
use anyhow::{bail, Context, Result};
use clap::Parser;
use rand::rngs::StdRng;
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rayon::prelude::*;
use regex::Regex;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::{Path, PathBuf};
use std::sync::OnceLock;
use std::time::Instant;
/// Built-in label inventory, used by `load_label_ids` when the label schema file is not
/// used. A label's position in this array becomes its integer id.
const FALLBACK_LABELS: [&str; 37] = [
"O",
"B-TITLE_CHS",
"I-TITLE_CHS",
"B-TITLE_CHT",
"I-TITLE_CHT",
"B-TITLE_JPN",
"I-TITLE_JPN",
"B-TITLE_LATIN",
"I-TITLE_LATIN",
"B-TITLE_MIXED",
"I-TITLE_MIXED",
"B-PATH_TITLE_CHS",
"I-PATH_TITLE_CHS",
"B-PATH_TITLE_CHT",
"I-PATH_TITLE_CHT",
"B-PATH_TITLE_JPN",
"I-PATH_TITLE_JPN",
"B-PATH_TITLE_LATIN",
"I-PATH_TITLE_LATIN",
"B-PATH_TITLE_MIXED",
"I-PATH_TITLE_MIXED",
"B-PATH_SEASON",
"I-PATH_SEASON",
"B-SEASON",
"I-SEASON",
"B-EPISODE",
"I-EPISODE",
"B-SPECIAL",
"I-SPECIAL",
"B-GROUP",
"I-GROUP",
"B-RESOLUTION",
"I-RESOLUTION",
"B-SOURCE",
"I-SOURCE",
"B-TAG",
"I-TAG",
];
/// Regex alternation of release source, codec, audio and subtitle-language tokens.
/// Shared by `source_re` (unanchored search) and `source_tag_re` (whole-string match); the
/// order of alternatives matters because the regex engine prefers the leftmost alternative
/// (e.g. `BDRip` is listed before `BD`).
const SOURCE_TOKEN_PATTERN: &str = r"WEB[-_ ]?DL|WEB[-_ ]?Rip|BDRip|BluRay|BDMV|BD|DVDRip|DVD|TVRip|HDTV|Netflix|NF|AMZN|Baha|CR|ABEMA|DSNP|U[-_ ]?NEXT|Hulu|AT[-_ ]?X|x26[45]|h\.?26[45]|HEVC|AVC|AV1|AAC\d*(?:\.\d+)?|AAC|FLAC|MP3|DTS|Opus|CHS|CHT|GB|BIG5|JPN?|JPSC|JPTC|繁中|简中";
// Lazily compiled regexes, each initialised on first use by its accessor function
// (`resolution_re`, `source_re`, `source_tag_re`, `special_tag_re`, `special_code_re`).
static RESOLUTION_RE: OnceLock<Regex> = OnceLock::new();
static SOURCE_RE: OnceLock<Regex> = OnceLock::new();
static SOURCE_TAG_RE: OnceLock<Regex> = OnceLock::new();
static SPECIAL_TAG_RE: OnceLock<Regex> = OnceLock::new();
static SPECIAL_CODE_RE: OnceLock<Regex> = OnceLock::new();
// Command-line options for the encoded cache builder.
// NOTE: plain `//` comments are used on purpose throughout this struct: clap's derive turns
// `///` doc comments into `--help` text, which would change the CLI's output.
#[derive(Parser, Debug)]
#[command(
about = "Build split train/eval encoded AniFileBERT shard caches",
version
)]
struct Args {
// JSONL input; every non-blank line must carry `tokens` and `labels` arrays (see `load_rows`).
#[arg(long)]
input: PathBuf,
// JSON object mapping token -> id; must contain [PAD], [UNK], [CLS] and [SEP].
#[arg(long)]
vocab_file: PathBuf,
// Receives `train/`, `eval/`, `eval_records.jsonl` and the top-level `manifest.json`.
#[arg(long)]
output_dir: PathBuf,
// JSON file with a `labels` array; when the file is missing, `FALLBACK_LABELS` is used.
#[arg(long, default_value = "label_schema.json")]
label_schema_file: PathBuf,
// Fixed encoded length per row, including [CLS] and [SEP]; `main` requires at least 4.
#[arg(long, default_value_t = 128)]
max_length: usize,
// Maximum rows per shard file; must be positive.
#[arg(long, default_value_t = 25_000)]
shard_size: usize,
// Stop after collecting this many non-blank rows; 0 means read the whole file.
#[arg(long, default_value_t = 0)]
limit_rows: usize,
// Fraction of rows assigned to the train split; the rest go to eval.
#[arg(long, default_value_t = 0.98)]
train_split: f64,
// Seed for the shuffle performed before splitting.
#[arg(long, default_value_t = 42)]
seed: u64,
// Keep input order instead of shuffling before the split.
#[arg(long)]
no_shuffle: bool,
// Size of the global rayon thread pool; 0 keeps rayon's default.
#[arg(long, default_value_t = 0)]
threads: usize,
}
/// Shape of the label schema file: `{"labels": [...]}`. Each label's id is its index.
#[derive(Debug, Deserialize)]
struct LabelSchema {
labels: Vec<String>,
}
/// One parsed, non-blank record of the JSONL input.
#[derive(Clone)]
struct SourceRow {
/// Zero-based physical line index in the input file (skipped blank lines still count).
row_index: usize,
/// The original JSONL line, copied verbatim to `eval_records.jsonl` for eval rows.
raw_line: String,
/// Optional `filename` string field of the record.
filename: Option<String>,
/// `tokens` array; `load_rows` guarantees it is as long as `labels`.
tokens: Vec<String>,
/// `labels` array, parallel to `tokens`.
labels: Vec<String>,
/// Optional `tokenizer_variant` field; the value `"char"` is checked by
/// `labels_for_char_tokenizer`.
tokenizer_variant: Option<String>,
}
/// Token vocabulary loaded by `load_vocab`, with the special-token ids resolved up front.
#[derive(Clone)]
struct Vocab {
/// Token string -> id.
ids: HashMap<String, u16>,
/// Id written to every unused position of an encoded row.
pad_id: u16,
/// Id used for tokens not present in `ids`.
unk_id: u16,
/// Id placed at position 0 of every encoded row.
cls_id: u16,
/// Id placed right after the last encoded token.
sep_id: u16,
}
/// State shared (by reference) by all shard-encoding workers.
#[derive(Clone)]
struct EncodeContext {
vocab: Vocab,
/// Canonical BIO label -> id.
label_ids: HashMap<String, i16>,
/// Fixed length of every encoded row.
max_length: usize,
}
/// One shard's entry in a split's `manifest.json`: its row count and the names of its three
/// `.npy` files, relative to the split directory.
#[derive(Serialize)]
struct ShardManifest {
rows: usize,
input_ids: String,
attention_mask: String,
labels: String,
}
/// Per-split summary embedded in the top-level `manifest.json`.
#[derive(Serialize)]
struct SplitSummary {
/// Split name (`"train"` or `"eval"` in `main`).
split: String,
rows: usize,
/// Number of shard files written for the split.
shards: usize,
/// Split directory name, relative to the output directory.
directory: String,
}
/// Builds the encoded dataset cache.
///
/// Reads the JSONL input, optionally shuffles it with a fixed seed and splits it into train
/// and eval, then writes:
/// - the encoded `.npy` shards of each split;
/// - `eval_records.jsonl` with the raw eval lines;
/// - a top-level `manifest.json`.
///
/// The manifest is also printed to stdout.
///
/// # Errors
/// Returns an error for invalid CLI values, unreadable or malformed inputs, fewer than two
/// usable rows, rows that fail to encode (e.g. unknown labels), or any write failure.
fn main() -> Result<()> {
    let args = Args::parse();
    if args.max_length < 4 {
        bail!("--max-length must be at least 4");
    }
    if args.shard_size == 0 {
        bail!("--shard-size must be positive");
    }
    // Strict open interval, as the message states. The half-open range `0.0..1.0` accepted
    // 0.0; this form rejects 0.0, 1.0 and NaN (every comparison with NaN is false).
    let split_is_valid = args.train_split > 0.0 && args.train_split < 1.0;
    if !split_is_valid {
        bail!("--train-split must be > 0 and < 1");
    }
    if args.threads > 0 {
        rayon::ThreadPoolBuilder::new()
            .num_threads(args.threads)
            .build_global()
            .context("failed to configure rayon thread pool")?;
    }
    let started = Instant::now();
    let vocab = load_vocab(&args.vocab_file)?;
    let label_ids = load_label_ids(&args.label_schema_file)?;
    let mut rows = load_rows(&args.input, args.limit_rows)?;
    if rows.len() < 2 {
        bail!("need at least two rows to build train/eval cache");
    }
    if !args.no_shuffle {
        let mut rng = StdRng::seed_from_u64(args.seed);
        rows.shuffle(&mut rng);
    }
    // Clamp so that both splits always receive at least one row.
    let split_idx = ((rows.len() as f64) * args.train_split) as usize;
    let split_idx = split_idx.max(1).min(rows.len() - 1);
    let (train_rows, eval_rows) = rows.split_at(split_idx);
    fs::create_dir_all(&args.output_dir).with_context(|| {
        format!(
            "failed to create output directory {}",
            args.output_dir.display()
        )
    })?;
    let context = EncodeContext {
        vocab,
        label_ids,
        max_length: args.max_length,
    };
    let train_summary = write_split(
        "train",
        train_rows,
        &args.output_dir,
        &context,
        args.shard_size,
    )?;
    let eval_summary = write_split(
        "eval",
        eval_rows,
        &args.output_dir,
        &context,
        args.shard_size,
    )?;
    // Eval rows are written in the same (possibly shuffled) order as the eval shards.
    write_eval_records(eval_rows, &args.output_dir.join("eval_records.jsonl"))?;
    let manifest = json!({
        "format": "anifilebert.encoded_dataset_cache.v1",
        "input": args.input,
        "vocab_file": args.vocab_file,
        "label_schema_file": args.label_schema_file,
        "output_dir": args.output_dir,
        "max_length": args.max_length,
        "shard_size": args.shard_size,
        "limit_rows": args.limit_rows,
        "source_rows": train_rows.len() + eval_rows.len(),
        "train_split": args.train_split,
        "seed": args.seed,
        "shuffle": !args.no_shuffle,
        "train": train_summary,
        "eval": eval_summary,
        "eval_records": "eval_records.jsonl",
        "elapsed_seconds": started.elapsed().as_secs_f64(),
    });
    // Serialize once and reuse the text for both the file and stdout.
    let manifest_text = serde_json::to_string_pretty(&manifest)?;
    let manifest_path = args.output_dir.join("manifest.json");
    fs::write(&manifest_path, &manifest_text)
        .with_context(|| format!("failed to write {}", manifest_path.display()))?;
    println!("{manifest_text}");
    Ok(())
}
fn load_vocab(path: &Path) -> Result<Vocab> {
let text = fs::read_to_string(path)
.with_context(|| format!("failed to read vocab {}", path.display()))?;
let raw: HashMap<String, u64> =
serde_json::from_str(&text).with_context(|| format!("invalid vocab {}", path.display()))?;
let mut ids = HashMap::with_capacity(raw.len());
for (token, id) in raw {
if id > u16::MAX as u64 {
bail!("vocab id for token '{token}' exceeds u16: {id}");
}
ids.insert(token, id as u16);
}
let special = |token: &str| -> Result<u16> {
ids.get(token)
.copied()
.with_context(|| format!("vocab is missing special token {token}"))
};
Ok(Vocab {
pad_id: special("[PAD]")?,
unk_id: special("[UNK]")?,
cls_id: special("[CLS]")?,
sep_id: special("[SEP]")?,
ids,
})
}
fn load_label_ids(path: &Path) -> Result<HashMap<String, i16>> {
let labels = match fs::read_to_string(path) {
Ok(text) => {
serde_json::from_str::<LabelSchema>(&text)
.with_context(|| format!("invalid label schema {}", path.display()))?
.labels
}
Err(_) => FALLBACK_LABELS
.iter()
.map(|label| (*label).to_string())
.collect(),
};
if labels.is_empty() {
bail!("label schema has no labels");
}
Ok(labels
.into_iter()
.enumerate()
.map(|(idx, label)| (label, idx as i16))
.collect())
}
/// Reads the JSONL input into memory.
///
/// Blank lines are skipped (but still counted for line numbers). When `limit_rows > 0`,
/// reading stops once that many rows have been collected.
///
/// # Errors
/// Fails on an I/O error, malformed JSON, a missing `tokens`/`labels` array, or when those
/// two arrays differ in length. Messages cite the 1-based line number.
fn load_rows(path: &Path, limit_rows: usize) -> Result<Vec<SourceRow>> {
    let handle = File::open(path).with_context(|| format!("failed to open {}", path.display()))?;
    let optional_str = |value: &Value, key: &str| -> Option<String> {
        value.get(key).and_then(Value::as_str).map(ToOwned::to_owned)
    };
    let mut collected: Vec<SourceRow> = Vec::new();
    for (row_index, line) in BufReader::new(handle).lines().enumerate() {
        let line_no = row_index + 1;
        let limit_reached = limit_rows > 0 && collected.len() >= limit_rows;
        if limit_reached {
            break;
        }
        let raw_line = line.with_context(|| format!("failed reading line {line_no}"))?;
        if raw_line.trim().is_empty() {
            continue;
        }
        let value: Value = serde_json::from_str(&raw_line)
            .with_context(|| format!("failed to parse JSONL line {line_no}"))?;
        let tokens = string_array_field(&value, "tokens", line_no)?;
        let labels = string_array_field(&value, "labels", line_no)?;
        if tokens.len() != labels.len() {
            bail!(
                "line {} has mismatched token/label lengths: {} vs {}",
                line_no,
                tokens.len(),
                labels.len()
            );
        }
        let filename = optional_str(&value, "filename");
        let tokenizer_variant = optional_str(&value, "tokenizer_variant");
        collected.push(SourceRow {
            row_index,
            raw_line,
            filename,
            tokens,
            labels,
            tokenizer_variant,
        });
    }
    Ok(collected)
}
/// Extracts `field` of a JSONL record as a list of strings.
///
/// String elements are taken verbatim and `null` becomes an empty string. Any other JSON
/// value is stored as its compact JSON text (e.g. the number `5` becomes `"5"`).
///
/// # Errors
/// Fails when `field` is absent or not an array; `line_no` appears in the message.
fn string_array_field(value: &Value, field: &str, line_no: usize) -> Result<Vec<String>> {
    let items = value
        .get(field)
        .and_then(Value::as_array)
        .with_context(|| format!("line {line_no} missing array field '{field}'"))?;
    let strings = items
        .iter()
        .map(|item| match item {
            Value::String(text) => text.clone(),
            Value::Null => String::new(),
            other => other.to_string(),
        })
        .collect();
    Ok(strings)
}
fn write_split(
split: &str,
rows: &[SourceRow],
output_dir: &Path,
context: &EncodeContext,
shard_size: usize,
) -> Result<SplitSummary> {
let split_dir = output_dir.join(split);
fs::create_dir_all(&split_dir)
.with_context(|| format!("failed to create {}", split_dir.display()))?;
let chunks = rows
.chunks(shard_size)
.enumerate()
.collect::<Vec<(usize, &[SourceRow])>>();
let shards = chunks
.par_iter()
.map(|(shard_idx, chunk)| write_shard(split, *shard_idx, chunk, &split_dir, context))
.collect::<Result<Vec<_>>>()?;
let manifest = json!({
"format": "anifilebert.virtual_dataset.shards.v1",
"generated_by": "tools/encoded_dataset_cache",
"split": split,
"max_length": context.max_length,
"total_rows": rows.len(),
"shards": shards,
});
let manifest_path = split_dir.join("manifest.json");
fs::write(&manifest_path, serde_json::to_string_pretty(&manifest)?)
.with_context(|| format!("failed to write {}", manifest_path.display()))?;
Ok(SplitSummary {
split: split.to_string(),
rows: rows.len(),
shards: chunks.len(),
directory: split.to_string(),
})
}
/// Encodes one shard of rows and writes its three `.npy` arrays into `split_dir`.
///
/// Files are named `part-<split>-s<shard_idx, 6 digits>.{input_ids,attention_mask,labels}.npy`.
/// Each array has shape `(rows.len(), context.max_length)`.
///
/// # Errors
/// Fails if a row cannot be encoded (the message names its 1-based source line) or a file
/// cannot be written.
fn write_shard(
    split: &str,
    shard_idx: usize,
    rows: &[SourceRow],
    split_dir: &Path,
    context: &EncodeContext,
) -> Result<ShardManifest> {
    let row_count = rows.len();
    let width = context.max_length;
    let total = row_count.saturating_mul(width);
    let mut ids_buf: Vec<u16> = Vec::with_capacity(total);
    let mut mask_buf: Vec<u8> = Vec::with_capacity(total);
    let mut label_buf: Vec<i16> = Vec::with_capacity(total);
    for row in rows {
        let (ids, mask, row_labels) = encode_row(row, context)
            .with_context(|| format!("failed to encode source line {}", row.row_index + 1))?;
        ids_buf.extend_from_slice(&ids);
        mask_buf.extend_from_slice(&mask);
        label_buf.extend_from_slice(&row_labels);
    }
    let stem = format!("part-{split}-s{shard_idx:06}");
    let input_name = format!("{stem}.input_ids.npy");
    let mask_name = format!("{stem}.attention_mask.npy");
    let label_name = format!("{stem}.labels.npy");
    write_npy_u16(&split_dir.join(&input_name), &ids_buf, row_count, width)?;
    write_npy_u8(&split_dir.join(&mask_name), &mask_buf, row_count, width)?;
    write_npy_i16(&split_dir.join(&label_name), &label_buf, row_count, width)?;
    Ok(ShardManifest {
        rows: row_count,
        input_ids: input_name,
        attention_mask: mask_name,
        labels: label_name,
    })
}
/// Encodes one row into fixed-length `(input_ids, attention_mask, label_ids)` vectors.
///
/// Layout:
/// - `[CLS]` sits at position 0;
/// - at most `max_length - 2` character tokens follow (longer rows are truncated);
/// - then `[SEP]`;
/// - every remaining position keeps `[PAD]`, mask 0 and label -100.
///
/// `[CLS]` and `[SEP]` are also labelled -100. Unknown characters map to `[UNK]`. Labels are
/// canonicalised with `canonical_bio_label` before lookup.
///
/// Assumes `max_length >= 2` (indexing panics otherwise); `main` enforces `>= 4`.
///
/// # Errors
/// Fails if a canonicalised label is not present in `context.label_ids`.
fn encode_row(row: &SourceRow, context: &EncodeContext) -> Result<(Vec<u16>, Vec<u8>, Vec<i16>)> {
    let (tokens, labels) = labels_for_char_tokenizer(row);
    let width = context.max_length;
    let vocab = &context.vocab;
    let mut input_ids = vec![vocab.pad_id; width];
    let mut attention_mask = vec![0u8; width];
    let mut label_ids = vec![-100i16; width];
    input_ids[0] = vocab.cls_id;
    attention_mask[0] = 1;
    let budget = width.saturating_sub(2);
    let mut filled = 0usize;
    for (token, label) in tokens.iter().zip(labels.iter()).take(budget) {
        // Position 0 is [CLS], so token k lands in slot k + 1.
        let slot = filled + 1;
        input_ids[slot] = token_id(vocab, token);
        attention_mask[slot] = 1;
        let canonical = canonical_bio_label(label);
        label_ids[slot] = context
            .label_ids
            .get(&canonical)
            .copied()
            .with_context(|| format!("unknown label '{canonical}'"))?;
        filled += 1;
    }
    let sep_slot = filled + 1;
    input_ids[sep_slot] = vocab.sep_id;
    attention_mask[sep_slot] = 1;
    Ok((input_ids, attention_mask, label_ids))
}
/// Produces per-character tokens and labels for a row.
///
/// Strategies, in priority order:
/// 1. Rows marked `tokenizer_variant == "char"` whose tokens are exactly the characters of
///    `filename` are returned unchanged.
/// 2. Otherwise labels are projected from the source tokens onto the filename's characters.
/// 3. If projection is impossible (or there is no filename), each token is split into its
///    characters instead.
///
/// Whenever a filename is present, the result of strategy 2 or 3 is passed through
/// `repair_structural_meta_labels`.
fn labels_for_char_tokenizer(row: &SourceRow) -> (Vec<String>, Vec<String>) {
    let filename = row.filename.as_deref();
    let is_char_variant = matches!(row.tokenizer_variant.as_deref(), Some("char"));
    if let (true, Some(name)) = (is_char_variant, filename) {
        if row.tokens == chars_as_strings(name) {
            return (row.tokens.clone(), row.labels.clone());
        }
    }
    let projected =
        filename.and_then(|name| project_labels_from_filename(name, &row.tokens, &row.labels));
    let (tokens, mut labels) = match projected {
        Some(pair) => pair,
        None => align_tokens_to_chars(&row.tokens, &row.labels),
    };
    if let Some(name) = filename {
        repair_structural_meta_labels(name, &mut labels);
    }
    (tokens, labels)
}
/// Re-labels `filename` character by character from token-level source labels.
///
/// Each source token is located in `filename` (left to right). The entity of every labelled
/// token is painted over the characters it occupies; for bracket-wrapped tokens such as
/// `[Group]`, the brackets themselves are excluded. BIO prefixes are then regenerated: a
/// character gets `I-` only when the previous character carries the same entity. Unlabelled
/// characters become `O`.
///
/// Returns `None` when a token cannot be found in order or the label count does not match.
fn project_labels_from_filename(
    filename: &str,
    source_tokens: &[String],
    source_labels: &[String],
) -> Option<(Vec<String>, Vec<String>)> {
    let spans = token_offsets_in_text(filename, source_tokens)?;
    if spans.len() != source_labels.len() {
        return None;
    }
    let char_total = filename.chars().count();
    let mut per_char: Vec<Option<String>> = vec![None; char_total];
    for ((token, label), (start, end)) in source_tokens.iter().zip(source_labels).zip(spans) {
        let entity = match bio_entity(label) {
            Some(entity) => entity,
            None => continue,
        };
        let (lo, hi) = if is_wrapped_token(token) && end > start + 1 {
            (start + 1, end - 1)
        } else {
            (start, end)
        };
        for slot in per_char.iter_mut().take(hi).skip(lo) {
            *slot = Some(entity.clone());
        }
    }
    let mut labels = Vec::with_capacity(char_total);
    let mut previous: Option<String> = None;
    for current in per_char {
        let label = match &current {
            None => "O".to_string(),
            Some(entity) if previous.as_ref() == Some(entity) => format!("I-{entity}"),
            Some(entity) => format!("B-{entity}"),
        };
        labels.push(label);
        previous = current;
    }
    Some((chars_as_strings(filename), labels))
}
/// Locates each token in `text`, scanning strictly left to right.
///
/// Returns one `(start, end)` pair of *character* indices per token. An empty token gets
/// the zero-width span at the current scan position. Returns `None` as soon as a token cannot
/// be found after the previous match.
fn token_offsets_in_text(text: &str, tokens: &[String]) -> Option<Vec<(usize, usize)>> {
    let mut spans = Vec::with_capacity(tokens.len());
    // Invariant: `char_pos` is the number of chars in `text[..byte_pos]`.
    let mut byte_pos = 0usize;
    let mut char_pos = 0usize;
    for token in tokens {
        if token.is_empty() {
            spans.push((char_pos, char_pos));
            continue;
        }
        let rest = text.get(byte_pos..)?;
        let skip = rest.find(token.as_str())?;
        let start_char = char_pos + rest[..skip].chars().count();
        let end_char = start_char + token.chars().count();
        spans.push((start_char, end_char));
        byte_pos += skip + token.len();
        char_pos = end_char;
    }
    Some(spans)
}
/// Splits every token into single characters and expands its label accordingly.
///
/// A `B-X` token yields `B-X` for its first character and `I-X` for the rest. Any other
/// label is repeated for every character. Empty tokens are dropped, and pairs beyond the
/// shorter of the two slices are ignored.
fn align_tokens_to_chars(tokens: &[String], labels: &[String]) -> (Vec<String>, Vec<String>) {
    let mut out_tokens: Vec<String> = Vec::new();
    let mut out_labels: Vec<String> = Vec::new();
    for (token, label) in tokens.iter().zip(labels) {
        let pieces: Vec<String> = token.chars().map(|ch| ch.to_string()).collect();
        if pieces.is_empty() {
            continue;
        }
        match label.strip_prefix("B-") {
            Some(entity) => {
                out_labels.push(label.clone());
                out_labels.extend(std::iter::repeat(format!("I-{entity}")).take(pieces.len() - 1));
            }
            None => out_labels.extend(std::iter::repeat(label.clone()).take(pieces.len())),
        }
        out_tokens.extend(pieces);
    }
    (out_tokens, out_labels)
}
/// Re-tags metadata in the char-level `labels` of `text`, using this file's regex accessors.
///
/// Metadata handled: resolution markers, source/codec tags, and special-episode or search
/// annotations.
///
/// Only regions starting at or after the end of the first EPISODE span are touched. When
/// there is no EPISODE label, every region is eligible.
///
/// All writes go through `label_span_if_safe`, which refuses spans overlapping
/// GROUP/EPISODE/SEASON/PATH_SEASON labels. Later passes can overwrite earlier ones, so the
/// order of the passes below matters.
///
/// Does nothing unless `labels` has exactly one entry per character of `text`.
fn repair_structural_meta_labels(text: &str, labels: &mut [String]) {
if labels.len() != text.chars().count() {
return;
}
// Char index just past the first EPISODE run (0 when there is none).
let episode_end = first_episode_span_end(labels);
// Pass 1: bracketed regions (`[...]`, `(...)`, `【...】`, `《...》`), as char ranges of
// their content excluding the brackets.
for (inner_start, inner_end) in bracket_inner_spans(text) {
// Char index of the opening bracket itself.
let bracket_start = inner_start.saturating_sub(1);
if bracket_start < episode_end {
continue;
}
let inner = chars_range_to_string(text, inner_start, inner_end);
let (trim_start, trim_end) = trimmed_bounds(&inner);
if trim_start >= trim_end {
continue;
}
let clean = chars_slice_to_string(&inner, trim_start, trim_end);
// Char range of the whitespace-trimmed content within `text`.
let clean_start = inner_start + trim_start;
let clean_end = inner_start + trim_end;
// The whole trimmed content is a search/alias tag or a special code (OP, NCED, OVA, ...).
if special_tag_re().is_match(&clean) || special_code_re().is_match(&clean) {
label_span_if_safe(labels, clean_start, clean_end, "SPECIAL");
continue;
}
// The whole trimmed content is a separator-joined list of source tokens.
if source_tag_re().is_match(&clean) {
label_span_if_safe(labels, clean_start, clean_end, "SOURCE");
continue;
}
// Otherwise label individual resolution / source tokens inside the bracket. Regex match
// offsets are bytes within `inner`; they are converted to char indices of `text`.
for mat in resolution_re().find_iter(&inner) {
if !has_ascii_token_boundaries(&inner, mat.start(), mat.end()) {
continue;
}
let start = inner_start + char_index_at_byte(&inner, mat.start());
let end = inner_start + char_index_at_byte(&inner, mat.end());
label_span_if_safe(labels, start, end, "RESOLUTION");
}
for mat in source_re().find_iter(&inner) {
if !has_ascii_token_boundaries(&inner, mat.start(), mat.end()) {
continue;
}
let start = inner_start + char_index_at_byte(&inner, mat.start());
let end = inner_start + char_index_at_byte(&inner, mat.end());
label_span_if_safe(labels, start, end, "SOURCE");
}
}
// Pass 2: resolution tokens anywhere in `text` that start at or after `episode_end` and
// are not glued to ASCII letters/digits on either side.
for mat in resolution_re().find_iter(text) {
if !has_ascii_token_boundaries(text, mat.start(), mat.end()) {
continue;
}
let start = char_index_at_byte(text, mat.start());
if start < episode_end {
continue;
}
let end = char_index_at_byte(text, mat.end());
label_span_if_safe(labels, start, end, "RESOLUTION");
}
// Pass 3: the same for source/codec tokens.
for mat in source_re().find_iter(text) {
if !has_ascii_token_boundaries(text, mat.start(), mat.end()) {
continue;
}
let start = char_index_at_byte(text, mat.start());
if start < episode_end {
continue;
}
let end = char_index_at_byte(text, mat.end());
label_span_if_safe(labels, start, end, "SOURCE");
}
}
/// Returns the index just past the first contiguous run of `B-EPISODE`/`I-EPISODE` labels,
/// or 0 when no EPISODE label exists.
fn first_episode_span_end(labels: &[String]) -> usize {
    // True for `B-EPISODE` / `I-EPISODE` only (the entity must be exactly `EPISODE`).
    fn is_episode(label: &str) -> bool {
        matches!(label.split_once('-'), Some(("B" | "I", "EPISODE")))
    }
    let Some(begin) = labels.iter().position(|label| is_episode(label)) else {
        return 0;
    };
    let run = labels[begin..]
        .iter()
        .take_while(|label| is_episode(label))
        .count();
    begin + run
}
/// Finds bracketed regions (`[]`, `()`, `【】`, `《》`) and returns the char-index range of
/// each region's content, excluding the brackets.
///
/// Each opener is paired with the first matching closer after it, so nesting is not
/// tracked. Scanning resumes after that closer. An opener with no closer is skipped.
fn bracket_inner_spans(text: &str) -> Vec<(usize, usize)> {
    let chars: Vec<char> = text.chars().collect();
    let mut spans = Vec::new();
    let mut cursor = 0usize;
    while let Some(&open) = chars.get(cursor) {
        let closer = match open {
            '[' => Some(']'),
            '(' => Some(')'),
            '【' => Some('】'),
            '《' => Some('》'),
            _ => None,
        };
        let matched =
            closer.and_then(|want| chars[cursor + 1..].iter().position(|&ch| ch == want));
        match matched {
            Some(offset) => {
                let close_at = cursor + 1 + offset;
                spans.push((cursor + 1, close_at));
                cursor = close_at + 1;
            }
            None => cursor += 1,
        }
    }
    spans
}
/// Returns the char-index range `(start, end)` of `text` with leading and trailing Unicode
/// whitespace removed.
///
/// An all-whitespace string yields `(n, n)`, where `n` is its char count.
fn trimmed_bounds(text: &str) -> (usize, usize) {
    let chars: Vec<char> = text.chars().collect();
    let leading = chars.iter().take_while(|ch| ch.is_whitespace()).count();
    let trailing = chars[leading..]
        .iter()
        .rev()
        .take_while(|ch| ch.is_whitespace())
        .count();
    (leading, chars.len() - trailing)
}
/// Returns the characters of `text` in the char-index range `start..end` as a new `String`.
///
/// Indices count Unicode scalar values, not bytes. An `end` past the last character simply
/// yields a shorter result, and `end <= start` yields an empty string.
fn chars_range_to_string(text: &str, start: usize, end: usize) -> String {
    text.chars()
        .skip(start)
        .take(end.saturating_sub(start))
        .collect()
}

/// Alias of [`chars_range_to_string`]. The two functions had identical bodies; this one now
/// delegates so the logic lives in a single place.
fn chars_slice_to_string(text: &str, start: usize, end: usize) -> String {
    chars_range_to_string(text, start, end)
}
/// Relabels `labels[start..end]` as one `entity` span, unless doing so would be unsafe.
///
/// Nothing happens when the range is empty or out of bounds. Nothing happens when the range
/// overlaps a GROUP, EPISODE, SEASON or PATH_SEASON label.
///
/// If the label just before `start` already has the same entity, the span continues it (all
/// `I-`); otherwise the span starts with `B-`.
///
/// If the label right after the span is an `I-` of a different entity (the span cut into
/// the middle of that entity), it is promoted to `B-` so the sequence stays valid BIO.
fn label_span_if_safe(labels: &mut [String], start: usize, end: usize, entity: &str) {
    if start >= end || end > labels.len() {
        return;
    }
    let touches_protected = labels[start..end].iter().any(|label| {
        matches!(
            label_entity(label),
            Some("GROUP" | "EPISODE" | "SEASON" | "PATH_SEASON")
        )
    });
    if touches_protected {
        return;
    }
    let previous_same = start > 0 && label_entity(&labels[start - 1]) == Some(entity);
    for (offset, label) in labels[start..end].iter_mut().enumerate() {
        let prefix = if offset == 0 && !previous_same { "B" } else { "I" };
        *label = format!("{prefix}-{entity}");
    }
    // Without this step, e.g. relabeling `1080p` inside a longer TAG span leaves the
    // following `I-TAG` directly after `I-RESOLUTION`, which is not valid BIO.
    if let Some(next) = labels.get_mut(end) {
        let promoted = next
            .strip_prefix("I-")
            .filter(|next_entity| *next_entity != entity)
            .map(|next_entity| format!("B-{next_entity}"));
        if let Some(promoted) = promoted {
            *next = promoted;
        }
    }
}
/// Reports whether the byte range `start..end` of `text` is not glued to an ASCII letter or
/// digit on either side.
///
/// The start and end of the string count as boundaries. Non-ASCII neighbours (e.g. CJK)
/// also count as boundaries. Panics if either offset is not a char boundary.
fn has_ascii_token_boundaries(text: &str, start: usize, end: usize) -> bool {
    let separates = |neighbour: Option<char>| match neighbour {
        Some(ch) => !ch.is_ascii_alphanumeric(),
        None => true,
    };
    let left = text[..start].chars().next_back();
    let right = text[end..].chars().next();
    separates(left) && separates(right)
}
/// Returns the entity part of a `B-`/`I-` label (everything after the first `-`), borrowing
/// from `label`. Returns `None` for `O` and any other prefix.
fn label_entity(label: &str) -> Option<&str> {
    match label.split_once('-') {
        Some(("B" | "I", entity)) => Some(entity),
        _ => None,
    }
}
/// Resolution markers such as `1080p`, `4K` or `1920x1080` (case-insensitive, unanchored).
fn resolution_re() -> &'static Regex {
    RESOLUTION_RE.get_or_init(|| {
        Regex::new(r"(?i)(?:\d{3,4}p|\d[kK]|\d{3,4}[xX×]\d{3,4})")
            .expect("resolution pattern is a valid constant regex")
    })
}

/// Any single token from `SOURCE_TOKEN_PATTERN` (case-insensitive, unanchored).
fn source_re() -> &'static Regex {
    SOURCE_RE.get_or_init(|| {
        Regex::new(&format!(r"(?i)(?:{SOURCE_TOKEN_PATTERN})"))
            .expect("source token pattern is a valid constant regex")
    })
}

/// A whole string consisting only of source tokens. Adjacent tokens are joined by one of
/// `& + / , _ -`, with optional surrounding whitespace.
fn source_tag_re() -> &'static Regex {
    SOURCE_TAG_RE.get_or_init(|| {
        Regex::new(&format!(
            r"(?i)^(?:{SOURCE_TOKEN_PATTERN})(?:\s*(?:[&+/,_-]|,\s*)\s*(?:{SOURCE_TOKEN_PATTERN}))*$"
        ))
        .expect("source tag pattern is a valid constant regex")
    })
}

/// A search/alias annotation: a keyword such as `检索`, `别名`, `alias` or `search`,
/// followed by an ASCII or full-width colon and some text.
fn special_tag_re() -> &'static Regex {
    SPECIAL_TAG_RE.get_or_init(|| {
        Regex::new(r"(?i)^(?:檢索|检索|搜索|搜寻|搜尋|别名|別名|alias|search|keyword)\s*[::].+")
            .expect("special tag pattern is a valid constant regex")
    })
}

/// A whole string that is an extra/special code. Accepted forms:
/// - `NCOP`/`NCED`/`OP`/`ED`/`PV`/`CM` with optional digits;
/// - `IV` followed by digits;
/// - `OVA`/`OAD`/`SP` with optional digits.
fn special_code_re() -> &'static Regex {
    SPECIAL_CODE_RE.get_or_init(|| {
        Regex::new(r"(?i)^(?:NCOP|NCED|OP|ED|PV|CM)\d*$|^IV\d+$|^(?:OVA|OAD|SP)\d*$")
            .expect("special code pattern is a valid constant regex")
    })
}
/// Splits `text` into one `String` per Unicode scalar value.
fn chars_as_strings(text: &str) -> Vec<String> {
    text.chars().map(String::from).collect()
}
/// Converts a byte offset into `text` to the corresponding char index.
///
/// Panics if `byte_index` is past the end or not on a char boundary.
fn char_index_at_byte(text: &str, byte_index: usize) -> usize {
    let prefix = &text[..byte_index];
    prefix.chars().count()
}
fn bio_entity(label: &str) -> Option<String> {
let (prefix, entity) = label.split_once('-')?;
if prefix == "B" || prefix == "I" {
Some(entity.to_string())
} else {
None
}
}
/// Reports whether `token` starts with an opening bracket (`[`, `【`, `(`, `《`) and ends with
/// a closing bracket (`]`, `】`, `)`, `》`).
///
/// The two brackets do not have to be of the same kind.
fn is_wrapped_token(token: &str) -> bool {
    let opens = |ch: char| matches!(ch, '[' | '【' | '(' | '《');
    let closes = |ch: char| matches!(ch, ']' | '】' | ')' | '》');
    match (token.chars().next(), token.chars().next_back()) {
        (Some(first), Some(last)) => opens(first) && closes(last),
        _ => false,
    }
}
/// Maps legacy entity names to their canonical form, keeping the `B-`/`I-` prefix.
///
/// `TITLE` becomes `TITLE_MIXED` and `PATH_TITLE` becomes `PATH_TITLE_MIXED`. Every other
/// label, including `O` and labels without a `B`/`I` prefix, is returned unchanged.
fn canonical_bio_label(label: &str) -> String {
    match label.split_once('-') {
        Some((prefix @ ("B" | "I"), entity)) => {
            let entity = match entity {
                "TITLE" => "TITLE_MIXED",
                "PATH_TITLE" => "PATH_TITLE_MIXED",
                other => other,
            };
            format!("{prefix}-{entity}")
        }
        _ => label.to_string(),
    }
}
/// Looks up `token` in the vocabulary, falling back to the `[UNK]` id.
fn token_id(vocab: &Vocab, token: &str) -> u16 {
    vocab.ids.get(token).copied().unwrap_or(vocab.unk_id)
}
/// Writes the raw JSONL line of every row to `path`, one per line, in the given order.
///
/// # Errors
/// Fails if the file cannot be created or any write, including the final flush, fails.
fn write_eval_records(rows: &[SourceRow], path: &Path) -> Result<()> {
    let file = File::create(path).with_context(|| format!("failed to create {}", path.display()))?;
    let mut writer = BufWriter::new(file);
    for row in rows {
        writer.write_all(row.raw_line.as_bytes())?;
        writer.write_all(b"\n")?;
    }
    // `BufWriter` flushes on drop but discards any error there. Flush explicitly so a failed
    // final write is reported instead of leaving a silently truncated file.
    writer
        .flush()
        .with_context(|| format!("failed to write {}", path.display()))?;
    Ok(())
}
fn write_npy_u16(path: &Path, data: &[u16], rows: usize, cols: usize) -> Result<()> {
let mut writer = BufWriter::new(
File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
);
write_npy_header(&mut writer, "<u2", rows, cols)?;
for value in data {
writer.write_all(&value.to_le_bytes())?;
}
Ok(())
}
fn write_npy_u8(path: &Path, data: &[u8], rows: usize, cols: usize) -> Result<()> {
let mut writer = BufWriter::new(
File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
);
write_npy_header(&mut writer, "|u1", rows, cols)?;
writer.write_all(data)?;
Ok(())
}
fn write_npy_i16(path: &Path, data: &[i16], rows: usize, cols: usize) -> Result<()> {
let mut writer = BufWriter::new(
File::create(path).with_context(|| format!("failed to create {}", path.display()))?,
);
write_npy_header(&mut writer, "<i2", rows, cols)?;
for value in data {
writer.write_all(&value.to_le_bytes())?;
}
Ok(())
}
fn write_npy_header<W: Write>(writer: &mut W, descr: &str, rows: usize, cols: usize) -> Result<()> {
let mut header = format!(
"{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}",
descr, rows, cols
)
.into_bytes();
let preamble_len = 10usize;
let pad_len = (16 - ((preamble_len + header.len() + 1) % 16)) % 16;
header.extend(std::iter::repeat(b' ').take(pad_len));
header.push(b'\n');
if header.len() > u16::MAX as usize {
bail!("npy header too large");
}
writer.write_all(b"\x93NUMPY")?;
writer.write_all(&[1, 0])?;
writer.write_all(&(header.len() as u16).to_le_bytes())?;
writer.write_all(&header)?;
Ok(())
}