DelaliScratchwerk committed on
Commit
d25d051
·
verified ·
1 Parent(s): 8ac19c7

Create train_hf_classifier.py

Browse files
Files changed (1) hide show
  1. train_hf_classifier.py +109 -0
train_hf_classifier.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train_hf_classifier.py
2
+
3
+ import json
4
+ from datasets import load_dataset
5
+ from transformers import (
6
+ AutoTokenizer,
7
+ AutoModelForSequenceClassification,
8
+ Trainer,
9
+ TrainingArguments,
10
+ )
11
+ from huggingface_hub import HfApi
12
+
13
MODEL_NAME = "distilbert-base-uncased"  # backbone checkpoint to fine-tune
REPO_ID = "DelaliScratchwerk/text-period-bert"  # <- choose a new model repo name

# Ordered time-period classes; an example's class id is its index in this list.
LABELS = [
    "pre-1900",
    "1900–1945",
    "1946–1990",
    "1991–2008",
    "2009–2015",
    "2016–2018",
    "2019–2022",
    "2023–present",
]

# Bidirectional maps between period name and integer class id.
id2label = dict(enumerate(LABELS))
label2id = {name: idx for idx, name in id2label.items()}
28
+
29
# 1) Load your jsonl data (same files you used for SetFit)
ds = load_dataset("json", data_files={"train": "train.jsonl", "val": "val.jsonl"})

# Rows are assumed to look like {"text": "...", "label": "1946–1990"};
# the Trainer wants an integer "labels" column, so attach the class id.
def encode_label(example):
    return {**example, "labels": label2id[example["label"]]}

ds = ds.map(encode_label)
38
+
39
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Fixed-length encoding: every example is padded/truncated to 256 tokens.
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)

tokenized = ds.map(tokenize, batched=True)
50
+
51
# HF Trainer expects only tensor-ready columns: drop the raw text and the
# string label (the integer "labels" column remains) and switch to torch format.
tokenized = tokenized.remove_columns(["text", "label"])
tokenized.set_format("torch")

# Classification head on top of the backbone, wired with the label maps so the
# pushed model can decode predicted ids back to period names.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABELS),
    label2id=label2id,
    id2label=id2label,
)
61
+
62
# Standard fine-tuning hyperparameters: evaluate and checkpoint once per epoch,
# then reload the checkpoint with the best validation accuracy at the end.
# NOTE(review): on transformers>=4.46 this kwarg is spelled `eval_strategy`
# instead of `evaluation_strategy` — confirm against the pinned version.
args = TrainingArguments(
    output_dir="./checkpoints-bert",
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)
74
+
75
def compute_metrics(eval_pred):
    """Return {"accuracy": float} for an evaluation pass.

    `eval_pred` is the (logits, labels) pair the Trainer hands over as numpy
    arrays. Accuracy is computed directly here instead of via
    `datasets.load_metric("accuracy")`: `load_metric` was deprecated and has
    been removed from the `datasets` library, so the original import breaks on
    current versions — and plain accuracy needs no external metric anyway.
    """
    logits, labels = eval_pred
    preds = logits.argmax(axis=-1)
    return {"accuracy": float((preds == labels).mean())}
82
+
83
# Wire everything into the Trainer: model, hyperparameters, tokenized splits,
# and the accuracy callback used for best-checkpoint selection.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["val"],
    compute_metrics=compute_metrics,
)

trainer.train()
final_metrics = trainer.evaluate()
print("Eval:", final_metrics)
93
+
94
# 2) Push model + tokenizer to the Hub.
# BUG FIX: the original called trainer.push_to_hub(REPO_ID), but the first
# positional parameter of Trainer.push_to_hub is `commit_message` — the target
# repo comes from args.hub_model_id (defaulting to the output_dir name), so
# REPO_ID was silently used as a commit message, not as the destination.
# Push the model and tokenizer to REPO_ID explicitly instead.
model.push_to_hub(REPO_ID)
tokenizer.push_to_hub(REPO_ID)

# 3) Also upload labels list as labels.json (handy but optional).
# Explicit utf-8 so the en-dash labels serialize correctly on any platform.
with open("labels.json", "w", encoding="utf-8") as f:
    json.dump(LABELS, f, ensure_ascii=False, indent=2)

api = HfApi()
api.upload_file(
    path_or_fileobj="labels.json",
    path_in_repo="labels.json",
    repo_id=REPO_ID,
    repo_type="model",
)

print("Pushed model to:", REPO_ID)