# hello/hello.py
# Last change: "Update model." by chinmaygarde (commit 9cacd9e, unverified)
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
TrainingArguments,
Trainer,
)
from datasets import load_dataset
# Load a small random subset of IMDB reviews.
# NOTE: the IMDB "train" split is stored ordered by label (all negative
# reviews first, then all positive), so a prefix slice such as
# "train[:500]" contains a single class and the classifier has nothing to
# learn. Shuffle with a fixed seed before taking the subset so it mixes both
# labels and stays reproducible.
NUM_EXAMPLES = 500
dataset = (
    load_dataset("stanfordnlp/imdb", split="train")
    .shuffle(seed=42)
    .select(range(NUM_EXAMPLES))
)
# Hold out 20% of the subset for evaluation.
dataset = dataset.train_test_split(test_size=0.2, seed=42)

# Use DistilBERT — small, fast, good enough for a demo
model_name = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Take the class names from the dataset's label feature so the saved model
# config maps prediction ids to readable names instead of the default
# LABEL_0 / LABEL_1 placeholders.
label_names = dataset["train"].features["label"].names
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(label_names),
    id2label=dict(enumerate(label_names)),
    label2id={name: idx for idx, name in enumerate(label_names)},
)
def tokenize(batch):
    """Convert a batch of raw reviews into fixed-length DistilBERT inputs.

    Intended for ``dataset.map(..., batched=True)``, where ``batch["text"]``
    holds the review texts of the batch. Reviews longer than 128 tokens are
    truncated and shorter ones are padded up to 128, so every example ends
    up with the same length.

    Returns:
        The module-level ``tokenizer``'s encoding of the batch, which ``map``
        merges into the dataset as new columns.
    """
    reviews = batch["text"]
    return tokenizer(
        reviews,
        max_length=128,
        padding="max_length",
        truncation=True,
    )
# Tokenize every split; batched=True hands the tokenizer many examples per call.
dataset = dataset.map(tokenize, batched=True)


def compute_metrics(eval_pred):
    """Compute classification accuracy for a Trainer evaluation pass.

    Args:
        eval_pred: the evaluation prediction object the Trainer passes to
            ``compute_metrics``; ``predictions`` holds per-class logits and
            ``label_ids`` the gold labels.

    Returns:
        ``{"accuracy": float}``; the Trainer reports it as ``eval_accuracy``.
    """
    logits = eval_pred.predictions
    labels = eval_pred.label_ids
    # Some models return extra outputs alongside the logits; keep the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predicted = logits.argmax(axis=-1)
    return {"accuracy": float((predicted == labels).mean())}


trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="./training_output",
        num_train_epochs=2,
        per_device_train_batch_size=8,
        logging_steps=25,
        save_strategy="epoch",
    ),
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    compute_metrics=compute_metrics,
)
# Train
trainer.train()
# Evaluate: report loss and real accuracy separately (loss is not accuracy).
results = trainer.evaluate()
print(f"Eval loss: {results['eval_loss']:.4f}")
print(f"Eval accuracy: {results['eval_accuracy']:.4f}")
# Save the model and tokenizer to the current working directory
trainer.save_model(".")
tokenizer.save_pretrained(".")
print("Done! Model and tokenizer saved to current directory.")