DelaliScratchwerk commited on
Commit
467cb44
·
verified ·
1 Parent(s): 7936f59

Delete train_hf_classifier.py

Browse files
Files changed (1) hide show
  1. train_hf_classifier.py +0 -129
train_hf_classifier.py DELETED
@@ -1,129 +0,0 @@
1
- import json
2
- import numpy as np
3
- from datasets import load_dataset
4
- from transformers import (
5
- AutoTokenizer,
6
- AutoModelForSequenceClassification,
7
- TrainingArguments,
8
- Trainer,
9
- )
10
- import evaluate
11
- from huggingface_hub import upload_file
12
-
13
# ---------- LABELS ----------
# Ordered era buckets; list position defines the integer class id.
LABELS = [
    "pre-1900",
    "1900–1945",
    "1946–1990",
    "1991–2008",
    "2009–2015",
    "2016–2018",
    "2019–2022",
    "2023–present",
]

# Bidirectional lookups between label names and class ids.
name2id = {}
id2label = {}
for idx, label_name in enumerate(LABELS):
    name2id[label_name] = idx
    id2label[idx] = label_name
27
-
28
# ---------- DATA ----------
# expects train.jsonl / val.jsonl with fields: "text", "label" (label is one of LABELS)
ds = load_dataset("json", data_files={"train": "train.jsonl", "val": "val.jsonl"})

# Fail fast if any era bucket has zero training examples.
seen = {row["label"] for row in ds["train"]}
missing = set(LABELS) - seen
if missing:
    raise ValueError(f"Train set missing labels: {missing}")
40
-
41
# map string labels -> ids
def encode_label(example):
    """Replace the example's string label with its integer class id."""
    example["label"] = name2id[example["label"]]
    return example


ds = ds.map(encode_label)
46
-
47
# ---------- TOKENIZATION ----------
model_ckpt = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)


def tokenize_batch(batch):
    """Encode a batch of texts to fixed-length (256 tokens) padded inputs."""
    texts = batch["text"]
    return tokenizer(texts, truncation=True, padding="max_length", max_length=256)
59
-
60
-
61
tokenized = ds.map(tokenize_batch, batched=True)

# Keep only the columns the Trainer consumes, served as torch tensors.
wanted = {"input_ids", "attention_mask", "label"}
to_drop = [col for col in tokenized["train"].column_names if col not in wanted]
tokenized = tokenized.remove_columns(to_drop)
tokenized.set_format("torch")
68
-
69
# ---------- MODEL ----------
# Classification head sized to the era buckets, with readable label names
# baked into the model config for downstream inference.
head_config = dict(num_labels=len(LABELS), id2label=id2label, label2id=name2id)
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, **head_config)
76
-
77
# ---------- METRICS ----------
def compute_metrics(eval_pred):
    """Return top-1 accuracy for a Trainer evaluation step.

    eval_pred is a (logits, labels) pair as handed over by Trainer.
    Computed directly with numpy instead of `evaluate.load("accuracy")`,
    which downloads a metric script at runtime — the result is identical:
    mean of exact prediction/label matches.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": float(np.mean(preds == np.asarray(labels)))}
85
-
86
-
87
# ---------- TRAINING ARGUMENTS (no evaluation_strategy etc.) ----------
args = TrainingArguments(
    output_dir="./checkpoints-bert",
    num_train_epochs=4,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_steps=10,  # log training loss every 10 steps
    save_total_limit=2,  # keep only the two most recent checkpoints on disk
)
98
-
99
# ---------- TRAINER ----------
# Wire the model, data splits, tokenizer, and metric callback together.
trainer_kwargs = dict(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["val"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer = Trainer(**trainer_kwargs)
108
-
109
# ---------- TRAIN + EVAL ----------
trainer.train()
eval_metrics = trainer.evaluate()
print("Eval:", eval_metrics)
112
-
113
# ---------- PUSH TO HUB ----------
repo_id = "DelaliScratchwerk/text-period-bert"  # pick the name you want

# BUG FIX: Trainer.push_to_hub's first positional parameter is
# `commit_message`, not a repo id — calling push_to_hub(repo_id) would set
# the commit message and push to a repo named after output_dir
# ("checkpoints-bert"), while labels.json below goes to repo_id. The target
# repo must instead be configured via args.hub_model_id before pushing.
trainer.args.hub_model_id = repo_id
trainer.push_to_hub()
print("Pushed model to:", repo_id)

# also push labels.json so your Space / client can load the label names
# (explicit utf-8: labels contain en-dashes and ensure_ascii=False is used,
# so the platform default encoding could fail or mangle them)
with open("labels_bert.json", "w", encoding="utf-8") as f:
    json.dump(LABELS, f, ensure_ascii=False)

upload_file(
    path_or_fileobj="labels_bert.json",
    path_in_repo="labels.json",
    repo_id=repo_id,
    repo_type="model",
)
print("Uploaded labels.json")