import numpy as np
import collections
import evaluate
from datasets import load_dataset
from transformers import (
    BertConfig,
    BertForQuestionAnswering,
    BertTokenizerFast,
    DefaultDataCollator,
    TrainingArguments,
    Trainer,
)
|
|
# Hyperparameters
MODEL_NAME = "bert-base-uncased"  # used for the tokenizer; the model itself is trained from scratch
MAX_LENGTH = 384
DOC_STRIDE = 128
BATCH_SIZE = 16
EPOCHS = 3
LR = 3e-4
OUTPUT_DIR = "Excerp"
|
|
# SQuAD v1.1
raw = load_dataset("squad")
|
|
# Fast tokenizer (required for offset mappings and sequence_ids)
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)
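
# Optional illustration: truncation="only_second" plus a stride turns one
# over-length context into several overlapping features instead of cutting it
# off. The question/context strings here are made up for the demo.
_demo = tokenizer(
    "What is shown?",
    "word " * 600,  # deliberately longer than MAX_LENGTH tokens
    max_length=MAX_LENGTH,
    truncation="only_second",
    stride=DOC_STRIDE,
    return_overflowing_tokens=True,
)
print(f"1 long example -> {len(_demo['input_ids'])} features")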
|
|
def preprocess_train(examples):
    """Tokenize question/context pairs and label answer spans in token space."""
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        max_length=MAX_LENGTH,
        truncation="only_second",
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = tokenized.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized.pop("offset_mapping")

    start_positions, end_positions = [], []

    for i, offsets in enumerate(offset_mapping):
        sample_idx = sample_map[i]
        answers = examples["answers"][sample_idx]
        cls_index = tokenized["input_ids"][i].index(tokenizer.cls_token_id)

        sequence_ids = tokenized.sequence_ids(i)

        # Unanswerable examples point at [CLS].
        if len(answers["answer_start"]) == 0:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # First and last tokens belonging to the context (sequence id 1).
        token_start = next((j for j, s in enumerate(sequence_ids) if s == 1), None)
        token_end = next((j for j in range(len(sequence_ids) - 1, -1, -1) if sequence_ids[j] == 1), None)

        # If the answer is not *fully* inside this feature's context window,
        # label it with [CLS]. (Checking only for zero overlap would let
        # partially truncated answers fall through with wrong positions.)
        if offsets[token_start][0] > start_char or offsets[token_end][1] < end_char:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        # Walk inwards to the tokens whose character spans bracket the answer.
        start_tok = token_start
        while start_tok <= token_end and offsets[start_tok][0] <= start_char:
            start_tok += 1
        start_positions.append(start_tok - 1)

        end_tok = token_end
        while end_tok >= token_start and offsets[end_tok][1] >= end_char:
            end_tok -= 1
        end_positions.append(end_tok + 1)

    tokenized["start_positions"] = start_positions
    tokenized["end_positions"] = end_positions
    return tokenized
|
|
|
|
def preprocess_validation(examples):
    """Tokenize for inference: keep example ids and context-only offsets."""
    tokenized = tokenizer(
        examples["question"],
        examples["context"],
        max_length=MAX_LENGTH,
        truncation="only_second",
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    sample_map = tokenized.pop("overflow_to_sample_mapping")
    tokenized["example_id"] = []

    for i in range(len(tokenized["input_ids"])):
        sample_idx = sample_map[i]
        tokenized["example_id"].append(examples["id"][sample_idx])
        # Null out offsets of non-context tokens so span decoding can skip them.
        sequence_ids = tokenized.sequence_ids(i)
        tokenized["offset_mapping"][i] = [
            o if sequence_ids[j] == 1 else None
            for j, o in enumerate(tokenized["offset_mapping"][i])
        ]
    return tokenized
|
|
|
|
train_dataset = raw["train"].map(
    preprocess_train,
    batched=True,
    remove_columns=raw["train"].column_names,
)
val_dataset = raw["validation"].map(
    preprocess_validation,
    batched=True,
    remove_columns=raw["validation"].column_names,
)
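
# Optional sanity check: the labeled token span of the first training feature
# should decode to (a lowercased/whitespace variant of) the gold answer text.
_f = train_dataset[0]
_span = _f["input_ids"][_f["start_positions"] : _f["end_positions"] + 1]
print(f"Labeled span: {tokenizer.decode(_span)!r}")
print(f"Gold answer:  {raw['train'][0]['answers']['text'][0]!r}")
print(f"{len(train_dataset):,} train features from {len(raw['train']):,} examples")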
|
|
# Model: a small BERT built from a fresh config and trained from scratch
config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=384,
    num_hidden_layers=6,
    num_attention_heads=6,
    intermediate_size=1536,
    max_position_embeddings=512,
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
model = BertForQuestionAnswering(config)
print(f"Parameters: {model.num_parameters():,}")
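
# Optional smoke test: one forward pass on a single feature; the QA head
# should emit per-token start/end logits of shape (1, MAX_LENGTH).
import torch

_batch = {
    k: torch.tensor([train_dataset[0][k]])
    for k in ("input_ids", "attention_mask", "token_type_ids")
}
with torch.no_grad():
    _out = model(**_batch)
print(_out.start_logits.shape, _out.end_logits.shape)  # (1, 384) each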
|
|
# Exact-match / F1 metric for SQuAD
metric = evaluate.load("squad")
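
# Optional format check: the "squad" metric scores id/prediction_text pairs
# against id/answers references; a matching pair scores 100 on both metrics.
# The ids and strings below are dummies for the demo.
_score = metric.compute(
    predictions=[{"id": "0", "prediction_text": "a span"}],
    references=[{"id": "0", "answers": {"text": ["a span"], "answer_start": [0]}}],
)
print(_score)  # expect {'exact_match': 100.0, 'f1': 100.0}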
|
|
def compute_metrics(p):
    """Decode n-best start/end logits into text answers and score them."""
    start_logits, end_logits = p.predictions

    n_best = 20
    max_answer_len = 30
    example_ids = val_dataset["example_id"]
    offset_mappings = val_dataset["offset_mapping"]
    contexts = {ex["id"]: ex["context"] for ex in raw["validation"]}
    references = {ex["id"]: ex["answers"] for ex in raw["validation"]}

    # A long context yields several overlapping features; group them back
    # under their original example.
    feat_per_example = collections.defaultdict(list)
    for feat_idx, ex_id in enumerate(example_ids):
        feat_per_example[ex_id].append(feat_idx)

    predicted_answers = []
    for ex_id, feat_indices in feat_per_example.items():
        context = contexts[ex_id]
        candidates = []

        for fi in feat_indices:
            offsets = offset_mappings[fi]
            s_logits = start_logits[fi]
            e_logits = end_logits[fi]
            # Indices of the n_best highest logits, best first.
            s_indexes = np.argsort(s_logits)[-1 : -n_best - 1 : -1].tolist()
            e_indexes = np.argsort(e_logits)[-1 : -n_best - 1 : -1].tolist()

            for s in s_indexes:
                for e in e_indexes:
                    # Skip spans outside the context or of invalid shape/length.
                    if offsets[s] is None or offsets[e] is None:
                        continue
                    if e < s or e - s + 1 > max_answer_len:
                        continue
                    candidates.append({
                        "score": s_logits[s] + e_logits[e],
                        "text": context[offsets[s][0] : offsets[e][1]],
                    })

        best = max(candidates, key=lambda x: x["score"]) if candidates else {"text": ""}
        predicted_answers.append({"id": ex_id, "prediction_text": best["text"]})

    formatted_refs = [{"id": k, "answers": v} for k, v in references.items()]
    return metric.compute(predictions=predicted_answers, references=formatted_refs)
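
# Illustration of the slicing trick above: np.argsort sorts ascending, so
# [-1 : -n_best - 1 : -1] reads the last n_best entries in reverse, i.e. the
# indices of the n_best largest logits, best first. (Hypothetical values.)
_logits = np.array([0.1, 2.0, -1.0, 0.5])
assert np.argsort(_logits)[-1:-4:-1].tolist() == [1, 3, 0]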
|
|
|
|
# Training configuration
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    logging_steps=100,
    fp16=True,  # requires a CUDA GPU; set False for CPU-only runs
    report_to="none",
)
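
# Rough schedule check (assumes a single device and no gradient accumulation):
import math
print(f"~{math.ceil(len(train_dataset) / BATCH_SIZE) * EPOCHS:,} optimizer steps planned")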
|
|
# example_id / offset_mapping are needed for decoding but are not model
# inputs, so strip them from the copy the Trainer evaluates on.
val_for_trainer = val_dataset.remove_columns(["example_id", "offset_mapping"])

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_for_trainer,
    processing_class=tokenizer,
    data_collator=DefaultDataCollator(),
    compute_metrics=None,  # SQuAD metrics are computed once, after training
)
|
|
trainer.train()
|
|
# Final evaluation on the full validation set
print("--- Starting final evaluation ---")
predictions = trainer.predict(val_for_trainer)
final_metrics = compute_metrics(predictions)
print(f"Final results: {final_metrics}")
|
|
trainer.save_model(OUTPUT_DIR)
print("✅ DONE!")