amcoff commited on
Commit
852edbc
·
unverified ·
1 Parent(s): 3d5601e

add train.py

Browse files
Files changed (2) hide show
  1. README.md +16 -3
  2. train.py +67 -0
README.md CHANGED
@@ -1,11 +1,24 @@
1
  ---
2
  license: mit
3
  datasets:
4
- - amcoff/skolmat
5
  language:
6
- - sv
7
  library_name: transformers
8
  pipeline_tag: text-classification
9
  widget:
10
- - text: "Kycklingwok med äggnudlar och sojasås"
 
11
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: mit
3
  datasets:
4
+ - amcoff/skolmat
5
  language:
6
+ - sv
7
  library_name: transformers
8
  pipeline_tag: text-classification
9
  widget:
10
+ - text: "Kycklingwok med äggnudlar och sojasås"
11
+ - text: "Kökets val"
12
  ---
13
+
14
+ ```python
15
+ from transformers import pipeline
16
+
17
+ nlp = pipeline(
18
+ "text-classification",
19
+ model="amcoff/classify_skolmat",
20
+ tokenizer="KBLab/bert-base-swedish-cased",
21
+ )
22
+
23
+ nlp("Kökets val")
24
+ ```
train.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import (
2
+ AutoModelForSequenceClassification,
3
+ AutoTokenizer,
4
+ TrainingArguments,
5
+ Trainer,
6
+ )
7
+ from datasets import load_dataset
8
+ import numpy as np
9
+ import evaluate
10
+
11
+ dataset = load_dataset("amcoff/skolmat")["train"].train_test_split(test_size=0.1)
12
+
13
+ id2label = {k: v for k, v in enumerate(dataset["train"].features["label"].names)}
14
+ label2id = {v: k for k, v in id2label.items()}
15
+
16
+ tokenizer = AutoTokenizer.from_pretrained("KBLab/bert-base-swedish-cased")
17
+
18
+ max_length = 128
19
+
20
+
21
+ def tokenize_function(examples):
22
+ return tokenizer(
23
+ examples["meal"], padding="max_length", truncation=True, max_length=max_length
24
+ )
25
+
26
+
27
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
28
+
29
+ small_train_dataset = tokenized_datasets["train"]
30
+ small_eval_dataset = tokenized_datasets["train"]
31
+
32
+ model = AutoModelForSequenceClassification.from_pretrained(
33
+ "KBLab/bert-base-swedish-cased",
34
+ num_labels=len(id2label),
35
+ id2label=id2label,
36
+ label2id=label2id,
37
+ )
38
+
39
+ training_args = TrainingArguments(
40
+ output_dir="trainer",
41
+ evaluation_strategy="epoch",
42
+ per_device_train_batch_size=4,
43
+ )
44
+
45
+ metric = evaluate.load("accuracy")
46
+
47
+
48
+ def compute_metrics(eval_pred):
49
+ logits, labels = eval_pred
50
+
51
+ predictions = np.argmax(logits, axis=-1)
52
+
53
+ return metric.compute(predictions=predictions, references=labels)
54
+
55
+
56
+ trainer = Trainer(
57
+ model=model,
58
+ args=training_args,
59
+ train_dataset=small_train_dataset,
60
+ eval_dataset=small_eval_dataset,
61
+ compute_metrics=compute_metrics,
62
+ )
63
+
64
+ trainer.train()
65
+
66
+
67
+ trainer.save_model("model")