import json import os import argparse import numpy as np from sklearn.model_selection import train_test_split parser = argparse.ArgumentParser() parser.add_argument("--bump", type=int, default=0, help="Extra epochs to train (resumes from last checkpoint)") args = parser.parse_args() import torch from transformers import ( AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForTokenClassification, ) from datasets import Dataset import wandb MODEL_NAME = "google/gemma-3-270m" TRAIN_FILE = "train.json" CACHE_FILE = "chunks.generative.cache.json" MAX_LEN = 1024 STRIDE = 256 COMPLETION_RESERVE = 256 def parse_annotated(annotated): """Parse 'title[SEP]text with [SPAN]...[/SPAN]' into title, plain_text, and char offsets.""" title, body = annotated.split("[SEP]", 1) spans = [] plain = "" i = 0 while i < len(body): if body[i:i+6] == "[SPAN]": start = len(plain) i += 6 while i < len(body) and body[i:i+7] != "[/SPAN]": plain += body[i] i += 1 end = len(plain) spans.append((start, end)) if body[i:i+7] == "[/SPAN]": i += 7 else: plain += body[i] i += 1 return title.strip(), plain, spans def chunk_generative(title, plain_text, span_offsets, tokenizer, max_len, stride): """Create overlapping chunks with generative span targets.""" prompt_prefix = f"Title: {title}\n\nText: " prompt_suffix = "\n\nHooks:\n" prefix_ids = tokenizer(prompt_prefix, add_special_tokens=False)["input_ids"] suffix_ids = tokenizer(prompt_suffix, add_special_tokens=False)["input_ids"] text_enc = tokenizer(plain_text, add_special_tokens=False, return_offsets_mapping=True) text_ids = text_enc["input_ids"] text_offsets = text_enc["offset_mapping"] # Budget for text tokens per chunk fixed_overhead = 1 + len(prefix_ids) + len(suffix_ids) + 1 # bos + prefix + suffix + eos text_budget = max_len - fixed_overhead - COMPLETION_RESERVE if text_budget <= 0: return [] chunks = [] start = 0 while start < len(text_ids): end = min(start + text_budget, len(text_ids)) chunk_text_ids = text_ids[start:end] # Determine char range of this chunk chunk_char_start = None for idx in range(start, end): if text_offsets[idx][0] != 0 or text_offsets[idx][1] != 0: chunk_char_start = text_offsets[idx][0] break chunk_char_end = None for idx in range(end - 1, start - 1, -1): if text_offsets[idx][0] != 0 or text_offsets[idx][1] != 0: chunk_char_end = text_offsets[idx][1] break if chunk_char_start is None or chunk_char_end is None: if end >= len(text_ids): break start += stride continue # Find spans fully contained in this chunk chunk_spans = [] for s_start, s_end in span_offsets: if s_start >= chunk_char_start and s_end <= chunk_char_end: chunk_spans.append(plain_text[s_start:s_end]) # Build completion completion_text = "\n".join(chunk_spans) if chunk_spans else "[NONE]" completion_ids = tokenizer(completion_text, add_special_tokens=False)["input_ids"] # Full sequence: prefix text suffix completion prompt_ids = [tokenizer.bos_token_id] + prefix_ids + chunk_text_ids + suffix_ids input_ids = prompt_ids + completion_ids + [tokenizer.eos_token_id] # Labels: -100 for prompt, actual ids for completion + eos labels = [-100] * len(prompt_ids) + completion_ids + [tokenizer.eos_token_id] attention_mask = [1] * len(input_ids) # Truncate if needed if len(input_ids) > max_len: input_ids = input_ids[:max_len] labels = labels[:max_len] attention_mask = attention_mask[:max_len] chunks.append({ "input_ids": input_ids, "attention_mask": attention_mask, "labels": labels, }) if end >= len(text_ids): break start += stride return chunks print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" if os.path.exists(CACHE_FILE): print(f"Loading cached chunks from {CACHE_FILE}...") with open(CACHE_FILE, "r", encoding="utf-8") as f: all_chunks = json.load(f) print(f"Loaded {len(all_chunks):,} chunks from cache") else: print(f"Loading {TRAIN_FILE}...") with open(TRAIN_FILE, "r", encoding="utf-8") as f: raw_data = json.load(f) print(f"Parsing and tokenizing {len(raw_data):,} articles...") all_chunks = [] for i, item in enumerate(raw_data): title, plain_text, span_offsets = parse_annotated(item["annotated"]) chunks = chunk_generative(title, plain_text, span_offsets, tokenizer, MAX_LEN, STRIDE) all_chunks.extend(chunks) if (i + 1) % 2000 == 0: print(f" [{i+1:,}/{len(raw_data):,}] chunks so far: {len(all_chunks):,}") print(f"Total chunks: {len(all_chunks):,}") print(f"Saving cache to {CACHE_FILE}...") with open(CACHE_FILE, "w", encoding="utf-8") as f: json.dump(all_chunks, f) print("Cache saved.") # Completion stats completion_lengths = [] none_count = 0 for c in all_chunks: comp_len = sum(1 for l in c["labels"] if l >= 0) completion_lengths.append(comp_len) comp_ids = [c["input_ids"][i] for i in range(len(c["labels"])) if c["labels"][i] >= 0] comp_text = tokenizer.decode(comp_ids, skip_special_tokens=True).strip() if comp_text == "[NONE]": none_count += 1 print(f"Completion stats:") print(f" Mean length: {np.mean(completion_lengths):.1f} tokens") print(f" Max length: {np.max(completion_lengths)} tokens") print(f" Chunks with no spans: {none_count:,} ({none_count/len(all_chunks)*100:.1f}%)") # Split print("Splitting 95/5 train/val...") train_chunks, val_chunks = train_test_split(all_chunks, test_size=0.05, random_state=42) print(f"Train: {len(train_chunks):,} | Val: {len(val_chunks):,}") train_ds = Dataset.from_list(train_chunks) val_ds = Dataset.from_list(val_chunks) # Model print("Loading model...") model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16) data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True) resume = args.bump > 0 total_epochs = 1 + args.bump wandb.init(project="span-extractor", name=f"gemma-270m-generative{f'-bump{args.bump}' if resume else ''}") training_args = TrainingArguments( output_dir="./span_model_generative", num_train_epochs=total_epochs, per_device_train_batch_size=1, per_device_eval_batch_size=1, gradient_accumulation_steps=16, learning_rate=2e-5, weight_decay=0.01, warmup_ratio=0.1, bf16=True, gradient_checkpointing=True, logging_steps=1, eval_strategy="steps", eval_steps=100, save_strategy="steps", save_steps=100, save_total_limit=10, load_best_model_at_end=True, metric_for_best_model="eval_loss", greater_is_better=False, dataloader_num_workers=0, report_to="wandb", remove_unused_columns=False, ) trainer = Trainer( model=model, args=training_args, train_dataset=train_ds, eval_dataset=val_ds, data_collator=data_collator, ) print(f"Training... (epochs={total_epochs}, resume={resume})") trainer.train(resume_from_checkpoint=resume) print("Saving final model...") trainer.save_model("./span_model_generative/final") tokenizer.save_pretrained("./span_model_generative/final") wandb.finish() print("Done.")