Dataset: `Johnade/consumer_complaints_cfpb` (≈153k rows, per its Hugging Face Hub listing)
A Bidirectional LSTM model trained to classify consumer financial complaints into 36 sub-issue categories.
| Parameter | Value |
|---|---|
| Architecture | BiLSTM (2 layers, 256 hidden, bidirectional) |
| Embedding dim | 128 |
| Vocab size | 30,522 (bert-base-uncased tokenizer) |
| Max sequence length | 256 |
| Total parameters | 6,292,772 |
| Dropout | 0.3 |

**Validation results**

| Metric | Value |
|---|---|
| Val Accuracy | 0.1700 |
| Val F1 (weighted) | 0.1370 |
| Best Epoch | 1 |
import torch
from huggingface_hub import PyTorchModelHubMixin
import torch.nn as nn
from transformers import AutoTokenizer
import json
# Define the model class (needed for from_pretrained)
class BiLSTMClassifier(nn.Module, PyTorchModelHubMixin):
    """Bidirectional LSTM text classifier loadable via ``PyTorchModelHubMixin``.

    Token ids are embedded, run through a stacked bidirectional LSTM, and the
    final hidden states of the last layer's forward and backward directions
    are concatenated and fed to a linear classification head.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding dimension.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of output logits.
        dropout: Dropout probability applied to the embeddings, between LSTM
            layers (only when ``num_layers > 1``), and before the classifier.
        pad_idx: Token id whose embedding is initialised to zero and receives
            no gradient (``nn.Embedding(padding_idx=...)``).
    """

    def __init__(self, vocab_size=30522, embed_dim=128, hidden_dim=256,
                 num_layers=2, num_classes=36, dropout=0.3, pad_idx=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        # nn.LSTM only applies dropout *between* layers and warns if it is set
        # on a single-layer LSTM, hence the conditional.
        self.bilstm = nn.LSTM(
            embed_dim, hidden_dim, num_layers, batch_first=True,
            bidirectional=True, dropout=dropout if num_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)
        # Forward + backward final states are concatenated -> 2 * hidden_dim.
        self.classifier = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits for a batch of token-id sequences.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``
                (``batch_first=True``).
            attention_mask: Optional tensor of the same shape, nonzero for
                real tokens and zero for padding. Padding is assumed to be on
                the right (the Hugging Face tokenizer default). When given,
                the sequences are packed so the LSTM never reads padding and
                each direction's final state ends on real tokens. When
                ``None`` (the default), the full padded sequence is fed to the
                LSTM exactly as before.
                NOTE(review): the published weights were presumably trained
                without a mask (the usage example passes only ``input_ids``),
                so supplying one may shift predictions — confirm before
                relying on it with those weights.

        Returns:
            Logits tensor of shape ``(batch, num_classes)``.
        """
        embedded = self.dropout(self.embedding(input_ids))
        if attention_mask is None:
            lstm_input = embedded
        else:
            # pack_padded_sequence rejects zero lengths, so an all-padding row
            # is treated as length 1; lengths must live on the CPU.
            lengths = attention_mask.long().sum(dim=1).clamp(min=1).cpu()
            lstm_input = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False,
            )
        # h_n has shape (num_layers * 2, batch, hidden_dim) and is returned in
        # the original batch order even when enforce_sorted=False.
        _, (hidden, _) = self.bilstm(lstm_input)
        # hidden[-2] / hidden[-1]: last layer's forward / backward final state.
        hidden_cat = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.classifier(self.dropout(hidden_cat))
# Load model and tokenizer
from huggingface_hub import hf_hub_download

REPO_ID = "akashsukhija/complaint-bilstm-classifier"

model = BiLSTMClassifier.from_pretrained(REPO_ID)
# Explicitly switch to inference mode so nn.Dropout is disabled and
# predictions are deterministic; don't rely on from_pretrained to do it.
model.eval()
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Load label mapping (JSON object keys are strings, hence str(pred) below).
with open(hf_hub_download(REPO_ID, "label_mapping.json"), encoding="utf-8") as f:
    label_map = json.load(f)

# Classify a complaint
text = "My credit report shows an account that isn't mine. I've disputed it twice."
# Pad/truncate to the model's documented max sequence length of 256 tokens.
inputs = tokenizer(text, padding="max_length", truncation=True, max_length=256, return_tensors="pt")
with torch.no_grad():
    logits = model(inputs["input_ids"])
pred = logits.argmax(dim=1).item()
print(f"Predicted: {label_map['id2label'][str(pred)]}")