vic35get/nhtsa_complaints_dataset
Viewer • Updated • 12.5k • 7
How to use vic35get/nhtsa_complaints_classifier with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("text-classification", model="vic35get/nhtsa_complaints_classifier") # Load model directly
from transformers import AutoTokenizer, AutoModelForSequenceClassification
tokenizer = AutoTokenizer.from_pretrained("vic35get/nhtsa_complaints_classifier")
model = AutoModelForSequenceClassification.from_pretrained("vic35get/nhtsa_complaints_classifier")Este modelo foi treinado para classificar reclamações de veículos registradas no banco de dados da NHTSA (National Highway Traffic Safety Administration) entre 2014 e 2024. Ele classifica textos em cinco categorias de componentes veiculares:
Os dados foram extraídos da API oficial da NHTSA e passaram por um pipeline de processamento de linguagem natural (NLP), incluindo:
bert-base-uncased para transformar o texto em tensores compatíveis com o modelo.📊 Divisão dos Dados:
| Conjunto | Amostras |
|---|---|
| Treinamento | 8.357 |
| Validação | 2.090 |
| Teste | 2.090 |
| Parâmetro | Valor |
|---|---|
| Modelo base | bert-base-uncased |
| Batch size | 4 |
| Taxa de aprendizado | 1e-5 |
| Épocas | 30 (com early stopping de 3 épocas sem melhora) |
| Otimizador | AdamW |
| Métrica | Valor |
|---|---|
| Acurácia | 86.40% |
| F1-Score | 85.78% |
| Precisão | 85.96% |
| Recall | 86.40% |
| Métrica | Valor |
|---|---|
| Acurácia | 69.94% |
| F1-Score | 75.69% |
| Precisão | 87.96% |
| Recall | 69.94% |
A diferença de desempenho entre os conjuntos de validação e teste pode ser explicada pelo desbalanceamento e pela natureza ampla da classe OTHER, que agrupa diferentes tipos de reclamações.
Para carregar e utilizar o modelo:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Carregar modelo e tokenizer
model_name = "vic35get/nhtsa_complaints_classifier"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)
# Função de inferência
def predict(text):
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
with torch.no_grad():
outputs = model(**inputs)
return torch.argmax(outputs.logits, dim=1).item()
# Exemplo de uso
text = "The airbag did not deploy during the accident."
print(predict(text))
Base model
google-bert/bert-base-uncased