"""Build a cache of tokenized, BIO-labelled training chunks.

Each record of ``TRAIN_FILE`` carries an ``"annotated"`` string of the form
``<title>[SEP]<body>`` where spans of interest inside the body are wrapped in
``[SPAN]...[/SPAN]``.  The script strips the markers, tokenizes title and
body, assigns one of the ``LABEL2ID`` labels to every body token, cuts long
bodies into windows that each repeat the title, and dumps all windows to
``CACHE_FILE`` as JSON.
"""

import json
import os
from collections import Counter

MODEL_NAME = "google/gemma-3-270m"
TRAIN_FILE = "train.json"
CACHE_FILE = "chunks.cache.json"
# Maximum number of token ids in one chunk, special tokens and title included.
MAX_LEN = 4096
# Distance (in body tokens) between the starts of two consecutive chunks.
STRIDE = 1024

LABEL2ID = {"O": 0, "B-SPAN": 1, "I-SPAN": 2}
ID2LABEL = {v: k for k, v in LABEL2ID.items()}

# Label value given to positions that must not contribute to the label
# statistics (special tokens, title tokens, zero-width tokens).
IGNORE_INDEX = -100

_SEP_TAG = "[SEP]"
_SPAN_OPEN = "[SPAN]"
_SPAN_CLOSE = "[/SPAN]"


def parse_annotated(annotated):
    """Split an annotated record into title, marker-free body and spans.

    Args:
        annotated: String shaped like ``<title>[SEP]<body>``; only the first
            ``[SEP]`` separates title from body.

    Returns:
        A ``(title, plain, spans)`` tuple: the stripped title, the body with
        all span markers removed, and a list of ``(start, end)`` character
        offsets into ``plain`` (end exclusive), in order of appearance.

    Raises:
        ValueError: If ``annotated`` contains no ``[SEP]`` separator.

    Notes:
        An opening marker without a closing one extends the span to the end
        of the body.  Markers are not nested: a ``[SPAN]`` inside an open
        span, or a ``[/SPAN]`` outside of one, is kept as literal text.
    """
    title, sep, body = annotated.partition(_SEP_TAG)
    if not sep:
        raise ValueError(f"annotated text has no {_SEP_TAG} separator")

    pieces = []      # fragments of the plain text, joined once at the end
    spans = []
    plain_len = 0    # length of the plain text accumulated so far
    pos = 0
    body_len = len(body)

    while pos < body_len:
        open_at = body.find(_SPAN_OPEN, pos)
        if open_at == -1:
            pieces.append(body[pos:])
            break

        # Text before the marker is copied through unchanged.
        pieces.append(body[pos:open_at])
        plain_len += open_at - pos

        inner_start = open_at + len(_SPAN_OPEN)
        close_at = body.find(_SPAN_CLOSE, inner_start)
        if close_at == -1:
            # Unterminated span: it swallows the rest of the body.
            inner = body[inner_start:]
            pos = body_len
        else:
            inner = body[inner_start:close_at]
            pos = close_at + len(_SPAN_CLOSE)

        spans.append((plain_len, plain_len + len(inner)))
        pieces.append(inner)
        plain_len += len(inner)

    return title.strip(), "".join(pieces), spans


def label_tokens(offsets, spans):
    """Assign a BIO label id to every token from its character offsets.

    A token belongs to a span when its character range overlaps the span's
    range.  The first token that overlaps a span gets ``B-SPAN`` and every
    later token of the same span gets ``I-SPAN``.  Using overlap rather than
    exact containment keeps a span's first token labelled even when the
    token's offsets start slightly before the span (for instance when a
    tokenizer folds the preceding space into the token — NOTE(review):
    whether the configured tokenizer does this was not verified; the result
    is the same as exact matching when offsets line up with the spans).

    Args:
        offsets: Iterable of ``(start, end)`` character offsets, one per
            token, assumed to be in text order.
        spans: Iterable of non-overlapping ``(start, end)`` character spans.

    Returns:
        A list with one label id per token.  Tokens whose offsets are
        ``(0, 0)`` receive ``IGNORE_INDEX``.
    """
    outside = LABEL2ID["O"]
    begin = LABEL2ID["B-SPAN"]
    inside = LABEL2ID["I-SPAN"]

    ordered = sorted(spans)
    labels = []
    span_idx = 0
    span_opened = False   # True once the current span has received its B tag

    for tok_start, tok_end in offsets:
        if tok_start == 0 and tok_end == 0:
            labels.append(IGNORE_INDEX)
            continue

        # Spans that end at or before this token can never match again.
        while span_idx < len(ordered) and ordered[span_idx][1] <= tok_start:
            span_idx += 1
            span_opened = False

        label = outside
        if span_idx < len(ordered):
            span_start, span_end = ordered[span_idx]
            if tok_start < span_end and tok_end > span_start:
                label = inside if span_opened else begin
                span_opened = True
        labels.append(label)

    return labels


def chunk_with_title(title_ids, text_ids, text_labels, max_len, stride, tokenizer):
    """Cut a tokenized body into windows that each start with the title.

    Every chunk is laid out as ``BOS + title + EOS + body window + EOS``; only
    the body-window positions carry real labels, all others are
    ``IGNORE_INDEX``.

    Args:
        title_ids: Token ids of the title.
        text_ids: Token ids of the body.
        text_labels: One label id per entry of ``text_ids``.
        max_len: Maximum total length of a chunk.
        stride: Number of body tokens between consecutive window starts.
        tokenizer: Object exposing ``bos_token_id`` and ``eos_token_id``.

    Returns:
        A list of dicts with ``input_ids``, ``attention_mask`` and ``labels``.
        The list is empty when the title leaves no room for body tokens or
        when the body is empty.

    Raises:
        ValueError: If ``stride`` is not positive (the window would never
            advance).
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    # Three slots are taken by BOS and the two EOS tokens.
    text_budget = max_len - (len(title_ids) + 3)
    if text_budget <= 0:
        return []

    # Never step further than one window, otherwise the tokens between the
    # end of a window and the start of the next would be lost.
    step = min(stride, text_budget)

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    prefix_ids = [bos_id] + list(title_ids) + [eos_id]
    prefix_labels = [IGNORE_INDEX] * len(prefix_ids)

    inside = LABEL2ID["I-SPAN"]
    begin = LABEL2ID["B-SPAN"]

    chunks = []
    total = len(text_ids)
    start = 0
    while start < total:
        end = min(start + text_budget, total)
        window_ids = list(text_ids[start:end])
        window_labels = list(text_labels[start:end])

        # A window that opens in the middle of a span would start with I-SPAN;
        # promote its first real label so the span has a B tag in this chunk.
        for pos, label in enumerate(window_labels):
            if label == IGNORE_INDEX:
                continue
            if label == inside:
                window_labels[pos] = begin
            break

        input_ids = prefix_ids + window_ids + [eos_id]
        chunks.append({
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": prefix_labels + window_labels + [IGNORE_INDEX],
        })

        if end >= total:
            break
        start += step

    return chunks


def main():
    """Tokenize ``TRAIN_FILE``, report label statistics, write ``CACHE_FILE``."""
    # Imported here so the helpers above can be imported (and tested) without
    # the transformers package being installed.
    from transformers import AutoTokenizer

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    print(f"Loading {TRAIN_FILE}...")
    with open(TRAIN_FILE, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    print(f"Parsing and tokenizing {len(raw_data):,} articles...")

    all_chunks = []
    for i, item in enumerate(raw_data):
        title, plain_text, span_offsets = parse_annotated(item["annotated"])

        title_ids = tokenizer(title, add_special_tokens=False)["input_ids"]
        text_enc = tokenizer(
            plain_text,
            add_special_tokens=False,
            return_offsets_mapping=True,
        )
        text_ids = text_enc["input_ids"]
        text_labels = label_tokens(text_enc["offset_mapping"], span_offsets)

        all_chunks.extend(
            chunk_with_title(title_ids, text_ids, text_labels, MAX_LEN, STRIDE, tokenizer)
        )

        if (i + 1) % 2000 == 0:
            print(f"  [{i+1:,}/{len(raw_data):,}] chunks so far: {len(all_chunks):,}")

    print(f"\nTotal chunks: {len(all_chunks):,}")

    # Label distribution (ignored positions have a negative label id).
    dist = Counter(
        label
        for chunk in all_chunks
        for label in chunk["labels"]
        if label >= 0
    )
    total_labeled = sum(dist.values())
    print("\nLabel distribution:")
    for label_id, count in sorted(dist.items()):
        print(f"  {ID2LABEL[label_id]}: {count:,} ({count/total_labeled*100:.2f}%)")

    print(f"\nSaving to {CACHE_FILE}...")
    with open(CACHE_FILE, "w", encoding="utf-8") as f:
        json.dump(all_chunks, f)
    print("Done.")


if __name__ == "__main__":
    main()