chiapudding committed on
Commit
0c732eb
·
1 Parent(s): 8504c73

upload model

Browse files
Files changed (1) hide show
  1. qatransformer2.py +91 -0
qatransformer2.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# --- Setup: data and tokenizer -------------------------------------------
# Take a 5,000-example slice of SQuAD and hold out 20% for evaluation.
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForQuestionAnswering,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)

squad = load_dataset("squad", split="train[:5000]")
squad = squad.train_test_split(test_size=0.2)

# preprocess
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def preprocess_function(examples):
    """Tokenize question/context pairs and attach answer-span token labels.

    Expects a batched SQuAD-style mapping with "question", "context", and
    "answers" keys (answers carry parallel "answer_start" and "text" lists).
    Returns the tokenizer's encoding extended with "start_positions" and
    "end_positions" — the token indices of the answer span, or (0, 0) when
    the answer does not fully survive context truncation.
    """
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",  # truncate only the context, never the question
        return_offsets_mapping=True,  # char offsets needed to locate the answer span
        padding="max_length",
    )

    # The model does not consume offsets; they are only used below to map
    # character positions to token positions.
    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        # NOTE(review): only the first listed answer is used as the label.
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        # Find the start and end of the context
        # (sequence id 1 marks context tokens; 0/None mark question/special tokens)
        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # If the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            # Otherwise it's the start and end token positions:
            # advance past tokens that start at or before the answer start,
            # then step back one to land on the first answer token.
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            # Symmetrically, walk backward past tokens ending at or after the
            # answer end, then step forward one to the last answer token.
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs
# --- Training -------------------------------------------------------------
# Label each split with answer-span positions; batched map keeps this fast.
train_dataset = squad["train"].map(preprocess_function, batched=True)
eval_dataset = squad["test"].map(preprocess_function, batched=True)

model = AutoModelForQuestionAnswering.from_pretrained("bert-base-uncased")

training_args = TrainingArguments(
    output_dir="question-answering",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=True,
)

# Fix: `data_collator` was passed to Trainer below but never defined,
# which raises NameError at runtime. DefaultDataCollator (imported at the
# top of the file) is the right choice here since preprocessing already
# pads every example to max_length.
data_collator = DefaultDataCollator()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

# evaluation - todo