|
|
import json |
|
|
from datasets import load_dataset |
|
|
from setfit import SetFitModel, SetFitTrainer |
|
|
from huggingface_hub import upload_file |
|
|
|
|
|
# Ordered class names; a label's position in this list is its integer class id.
# NOTE: the year ranges are written with an EN DASH (U+2013), not an ASCII
# hyphen ("pre-1900" is the only one using a plain hyphen), so the "label"
# strings in the data files must use exactly the same characters to match.
LABELS = [
    "pre-1900",
    "1900–1945",
    "1946–1990",
    "1991–2008",
    "2009–2015",
    "2016–2018",
    "2019–2022",
    "2023–present",
]

# Reverse lookup: label string -> integer class id (its index in LABELS).
name2id = dict(zip(LABELS, range(len(LABELS))))
|
|
|
|
|
# Load the train/val splits from local JSON Lines files. Each record is read for
# a string "label" field below; the other fields (presumably the text) are not
# referenced here — confirm the schema against the data files.
ds = load_dataset("json", data_files={"train": "train.jsonl", "val": "val.jsonl"})

# Validate the label strings before converting them to ids. Explicit `raise` is
# used instead of `assert` so these checks still run under `python -O`.
seen = set(ds["train"]["label"])
missing = set(LABELS) - seen
if missing:
    # Require every class to be represented in the training split.
    raise ValueError(f"Train set missing labels: {sorted(missing)}")

for split_name, split in ds.items():
    # Reject labels not in LABELS (e.g. an ASCII hyphen where LABELS uses an en
    # dash) with a clear message, instead of a bare KeyError from `name2id`
    # raised deep inside `ds.map` below. `key=str` keeps sorting safe even if a
    # record has a non-string label such as None.
    unknown = set(split["label"]).difference(name2id)
    if unknown:
        raise ValueError(
            f"{split_name!r} split has unknown labels: {sorted(unknown, key=str)}"
        )

# Replace each string label with its integer class id in every split.
ds = ds.map(lambda x: {"label": name2id[x["label"]]})
|
|
|
|
|
# Sentence-transformer backbone to be fine-tuned with SetFit.
# NOTE(review): depending on the installed setfit version, `num_labels` may be
# ignored by the default (scikit-learn) classification head — confirm.
model = SetFitModel.from_pretrained(
    "sentence-transformers/paraphrase-mpnet-base-v2",
    num_labels=len(LABELS),
)

# Training/evaluation settings, grouped so they are easy to find and tweak.
trainer_kwargs = dict(
    metric="accuracy",
    num_iterations=20,
    num_epochs=2,
    batch_size=16,
)

trainer = SetFitTrainer(
    model=model,
    train_dataset=ds["train"],
    eval_dataset=ds["val"],
    **trainer_kwargs,
)

trainer.train()

# Score the trained model on the validation split and report the result.
eval_metrics = trainer.evaluate()
print("Eval:", eval_metrics)
|
|
|
|
|
|
|
|
# Destination model repository on the Hugging Face Hub.
repo_id = "DelaliScratchwerk/text-period-setfit"

# Upload the trained model to that repository.
trainer.push_to_hub(repo_id)
print(f"Pushed to: {repo_id}")
|
|
|
|
|
|
|
|
# Save the ordered label names (list index == class id) locally, then upload the
# file into the same model repository as the trained model. json escapes the
# non-ASCII en dashes as \u2013 by default, so the file content is pure ASCII.
labels_path = "labels.json"
with open(labels_path, "w", encoding="utf-8") as fh:
    fh.write(json.dumps(LABELS))

upload_file(
    path_or_fileobj=labels_path,
    path_in_repo=labels_path,
    repo_id=repo_id,
    repo_type="model",
)
print("Uploaded labels.json")
|
|
|