DelaliScratchwerk committed on
Commit
7936f59
·
verified ·
1 Parent(s): d25d051

Update train_hf_classifier.py

Browse files
Files changed (1) hide show
  1. train_hf_classifier.py +69 -49
train_hf_classifier.py CHANGED
@@ -1,18 +1,16 @@
1
- # train_hf_classifier.py
2
-
3
  import json
 
4
  from datasets import load_dataset
5
  from transformers import (
6
  AutoTokenizer,
7
  AutoModelForSequenceClassification,
8
- Trainer,
9
  TrainingArguments,
 
10
  )
11
- from huggingface_hub import HfApi
12
-
13
- MODEL_NAME = "distilbert-base-uncased" # backbone
14
- REPO_ID = "DelaliScratchwerk/text-period-bert" # <- choose a new model repo name
15
 
 
16
  LABELS = [
17
  "pre-1900",
18
  "1900–1945",
@@ -23,87 +21,109 @@ LABELS = [
23
  "2019–2022",
24
  "2023–present",
25
  ]
26
- label2id = {l: i for i, l in enumerate(LABELS)}
27
- id2label = {i: l for l, i in label2id.items()}
28
 
29
- # 1) Load your jsonl data (same files you used for SetFit)
30
- ds = load_dataset("json", data_files={"train": "train.jsonl", "val": "val.jsonl"})
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- # Check columns: assume {"text": "...", "label": "1946–1990"}
33
  def encode_label(example):
34
- example["labels"] = label2id[example["label"]]
35
- return example
36
 
37
  ds = ds.map(encode_label)
38
 
39
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 
 
40
 
41
- def tokenize(examples):
42
  return tokenizer(
43
- examples["text"],
44
- padding="max_length",
45
  truncation=True,
 
46
  max_length=256,
47
  )
48
 
49
- tokenized = ds.map(tokenize, batched=True)
50
 
51
- # HF Trainer expects these columns
52
- tokenized = tokenized.remove_columns(["text", "label"])
 
 
 
 
53
  tokenized.set_format("torch")
54
 
 
55
  model = AutoModelForSequenceClassification.from_pretrained(
56
- MODEL_NAME,
57
  num_labels=len(LABELS),
58
  id2label=id2label,
59
- label2id=label2id,
60
  )
61
 
 
 
 
 
 
 
 
 
 
 
 
62
  args = TrainingArguments(
63
  output_dir="./checkpoints-bert",
64
- evaluation_strategy="epoch",
65
- save_strategy="epoch",
66
  learning_rate=2e-5,
67
- per_device_train_batch_size=16,
68
- per_device_eval_batch_size=16,
69
- num_train_epochs=3,
70
  weight_decay=0.01,
71
- load_best_model_at_end=True,
72
- metric_for_best_model="accuracy",
73
  )
74
 
75
- from datasets import load_metric
76
- metric = load_metric("accuracy")
77
-
78
- def compute_metrics(eval_pred):
79
- logits, labels = eval_pred
80
- preds = logits.argmax(axis=-1)
81
- return metric.compute(predictions=preds, references=labels)
82
-
83
  trainer = Trainer(
84
  model=model,
85
  args=args,
86
  train_dataset=tokenized["train"],
87
  eval_dataset=tokenized["val"],
 
88
  compute_metrics=compute_metrics,
89
  )
90
 
 
91
  trainer.train()
92
  print("Eval:", trainer.evaluate())
93
 
94
- # 2) Push model to Hub
95
- trainer.push_to_hub(REPO_ID)
 
 
 
96
 
97
- # 3) Also upload labels list as labels.json (handy but optional)
98
- with open("labels.json", "w") as f:
99
- json.dump(LABELS, f, ensure_ascii=False, indent=2)
100
 
101
- api = HfApi()
102
- api.upload_file(
103
- path_or_fileobj="labels.json",
104
  path_in_repo="labels.json",
105
- repo_id=REPO_ID,
106
  repo_type="model",
107
  )
108
-
109
- print("Pushed model to:", REPO_ID)
 
 
 
1
  import json
2
+ import numpy as np
3
  from datasets import load_dataset
4
  from transformers import (
5
  AutoTokenizer,
6
  AutoModelForSequenceClassification,
 
7
  TrainingArguments,
8
+ Trainer,
9
  )
10
+ import evaluate
11
+ from huggingface_hub import upload_file
 
 
12
 
13
+ # ---------- LABELS ----------
14
  LABELS = [
15
  "pre-1900",
16
  "1900–1945",
 
21
  "2019–2022",
22
  "2023–present",
23
  ]
 
 
24
 
25
# Integer id <-> label-name lookup tables derived from LABELS.
id2label = dict(enumerate(LABELS))
name2id = {label: idx for idx, label in id2label.items()}
27
+
28
# ---------- DATA ----------
# expects train.jsonl / val.jsonl with fields: "text", "label" (label is one of LABELS)
ds = load_dataset("json", data_files={"train": "train.jsonl", "val": "val.jsonl"})

# Fail fast if any label name never occurs in the training split.
seen = {row["label"] for row in ds["train"]}
missing = set(LABELS) - seen
if missing:
    raise ValueError(f"Train set missing labels: {missing}")
40
 
41
# map string labels -> ids
def encode_label(example):
    """Replace the string label of one example with its integer id."""
    label_id = name2id[example["label"]]
    return {"label": label_id}


ds = ds.map(encode_label)
46
 
47
# ---------- TOKENIZATION ----------
# Backbone checkpoint; the same name is reused below to load the classifier.
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
50
+
51
 
52
def tokenize_batch(batch):
    """Tokenize a batch of examples to fixed-length (256-token) encodings."""
    texts = batch["text"]
    return tokenizer(
        texts,
        max_length=256,
        padding="max_length",
        truncation=True,
    )
59
 
 
60
 
61
tokenized = ds.map(tokenize_batch, batched=True)

# Keep only the columns the Trainer consumes; drop everything else (e.g. "text").
keep_columns = {"input_ids", "attention_mask", "label"}
drop_columns = [c for c in tokenized["train"].column_names if c not in keep_columns]
tokenized = tokenized.remove_columns(drop_columns)
tokenized.set_format("torch")
68
 
69
# ---------- MODEL ----------
# Encoder with a fresh sequence-classification head sized to LABELS;
# id2label/label2id are stored in the model config so predictions map back
# to the human-readable period names.
model = AutoModelForSequenceClassification.from_pretrained(
    model_ckpt,
    num_labels=len(LABELS),
    id2label=id2label,
    label2id=name2id,
)
76
 
77
# ---------- METRICS ----------
accuracy_metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Compute accuracy from an (logits, label_ids) evaluation pair."""
    logits, label_ids = eval_pred
    predicted_ids = np.argmax(logits, axis=-1)
    return accuracy_metric.compute(predictions=predicted_ids, references=label_ids)
85
+
86
+
87
# ---------- TRAINING ARGUMENTS (no evaluation_strategy etc.) ----------
# NOTE(review): no evaluation_strategy/save_strategy is set, so no mid-training
# evaluation runs; the only evaluation is the explicit trainer.evaluate() call
# later in this script — presumably dropped for version compatibility; confirm.
args = TrainingArguments(
    output_dir="./checkpoints-bert",  # checkpoints are written here
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=4,
    weight_decay=0.01,
    logging_steps=10,     # log training loss every 10 optimizer steps
    save_total_limit=2,   # keep at most 2 checkpoints on disk
)
98
 
99
# ---------- TRAINER ----------
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["val"],
    tokenizer=tokenizer,  # saved with checkpoints so pushes include the tokenizer
    compute_metrics=compute_metrics,
)
108
 
109
# ---------- TRAIN + EVAL ----------
trainer.train()
# Single evaluation pass over the "val" split after training finishes.
print("Eval:", trainer.evaluate())
112
 
113
# ---------- PUSH TO HUB ----------
repo_id = "DelaliScratchwerk/text-period-bert"  # pick the name you want

# BUG FIX: Trainer.push_to_hub's first positional parameter is `commit_message`,
# not a repo id — passing repo_id there pushes to a repo derived from output_dir
# ("checkpoints-bert"). The target repo must be set via args.hub_model_id.
trainer.args.hub_model_id = repo_id
trainer.push_to_hub(commit_message="Add fine-tuned text-period classifier")
print("Pushed model to:", repo_id)

# also push labels.json so your Space / client can load the label names
# utf-8 is required: label names contain en dashes and ensure_ascii is off.
with open("labels_bert.json", "w", encoding="utf-8") as f:
    json.dump(LABELS, f, ensure_ascii=False)

upload_file(
    path_or_fileobj="labels_bert.json",
    path_in_repo="labels.json",
    repo_id=repo_id,
    repo_type="model",
)
print("Uploaded labels.json")