ChiJuiChen committed on
Commit
f966412
·
verified ·
1 Parent(s): 70fe017

add how to use

Browse files
Files changed (1) hide show
  1. README.md +51 -0
README.md CHANGED
@@ -32,6 +32,57 @@ More information needed
32
 
33
  More information needed
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ## Training procedure
36
 
37
  ### Training hyperparameters
 
32
 
33
  More information needed
34
 
35
+ ## How to use
36
+
37
+ ```python
38
+ from datasets import load_dataset
39
+ from transformers import AutoTokenizer, DataCollatorWithPadding
40
+
41
+ raw_datasets = load_dataset("glue", "sst2")
42
+ checkpoint = "ChiJuiChen/Bert-Lab4"
43
+ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
44
+
45
+
46
+ def tokenize_function(example):
47
+ return tokenizer(example["sentence"], truncation=True)
48
+
49
+
50
+ tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
51
+
52
+ small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(100))
53
+ small_eval_dataset = tokenized_datasets["validation"].shuffle(seed=42).select(range(100))
54
+
55
+ data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
56
+
57
+ from transformers import TrainingArguments
58
+
59
+ training_args = TrainingArguments(output_dir="ChiJuiChen/Bert-Lab4",
60
+ evaluation_strategy="epoch",
61
+ hub_model_id="ChiJuiChen/Bert-Lab4")
62
+
63
+ from transformers import AutoModelForSequenceClassification
64
+ model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
65
+
66
+ from transformers import Trainer
67
+ trainer = Trainer(
68
+ model,
69
+ training_args,
70
+ train_dataset=small_train_dataset, # 100-example subset, small enough to run on CPU
71
+ eval_dataset=small_eval_dataset, # 100-example subset, small enough to run on CPU
72
+ data_collator=data_collator,
73
+ tokenizer=tokenizer,
74
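+ compute_metrics=compute_metrics, # user-defined function, not shown here: define it before creating the Trainer, or drop this argument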
75
+ )
76
+
77
+ # Evaluation
78
+ predictions = trainer.predict(small_eval_dataset)
79
+ print(predictions.predictions.shape, predictions.label_ids.shape)
80
+ preds = np.argmax(predictions.predictions, axis=-1) # requires `import numpy as np`
81
+
82
+ import evaluate
83
+ metric = evaluate.load("glue", "sst2")
84
+ metric.compute(predictions=preds, references=predictions.label_ids)
85
+ ```
86
  ## Training procedure
87
 
88
  ### Training hyperparameters