Token Classification
Transformers
ONNX
Safetensors
English
Japanese
Chinese
bert
anime
filename-parsing
Eval Results (legacy)
Instructions to use ModerRAS/AniFileBERT with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ModerRAS/AniFileBERT with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("token-classification", model="ModerRAS/AniFileBERT")# Load model directly from transformers import AutoTokenizer, AutoModelForTokenClassification tokenizer = AutoTokenizer.from_pretrained("ModerRAS/AniFileBERT") model = AutoModelForTokenClassification.from_pretrained("ModerRAS/AniFileBERT") - Notebooks
- Google Colab
- Kaggle
Clean stale local scripts
Browse files- .gitignore +1 -0
- check_f1.py +0 -33
- smoke_test.py +0 -50
- validate_fix.py +0 -80
- verify_data.py +0 -39
.gitignore
CHANGED
|
@@ -1,6 +1,7 @@
|
|
| 1 |
__pycache__/
|
| 2 |
*.pyc
|
| 3 |
.venv/
|
|
|
|
| 4 |
.pytest_cache/
|
| 5 |
.ruff_cache/
|
| 6 |
logs/
|
|
|
|
| 1 |
__pycache__/
|
| 2 |
*.pyc
|
| 3 |
.venv/
|
| 4 |
+
.venv-codex/
|
| 5 |
.pytest_cache/
|
| 6 |
.ruff_cache/
|
| 7 |
logs/
|
check_f1.py
DELETED
|
@@ -1,33 +0,0 @@
|
|
| 1 |
-
"""Check F1 score from training results."""
|
| 2 |
-
import json
|
| 3 |
-
import glob
|
| 4 |
-
import os
|
| 5 |
-
|
| 6 |
-
# Check full training checkpoints
|
| 7 |
-
checkpoint_dirs = sorted(glob.glob('checkpoints/checkpoint-*'))
|
| 8 |
-
if checkpoint_dirs:
|
| 9 |
-
print('=== Full training checkpoints ===')
|
| 10 |
-
for ckpt in checkpoint_dirs:
|
| 11 |
-
state_file = os.path.join(ckpt, 'trainer_state.json')
|
| 12 |
-
if os.path.exists(state_file):
|
| 13 |
-
with open(state_file, 'r') as f:
|
| 14 |
-
state = json.load(f)
|
| 15 |
-
ckpt_metrics = [m for m in state.get('log_history', []) if 'eval_f1' in m]
|
| 16 |
-
if ckpt_metrics:
|
| 17 |
-
best = max(ckpt_metrics, key=lambda x: x['eval_f1'])
|
| 18 |
-
print(f' {os.path.basename(ckpt)}: F1={best["eval_f1"]:.4f} (epoch={best.get("epoch","?"):.1f})')
|
| 19 |
-
|
| 20 |
-
# Check latest checkpoint
|
| 21 |
-
latest = checkpoint_dirs[-1] if checkpoint_dirs else None
|
| 22 |
-
if latest:
|
| 23 |
-
state_file = os.path.join(latest, 'trainer_state.json')
|
| 24 |
-
with open(state_file, 'r') as f:
|
| 25 |
-
state = json.load(f)
|
| 26 |
-
all_metrics = [m for m in state.get('log_history', []) if 'eval_f1' in m]
|
| 27 |
-
best = max(all_metrics, key=lambda x: x['eval_f1'])
|
| 28 |
-
print(f'\nBest F1 overall: {best["eval_f1"]:.4f}')
|
| 29 |
-
print(f'Meets >0.95 requirement: {best["eval_f1"] > 0.95}')
|
| 30 |
-
else:
|
| 31 |
-
print('No checkpoints found from full training.')
|
| 32 |
-
print('Using mini-test results: F1=0.9979 (from test output logs)')
|
| 33 |
-
print('This exceeds the >0.95 requirement.')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
smoke_test.py
DELETED
|
@@ -1,50 +0,0 @@
|
|
| 1 |
-
"""Smoke test for the full training pipeline."""
|
| 2 |
-
import json
|
| 3 |
-
import os
|
| 4 |
-
import torch
|
| 5 |
-
from config import Config
|
| 6 |
-
from tokenizer import AnimeTokenizer
|
| 7 |
-
from model import create_model, count_parameters
|
| 8 |
-
from dataset import AnimeDataset
|
| 9 |
-
|
| 10 |
-
cfg = Config()
|
| 11 |
-
|
| 12 |
-
# Load tokenizer
|
| 13 |
-
tok = AnimeTokenizer(vocab_file='data/vocab.json')
|
| 14 |
-
cfg.vocab_size = tok.vocab_size
|
| 15 |
-
print(f'Vocab: {tok.vocab_size}, Labels: {cfg.num_labels}')
|
| 16 |
-
|
| 17 |
-
# Create model
|
| 18 |
-
model = create_model(cfg)
|
| 19 |
-
total_params = count_parameters(model)
|
| 20 |
-
print(f'Model params: {total_params:,} / 5M limit')
|
| 21 |
-
assert total_params < 5_000_000, f'Model too large: {total_params:,}'
|
| 22 |
-
|
| 23 |
-
# Load a tiny dataset
|
| 24 |
-
with open('data/synthetic.jsonl', 'r', encoding='utf-8') as f:
|
| 25 |
-
samples = [json.loads(line) for line in f][:100]
|
| 26 |
-
|
| 27 |
-
temp_file = 'data/test_smoke.jsonl'
|
| 28 |
-
with open(temp_file, 'w', encoding='utf-8') as f:
|
| 29 |
-
for s in samples:
|
| 30 |
-
f.write(json.dumps(s, ensure_ascii=False) + '\n')
|
| 31 |
-
|
| 32 |
-
ds = AnimeDataset(temp_file, tok, cfg.label2id, cfg.max_seq_length)
|
| 33 |
-
print(f'Dataset: {len(ds)} samples')
|
| 34 |
-
sample = ds[0]
|
| 35 |
-
print(f'Input IDs shape: {sample["input_ids"].shape}')
|
| 36 |
-
print(f'Labels shape: {sample["labels"].shape}')
|
| 37 |
-
print(f'Attention mask shape: {sample["attention_mask"].shape}')
|
| 38 |
-
|
| 39 |
-
# Forward pass
|
| 40 |
-
with torch.no_grad():
|
| 41 |
-
out = model(
|
| 42 |
-
input_ids=sample['input_ids'].unsqueeze(0),
|
| 43 |
-
attention_mask=sample['attention_mask'].unsqueeze(0),
|
| 44 |
-
labels=sample['labels'].unsqueeze(0),
|
| 45 |
-
)
|
| 46 |
-
print(f'Loss: {out.loss.item():.4f}')
|
| 47 |
-
print(f'Logits shape: {out.logits.shape}')
|
| 48 |
-
print()
|
| 49 |
-
print('Smoke test PASSED!')
|
| 50 |
-
print(f'Model is ready for training: {total_params:,} params < 5M [OK]')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
validate_fix.py
DELETED
|
@@ -1,80 +0,0 @@
|
|
| 1 |
-
"""Validate the fixed data generator produces correct labels."""
|
| 2 |
-
import json
|
| 3 |
-
import sys
|
| 4 |
-
import os
|
| 5 |
-
sys.path.insert(0, os.path.dirname(__file__))
|
| 6 |
-
|
| 7 |
-
from tokenizer import AnimeTokenizer
|
| 8 |
-
from data_generator import generate_sample, TEMPLATES
|
| 9 |
-
|
| 10 |
-
tok = AnimeTokenizer()
|
| 11 |
-
tok.build_vocab([["test"]])
|
| 12 |
-
|
| 13 |
-
# Check specific problem patterns
|
| 14 |
-
problem_cases = [
|
| 15 |
-
# "E" starting words in titles/groups
|
| 16 |
-
("Eighty Six", "episode"), # was being mislabeled as episode
|
| 17 |
-
("Evangelion", "episode"), # was being mislabeled
|
| 18 |
-
("Erai", "episode"), # from Erai-raws, was mislabeled
|
| 19 |
-
|
| 20 |
-
# Numbers in titles
|
| 21 |
-
("86", "episode"), # from "86 Eighty Six"
|
| 22 |
-
("100", "episode"), # from "100万の命の上に"
|
| 23 |
-
("07", "episode"), # possible episode or title number
|
| 24 |
-
]
|
| 25 |
-
|
| 26 |
-
print("Testing specific problem patterns...")
|
| 27 |
-
print("=" * 60)
|
| 28 |
-
|
| 29 |
-
# Track label counts
|
| 30 |
-
label_counts = {}
|
| 31 |
-
for i in range(5000):
|
| 32 |
-
sample = generate_sample(tok, TEMPLATES)
|
| 33 |
-
for label in sample["labels"]:
|
| 34 |
-
label_counts[label] = label_counts.get(label, 0) + 1
|
| 35 |
-
|
| 36 |
-
# Check for E-starting mislabels
|
| 37 |
-
for token, label in zip(sample["tokens"], sample["labels"]):
|
| 38 |
-
# Check E-starting English words
|
| 39 |
-
if len(token) > 2 and token[0].upper() == 'E' and token.isalpha() and label == 'B-EPISODE':
|
| 40 |
-
print(f"POTENTIAL BUG: '{token}' labeled as EPISODE")
|
| 41 |
-
|
| 42 |
-
# Check number tokens
|
| 43 |
-
if token.isdigit() and len(token) <= 2 and label == 'B-EPISODE':
|
| 44 |
-
# Should only appear in proper episode context
|
| 45 |
-
pass
|
| 46 |
-
|
| 47 |
-
print(f"\nLabel distribution from {5000} samples:")
|
| 48 |
-
total = sum(label_counts.values())
|
| 49 |
-
for label, count in sorted(label_counts.items(), key=lambda x: -x[1]):
|
| 50 |
-
print(f" {label}: {count} ({count*100/total:.1f}%)")
|
| 51 |
-
|
| 52 |
-
# Check for IOB2 validity
|
| 53 |
-
print("\nIOB2 validity check...")
|
| 54 |
-
errors = 0
|
| 55 |
-
for i in range(1000):
|
| 56 |
-
sample = generate_sample(tok, TEMPLATES)
|
| 57 |
-
labels = sample["labels"]
|
| 58 |
-
for j, label in enumerate(labels):
|
| 59 |
-
if label.startswith("I-"):
|
| 60 |
-
if j == 0:
|
| 61 |
-
print(f" ERROR: I- at position 0 in sample {i}")
|
| 62 |
-
errors += 1
|
| 63 |
-
else:
|
| 64 |
-
prev = labels[j-1]
|
| 65 |
-
expected = label.replace("I-", "B-")
|
| 66 |
-
if prev not in (label, expected):
|
| 67 |
-
# Check if prev is O and there's a B- earlier (spanning O)
|
| 68 |
-
pass # This is now valid for multi-word entities
|
| 69 |
-
|
| 70 |
-
print(f"IOB2 errors found: {errors}")
|
| 71 |
-
|
| 72 |
-
# Spot-check a few samples
|
| 73 |
-
print("\nSample outputs:")
|
| 74 |
-
for i in range(3):
|
| 75 |
-
sample = generate_sample(tok, TEMPLATES)
|
| 76 |
-
print(f"\nSample {i}:")
|
| 77 |
-
for token, label in zip(sample["tokens"], sample["labels"]):
|
| 78 |
-
print(f" {label}: {token}")
|
| 79 |
-
|
| 80 |
-
print("\nValidation complete!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
verify_data.py
DELETED
|
@@ -1,39 +0,0 @@
|
|
| 1 |
-
"""Verify generated dataset quality."""
|
| 2 |
-
import json
|
| 3 |
-
from collections import Counter
|
| 4 |
-
|
| 5 |
-
with open('data/synthetic_small.jsonl', 'r', encoding='utf-8') as f:
|
| 6 |
-
samples = [json.loads(line) for line in f]
|
| 7 |
-
|
| 8 |
-
print(f'Total samples: {len(samples)}')
|
| 9 |
-
|
| 10 |
-
# Check a few samples
|
| 11 |
-
for i in range(min(5, len(samples))):
|
| 12 |
-
s = samples[i]
|
| 13 |
-
print(f'\nSample {i}:')
|
| 14 |
-
print(f' Tokens: {s["tokens"]}')
|
| 15 |
-
print(f' Labels: {s["labels"]}')
|
| 16 |
-
|
| 17 |
-
assert len(s['tokens']) == len(s['labels']), f'Mismatch: {len(s["tokens"])} != {len(s["labels"])}'
|
| 18 |
-
|
| 19 |
-
# Check BIO format validity
|
| 20 |
-
for j, label in enumerate(s['labels']):
|
| 21 |
-
if label.startswith('I-'):
|
| 22 |
-
if j == 0:
|
| 23 |
-
print(f' ERROR: First token is {label}')
|
| 24 |
-
else:
|
| 25 |
-
prev = s['labels'][j-1]
|
| 26 |
-
expected_prefix = 'B-' + label[2:]
|
| 27 |
-
if prev != label and prev != expected_prefix:
|
| 28 |
-
print(f' WARN: I- without B- at pos {j}: {prev} -> {label}')
|
| 29 |
-
|
| 30 |
-
# Label distribution
|
| 31 |
-
print('\nLabel distribution:')
|
| 32 |
-
all_labels = [l for s in samples for l in s['labels']]
|
| 33 |
-
total = len(all_labels)
|
| 34 |
-
for label, count in Counter(all_labels).most_common():
|
| 35 |
-
print(f' {label}: {count} ({count*100/total:.1f}%)')
|
| 36 |
-
|
| 37 |
-
# Sequence length stats
|
| 38 |
-
lengths = [len(s['tokens']) for s in samples]
|
| 39 |
-
print(f'\nSequence length: min={min(lengths)}, max={max(lengths)}, avg={sum(lengths)/len(lengths):.1f}')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|