# latent-entity / pretokenize.py
# Author: dejanseo. Commit 41e2a73 ("Upload 3 files", verified).
import json
import os
from transformers import AutoTokenizer
from collections import Counter
# Hugging Face checkpoint whose tokenizer is used to build the cache.
MODEL_NAME = "google/gemma-3-270m"
# Input JSON: a sequence of records, each carrying an "annotated" string.
TRAIN_FILE = "train.json"
# Output JSON: the list of pre-tokenized, pre-labelled chunks.
CACHE_FILE = "chunks.cache.json"
# Upper bound on tokens per chunk, counting BOS, title and both EOS markers.
MAX_LEN = 4096
# Number of text tokens between the starts of consecutive windows.
STRIDE = 1024

# BIO tag set for entity spans, and its inverse for reporting.
LABEL2ID = {"O": 0, "B-SPAN": 1, "I-SPAN": 2}
ID2LABEL = dict(zip(LABEL2ID.values(), LABEL2ID.keys()))
def parse_annotated(annotated):
title, body = annotated.split("[SEP]", 1)
spans = []
plain = ""
i = 0
while i < len(body):
if body[i:i+6] == "[SPAN]":
start = len(plain)
i += 6
while i < len(body) and body[i:i+7] != "[/SPAN]":
plain += body[i]
i += 1
end = len(plain)
spans.append((start, end))
if body[i:i+7] == "[/SPAN]":
i += 7
else:
plain += body[i]
i += 1
return title.strip(), plain, spans
def chunk_with_title(title_ids, text_ids, text_labels, max_len, stride, tokenizer):
    """Split a tokenized article into windows, each prefixed with the title.

    Every chunk is laid out as ``[BOS] title [EOS] text_window [EOS]``. Only
    the text-window positions keep their labels; BOS, title and EOS positions
    are set to -100 (the usual ignore index for token-classification losses).

    Args:
        title_ids: Token ids of the title, without special tokens.
        text_ids: Token ids of the article body, without special tokens.
        text_labels: One label id (or -100) per entry of ``text_ids``.
        max_len: Maximum chunk length, special tokens included.
        stride: Text tokens to advance between window starts. It is clamped to
            the window width so that no text token is skipped when a long title
            shrinks the window below ``stride``.
        tokenizer: Object exposing ``bos_token_id`` and ``eos_token_id``.

    Returns:
        A list of dicts with equal-length ``input_ids``, ``attention_mask`` and
        ``labels`` lists. Empty when ``text_ids`` is empty or the title leaves
        no room for text.

    Raises:
        ValueError: If ``stride`` is not positive (the window would never advance).
    """
    if stride <= 0:
        raise ValueError(f"stride must be a positive integer, got {stride!r}")

    # BOS before the title, EOS after it, and EOS after the text window.
    title_budget = len(title_ids) + 3
    text_budget = max_len - title_budget
    if text_budget <= 0:
        return []

    # Advancing further than the window width would leave stretches of text
    # that no chunk ever covers.
    step = min(stride, text_budget)

    # The title prefix is identical for every window of this article.
    prefix_ids = [tokenizer.bos_token_id] + list(title_ids) + [tokenizer.eos_token_id]
    prefix_labels = [-100] * len(prefix_ids)

    chunks = []
    start = 0
    while start < len(text_ids):
        end = min(start + text_budget, len(text_ids))
        chunk_labels = list(text_labels[start:end])
        # A window may open in the middle of a span; promote its first real
        # label from I-SPAN to B-SPAN so it does not begin with a dangling I.
        for j, lbl in enumerate(chunk_labels):
            if lbl == -100:
                continue
            if lbl == LABEL2ID["I-SPAN"]:
                chunk_labels[j] = LABEL2ID["B-SPAN"]
            break
        input_ids = prefix_ids + list(text_ids[start:end]) + [tokenizer.eos_token_id]
        labels = prefix_labels + chunk_labels + [-100]
        chunks.append({
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": labels,
        })
        if end >= len(text_ids):
            break
        start += step
    return chunks
def _token_labels(plain_text, offsets, spans):
    """Return one BIO label id (or -100) per token from its character offsets.

    Args:
        plain_text: The text that ``offsets`` index into.
        offsets: ``(start, end)`` character offsets, one per token, as returned
            in the tokenizer's ``offset_mapping``.
        spans: ``(start, end)`` character spans into ``plain_text``, as produced
            by ``parse_annotated`` (in order, non-overlapping).

    A token is inside a span when it lies entirely within it. The first such
    token of each span gets B-SPAN, later ones I-SPAN, and tokens outside every
    span get O. ``(0, 0)`` offsets get -100.
    """
    o_id = LABEL2ID["O"]
    b_id = LABEL2ID["B-SPAN"]
    i_id = LABEL2ID["I-SPAN"]
    labels = []
    prev_span = None  # index of the span the previous labelled token fell in
    for tok_start, tok_end in offsets:
        if tok_start == 0 and tok_end == 0:
            labels.append(-100)
            continue
        # NOTE(review): SentencePiece-style fast tokenizers may report offsets
        # that include a token's leading space (" Paris" for "▁Paris"), which
        # would place the token one character before the span it belongs to.
        # Compare from the first non-whitespace character instead; tokens that
        # are entirely whitespace keep their reported start.
        piece = plain_text[tok_start:tok_end]
        lead_ws = len(piece) - len(piece.lstrip())
        eff_start = tok_start + lead_ws if lead_ws < len(piece) else tok_start

        span_idx = None
        for k, (span_start, span_end) in enumerate(spans):
            if span_start <= eff_start and tok_end <= span_end:
                span_idx = k
                break

        if span_idx is None:
            labels.append(o_id)
        elif span_idx != prev_span:
            labels.append(b_id)
        else:
            labels.append(i_id)
        prev_span = span_idx
    return labels


print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

print(f"Loading {TRAIN_FILE}...")
with open(TRAIN_FILE, "r", encoding="utf-8") as f:
    raw_data = json.load(f)

print(f"Parsing and tokenizing {len(raw_data):,} articles...")
all_chunks = []
for i, item in enumerate(raw_data):
    title, plain_text, span_offsets = parse_annotated(item["annotated"])
    title_ids = tokenizer(title, add_special_tokens=False)["input_ids"]
    # Offsets are needed to map character-level spans onto tokens.
    text_enc = tokenizer(plain_text, add_special_tokens=False, return_offsets_mapping=True)
    text_ids = text_enc["input_ids"]
    text_labels = _token_labels(plain_text, text_enc["offset_mapping"], span_offsets)
    chunks = chunk_with_title(title_ids, text_ids, text_labels, MAX_LEN, STRIDE, tokenizer)
    all_chunks.extend(chunks)
    if (i + 1) % 2000 == 0:
        print(f" [{i+1:,}/{len(raw_data):,}] chunks so far: {len(all_chunks):,}")

print(f"\nTotal chunks: {len(all_chunks):,}")

# Label distribution over all chunks; tokens in overlapping windows are
# counted once per window they appear in.
dist = Counter(lbl for c in all_chunks for lbl in c["labels"] if lbl >= 0)
total_labeled = sum(dist.values())
print("\nLabel distribution:")
for label_id, count in sorted(dist.items()):
    print(f" {ID2LABEL[label_id]}: {count:,} ({count/total_labeled*100:.2f}%)")

print(f"\nSaving to {CACHE_FILE}...")
with open(CACHE_FILE, "w", encoding="utf-8") as f:
    json.dump(all_chunks, f)
print("Done.")