import json
import random
from typing import Dict, Any

import numpy as np
from datasets import Dataset, DatasetDict
from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          DataCollatorWithPadding, TrainingArguments, Trainer,
                          set_seed)
from sklearn.metrics import precision_recall_fscore_support
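
# --- Configuration ---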
MODEL_NAME = "bert-base-uncased"
LABELS = ["mentorship", "entrepreneurship", "startup success"]
TEXT_FIELDS = ["original_text", "summary"]
SEED = 42
HF_REPO_ID = "4hnk/theme-multilabel-model"
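
# Seed Python and NumPy; set_seed additionally seeds PyTorch so training
# itself is reproducible.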
random.seed(SEED)
np.random.seed(SEED)
set_seed(SEED)
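
# --- Load the raw annotations ---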
DATA_PATH = "theme_response.json"

with open(DATA_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)["knowledge_theme_training_data"]
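

# Convert each row into {"text", "labels"} with a multi-hot label vector.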
def to_example(row: Dict[str, Any]) -> Dict[str, Any]:
    # Join whichever text fields are present into a single input string.
    text = " ".join([row.get(k, "") for k in TEXT_FIELDS if row.get(k)])
    # Multi-hot label vector. Floats are required: with
    # problem_type="multi_label_classification" the model computes
    # BCEWithLogitsLoss, which rejects integer targets.
    y = [1.0 if lbl in row.get("themes", []) else 0.0 for lbl in LABELS]
    return {"text": text.strip(), "labels": y}


examples = [to_example(r) for r in data if r.get("original_text")]
ds_full = Dataset.from_list(examples)
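
# --- Shuffle and split 80/20 into train/validation ---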
ds_full = ds_full.shuffle(seed=SEED)
n = len(ds_full)
n_train = max(1, int(0.8 * n))
ds = DatasetDict({
    "train": ds_full.select(range(n_train)),
    "validation": ds_full.select(range(n_train, n)),
})
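
# --- Tokenization ---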
tok = AutoTokenizer.from_pretrained(MODEL_NAME)


def tokenize(batch):
    # Truncate to the model max length; padding is deferred to the collator
    # so each batch is padded only to its own longest sequence.
    return tok(batch["text"], truncation=True)


ds = ds.map(tokenize, batched=True, remove_columns=["text"])
data_collator = DataCollatorWithPadding(tokenizer=tok)
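
# --- Model ---
# problem_type="multi_label_classification" makes the model use
# BCEWithLogitsLoss: an independent sigmoid per label instead of a softmax
# over the label set.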
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABELS),
    problem_type="multi_label_classification",
)
model.config.id2label = {i: l for i, l in enumerate(LABELS)}
model.config.label2id = {l: i for i, l in enumerate(LABELS)}
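

# --- Metrics ---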
def sigmoid(x):
    return 1 / (1 + np.exp(-x))


def compute_metrics(eval_pred, threshold=0.5):
    logits, labels = eval_pred
    # Per-label probabilities, binarized at the decision threshold.
    probs = sigmoid(logits)
    preds = (probs >= threshold).astype(int)

    micro_p, micro_r, micro_f1, _ = precision_recall_fscore_support(
        labels, preds, average="micro", zero_division=0
    )
    macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(
        labels, preds, average="macro", zero_division=0
    )

    return {
        "micro/precision": micro_p,
        "micro/recall": micro_r,
        "micro/f1": micro_f1,
        "macro/precision": macro_p,
        "macro/recall": macro_r,
        "macro/f1": macro_f1,
    }
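
# --- Training setup ---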
args = TrainingArguments(
    output_dir="./theme_model_outputs",
    evaluation_strategy="epoch",  # newer transformers releases call this eval_strategy
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="micro/f1",  # Trainer prepends "eval_" when looking this up
    greater_is_better=True,
    push_to_hub=True,
    hub_model_id=HF_REPO_ID,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    tokenizer=tok,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()
print(trainer.evaluate())
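
# Upload the final (best) checkpoint, tokenizer, and an auto-generated
# model card to the Hub.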
trainer.push_to_hub()
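
# --- Optional sanity check: reload the pushed model and score one example ---
# A minimal sketch, not part of the original pipeline. It assumes the push
# above succeeded and HF_REPO_ID is readable; the sample sentence is made up,
# and the 0.5 threshold mirrors the default in compute_metrics.
import torch

clf_tok = AutoTokenizer.from_pretrained(HF_REPO_ID)
clf = AutoModelForSequenceClassification.from_pretrained(HF_REPO_ID)

sample = "A mentor helped the founder sharpen the startup's pitch."
inputs = clf_tok(sample, return_tensors="pt", truncation=True)
with torch.no_grad():
    logits = clf(**inputs).logits
probs = torch.sigmoid(logits)[0]
print([LABELS[i] for i, p in enumerate(probs) if p >= 0.5])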