# train.py — uploaded by yagnik12
# Last commit: "Update train.py" (c847ada, verified)
import inspect
import os

from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)
# ✅ Dataset
# NOTE(review): assumes the default config of this Hub dataset has a "text"
# column (read by `tokenize`) and a label column the Trainer can consume —
# confirm against the dataset card.
dataset = load_dataset("HanxiGuo/BiScope_Data")
if "test" not in dataset:
    # Training evaluates on the "test" split every epoch. If the Hub dataset
    # ships without one, hold out 10% of "train" instead of failing later
    # with a KeyError. The fixed seed keeps the split identical across reruns.
    dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
# ✅ Base model
BASE_MODEL = "distilbert-base-uncased"
# Hub repository the fine-tuned model is pushed to.
MODEL_REPO = "yagnik12/AI_Text_Detecter_HanxiGuo_BiScope-Data"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
def tokenize(batch, max_length=256):
    """Tokenize a batch of examples with the module-level ``tokenizer``.

    Intended for ``dataset.map(tokenize, batched=True)``, where
    ``batch["text"]`` holds the texts of the whole batch. Every sequence is
    truncated to, and padded up to, exactly ``max_length`` tokens.

    Args:
        batch: Mapping with a ``"text"`` column.
        max_length: Fixed sequence length. Defaults to 256, as before. Can be
            overridden via ``dataset.map(..., fn_kwargs={"max_length": n})``.

    Returns:
        The tokenizer's encoding for the batch. Its keys become new
        dataset columns.
    """
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
# Tokenize every split of the DatasetDict in batched mode.
tokenized = dataset.map(tokenize, batched=True)
# ✅ Model: BASE_MODEL encoder with a classification head of 2 labels.
model = AutoModelForSequenceClassification.from_pretrained(
    BASE_MODEL,
    num_labels=2,
)
# ✅ Training setup
# `evaluation_strategy` was renamed `eval_strategy` in transformers 4.41.
# The old spelling was removed in 4.46, so passing it unconditionally raises
# TypeError on current releases. Use whichever name the installed version
# accepts.
_EVAL_STRATEGY_KWARG = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)
training_args = TrainingArguments(
    output_dir="./results",
    save_strategy="epoch",
    num_train_epochs=1,  # start small for demo
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    push_to_hub=True,
    hub_model_id=MODEL_REPO,
    # If HF_TOKEN is unset this is None, and the token cached by
    # `huggingface-cli login` is used instead.
    hub_token=os.getenv("HF_TOKEN"),
    **{_EVAL_STRATEGY_KWARG: "epoch"},
)
# Trainer's `tokenizer=` argument was renamed `processing_class=` in
# transformers 4.46, and the old name is deprecated. Pass the tokenizer under
# whichever name the installed version exposes, so it is still saved and
# pushed alongside the model.
_TOKENIZER_KWARG = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer).parameters
    else "tokenizer"
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["test"],
    **{_TOKENIZER_KWARG: tokenizer},
)
# ✅ Train & push
trainer.train()
# Upload the trained model (and tokenizer) to the repo given by hub_model_id.
trainer.push_to_hub()
model_url = f"https://huggingface.co/{MODEL_REPO}"
print(f"✅ Model pushed to {model_url}")