# ============================================================
# Extractive Question Answering – From Scratch on SQuAD
# Kaggle T4 (16GB VRAM) | HF Transformers
# ============================================================
# ── Imports ─────────────────────────────────────────────────
import numpy as np
import collections
import evaluate
from datasets import load_dataset
from transformers import (
    BertConfig,
    BertForQuestionAnswering,
    BertTokenizerFast,
    DefaultDataCollator,
    TrainingArguments,
    Trainer,
)
# ── Config ───────────────────────────────────────────────────
MODEL_NAME = "bert-base-uncased" # tokenizer only!
MAX_LENGTH = 384
DOC_STRIDE = 128
BATCH_SIZE = 16
EPOCHS = 3
LR = 3e-4
OUTPUT_DIR = "Excerp"
# ── 1. Dataset ───────────────────────────────────────────────
raw = load_dataset("squad")
# ── 2. Tokenizer (pretrained vocab, NO pretrained weights) ─
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
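# Illustrative sanity check (safe to delete): with return_overflowing_tokens,
# a long context is split into several overlapping features of up to
# MAX_LENGTH tokens that share DOC_STRIDE tokens of context. The example
# index 0 below is arbitrary.
_demo = tokenizer(
    raw["train"][0]["question"],
    raw["train"][0]["context"],
    max_length=MAX_LENGTH,
    truncation="only_second",
    stride=DOC_STRIDE,
    return_overflowing_tokens=True,
)
print(f"Features built from example 0: {len(_demo['input_ids'])}")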
# ── 3. Preprocessing ─────────────────────────────────────────
def preprocess_train(examples):
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        max_length=MAX_LENGTH,
        truncation="only_second",  # truncate the context only, never the question
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = tokenized.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized.pop("offset_mapping")
    start_positions, end_positions = [], []
    for i, offsets in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answers = examples["answers"][sample_idx]
        cls_index = tokenized["input_ids"][i].index(tokenizer.cls_token_id)
        sequence_ids = tokenized.sequence_ids(i)
        # No gold answer: label the feature with the [CLS] position.
        if len(answers["answer_start"]) == 0:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])
        # First and last tokens that belong to the context (sequence id 1).
        token_start = next((j for j, s in enumerate(sequence_ids) if s == 1), None)
        token_end = next((j for j in range(len(sequence_ids) - 1, -1, -1) if sequence_ids[j] == 1), None)
        # If the answer is not fully inside this feature's context window,
        # label the feature with [CLS] as well.
        if offsets[token_start][0] > start_char or offsets[token_end][1] < end_char:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue
        # Otherwise walk inward to the token span that covers the answer.
        start_tok = token_start
        while start_tok <= token_end and offsets[start_tok][0] <= start_char:
            start_tok += 1
        start_positions.append(start_tok - 1)
        end_tok = token_end
        while end_tok >= token_start and offsets[end_tok][1] >= end_char:
            end_tok -= 1
        end_positions.append(end_tok + 1)
    tokenized["start_positions"] = start_positions
    tokenized["end_positions"] = end_positions
    return tokenized
def preprocess_validation(examples):
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        max_length=MAX_LENGTH,
        truncation="only_second",
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = tokenized.pop("overflow_to_sample_mapping")
    tokenized["example_id"] = []
    for i in range(len(tokenized["input_ids"])):
        sample_idx = sample_map[i]
        tokenized["example_id"].append(examples["id"][sample_idx])
        # Mask offsets of everything that is not context, so post-processing
        # can tell context tokens apart from question/special tokens.
        sequence_ids = tokenized.sequence_ids(i)
        tokenized["offset_mapping"][i] = [
            o if sequence_ids[j] == 1 else None
            for j, o in enumerate(tokenized["offset_mapping"][i])
        ]
    return tokenized
train_dataset = raw["train"].map(
    preprocess_train,
    batched=True,
    remove_columns=raw["train"].column_names,
)
val_dataset = raw["validation"].map(
    preprocess_validation,
    batched=True,
    remove_columns=raw["validation"].column_names,
)
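# Quick shape check (illustrative): the splits now hold features rather than
# examples, so these counts exceed the raw SQuAD sizes (~87.6k / ~10.6k).
print(f"Train features: {len(train_dataset):,} | Val features: {len(val_dataset):,}")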
# ── 4. Model FROM SCRATCH ────────────────────────────────────
config = BertConfig(
    vocab_size=tokenizer.vocab_size,  # 30522
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=1536,
    max_position_embeddings=512,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
model = BertForQuestionAnswering(config)
print(f"Parameters: {model.num_parameters():,}") # ~22M
# ── 5. Evaluation (Exact Match + F1) ─────────────────────────
metric = evaluate.load("squad")
def compute_metrics(p):
    # p = EvalPrediction with predictions=(start_logits, end_logits)
    start_logits, end_logits = p.predictions
    n_best = 20
    max_answer_len = 30
    example_ids = val_dataset["example_id"]
    offset_mappings = val_dataset["offset_mapping"]
    contexts = {ex["id"]: ex["context"] for ex in raw["validation"]}
    references = {ex["id"]: ex["answers"] for ex in raw["validation"]}
    # Group feature indices by the example they came from (long contexts
    # were split into several features).
    feat_per_example = collections.defaultdict(list)
    for feat_idx, ex_id in enumerate(example_ids):
        feat_per_example[ex_id].append(feat_idx)
    predicted_answers = []
    for ex_id, feat_indices in feat_per_example.items():
        context = contexts[ex_id]
        candidates = []
        for fi in feat_indices:
            offsets = offset_mappings[fi]
            s_logits = start_logits[fi]
            e_logits = end_logits[fi]
            # Top n_best start/end indices, highest logit first.
            s_indexes = np.argsort(s_logits)[-1:-n_best-1:-1].tolist()
            e_indexes = np.argsort(e_logits)[-1:-n_best-1:-1].tolist()
            for s in s_indexes:
                for e in e_indexes:
                    # Skip positions outside the context and invalid spans.
                    if offsets[s] is None or offsets[e] is None:
                        continue
                    if e < s or e - s + 1 > max_answer_len:
                        continue
                    candidates.append({
                        "score": s_logits[s] + e_logits[e],
                        "text": context[offsets[s][0]: offsets[e][1]],
                    })
        best = max(candidates, key=lambda x: x["score"]) if candidates else {"text": ""}
        predicted_answers.append({"id": ex_id, "prediction_text": best["text"]})
    formatted_refs = [{"id": k, "answers": v} for k, v in references.items()]
    return metric.compute(predictions=predicted_answers, references=formatted_refs)
# ── 6. Training ───────────────────────────────────────────────
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    logging_steps=100,
    fp16=True,
    report_to="none",
)
# The model only accepts tensor inputs, so drop the bookkeeping columns
# (kept in val_dataset for compute_metrics) before handing the eval set
# to the Trainer.
val_for_trainer = val_dataset.remove_columns(["example_id", "offset_mapping"])
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_for_trainer,
    processing_class=tokenizer,
    data_collator=DefaultDataCollator(),
    compute_metrics=None,  # SQuAD EM/F1 is computed manually in step 7
)
trainer.train()
# ── 7. Final evaluation ────────────────────────────
print("--- Starting final evaluation ---")
predictions = trainer.predict(val_for_trainer)
final_metrics = compute_metrics(predictions)
print(f"Final results: {final_metrics}")
trainer.save_model(OUTPUT_DIR)
print("✅ DONE!")