File size: 7,920 Bytes
3112824
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import json
import os
import argparse
import numpy as np
from sklearn.model_selection import train_test_split

parser = argparse.ArgumentParser()
parser.add_argument("--bump", type=int, default=0, help="Extra epochs to train (resumes from last checkpoint)")
args = parser.parse_args()
# The epoch count is 1 + bump, so a negative bump would ask the trainer for
# zero or a negative number of epochs; reject it up front with a usage error.
if args.bump < 0:
    parser.error("--bump must be a non-negative integer")

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
)
from datasets import Dataset
import wandb

# Base checkpoint to fine-tune; its tokenizer is also used to build the chunks.
MODEL_NAME = "google/gemma-3-270m"

# Source annotations and the JSON cache of tokenized training chunks.
TRAIN_FILE = "train.json"
CACHE_FILE = "chunks.generative.cache.json"

# Chunking: each training sequence holds at most MAX_LEN tokens, of which
# COMPLETION_RESERVE are kept free for the generated span list; successive
# text windows start STRIDE tokens apart.
MAX_LEN = 1024
COMPLETION_RESERVE = 256
STRIDE = 256


def parse_annotated(annotated):
    """Parse 'title[SEP]text with [SPAN]...[/SPAN]' into title, plain_text, and char offsets.

    Everything before the first ``[SEP]`` is the title; the rest is the body.
    ``[SPAN]``/``[/SPAN]`` markers are removed from the body and each marked
    region is reported as a half-open ``(start, end)`` character range into the
    returned plain text.

    Markers do not nest: a ``[SPAN]`` inside a span, or a ``[/SPAN]`` outside
    one, is kept as literal text. An unterminated ``[SPAN]`` runs to the end of
    the body.

    Args:
        annotated: The annotated record string.

    Returns:
        ``(title, plain_text, spans)`` where ``title`` is stripped of
        surrounding whitespace and ``spans`` is a list of ``(start, end)``
        offsets into ``plain_text``, in order of appearance.

    Raises:
        ValueError: If ``annotated`` contains no ``[SEP]`` separator.
    """
    open_tag = "[SPAN]"
    close_tag = "[/SPAN]"

    title, sep, body = annotated.partition("[SEP]")
    if not sep:
        raise ValueError("annotated text is missing the '[SEP]' separator")

    pieces = []  # fragments of the plain text, joined once at the end
    length = 0   # running length of the plain text built so far
    spans = []
    pos = 0
    while True:
        open_at = body.find(open_tag, pos)
        if open_at == -1:
            pieces.append(body[pos:])
            break

        # Unmarked text preceding this span.
        outside = body[pos:open_at]
        pieces.append(outside)
        length += len(outside)

        content_start = open_at + len(open_tag)
        close_at = body.find(close_tag, content_start)
        if close_at == -1:
            # Unterminated span: it swallows the remainder of the body.
            inside = body[content_start:]
            pos = len(body)
        else:
            inside = body[content_start:close_at]
            pos = close_at + len(close_tag)

        pieces.append(inside)
        spans.append((length, length + len(inside)))
        length += len(inside)

    return title.strip(), "".join(pieces), spans


def _window_char_range(text_offsets, start, end):
    """Return the (char_start, char_end) covered by tokens start:end, or None.

    Tokens whose offset is (0, 0) are skipped (presumably special/added tokens
    that map to no source characters). The range runs from the start offset of
    the first remaining token to the end offset of the last one; None means no
    token in the window maps to source text.
    """
    real = [(lo, hi) for lo, hi in text_offsets[start:end] if lo != 0 or hi != 0]
    if not real:
        return None
    return real[0][0], real[-1][1]


def chunk_generative(title, plain_text, span_offsets, tokenizer, max_len, stride):
    """Create overlapping chunks with generative span targets.

    ``plain_text`` is tokenized once and cut into token windows of at most
    ``max_len - COMPLETION_RESERVE - overhead`` tokens, where the overhead is
    BOS, the prompt prefix, the prompt suffix and EOS. Window starts advance by
    ``stride`` tokens, so windows overlap when ``stride`` is smaller than the
    window size (and tokens between windows are skipped when it is larger).

    Each window becomes one causal-LM example laid out as
    ``<bos> prefix window suffix completion <eos>``: the prefix is
    "Title: <title>", a blank line and "Text: "; the suffix is a blank line
    followed by "Hooks:" and a newline. The completion lists, one per line, the
    spans lying entirely inside the window, or is "[NONE]" when there are none.
    Prompt positions get label -100 so only the completion and EOS are trained.

    Args:
        title: Article title placed in the prompt.
        plain_text: Article text without span markers.
        span_offsets: ``(start, end)`` character offsets of the target spans
            into ``plain_text`` (as produced by ``parse_annotated``).
        tokenizer: Tokenizer that supports ``return_offsets_mapping``
            (presumably a Hugging Face fast tokenizer) and defines BOS/EOS ids.
        max_len: Maximum tokens per example, completion included.
        stride: Number of tokens between consecutive window starts.

    Returns:
        A list of dicts with ``input_ids``, ``attention_mask`` and ``labels``
        lists; empty when the prompt overhead leaves no room for text.

    Raises:
        ValueError: If ``stride`` is not positive (the window would never move,
            looping forever).
    """
    if stride <= 0:
        raise ValueError(f"stride must be a positive number of tokens, got {stride!r}")

    prompt_prefix = f"Title: {title}\n\nText: "
    prompt_suffix = "\n\nHooks:\n"

    prefix_ids = tokenizer(prompt_prefix, add_special_tokens=False)["input_ids"]
    suffix_ids = tokenizer(prompt_suffix, add_special_tokens=False)["input_ids"]

    text_enc = tokenizer(plain_text, add_special_tokens=False, return_offsets_mapping=True)
    text_ids = text_enc["input_ids"]
    text_offsets = text_enc["offset_mapping"]

    # Budget for text tokens per chunk: what remains after bos + prefix +
    # suffix + eos and the room reserved for the completion.
    fixed_overhead = 1 + len(prefix_ids) + len(suffix_ids) + 1
    text_budget = max_len - fixed_overhead - COMPLETION_RESERVE
    if text_budget <= 0:
        return []

    chunks = []
    start = 0
    while start < len(text_ids):
        end = min(start + text_budget, len(text_ids))
        char_range = _window_char_range(text_offsets, start, end)

        if char_range is not None:
            chunk_char_start, chunk_char_end = char_range

            # Targets are the spans lying entirely inside the window; a span cut
            # by a window edge is not a target for this chunk. Blank spans are
            # dropped: they would add empty target lines and, when they were the
            # only spans, yield an empty completion instead of "[NONE]".
            chunk_spans = []
            for s_start, s_end in span_offsets:
                if s_start >= chunk_char_start and s_end <= chunk_char_end:
                    span_text = plain_text[s_start:s_end]
                    if span_text.strip():
                        chunk_spans.append(span_text)

            completion_text = "\n".join(chunk_spans) if chunk_spans else "[NONE]"
            completion_ids = tokenizer(completion_text, add_special_tokens=False)["input_ids"]

            # Full sequence: <bos> prefix text suffix completion <eos>
            prompt_ids = [tokenizer.bos_token_id] + prefix_ids + text_ids[start:end] + suffix_ids
            target_ids = completion_ids + [tokenizer.eos_token_id]

            # The prompt always fits within max_len, so truncation only bites
            # when the completion exceeds COMPLETION_RESERVE tokens; the
            # completion then loses its tail (and the EOS).
            input_ids = (prompt_ids + target_ids)[:max_len]
            labels = ([-100] * len(prompt_ids) + target_ids)[:max_len]

            chunks.append({
                "input_ids": input_ids,
                "attention_mask": [1] * len(input_ids),
                "labels": labels,
            })

        if end >= len(text_ids):
            break
        start += stride

    return chunks


print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Batches are padded on the right-hand side.
tokenizer.padding_side = "right"
if tokenizer.pad_token is None:
    # No dedicated pad token defined: reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

# The cached chunks depend on the tokenizer and the chunking constants, so the
# settings used to build the cache are stored in a sidecar file next to it. The
# cache is reused only when those settings match; otherwise (including caches
# written before the sidecar existed) the chunks are rebuilt.
# NOTE(review): changes to the *contents* of TRAIN_FILE are not detected;
# delete the cache by hand after editing the data.
cache_meta_file = CACHE_FILE + ".meta.json"
cache_config = {
    "model_name": MODEL_NAME,
    "train_file": TRAIN_FILE,
    "max_len": MAX_LEN,
    "stride": STRIDE,
    "completion_reserve": COMPLETION_RESERVE,
}

cached_config = None
if os.path.exists(CACHE_FILE) and os.path.exists(cache_meta_file):
    try:
        with open(cache_meta_file, "r", encoding="utf-8") as f:
            cached_config = json.load(f)
    except (OSError, ValueError):
        # Unreadable/corrupt sidecar: treat the cache as stale and rebuild.
        cached_config = None

if cached_config == cache_config:
    print(f"Loading cached chunks from {CACHE_FILE}...")
    with open(CACHE_FILE, "r", encoding="utf-8") as f:
        all_chunks = json.load(f)
    print(f"Loaded {len(all_chunks):,} chunks from cache")
else:
    if os.path.exists(CACHE_FILE):
        print(f"Cache {CACHE_FILE} does not match the current settings; rebuilding...")

    print(f"Loading {TRAIN_FILE}...")
    with open(TRAIN_FILE, "r", encoding="utf-8") as f:
        raw_data = json.load(f)

    print(f"Parsing and tokenizing {len(raw_data):,} articles...")
    all_chunks = []

    for i, item in enumerate(raw_data):
        title, plain_text, span_offsets = parse_annotated(item["annotated"])
        chunks = chunk_generative(title, plain_text, span_offsets, tokenizer, MAX_LEN, STRIDE)
        all_chunks.extend(chunks)

        if (i + 1) % 2000 == 0:
            print(f"  [{i+1:,}/{len(raw_data):,}] chunks so far: {len(all_chunks):,}")

    print(f"Total chunks: {len(all_chunks):,}")
    print(f"Saving cache to {CACHE_FILE}...")
    # Dump to a temporary file and rename it into place so an interrupted run
    # cannot leave behind a truncated cache that fails to parse next time.
    tmp_cache_file = CACHE_FILE + ".tmp"
    with open(tmp_cache_file, "w", encoding="utf-8") as f:
        json.dump(all_chunks, f)
    os.replace(tmp_cache_file, CACHE_FILE)
    # Written last: a missing sidecar forces a rebuild, so a crash between the
    # two writes is harmless.
    with open(cache_meta_file, "w", encoding="utf-8") as f:
        json.dump(cache_config, f)
    print("Cache saved.")

# Completion stats. A label >= 0 marks a supervised token (completion or EOS);
# prompt positions carry -100.
if not all_chunks:
    # np.max on an empty list and the percentage below would otherwise fail
    # with cryptic errors; there is nothing to train on anyway.
    raise SystemExit("No training chunks available (empty dataset); nothing to train on.")

completion_lengths = []
none_count = 0
for chunk in all_chunks:
    completion_ids = [tok for tok, lab in zip(chunk["input_ids"], chunk["labels"]) if lab >= 0]
    completion_lengths.append(len(completion_ids))
    # skip_special_tokens drops the EOS, leaving just the completion text.
    if tokenizer.decode(completion_ids, skip_special_tokens=True).strip() == "[NONE]":
        none_count += 1

print("Completion stats:")
print(f"  Mean length: {np.mean(completion_lengths):.1f} tokens")
print(f"  Max length: {np.max(completion_lengths)} tokens")
print(f"  Chunks with no spans: {none_count:,} ({none_count/len(all_chunks)*100:.1f}%)")

# Hold out 5% of the chunks for evaluation, with a fixed seed so the split is
# reproducible across runs.
# NOTE(review): the split is per chunk, not per article. Windows advance by
# STRIDE tokens, so when STRIDE is smaller than the window size (as with the
# current constants) chunks from the same article overlap and validation chunks
# may share text with training chunks; eval_loss could be optimistic.
# Consider grouping the split by article.
print("Splitting 95/5 train/val...")
train_chunks, val_chunks = train_test_split(
    all_chunks,
    test_size=0.05,
    random_state=42,
)
print(f"Train: {len(train_chunks):,} | Val: {len(val_chunks):,}")

train_ds, val_ds = (Dataset.from_list(part) for part in (train_chunks, val_chunks))

# Model
print("Loading model...")
# NOTE(review): the weights are loaded in bfloat16, so (together with bf16=True)
# the optimizer presumably updates bf16 parameters directly; very small updates
# at lr=2e-5 can be rounded away. Loading in float32 and relying on bf16
# autocast may train better — confirm before changing.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
)

# Pads each batch to its longest sequence; labels are padded with the
# collator's label_pad_token_id (-100 by default), which the loss ignores.
data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer,
    padding=True,
)

# --bump N trains N extra epochs on top of the original one, resuming from the
# latest checkpoint; the run name records the bump so runs stay distinguishable.
resume = args.bump > 0
total_epochs = 1 + args.bump

run_name = "gemma-270m-generative"
if resume:
    run_name += f"-bump{args.bump}"
wandb.init(project="span-extractor", name=run_name)

training_args = TrainingArguments(
    # Output location and training length.
    output_dir="./span_model_generative",
    num_train_epochs=total_epochs,
    # Batching: 1 sequence per device step, 16 steps accumulated per update.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    per_device_eval_batch_size=1,
    # Optimizer and learning-rate schedule.
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_ratio=0.1,
    # Precision and memory.
    bf16=True,
    gradient_checkpointing=True,
    # Evaluate and checkpoint on the same 100-step cadence so the checkpoint
    # with the lowest eval_loss can be restored at the end of training.
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=10,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    # Logging.
    logging_steps=1,
    report_to="wandb",
    # Data loading: load in the main process and pass every dataset column
    # through to the collator.
    dataloader_num_workers=0,
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_ds,
    eval_dataset=val_ds,
)

print(f"Training... (epochs={total_epochs}, resume={resume})")
# True resumes from the latest checkpoint in output_dir; False starts fresh.
trainer.train(resume_from_checkpoint=resume)

final_dir = "./span_model_generative/final"
print("Saving final model...")
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)

wandb.finish()
print("Done.")