| """Simple training script for a text toxicity classifier. | |
| Usage examples: | |
| - Train from a CSV: python train.py --dataset_csv data/toxic_train.csv --text_col text --label_col label --output_dir ./outputs | |
| - Push to Hub: python train.py --dataset_csv data/toxic_train.csv --output_dir ./outputs --push_to_hub --hub_model_id your-username/toxic-detector | |
| Expect CSV with columns: text, label (0/1) for single-label classification. For multi-label adjust the preprocessing. | |
| """ | |
import argparse

import numpy as np
import evaluate
from datasets import DatasetDict, load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
)


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--dataset_csv", type=str, default=None, help="Path to a CSV dataset with text and label columns")
    p.add_argument("--text_col", type=str, default="text")
    p.add_argument("--label_col", type=str, default="label")
    p.add_argument("--model_name_or_path", type=str, default="distilbert-base-uncased")
    p.add_argument("--output_dir", type=str, default="./model_output")
    p.add_argument("--push_to_hub", action="store_true")
    p.add_argument("--hub_model_id", type=str, default=None)
    p.add_argument("--num_train_epochs", type=int, default=1)
    p.add_argument("--per_device_train_batch_size", type=int, default=16)
    return p.parse_args()


def main():
    args = parse_args()

    if args.dataset_csv:
        # No validation split in the CSV, so hold out 10% of train for evaluation.
        ds = load_dataset("csv", data_files={"train": args.dataset_csv})
        dataset = ds["train"].train_test_split(test_size=0.1)
    else:
        # Small built-in fallback: a tiny subset of IMDB (binary sentiment) for demos.
        # Built as an explicit DatasetDict so both branches feed the same
        # .map/Trainer code path below.
        dataset = DatasetDict(
            {
                "train": load_dataset("imdb", split="train[:2000]"),
                "test": load_dataset("imdb", split="test[:500]"),
            }
        )
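    # At this point `dataset` is a DatasetDict with "train" and "test" splits in both
    # branches. Note: train_test_split shuffles before splitting; pass seed=42 (or any
    # fixed value) if you need the CSV hold-out to be reproducible across runs.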

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    def preprocess_function(examples):
        # The CSV path uses the user-supplied text column; the IMDB fallback uses "text".
        texts = examples[args.text_col] if args.dataset_csv else examples["text"]
        return tokenizer(texts, truncation=True)

    tokenized = dataset.map(preprocess_function, batched=True)
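    # No padding happens here: truncation=True only clips to the model's max length.
    # DataCollatorWithPadding (below) pads each batch dynamically to its longest
    # sequence, which is cheaper than padding everything to a fixed max length.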

    # Trainer expects the targets in a "label"/"labels" column; rename a custom one.
    if args.dataset_csv and args.label_col != "label":
        tokenized = tokenized.rename_column(args.label_col, "label")

    num_labels = 2  # single-label binary classification (0/1), per the docstring
    model = AutoModelForSequenceClassification.from_pretrained(args.model_name_or_path, num_labels=num_labels)
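    # Optional nicety (assumed label names, not from the dataset): from_pretrained also
    # accepts readable class names that are stored in the model config, e.g.
    #   id2label={0: "non-toxic", 1: "toxic"}, label2id={"non-toxic": 0, "toxic": 1}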

    metric_acc = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        return metric_acc.compute(predictions=preds, references=labels)
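    # Toxicity corpora are often imbalanced, so accuracy alone can look deceptively
    # good; a sketch of reporting F1 alongside it (standard `evaluate` metric):
    #   metric_f1 = evaluate.load("f1")
    #   metric_f1.compute(predictions=preds, references=labels)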

    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy="epoch",
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        save_total_limit=2,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
    )
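    # save_total_limit=2 keeps only the two most recent checkpoints on disk. Recent
    # transformers releases renamed evaluation_strategy to eval_strategy; switch the
    # keyword if TrainingArguments rejects it on your version.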

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized["train"],
        eval_dataset=tokenized.get("test"),
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    trainer.train()
    trainer.save_model()
    if args.push_to_hub and args.hub_model_id:
        trainer.push_to_hub()
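    # Pushing assumes you are already authenticated with the Hub, e.g. via
    # `huggingface-cli login` or an HF_TOKEN set in the environment.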


if __name__ == "__main__":
    main()