ModerRAS commited on
Commit
93f322e
·
1 Parent(s): 494b24c

Verify low-frequency DMHY generated output

Browse files
tools/rust_dmhy_template_apply/README.md CHANGED
@@ -43,6 +43,16 @@ cargo run --release --manifest-path tools\rust_dmhy_template_apply\Cargo.toml --
43
  --threads 24
44
  ```
45
 
 
 
 
 
 
 
 
 
 
 
46
  Optional controls:
47
 
48
  ```powershell
 
43
  --threads 24
44
  ```
45
 
46
+ Verify the generated training output has no low-frequency blocking warnings:
47
+
48
+ ```powershell
49
+ cargo run --release --manifest-path tools\rust_dmhy_template_apply\Cargo.toml -- `
50
+ --verify-generated-output `
51
+ --input reports\dmhy_weak.template_generated.top5000.rust.jsonl `
52
+ --recipes reports\dmhy_template_recipes.full_top5000.seed.jsonl `
53
+ --audit-max-count 50
54
+ ```
55
+
56
  Optional controls:
57
 
58
  ```powershell
tools/rust_dmhy_template_apply/src/main.rs CHANGED
@@ -19,6 +19,8 @@ struct Args {
19
  cluster: bool,
20
  #[arg(long)]
21
  audit_low_frequency: bool,
 
 
22
  #[arg(long, default_value = "datasets/AnimeName/dmhy_list.jsonl")]
23
  input: PathBuf,
24
  #[arg(long, default_value = "reports/dmhy_template_recipes.seed.jsonl")]
@@ -92,7 +94,7 @@ struct Recipe {
92
  count: Option<u64>,
93
  }
94
 
95
- #[derive(Debug, Clone, Serialize)]
96
  struct Record {
97
  filename: String,
98
  tokens: Vec<String>,
@@ -237,6 +239,9 @@ fn main() -> Result<()> {
237
  if args.audit_low_frequency {
238
  return run_low_frequency_audit(&args);
239
  }
 
 
 
240
  if args.expand != "all" && args.expand != "sample" {
241
  bail!("--expand must be all or sample");
242
  }
@@ -704,6 +709,77 @@ fn run_low_frequency_audit(args: &Args) -> Result<()> {
704
  Ok(())
705
  }
706
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
  fn entity_spans(tokens: &[String], labels: &[String]) -> Vec<Value> {
708
  let mut spans = Vec::new();
709
  let mut current_label: Option<String> = None;
 
19
  cluster: bool,
20
  #[arg(long)]
21
  audit_low_frequency: bool,
22
+ #[arg(long)]
23
+ verify_generated_output: bool,
24
  #[arg(long, default_value = "datasets/AnimeName/dmhy_list.jsonl")]
25
  input: PathBuf,
26
  #[arg(long, default_value = "reports/dmhy_template_recipes.seed.jsonl")]
 
94
  count: Option<u64>,
95
  }
96
 
97
+ #[derive(Debug, Clone, Serialize, Deserialize)]
98
  struct Record {
99
  filename: String,
100
  tokens: Vec<String>,
 
239
  if args.audit_low_frequency {
240
  return run_low_frequency_audit(&args);
241
  }
242
+ if args.verify_generated_output {
243
+ return run_verify_generated_output(&args);
244
+ }
245
  if args.expand != "all" && args.expand != "sample" {
246
  bail!("--expand must be all or sample");
247
  }
 
709
  Ok(())
710
  }
711
 
712
+ fn run_verify_generated_output(args: &Args) -> Result<()> {
713
+ let file = File::open(&args.input)
714
+ .with_context(|| format!("generated JSONL not found: {}", args.input.display()))?;
715
+ let recipes_by_id: HashMap<String, u64> = load_recipes(args)?
716
+ .into_values()
717
+ .map(|recipe| (recipe.template_id, recipe.count.unwrap_or(0)))
718
+ .collect();
719
+ let mut rows = 0usize;
720
+ let mut low_frequency_rows = 0usize;
721
+ let mut warning_counts: HashMap<String, usize> = HashMap::new();
722
+ let mut examples: HashMap<String, Vec<Value>> = HashMap::new();
723
+
724
+ for (line_number, line) in BufReader::new(file).lines().enumerate() {
725
+ let line = line?;
726
+ if line.trim().is_empty() {
727
+ continue;
728
+ }
729
+ let record: Record = serde_json::from_str(&line).with_context(|| {
730
+ format!(
731
+ "invalid generated record at {}:{}",
732
+ args.input.display(),
733
+ line_number + 1
734
+ )
735
+ })?;
736
+ rows += 1;
737
+ let count = recipes_by_id
738
+ .get(&record.template_id)
739
+ .copied()
740
+ .unwrap_or(u64::MAX);
741
+ if count > args.audit_max_count {
742
+ continue;
743
+ }
744
+ low_frequency_rows += 1;
745
+ for warning in audit_warnings(&record) {
746
+ if !matches!(
747
+ warning.as_str(),
748
+ "hash_labeled" | "multiple_title_spans" | "no_title" | "path_retained"
749
+ ) {
750
+ continue;
751
+ }
752
+ *warning_counts.entry(warning.clone()).or_default() += 1;
753
+ let bucket = examples.entry(warning).or_default();
754
+ if bucket.len() < 5 {
755
+ bucket.push(json!({
756
+ "template_id": record.template_id,
757
+ "template_count": count,
758
+ "filename": record.filename,
759
+ "spans": entity_spans(&record.tokens, &record.labels),
760
+ }));
761
+ }
762
+ }
763
+ }
764
+
765
+ let manifest = json!({
766
+ "generated_at": Utc::now().to_rfc3339(),
767
+ "input": args.input.to_string_lossy(),
768
+ "recipes": args.recipes.to_string_lossy(),
769
+ "audit_max_count": args.audit_max_count,
770
+ "rows": rows,
771
+ "low_frequency_rows": low_frequency_rows,
772
+ "blocking_warning_counts": warning_counts,
773
+ "examples": examples,
774
+ "implementation": "rust_dmhy_generated_output_verify"
775
+ });
776
+ println!("{}", serde_json::to_string_pretty(&manifest)?);
777
+ if !warning_counts.is_empty() {
778
+ bail!("generated output still has low-frequency blocking warnings");
779
+ }
780
+ Ok(())
781
+ }
782
+
783
  fn entity_spans(tokens: &[String], labels: &[String]) -> Vec<Value> {
784
  let mut spans = Vec::new();
785
  let mut current_label: Option<String> = None;