| import json |
| import os |
| from transformers import AutoTokenizer |
| from collections import Counter |
|
|
# Hugging Face model id whose tokenizer is loaded for all encoding.
MODEL_NAME = "google/gemma-3-270m"
# Annotated input articles and the output cache of tokenized chunks.
TRAIN_FILE = "train.json"
CACHE_FILE = "chunks.cache.json"
# Maximum chunk length in tokens, counting BOS, the title, and both EOS separators.
MAX_LEN = 4096
# Step (in text tokens) between consecutive window starts -- not the overlap size.
# Consecutive windows overlap whenever the per-chunk text budget exceeds this value.
STRIDE = 1024
# BIO tag scheme for span tagging, plus its inverse for reporting.
LABEL2ID = {"O": 0, "B-SPAN": 1, "I-SPAN": 2}
ID2LABEL = dict(zip(LABEL2ID.values(), LABEL2ID))
|
|
|
|
def parse_annotated(annotated):
    """Split an annotated record into title, plain body text and span offsets.

    The input has the form ``"<title>[SEP]<body>"``, where the body may contain
    ``[SPAN]...[/SPAN]`` markers. The markers are removed from the body, and the
    character range each marked region occupies in the resulting plain text is
    recorded.

    Parsing rules:
      * Only the first ``[SEP]`` separates title from body; later ones stay in
        the body text.
      * Inside a span, a nested ``[SPAN]`` is kept as literal text; the span ends
        at the first ``[/SPAN]``.
      * An unterminated ``[SPAN]`` runs to the end of the body.
      * A ``[/SPAN]`` outside any span is kept as literal text.

    Args:
        annotated: The raw annotated string.

    Returns:
        ``(title, plain, spans)`` where ``title`` is whitespace-stripped,
        ``plain`` is the marker-free body (not stripped), and ``spans`` is a list
        of half-open ``(start, end)`` character offsets into ``plain`` in
        document order (empty spans are kept).

    Raises:
        ValueError: If ``annotated`` contains no ``[SEP]``.
    """
    open_tag, close_tag = "[SPAN]", "[/SPAN]"

    title, sep, body = annotated.partition("[SEP]")
    if not sep:
        raise ValueError("annotated text has no [SEP] separating title and body")

    pieces = []  # plain-text fragments, joined once at the end (no quadratic +=)
    length = 0   # length of the plain text emitted so far
    spans = []
    pos = 0
    n = len(body)
    while pos < n:
        tag_at = body.find(open_tag, pos)
        if tag_at == -1:
            pieces.append(body[pos:])
            break
        # Text before the marker is copied verbatim.
        pieces.append(body[pos:tag_at])
        length += tag_at - pos

        inner_start = tag_at + len(open_tag)
        close_at = body.find(close_tag, inner_start)
        if close_at == -1:
            close_at = n  # unterminated span: everything up to the end is inside it
        inner = body[inner_start:close_at]
        pieces.append(inner)
        spans.append((length, length + len(inner)))
        length += len(inner)
        # Past the closing marker (or past the end for an unterminated span).
        pos = close_at + len(close_tag)

    return title.strip(), "".join(pieces), spans
|
|
|
|
def chunk_with_title(title_ids, text_ids, text_labels, max_len, stride, tokenizer):
    """Split one tokenized article into title-prefixed sliding windows.

    Each chunk is laid out as ``[BOS] title [EOS] text-window [EOS]`` and is at
    most ``max_len`` tokens long. Only text-window positions carry labels; the
    special tokens and the title are masked with -100. If a window opens in the
    middle of a span, its first labelled token is promoted from I-SPAN to
    B-SPAN so every window is a valid BIO sequence.

    Args:
        title_ids: Token ids of the title (no special tokens).
        text_ids: Token ids of the body text (no special tokens).
        text_labels: One label id per entry of ``text_ids`` (-100 = ignored).
        max_len: Maximum total chunk length in tokens.
        stride: Step between consecutive window starts, in text tokens. It is
            clamped to the window size so that no text token is ever skipped.
        tokenizer: Object providing ``bos_token_id`` and ``eos_token_id``.

    Returns:
        A list of dicts with equal-length ``input_ids``, ``attention_mask`` and
        ``labels`` lists. Empty when ``text_ids`` is empty or when the title
        alone leaves no room for text.

    Raises:
        ValueError: If ``stride`` is not positive (it would never advance).
    """
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    # BOS + EOS after the title + EOS after the text.
    title_budget = len(title_ids) + 3
    text_budget = max_len - title_budget
    if text_budget <= 0:
        return []

    # A step larger than the window would leave gaps of tokens that appear in
    # no chunk at all (happens with long titles), so never step past the window.
    step = min(stride, text_budget)

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    b_id = LABEL2ID["B-SPAN"]
    i_id = LABEL2ID["I-SPAN"]

    # The title prefix is identical for every chunk of this article.
    prefix_ids = [bos_id] + list(title_ids) + [eos_id]
    prefix_labels = [-100] * len(prefix_ids)

    chunks = []
    n_text = len(text_ids)
    start = 0
    while start < n_text:
        end = min(start + text_budget, n_text)
        chunk_labels = list(text_labels[start:end])

        # Promote a leading I-SPAN (window cut mid-span) to B-SPAN; only the
        # first non-ignored label is inspected.
        for j, lbl in enumerate(chunk_labels):
            if lbl == -100:
                continue
            if lbl == i_id:
                chunk_labels[j] = b_id
            break

        input_ids = prefix_ids + list(text_ids[start:end]) + [eos_id]
        labels = prefix_labels + chunk_labels + [-100]
        chunks.append({
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": labels,
        })

        if end >= n_text:
            break
        start += step

    return chunks
|
|
|
|
# Load the tokenizer before touching the training data.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Read the whole annotated training set into memory; each record is later
# accessed via item["annotated"].
print(f"Loading {TRAIN_FILE}...")
with open(TRAIN_FILE, encoding="utf-8") as f:
    raw_data = json.loads(f.read())
|
|
def _label_tokens(text, offsets, spans):
    """Return one BIO label id per token of *text*.

    Args:
        text: The string that was tokenized.
        offsets: Per-token ``(start, end)`` character offsets into ``text``
            (the tokenizer's ``offset_mapping``).
        spans: Half-open ``(start, end)`` character ranges in ``text``, in
            ascending, non-overlapping order (as ``parse_annotated`` emits them).

    A token belongs to a span when its text *overlaps* the span, not only when
    it lies fully inside it. Otherwise subword tokens straddling a span
    boundary would be labelled "O". Tokenizers that fold the preceding space
    into a word piece (SentencePiece-style "▁word") may also report an offset
    that starts on that space, one character before the span, so leading
    whitespace is ignored when matching. The first token overlapping each span
    gets B-SPAN, and later ones get I-SPAN. Tokens with offset ``(0, 0)`` get -100.
    """
    o_id = LABEL2ID["O"]
    b_id = LABEL2ID["B-SPAN"]
    i_id = LABEL2ID["I-SPAN"]
    # Empty spans cover no characters and can never be labelled.
    spans = [(s, e) for s, e in spans if e > s]

    labels = []
    span_idx = 0      # first span that might still overlap the current token
    opened_span = -1  # index of the span whose B-SPAN has already been emitted
    for tok_start, tok_end in offsets:
        if tok_start == 0 and tok_end == 0:
            labels.append(-100)
            continue

        core_start = tok_start
        while core_start < tok_end and text[core_start].isspace():
            core_start += 1
        if core_start == tok_end:
            core_start = tok_start  # whitespace-only token: use its full extent

        # Spans ending before this token's content are finished; tokens are in
        # text order, so the pointer only moves forward.
        while span_idx < len(spans) and spans[span_idx][1] <= core_start:
            span_idx += 1

        if span_idx < len(spans) and spans[span_idx][0] < tok_end:
            if opened_span != span_idx:
                labels.append(b_id)
                opened_span = span_idx
            else:
                labels.append(i_id)
        else:
            labels.append(o_id)
    return labels


print(f"Parsing and tokenizing {len(raw_data):,} articles...")
all_chunks = []
articles_without_chunks = 0

for i, item in enumerate(raw_data):
    title, plain_text, span_offsets = parse_annotated(item["annotated"])

    title_ids = tokenizer(title, add_special_tokens=False)["input_ids"]

    text_enc = tokenizer(plain_text, add_special_tokens=False, return_offsets_mapping=True)
    text_ids = text_enc["input_ids"]
    text_labels = _label_tokens(plain_text, text_enc["offset_mapping"], span_offsets)

    chunks = chunk_with_title(title_ids, text_ids, text_labels, MAX_LEN, STRIDE, tokenizer)
    if not chunks:
        # Empty body, or a title too long to leave room for any text.
        articles_without_chunks += 1
    all_chunks.extend(chunks)

    if (i + 1) % 2000 == 0:
        print(f" [{i+1:,}/{len(raw_data):,}] chunks so far: {len(all_chunks):,}")

print(f"\nTotal chunks: {len(all_chunks):,}")
if articles_without_chunks:
    print(f"Articles that produced no chunks (empty body or title too long for MAX_LEN): "
          f"{articles_without_chunks:,}")
|
|
| |
# Label balance over supervised positions only (-100 marks ignored positions).
all_labels_flat = [lab for chunk in all_chunks for lab in chunk["labels"] if lab >= 0]
dist = Counter(all_labels_flat)
total_labeled = sum(dist.values())
print(f"\nLabel distribution:")
# dist only holds labels that occurred, so total_labeled > 0 whenever this loop runs.
for label_id, count in sorted(dist.items()):
    print(f" {ID2LABEL[label_id]}: {count:,} ({count/total_labeled*100:.2f}%)")

# Write to a temporary file and atomically swap it into place, so an
# interrupted or failed dump never leaves a truncated cache behind.
print(f"\nSaving to {CACHE_FILE}...")
tmp_cache_path = CACHE_FILE + ".tmp"
with open(tmp_cache_path, "w", encoding="utf-8") as f:
    json.dump(all_chunks, f)
os.replace(tmp_cache_path, CACHE_FILE)
print("Done.")
|
|