4nkh committed on
Commit
36b732d
·
verified ·
1 Parent(s): 7558b0f

Upload train_theme_model.py

Browse files
Files changed (1) hide show
  1. train_theme_model.py +143 -0
train_theme_model.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import json
import math
import os
import random
from dataclasses import dataclass
from typing import Any, Dict, List

import evaluate
import numpy as np
from datasets import Dataset, DatasetDict
from sklearn.metrics import precision_recall_fscore_support
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)
11
+
12
# ------------------
# CONFIG
# ------------------
MODEL_NAME = "bert-base-uncased"  # swap to a lighter model (e.g., distilbert-base-uncased) if desired
LABELS = ["mentorship", "entrepreneurship", "startup success"]
TEXT_FIELDS = ["original_text", "summary"]  # concatenated so the model sees more signal
SEED = 42
# NOTE(review): confirm the namespace spelling matches your Hub username before pushing.
HF_REPO_ID = "4hnk/theme-multilabel-model"  # <--- change this to your namespace

# Seed Python's `random`, NumPy and the deep-learning backend in one call.
# Seeding only `random` and NumPy leaves the newly initialised classification
# head (created when the model is loaded further down) different on every run.
set_seed(SEED)

# ------------------
# LOAD YOUR JSON
# ------------------
# Change this path if needed; it matches the file you mentioned.
DATA_PATH = "theme_response.json"

# The file is expected to hold a top-level "knowledge_theme_training_data" key
# whose value is the list of training rows; a missing key raises KeyError.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)["knowledge_theme_training_data"]
32
+
33
def to_example(row: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one raw JSON row into a ``{"text", "labels"}`` training example.

    Args:
        row: A record from the training file. The fields named in
            ``TEXT_FIELDS`` are read for text and ``"themes"`` for the labels.

    Returns:
        A dict with ``"text"`` (the non-empty text fields joined by a space and
        stripped) and ``"labels"`` (a multi-hot vector aligned with ``LABELS``).
    """
    text = " ".join(row[k] for k in TEXT_FIELDS if row.get(k))
    # `or []` also covers an explicit `"themes": null`, which `.get(..., [])`
    # would pass through as None and make the `in` test raise TypeError.
    themes = row.get("themes") or []
    # Multi-label classification heads use a BCE-with-logits loss, which
    # requires float targets; integer 0/1 labels fail at the first train step.
    labels = [1.0 if lbl in themes else 0.0 for lbl in LABELS]
    return {"text": text.strip(), "labels": labels}
37
+
38
# Rows without an "original_text" value are skipped.
examples = [to_example(r) for r in data if r.get("original_text")]
if len(examples) < 2:
    # With 0 rows the split below would index past the end of the dataset, and
    # with 1 row the validation split would be empty; fail early and clearly.
    raise ValueError(
        f"Need at least 2 usable rows in {DATA_PATH} to build train and "
        f"validation splits; found {len(examples)}."
    )
ds_full = Dataset.from_list(examples)

# ------------------
# TRAIN/VAL SPLIT (80/20)
# ------------------
ds_full = ds_full.shuffle(seed=SEED)
n = len(ds_full)
# For n >= 2 this always leaves at least one example on each side of the cut.
n_train = max(1, int(0.8 * n))
ds = DatasetDict({
    "train": ds_full.select(range(n_train)),
    "validation": ds_full.select(range(n_train, n)),
})

# ------------------
# TOKENIZATION
# ------------------
tok = AutoTokenizer.from_pretrained(MODEL_NAME)
56
+
57
def tokenize(batch):
    """Tokenize the ``"text"`` column of *batch* with the module-level tokenizer.

    Truncation is enabled; no padding is applied here.
    """
    texts = batch["text"]
    encoded = tok(texts, truncation=True)
    return encoded
59
+
60
# Tokenize in batches and drop the raw text column once it has been encoded.
ds = ds.map(tokenize, batched=True, remove_columns=["text"])
data_collator = DataCollatorWithPadding(tokenizer=tok)

# ------------------
# MODEL
# ------------------
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABELS),
    problem_type="multi_label_classification",
)
# Store human-readable label names in the config so they are saved with it.
model.config.id2label = dict(enumerate(LABELS))
model.config.label2id = {name: idx for idx, name in enumerate(LABELS)}

# ------------------
# METRICS (multi-label)
# ------------------
# The previously loaded `evaluate` "accuracy" metric was never used anywhere in
# the script and forced a needless fetch at start-up; precision/recall/F1 are
# computed directly with scikit-learn in `compute_metrics`.
78
+
79
def sigmoid(x):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-x))`` of *x*.

    Computed as ``exp(-log(1 + exp(-x)))`` through ``np.logaddexp`` so that
    large-magnitude logits do not overflow inside ``exp`` (the naive form emits
    overflow warnings for strongly negative inputs).

    Args:
        x: A scalar or array-like of logits.

    Returns:
        Values in ``[0, 1]`` with the same shape as *x*.
    """
    x = np.asarray(x)
    return np.exp(-np.logaddexp(0.0, -x))
81
+
82
def compute_metrics(eval_pred, threshold=0.5):
    """Compute micro- and macro-averaged precision, recall and F1.

    Args:
        eval_pred: A ``(logits, labels)`` pair.
        threshold: Probability at or above which a label counts as predicted.

    Returns:
        A dict keyed ``"micro/precision"``, ``"micro/recall"``, ``"micro/f1"``,
        ``"macro/precision"``, ``"macro/recall"`` and ``"macro/f1"``.
    """
    logits, labels = eval_pred
    # Turn logits into probabilities, then into hard 0/1 predictions.
    preds = (sigmoid(logits) >= threshold).astype(int)

    # Each call returns (precision, recall, f1, support); support is unused.
    micro = precision_recall_fscore_support(
        labels, preds, average="micro", zero_division=0
    )
    macro = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )
    return {
        "micro/precision": micro[0],
        "micro/recall": micro[1],
        "micro/f1": micro[2],
        "macro/precision": macro[0],
        "macro/recall": macro[1],
        "macro/f1": macro[2],
    }
104
+
105
# ------------------
# TRAINING ARGS
# ------------------
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` and Trainer's `tokenizer` argument to `processing_class`;
# confirm against the installed version before upgrading.
args = TrainingArguments(
    output_dir="./theme_model_outputs",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=10,  # small dataset -> more epochs
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="micro/f1",
    greater_is_better=True,
    seed=SEED,  # tie the training seed to the script-wide SEED constant
    push_to_hub=True,  # <--- enable Hub push
    hub_model_id=HF_REPO_ID,
)

# ------------------
# TRAIN
# ------------------
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    tokenizer=tok,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
# Final evaluation on the validation split.
trainer.evaluate()

# ------------------
# SAVE + PUSH
# ------------------
trainer.push_to_hub()