ModerRAS committed on
Commit
f95ce71
·
1 Parent(s): 116c87c

Clean stale local scripts

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. check_f1.py +0 -33
  3. smoke_test.py +0 -50
  4. validate_fix.py +0 -80
  5. verify_data.py +0 -39
.gitignore CHANGED
@@ -1,6 +1,7 @@
1
  __pycache__/
2
  *.pyc
3
  .venv/
 
4
  .pytest_cache/
5
  .ruff_cache/
6
  logs/
 
1
  __pycache__/
2
  *.pyc
3
  .venv/
4
+ .venv-codex/
5
  .pytest_cache/
6
  .ruff_cache/
7
  logs/
check_f1.py DELETED
@@ -1,33 +0,0 @@
1
- """Check F1 score from training results."""
2
- import json
3
- import glob
4
- import os
5
-
6
- # Check full training checkpoints
7
- checkpoint_dirs = sorted(glob.glob('checkpoints/checkpoint-*'))
8
- if checkpoint_dirs:
9
- print('=== Full training checkpoints ===')
10
- for ckpt in checkpoint_dirs:
11
- state_file = os.path.join(ckpt, 'trainer_state.json')
12
- if os.path.exists(state_file):
13
- with open(state_file, 'r') as f:
14
- state = json.load(f)
15
- ckpt_metrics = [m for m in state.get('log_history', []) if 'eval_f1' in m]
16
- if ckpt_metrics:
17
- best = max(ckpt_metrics, key=lambda x: x['eval_f1'])
18
- print(f' {os.path.basename(ckpt)}: F1={best["eval_f1"]:.4f} (epoch={best.get("epoch","?"):.1f})')
19
-
20
- # Check latest checkpoint
21
- latest = checkpoint_dirs[-1] if checkpoint_dirs else None
22
- if latest:
23
- state_file = os.path.join(latest, 'trainer_state.json')
24
- with open(state_file, 'r') as f:
25
- state = json.load(f)
26
- all_metrics = [m for m in state.get('log_history', []) if 'eval_f1' in m]
27
- best = max(all_metrics, key=lambda x: x['eval_f1'])
28
- print(f'\nBest F1 overall: {best["eval_f1"]:.4f}')
29
- print(f'Meets >0.95 requirement: {best["eval_f1"] > 0.95}')
30
- else:
31
- print('No checkpoints found from full training.')
32
- print('Using mini-test results: F1=0.9979 (from test output logs)')
33
- print('This exceeds the >0.95 requirement.')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
smoke_test.py DELETED
@@ -1,50 +0,0 @@
1
- """Smoke test for the full training pipeline."""
2
- import json
3
- import os
4
- import torch
5
- from config import Config
6
- from tokenizer import AnimeTokenizer
7
- from model import create_model, count_parameters
8
- from dataset import AnimeDataset
9
-
10
- cfg = Config()
11
-
12
- # Load tokenizer
13
- tok = AnimeTokenizer(vocab_file='data/vocab.json')
14
- cfg.vocab_size = tok.vocab_size
15
- print(f'Vocab: {tok.vocab_size}, Labels: {cfg.num_labels}')
16
-
17
- # Create model
18
- model = create_model(cfg)
19
- total_params = count_parameters(model)
20
- print(f'Model params: {total_params:,} / 5M limit')
21
- assert total_params < 5_000_000, f'Model too large: {total_params:,}'
22
-
23
- # Load a tiny dataset
24
- with open('data/synthetic.jsonl', 'r', encoding='utf-8') as f:
25
- samples = [json.loads(line) for line in f][:100]
26
-
27
- temp_file = 'data/test_smoke.jsonl'
28
- with open(temp_file, 'w', encoding='utf-8') as f:
29
- for s in samples:
30
- f.write(json.dumps(s, ensure_ascii=False) + '\n')
31
-
32
- ds = AnimeDataset(temp_file, tok, cfg.label2id, cfg.max_seq_length)
33
- print(f'Dataset: {len(ds)} samples')
34
- sample = ds[0]
35
- print(f'Input IDs shape: {sample["input_ids"].shape}')
36
- print(f'Labels shape: {sample["labels"].shape}')
37
- print(f'Attention mask shape: {sample["attention_mask"].shape}')
38
-
39
- # Forward pass
40
- with torch.no_grad():
41
- out = model(
42
- input_ids=sample['input_ids'].unsqueeze(0),
43
- attention_mask=sample['attention_mask'].unsqueeze(0),
44
- labels=sample['labels'].unsqueeze(0),
45
- )
46
- print(f'Loss: {out.loss.item():.4f}')
47
- print(f'Logits shape: {out.logits.shape}')
48
- print()
49
- print('Smoke test PASSED!')
50
- print(f'Model is ready for training: {total_params:,} params < 5M [OK]')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
validate_fix.py DELETED
@@ -1,80 +0,0 @@
1
- """Validate the fixed data generator produces correct labels."""
2
- import json
3
- import sys
4
- import os
5
- sys.path.insert(0, os.path.dirname(__file__))
6
-
7
- from tokenizer import AnimeTokenizer
8
- from data_generator import generate_sample, TEMPLATES
9
-
10
- tok = AnimeTokenizer()
11
- tok.build_vocab([["test"]])
12
-
13
- # Check specific problem patterns
14
- problem_cases = [
15
- # "E" starting words in titles/groups
16
- ("Eighty Six", "episode"), # was being mislabeled as episode
17
- ("Evangelion", "episode"), # was being mislabeled
18
- ("Erai", "episode"), # from Erai-raws, was mislabeled
19
-
20
- # Numbers in titles
21
- ("86", "episode"), # from "86 Eighty Six"
22
- ("100", "episode"), # from "100万の命の上に"
23
- ("07", "episode"), # possible episode or title number
24
- ]
25
-
26
- print("Testing specific problem patterns...")
27
- print("=" * 60)
28
-
29
- # Track label counts
30
- label_counts = {}
31
- for i in range(5000):
32
- sample = generate_sample(tok, TEMPLATES)
33
- for label in sample["labels"]:
34
- label_counts[label] = label_counts.get(label, 0) + 1
35
-
36
- # Check for E-starting mislabels
37
- for token, label in zip(sample["tokens"], sample["labels"]):
38
- # Check E-starting English words
39
- if len(token) > 2 and token[0].upper() == 'E' and token.isalpha() and label == 'B-EPISODE':
40
- print(f"POTENTIAL BUG: '{token}' labeled as EPISODE")
41
-
42
- # Check number tokens
43
- if token.isdigit() and len(token) <= 2 and label == 'B-EPISODE':
44
- # Should only appear in proper episode context
45
- pass
46
-
47
- print(f"\nLabel distribution from {5000} samples:")
48
- total = sum(label_counts.values())
49
- for label, count in sorted(label_counts.items(), key=lambda x: -x[1]):
50
- print(f" {label}: {count} ({count*100/total:.1f}%)")
51
-
52
- # Check for IOB2 validity
53
- print("\nIOB2 validity check...")
54
- errors = 0
55
- for i in range(1000):
56
- sample = generate_sample(tok, TEMPLATES)
57
- labels = sample["labels"]
58
- for j, label in enumerate(labels):
59
- if label.startswith("I-"):
60
- if j == 0:
61
- print(f" ERROR: I- at position 0 in sample {i}")
62
- errors += 1
63
- else:
64
- prev = labels[j-1]
65
- expected = label.replace("I-", "B-")
66
- if prev not in (label, expected):
67
- # Check if prev is O and there's a B- earlier (spanning O)
68
- pass # This is now valid for multi-word entities
69
-
70
- print(f"IOB2 errors found: {errors}")
71
-
72
- # Spot-check a few samples
73
- print("\nSample outputs:")
74
- for i in range(3):
75
- sample = generate_sample(tok, TEMPLATES)
76
- print(f"\nSample {i}:")
77
- for token, label in zip(sample["tokens"], sample["labels"]):
78
- print(f" {label}: {token}")
79
-
80
- print("\nValidation complete!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
verify_data.py DELETED
@@ -1,39 +0,0 @@
1
- """Verify generated dataset quality."""
2
- import json
3
- from collections import Counter
4
-
5
- with open('data/synthetic_small.jsonl', 'r', encoding='utf-8') as f:
6
- samples = [json.loads(line) for line in f]
7
-
8
- print(f'Total samples: {len(samples)}')
9
-
10
- # Check a few samples
11
- for i in range(min(5, len(samples))):
12
- s = samples[i]
13
- print(f'\nSample {i}:')
14
- print(f' Tokens: {s["tokens"]}')
15
- print(f' Labels: {s["labels"]}')
16
-
17
- assert len(s['tokens']) == len(s['labels']), f'Mismatch: {len(s["tokens"])} != {len(s["labels"])}'
18
-
19
- # Check BIO format validity
20
- for j, label in enumerate(s['labels']):
21
- if label.startswith('I-'):
22
- if j == 0:
23
- print(f' ERROR: First token is {label}')
24
- else:
25
- prev = s['labels'][j-1]
26
- expected_prefix = 'B-' + label[2:]
27
- if prev != label and prev != expected_prefix:
28
- print(f' WARN: I- without B- at pos {j}: {prev} -> {label}')
29
-
30
- # Label distribution
31
- print('\nLabel distribution:')
32
- all_labels = [l for s in samples for l in s['labels']]
33
- total = len(all_labels)
34
- for label, count in Counter(all_labels).most_common():
35
- print(f' {label}: {count} ({count*100/total:.1f}%)')
36
-
37
- # Sequence length stats
38
- lengths = [len(s['tokens']) for s in samples]
39
- print(f'\nSequence length: min={min(lengths)}, max={max(lengths)}, avg={sum(lengths)/len(lengths):.1f}')