| import json |
| import os |
| import argparse |
| import numpy as np |
| from sklearn.model_selection import train_test_split |
|
|
# Command-line interface. Arguments are parsed here, ahead of the
# torch/transformers imports that follow in this file.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--bump",
    type=int,
    default=0,
    help="Extra epochs to train (resumes from last checkpoint)",
)
args = parser.parse_args()
|
|
| import torch |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| TrainingArguments, |
| Trainer, |
| DataCollatorForTokenClassification, |
| ) |
| from datasets import Dataset |
| import wandb |
|
|
# Checkpoint name handed to both the tokenizer and the causal-LM loader.
MODEL_NAME = "google/gemma-3-270m"

# Input data (JSON list of items carrying an "annotated" string) and the JSON
# file in which the tokenized chunks are cached between runs.
TRAIN_FILE = "train.json"
CACHE_FILE = "chunks.generative.cache.json"

# Windowing, all measured in tokens:
#   MAX_LEN            - hard cap on the length of one training example
#   COMPLETION_RESERVE - room held back from the text window for the completion
#   STRIDE             - how far the text window advances between chunks
MAX_LEN = 1024
COMPLETION_RESERVE = 256
STRIDE = 256
|
|
|
|
def parse_annotated(annotated):
    """Parse 'title[SEP]text with [SPAN]...[/SPAN]' into title, plain text and spans.

    The string is split at the first ``[SEP]``. Everything before it is the
    title (returned stripped); everything after it is the body, from which the
    ``[SPAN]`` / ``[/SPAN]`` markers are removed.

    Args:
        annotated: The annotated article string.

    Returns:
        A ``(title, plain_text, spans)`` tuple where ``spans`` is a list of
        ``(start, end)`` character offsets into ``plain_text`` (end exclusive),
        one per ``[SPAN]`` marker, in order of appearance.

    Raises:
        ValueError: If ``annotated`` contains no ``[SEP]`` separator.

    Notes:
        * A ``[SPAN]`` without a closing ``[/SPAN]`` extends to the end of the body.
        * A ``[SPAN]`` appearing inside an open span, or a ``[/SPAN]`` outside
          any span, is kept as literal text.
    """
    title, separator, body = annotated.partition("[SEP]")
    if not separator:
        raise ValueError("annotated text is missing the '[SEP]' title separator")

    open_tag, close_tag = "[SPAN]", "[/SPAN]"

    spans = []
    pieces = []      # fragments of the marker-free text, joined once at the end
    plain_len = 0    # total length of everything appended to `pieces` so far
    pos = 0          # read cursor into `body`

    while True:
        open_at = body.find(open_tag, pos)
        if open_at == -1:
            # No further spans: the remainder is ordinary text.
            pieces.append(body[pos:])
            break

        # Ordinary text preceding the span.
        pieces.append(body[pos:open_at])
        plain_len += open_at - pos

        inner_start = open_at + len(open_tag)
        close_at = body.find(close_tag, inner_start)
        if close_at == -1:
            # Unterminated span: it swallows the rest of the body.
            inner_end = pos = len(body)
        else:
            inner_end = close_at
            pos = close_at + len(close_tag)

        inner_len = inner_end - inner_start
        pieces.append(body[inner_start:inner_end])
        spans.append((plain_len, plain_len + inner_len))
        plain_len += inner_len

    return title.strip(), "".join(pieces), spans
|
|
|
|
def chunk_generative(title, plain_text, span_offsets, tokenizer, max_len, stride):
    """Create overlapping chunks with generative span targets.

    The article text is tokenized once and cut into windows of at most
    ``text_budget`` tokens, each window starting ``stride`` tokens after the
    previous one. Every window becomes one example laid out as::

        BOS, "Title: <title>" + blank line + "Text: ", <window tokens>,
        blank line + "Hooks:" + newline, <completion>, EOS

    The completion lists, one per line, the text of every span that lies
    entirely inside the window's character range, or the literal ``[NONE]``
    when there is none. Only the completion and EOS carry labels; the prompt
    is masked with ``-100``.

    Args:
        title: Article title inserted into the prompt.
        plain_text: Article body without span markers.
        span_offsets: Iterable of ``(start, end)`` character offsets into
            ``plain_text`` (end exclusive).
        tokenizer: Callable tokenizer that accepts ``add_special_tokens`` and
            ``return_offsets_mapping`` and exposes ``bos_token_id`` /
            ``eos_token_id``.
        max_len: Maximum number of tokens in one example.
        stride: Number of tokens the window advances between chunks.

    Returns:
        A list of dicts with ``input_ids``, ``attention_mask`` and ``labels``.
        The list is empty when the text tokenizes to nothing or when the
        prompt overhead plus ``COMPLETION_RESERVE`` leaves no room for text.

    Raises:
        ValueError: If the text needs more than one window and ``stride`` is
            not positive (the window could never reach the end of the text).
    """
    prefix_ids = tokenizer(f"Title: {title}\n\nText: ", add_special_tokens=False)["input_ids"]
    suffix_ids = tokenizer("\n\nHooks:\n", add_special_tokens=False)["input_ids"]

    text_enc = tokenizer(plain_text, add_special_tokens=False, return_offsets_mapping=True)
    text_ids = text_enc["input_ids"]
    text_offsets = text_enc["offset_mapping"]
    n_tokens = len(text_ids)

    # BOS + prefix + suffix + EOS surround every window; COMPLETION_RESERVE
    # tokens are additionally held back for the completion.
    fixed_overhead = 1 + len(prefix_ids) + len(suffix_ids) + 1
    text_budget = max_len - fixed_overhead - COMPLETION_RESERVE
    if text_budget <= 0:
        return []

    # A single window never advances, so stride only matters beyond that.
    # Without this guard a non-positive stride would loop forever.
    if n_tokens > text_budget and stride <= 0:
        raise ValueError(
            f"stride must be positive to window {n_tokens} tokens "
            f"with a budget of {text_budget}, got {stride}"
        )

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id

    chunks = []
    start = 0
    while start < n_tokens:
        end = min(start + text_budget, n_tokens)

        # Character range covered by the window: start of the first and end of
        # the last token whose offset is not (0, 0). Tokens reporting (0, 0)
        # are skipped when locating the edges.
        chunk_char_start = next(
            (
                text_offsets[idx][0]
                for idx in range(start, end)
                if text_offsets[idx][0] != 0 or text_offsets[idx][1] != 0
            ),
            None,
        )
        chunk_char_end = next(
            (
                text_offsets[idx][1]
                for idx in range(end - 1, start - 1, -1)
                if text_offsets[idx][0] != 0 or text_offsets[idx][1] != 0
            ),
            None,
        )

        # Windows with no usable character range produce no example.
        if chunk_char_start is not None and chunk_char_end is not None:
            # Keep only spans fully contained in the window; spans straddling
            # a window edge are left out of this chunk.
            chunk_spans = [
                plain_text[s_start:s_end]
                for s_start, s_end in span_offsets
                if s_start >= chunk_char_start and s_end <= chunk_char_end
            ]
            completion_text = "\n".join(chunk_spans) if chunk_spans else "[NONE]"
            completion_ids = tokenizer(completion_text, add_special_tokens=False)["input_ids"]

            prompt_ids = [bos_id] + prefix_ids + text_ids[start:end] + suffix_ids
            target_ids = completion_ids + [eos_id]

            # Hard cap at max_len. NOTE: when the completion outgrows the room
            # left after the prompt, this cuts off its tail together with the
            # EOS token.
            input_ids = (prompt_ids + target_ids)[:max_len]
            labels = ([-100] * len(prompt_ids) + target_ids)[:max_len]

            chunks.append({
                "input_ids": input_ids,
                "attention_mask": [1] * len(input_ids),
                "labels": labels,
            })

        if end >= n_tokens:
            break
        start += stride

    return chunks
|
|
|
|
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Batches are padded on the right-hand side.
tokenizer.padding_side = "right"

# Fall back to the EOS token for padding when the checkpoint defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
|
# Load the tokenized chunks from the cache when it exists; otherwise build
# them from TRAIN_FILE and write the cache.
# NOTE: the cache is reused whenever the file exists. It is not keyed on
# MODEL_NAME, MAX_LEN, STRIDE or COMPLETION_RESERVE, so delete it after
# changing any of those.
if os.path.exists(CACHE_FILE):
    print(f"Loading cached chunks from {CACHE_FILE}...")
    with open(CACHE_FILE, "r", encoding="utf-8") as f:
        all_chunks = json.load(f)
    print(f"Loaded {len(all_chunks):,} chunks from cache")
else:
    print(f"Loading {TRAIN_FILE}...")
    with open(TRAIN_FILE, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    print(f"Parsing and tokenizing {len(raw_data):,} articles...")
    all_chunks = []

    for done, item in enumerate(raw_data, start=1):
        title, plain_text, span_offsets = parse_annotated(item["annotated"])
        all_chunks.extend(
            chunk_generative(title, plain_text, span_offsets, tokenizer, MAX_LEN, STRIDE)
        )

        # Progress report every 2000 articles.
        if done % 2000 == 0:
            print(f"  [{done:,}/{len(raw_data):,}] chunks so far: {len(all_chunks):,}")

    print(f"Total chunks: {len(all_chunks):,}")
    print(f"Saving cache to {CACHE_FILE}...")
    # Write to a sibling temp file and rename it into place, so an interrupted
    # run cannot leave a truncated cache that the next run would try to load.
    tmp_cache_file = CACHE_FILE + ".tmp"
    with open(tmp_cache_file, "w", encoding="utf-8") as f:
        json.dump(all_chunks, f)
    os.replace(tmp_cache_file, CACHE_FILE)
    print("Cache saved.")
|
|
| |
# Summary statistics over the completion part of every chunk (the positions
# whose label is not masked with -100).
if not all_chunks:
    # Without this check the statistics below fail with an opaque numpy /
    # ZeroDivisionError, and there is nothing to train on anyway.
    raise SystemExit(
        f"No training chunks available (check {TRAIN_FILE} / {CACHE_FILE}); nothing to train on."
    )

completion_lengths = []
none_count = 0
for chunk in all_chunks:
    # Token ids at labelled positions, i.e. the completion plus EOS.
    comp_ids = [
        token_id
        for token_id, label in zip(chunk["input_ids"], chunk["labels"])
        if label >= 0
    ]
    completion_lengths.append(len(comp_ids))

    comp_text = tokenizer.decode(comp_ids, skip_special_tokens=True).strip()
    if comp_text == "[NONE]":
        none_count += 1

print("Completion stats:")
print(f"  Mean length: {np.mean(completion_lengths):.1f} tokens")
print(f"  Max length: {np.max(completion_lengths)} tokens")
print(f"  Chunks with no spans: {none_count:,} ({none_count/len(all_chunks)*100:.1f}%)")
|
|
| |
# Hold out 5% of the chunks for validation, with a fixed seed.
# NOTE(review): the split is made over chunks, not articles, so overlapping
# windows of the same article can end up on both sides - confirm this is intended.
print("Splitting 95/5 train/val...")
train_chunks, val_chunks = train_test_split(
    all_chunks,
    test_size=0.05,
    random_state=42,
)
print(f"Train: {len(train_chunks):,} | Val: {len(val_chunks):,}")

train_ds, val_ds = (Dataset.from_list(part) for part in (train_chunks, val_chunks))
|
|
| |
print("Loading model...")
# Weights are loaded directly in bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
)

# NOTE(review): a token-classification collator on a causal LM - presumably
# chosen because it pads the per-token `labels` along with the inputs; confirm.
data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer,
    padding=True,
)
|
|
# One base epoch plus however many extra epochs --bump asks for; any positive
# bump also switches the run into resume mode.
total_epochs = 1 + args.bump
resume = args.bump > 0

run_name = "gemma-270m-generative"
if resume:
    run_name += f"-bump{args.bump}"
wandb.init(project="span-extractor", name=run_name)

training_args = TrainingArguments(
    output_dir="./span_model_generative",
    # Schedule and optimisation
    num_train_epochs=total_epochs,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Batching: one example per device step, 16 steps accumulated per update
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    # Precision and memory
    bf16=True,
    gradient_checkpointing=True,
    # Evaluation and checkpointing, both every 100 steps
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=10,
    # Best checkpoint is the one with the lowest eval_loss
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Logging and data loading
    logging_steps=1,
    report_to="wandb",
    dataloader_num_workers=0,
    remove_unused_columns=False,
)
|
|
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)

print(f"Training... (epochs={total_epochs}, resume={resume})")
# resume is a bool: when True the Trainer is asked to resume from a checkpoint
# instead of starting fresh (see the Trainer.train documentation for how the
# checkpoint is located).
trainer.train(resume_from_checkpoint=resume)

# The model and its tokenizer are saved side by side in the same directory.
print("Saving final model...")
final_dir = "./span_model_generative/final"
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)

wandb.finish()
print("Done.")
|
|