Dataset: [`fancyzhx/ag_news`](https://huggingface.co/datasets/fancyzhx/ag_news)
Fine-tuned DistilBERT model for text classification.
- Base model: `distilbert-base-uncased`
- Training data: `ag_news` (quick test run for pipeline validation)
- Project: `mlops-assignment2`

Evaluation results:

| Metric | Score |
|---|---|
| Accuracy | 0.87145 |
| F1 (weighted) | 0.86951 |
| Eval Loss | 0.46378 |
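The card does not say which library computed the accuracy and weighted F1 above. As a rough sketch of the standard definitions (the same ones scikit-learn uses for `accuracy_score` and `f1_score(average="weighted")`), here is a pure-Python version applied to toy labels. The function names are illustrative, not part of this project. Weighted F1 computes F1 per class and averages the results, weighting each class by its number of true examples (support).

```python
from collections import Counter


def accuracy(y_true: list[int], y_pred: list[int]) -> float:
    """Fraction of predictions that match the reference labels."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def weighted_f1(y_true: list[int], y_pred: list[int]) -> float:
    """Per-class F1, averaged with weights proportional to class support."""
    support = Counter(y_true)
    total = len(y_true)
    score = 0.0
    for cls, n_true in support.items():
        # True positives and number of predictions for this class
        tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
        n_pred = sum(p == cls for p in y_pred)
        precision = tp / n_pred if n_pred else 0.0
        recall = tp / n_true
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += f1 * n_true / total
    return score


# Toy labels only; these are not the model's real predictions
y_true = [0, 0, 1, 1, 2, 3]
y_pred = [0, 1, 1, 1, 2, 0]
print(accuracy(y_true, y_pred), weighted_f1(y_true, y_pred))
```

Because the weights follow class support, weighted F1 can differ from accuracy when errors concentrate in particular classes. This explains why the two scores in the table are close but not identical.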
Usage:

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Replace YOUR_USERNAME with the Hub account that hosts this model
model_name = "YOUR_USERNAME/distilbert-agnews-smoke"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

text = "Stock markets rose today after strong earnings reports."
# Tokenize into PyTorch tensors; inputs longer than 256 tokens are truncated
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)

# Forward pass without gradient tracking
with torch.no_grad():
    logits = model(**inputs).logits

# Index of the highest-scoring class for the single input
pred_id = int(torch.argmax(logits, dim=-1))
print("Predicted class id:", pred_id)
```
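The usage code prints only a class index. The following minimal sketch turns a row of logits into a label name and a probability. It assumes the model keeps the `ag_news` label order (0 World, 1 Sports, 2 Business, 3 Sci/Tech). If the uploaded config defines meaningful labels, `model.config.id2label[pred_id]` is the safer source. The helpers `softmax` and `decode` are hypothetical names, and the logits in the example are made up.

```python
import math

# Assumption: class ids follow the ag_news ClassLabel order used in training
AG_NEWS_LABELS = ["World", "Sports", "Business", "Sci/Tech"]


def softmax(logits: list[float]) -> list[float]:
    """Convert raw logits to probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(logits: list[float]) -> tuple[str, float]:
    """Return the top label name and its softmax probability."""
    probs = softmax(logits)
    pred_id = max(range(len(probs)), key=probs.__getitem__)
    return AG_NEWS_LABELS[pred_id], probs[pred_id]


# With the real model, pass one row of its output: decode(logits[0].tolist())
label, prob = decode([0.1, -1.2, 3.4, 0.5])
print(label, round(prob, 3))
```

Softmax preserves the ordering of the logits, so the label chosen here always matches the `argmax` result. The probability only adds a confidence estimate.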