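# NOTE: the three lines below are web-page residue captured together with this file.
# They are not part of the script and are kept only as comments:
#   dejanseo's picture
#   Upload train.py
#   3112824 verified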
import json
import os
import argparse
import numpy as np
from sklearn.model_selection import train_test_split
# Parse command-line flags up front; this runs before the torch/transformers
# imports that follow.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--bump",
    type=int,
    default=0,
    help="Extra epochs to train (resumes from last checkpoint)",
)
args = parser.parse_args()
# The epoch count is later computed as 1 + args.bump, so a negative value would
# request zero or negative epochs. Reject it here with a normal usage error.
if args.bump < 0:
    parser.error("--bump must be >= 0")
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
Trainer,
DataCollatorForTokenClassification,
)
from datasets import Dataset
import wandb
# Model identifier passed to both the tokenizer and the model `from_pretrained` calls.
MODEL_NAME = "google/gemma-3-270m"
# JSON file with the raw training items; each item's "annotated" field is fed to
# parse_annotated().
TRAIN_FILE = "train.json"
# JSON cache of the tokenized chunks. When this file exists it is loaded instead of
# re-parsing TRAIN_FILE.
# NOTE(review): the cache is reused regardless of MODEL_NAME / MAX_LEN / STRIDE, so it
# presumably has to be deleted by hand after changing any of them — confirm.
CACHE_FILE = "chunks.generative.cache.json"
# Upper bound, in tokens, on one training sequence (prompt + completion + bos/eos).
MAX_LEN = 1024
# Number of text tokens the chunk window advances between consecutive chunks.
STRIDE = 256
# Tokens subtracted from the per-chunk text budget to leave room for the completion.
COMPLETION_RESERVE = 256
def parse_annotated(annotated):
    """Split an annotated example into its title, plain text and span offsets.

    The input has the form ``title[SEP]body`` where ``body`` may wrap spans in
    ``[SPAN]...[/SPAN]`` markers.

    Args:
        annotated: The annotated string. Only the first ``[SEP]`` separates the
            title from the body; later occurrences stay in the body text.

    Returns:
        A ``(title, plain_text, spans)`` tuple: the whitespace-stripped title,
        the body with the span markers removed, and a list of ``(start, end)``
        character offsets into ``plain_text`` (end exclusive), one per span in
        order of appearance.

    Raises:
        ValueError: If ``annotated`` contains no ``[SEP]`` separator.
    """
    title, sep, body = annotated.partition("[SEP]")
    if not sep:
        raise ValueError("annotated text is missing the '[SEP]' separator")

    open_tag, close_tag = "[SPAN]", "[/SPAN]"
    pieces = []      # fragments of the marker-free text, joined once at the end
    spans = []
    plain_len = 0    # length of the marker-free text collected so far
    pos = 0
    while pos < len(body):
        open_at = body.find(open_tag, pos)
        if open_at == -1:
            # No further spans: the rest of the body is plain text. A stray
            # "[/SPAN]" outside a span is kept literally.
            pieces.append(body[pos:])
            break

        before = body[pos:open_at]
        pieces.append(before)
        plain_len += len(before)

        inner_start = open_at + len(open_tag)
        close_at = body.find(close_tag, inner_start)
        if close_at == -1:
            # Unterminated span: it extends to the end of the body.
            inner = body[inner_start:]
            pos = len(body)
        else:
            # Anything before the first closing tag belongs to the span,
            # including a nested "[SPAN]" which is kept literally.
            inner = body[inner_start:close_at]
            pos = close_at + len(close_tag)

        pieces.append(inner)
        spans.append((plain_len, plain_len + len(inner)))
        plain_len += len(inner)

    return title.strip(), "".join(pieces), spans
def chunk_generative(title, plain_text, span_offsets, tokenizer, max_len, stride,
                     completion_reserve=None):
    """Create overlapping prompt/completion training chunks for one article.

    Each chunk is the token sequence
    ``<bos> "Title: {title}\\n\\nText: " <text window> "\\n\\nHooks:\\n" <completion> <eos>``
    where the completion lists, one per line, the spans that lie fully inside
    the window's character range, or ``[NONE]`` when there are none.

    Args:
        title: Article title placed in the prompt.
        plain_text: Article text without span markers.
        span_offsets: Iterable of ``(start, end)`` character offsets into
            ``plain_text`` (end exclusive).
        tokenizer: Callable tokenizer that accepts ``add_special_tokens`` and
            ``return_offsets_mapping`` and returns a mapping with
            ``"input_ids"`` and ``"offset_mapping"``; it must also expose
            ``bos_token_id`` and ``eos_token_id``.
        max_len: Maximum number of tokens in one chunk.
        stride: Number of text tokens the window advances between chunks.
        completion_reserve: Tokens held back from the text window for the
            completion. ``None`` (the default) uses the module-level
            ``COMPLETION_RESERVE``.

    Returns:
        A list of dicts with ``"input_ids"``, ``"attention_mask"`` and
        ``"labels"``. Labels are ``-100`` over the prompt and the real token
        ids over the completion and eos. The list is empty when the text is
        empty or when the fixed prompt overhead leaves no room for text.

    Raises:
        ValueError: If ``stride`` is not positive (the window could never
            advance).
    """
    if stride <= 0:
        raise ValueError("stride must be a positive number of tokens")
    if completion_reserve is None:
        completion_reserve = COMPLETION_RESERVE

    def encode(text):
        # Token ids only, without bos/eos; those are added explicitly below.
        return tokenizer(text, add_special_tokens=False)["input_ids"]

    def has_offset(offset):
        # (0, 0) entries are skipped when locating the window's character range.
        return offset[0] != 0 or offset[1] != 0

    prefix_ids = encode(f"Title: {title}\n\nText: ")
    suffix_ids = encode("\n\nHooks:\n")
    text_enc = tokenizer(plain_text, add_special_tokens=False, return_offsets_mapping=True)
    text_ids = text_enc["input_ids"]
    text_offsets = text_enc["offset_mapping"]

    # Budget for text tokens per chunk: bos + prefix + suffix + eos are fixed.
    fixed_overhead = 1 + len(prefix_ids) + len(suffix_ids) + 1
    text_budget = max_len - fixed_overhead - completion_reserve
    if text_budget <= 0:
        return []

    bos_ids = [tokenizer.bos_token_id]
    eos_ids = [tokenizer.eos_token_id]
    total_tokens = len(text_ids)
    chunks = []

    for start in range(0, total_tokens, stride):
        end = min(start + text_budget, total_tokens)
        is_last_window = end >= total_tokens
        window_offsets = text_offsets[start:end]

        # Character range covered by this window: start of the first token
        # with a real offset up to the end of the last such token.
        first = next((off for off in window_offsets if has_offset(off)), None)
        last = next((off for off in reversed(window_offsets) if has_offset(off)), None)
        if first is None or last is None:
            if is_last_window:
                break
            continue
        chunk_char_start, chunk_char_end = first[0], last[1]

        # Only spans fully contained in the window become targets; a span cut
        # by the window edge is left out of this chunk.
        chunk_spans = [
            plain_text[s_start:s_end]
            for s_start, s_end in span_offsets
            if s_start >= chunk_char_start and s_end <= chunk_char_end
        ]
        completion_text = "\n".join(chunk_spans) if chunk_spans else "[NONE]"
        target_ids = encode(completion_text) + eos_ids

        prompt_ids = bos_ids + prefix_ids + text_ids[start:end] + suffix_ids
        input_ids = prompt_ids + target_ids
        # The loss is computed on the completion and eos only.
        labels = [-100] * len(prompt_ids) + target_ids

        # NOTE: a completion longer than the reserve is cut at max_len, which
        # drops the tail of the completion together with the eos token.
        if len(input_ids) > max_len:
            input_ids = input_ids[:max_len]
            labels = labels[:max_len]

        chunks.append({
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": labels,
        })

        if is_last_window:
            break

    return chunks
# --- Tokenizer ---
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Batches are padded on the right; reuse eos as the pad token when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# --- Chunks: build from TRAIN_FILE once, afterwards reuse the JSON cache ---
if not os.path.exists(CACHE_FILE):
    print(f"Loading {TRAIN_FILE}...")
    with open(TRAIN_FILE, "r", encoding="utf-8") as src:
        raw_data = json.load(src)

    print(f"Parsing and tokenizing {len(raw_data):,} articles...")
    all_chunks = []
    for done, record in enumerate(raw_data, start=1):
        all_chunks += chunk_generative(
            *parse_annotated(record["annotated"]),
            tokenizer,
            MAX_LEN,
            STRIDE,
        )
        # Progress line every 2000 articles.
        if done % 2000 == 0:
            print(f" [{done:,}/{len(raw_data):,}] chunks so far: {len(all_chunks):,}")
    print(f"Total chunks: {len(all_chunks):,}")

    print(f"Saving cache to {CACHE_FILE}...")
    with open(CACHE_FILE, "w", encoding="utf-8") as dst:
        json.dump(all_chunks, dst)
    print("Cache saved.")
else:
    print(f"Loading cached chunks from {CACHE_FILE}...")
    with open(CACHE_FILE, "r", encoding="utf-8") as src:
        all_chunks = json.load(src)
    print(f"Loaded {len(all_chunks):,} chunks from cache")
# Completion stats
# Fail early with a readable message: with no chunks the statistics below would
# otherwise die on an opaque numpy reduction error or a division by zero.
if not all_chunks:
    raise ValueError(
        f"No training chunks available (check {TRAIN_FILE}, or delete {CACHE_FILE} to rebuild)"
    )

completion_lengths = []
none_count = 0
for c in all_chunks:
    # Completion tokens are exactly the positions whose label is not masked (-100).
    comp_ids = [tok_id for tok_id, label in zip(c["input_ids"], c["labels"]) if label >= 0]
    completion_lengths.append(len(comp_ids))
    # A chunk without spans has the literal "[NONE]" as its whole completion.
    comp_text = tokenizer.decode(comp_ids, skip_special_tokens=True).strip()
    if comp_text == "[NONE]":
        none_count += 1

print("Completion stats:")
print(f" Mean length: {np.mean(completion_lengths):.1f} tokens")
print(f" Max length: {np.max(completion_lengths)} tokens")
print(f" Chunks with no spans: {none_count:,} ({none_count/len(all_chunks)*100:.1f}%)")
# Split
# NOTE(review): the split is made over chunks, not articles. Windows of one article
# can overlap, so validation chunks may share text with training chunks — confirm
# this is acceptable for the reported eval loss.
print("Splitting 95/5 train/val...")
train_chunks, val_chunks = train_test_split(
    all_chunks,
    test_size=0.05,
    random_state=42,  # fixed seed so the split is reproducible
)
print(f"Train: {len(train_chunks):,} | Val: {len(val_chunks):,}")

# Wrap both partitions as datasets, train first and then validation.
train_ds, val_ds = (Dataset.from_list(part) for part in (train_chunks, val_chunks))
# Model
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
)
# NOTE(review): a token-classification collator on a causal LM; presumably chosen
# because it also pads the per-token "labels" column — confirm against the
# transformers documentation.
data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer, padding=True)

# --bump N trains N epochs beyond the first one and resumes from a checkpoint.
resume = args.bump > 0
total_epochs = 1 + args.bump

run_name = "gemma-270m-generative" + (f"-bump{args.bump}" if resume else "")
wandb.init(project="span-extractor", name=run_name)

training_args = TrainingArguments(
    output_dir="./span_model_generative",
    # Schedule and optimisation
    num_train_epochs=total_epochs,
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Batching: one sequence per step, 16 accumulation steps per update
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    # Precision and memory
    bf16=True,
    gradient_checkpointing=True,
    # Evaluation and checkpointing share the same 100-step cadence
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=10,
    # Best checkpoint is the one with the lowest eval loss
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Logging and data loading
    logging_steps=1,
    report_to="wandb",
    dataloader_num_workers=0,
    remove_unused_columns=False,
)
# Assemble the trainer from the model, arguments, datasets and collator built above.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=data_collator,
)

print(f"Training... (epochs={total_epochs}, resume={resume})")
# `resume` is a bool: True asks the trainer to continue from a checkpoint.
trainer.train(resume_from_checkpoint=resume)

# Model and tokenizer are written to the same directory.
print("Saving final model...")
final_dir = "./span_model_generative/final"
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)

wandb.finish()
print("Done.")