---
language: en
license: apache-2.0
tags:
- text-classification
- multi-label-classification
- tinybert
- pytorch
datasets:
- JayShah07/multi_label_reporting
metrics:
- accuracy
- f1
widget:
- text: "Show me my current holdings"
- text: "What are my capital gains for this year?"
- text: "Give me monthly scheme-wise returns"
---
# TinyBERT Dual Classifier for Investment Reporting
This model is a fine-tuned TinyBERT with two independent classification heads for multi-label classification of investment reporting queries: each query is assigned both a module label and a date label.
## Model Description
- **Base Model**: TinyBERT (huawei-noah/TinyBERT_General_4L_312D)
- **Parameters**: ~14-15M
- **Architecture**: Single encoder with two independent classification heads
- **Task**: Multi-label classification (Module + Date)
## Labels
**Module Labels (6 classes)**:
- holdings
- capital_gains
- scheme_wise_returns
- investment_account_wise_returns
- portfolio_update
- None_module
**Date Labels (7 classes)**:
- Current Year
- Previous Year
- Daily
- Monthly
- Weekly
- Yearly
- None_date
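For programmatic use, the label order above maps directly to the classifier output indices. A minimal sketch of the mapping, assuming the indices follow the listing order shown here (the same order used in the Usage example below):

```python
# Assumed index order: matches the label lists above and the Usage example below
MODULE_LABELS = [
    "holdings", "capital_gains", "scheme_wise_returns",
    "investment_account_wise_returns", "portfolio_update", "None_module",
]
DATE_LABELS = [
    "Current Year", "Previous Year", "Daily",
    "Monthly", "Weekly", "Yearly", "None_date",
]

# id <-> label lookups for each head
module_id2label = dict(enumerate(MODULE_LABELS))
module_label2id = {label: i for i, label in module_id2label.items()}
date_id2label = dict(enumerate(DATE_LABELS))
date_label2id = {label: i for i, label in date_id2label.items()}
```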
## Performance
**Test Set Results**:
- Module Accuracy: 1.0000
- Module F1 Score: 1.0000
- Date Accuracy: 1.0000
- Date F1 Score: 1.0000
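The metrics above are computed per head. A minimal sketch of how such scores can be reproduced with scikit-learn, where `y_true` and `y_pred` are placeholder integer label arrays for one head; the F1 averaging mode is an assumption, as the card does not state which variant was used:

```python
from sklearn.metrics import accuracy_score, f1_score

# Placeholder class ids for a single head (module or date)
y_true = [0, 1, 2, 5, 3]
y_pred = [0, 1, 2, 5, 3]

accuracy = accuracy_score(y_true, y_pred)
# "weighted" averaging is an assumption, not stated in the card
f1 = f1_score(y_true, y_pred, average="weighted")
print(f"Accuracy: {accuracy:.4f}  F1: {f1:.4f}")
```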
## Usage
```python
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn as nn

# Single TinyBERT encoder shared by two independent linear classification heads
class TinyBERTDualClassifier(nn.Module):
    def __init__(self, num_module_labels, num_date_labels, dropout_rate=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained("JayShah07/tinybert-dual-classifier")
        self.hidden_size = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(p=dropout_rate)
        self.module_classifier = nn.Linear(self.hidden_size, num_module_labels)
        self.date_classifier = nn.Linear(self.hidden_size, num_date_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token representation as the pooled sequence embedding
        cls_output = outputs.last_hidden_state[:, 0, :]
        cls_output = self.dropout(cls_output)
        module_logits = self.module_classifier(cls_output)
        date_logits = self.date_classifier(cls_output)
        return module_logits, date_logits

# Load the fine-tuned classification head weights
classifier_config = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/JayShah07/tinybert-dual-classifier/resolve/main/classifier_heads.pt",
    map_location="cpu",
)
model = TinyBERTDualClassifier(num_module_labels=6, num_date_labels=7)
model.module_classifier.load_state_dict(classifier_config['module_classifier'])
model.date_classifier.load_state_dict(classifier_config['date_classifier'])

tokenizer = AutoTokenizer.from_pretrained("JayShah07/tinybert-dual-classifier")

# Inference
model.eval()
text = "Show my holdings for this month"
inputs = tokenizer(text, return_tensors='pt', padding='max_length',
                   truncation=True, max_length=128)
with torch.no_grad():
    module_logits, date_logits = model(inputs['input_ids'], inputs['attention_mask'])

module_pred = torch.argmax(module_logits, dim=1).item()
date_pred = torch.argmax(date_logits, dim=1).item()

module_labels = ['holdings', 'capital_gains', 'scheme_wise_returns',
                 'investment_account_wise_returns', 'portfolio_update', 'None_module']
date_labels = ['Current Year', 'Previous Year', 'Daily', 'Monthly',
               'Weekly', 'Yearly', 'None_date']

print(f"Module: {module_labels[module_pred]}")
print(f"Date: {date_labels[date_pred]}")
```
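To get confidence scores alongside the predictions, the logits from each head can be passed through a softmax. A short extension of the example above:

```python
import torch.nn.functional as F

# Convert per-head logits to probabilities for a confidence estimate
module_probs = F.softmax(module_logits, dim=1)
date_probs = F.softmax(date_logits, dim=1)

module_conf = module_probs[0, module_pred].item()
date_conf = date_probs[0, date_pred].item()
print(f"Module: {module_labels[module_pred]} ({module_conf:.2%})")
print(f"Date: {date_labels[date_pred]} ({date_conf:.2%})")
```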
## Training Details
- **Dataset**: JayShah07/multi_label_reporting
- **Training Samples**: 3097
- **Validation Samples**: 387
- **Test Samples**: 388
- **Epochs**: 10
- **Batch Size**: 16
- **Learning Rate**: 2e-05
- **Optimizer**: AdamW
- **Loss Function**: CrossEntropyLoss (separate for each head)
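Since the card lists a separate CrossEntropyLoss per head, the natural formulation is to sum the two losses and backpropagate once. A minimal sketch under that assumption; `train_loader` and the batch keys are placeholders, and the equal weighting of the two losses is an assumption:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

model.train()
for epoch in range(10):
    for batch in train_loader:  # assumed DataLoader yielding a dict of tensors
        optimizer.zero_grad()
        module_logits, date_logits = model(batch["input_ids"], batch["attention_mask"])
        # One CrossEntropyLoss per head, combined into a single joint objective
        module_loss = criterion(module_logits, batch["module_label"])
        date_loss = criterion(date_logits, batch["date_label"])
        loss = module_loss + date_loss  # equal weighting is an assumption
        loss.backward()
        optimizer.step()
```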
## Latency
Average inference latency on sample queries (mean ± std) is reported in the accompanying notebook; a sketch of how to reproduce the measurement follows.
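A minimal timing sketch using the `model` and `tokenizer` from the Usage example; the warm-up and repeat counts are arbitrary choices, not values from the notebook:

```python
import time
import torch

model.eval()
inputs = tokenizer("Show me my current holdings", return_tensors="pt",
                   padding="max_length", truncation=True, max_length=128)

with torch.no_grad():
    # Warm-up runs so one-time setup cost is excluded from the timing
    for _ in range(5):
        model(inputs["input_ids"], inputs["attention_mask"])

    timings = []
    for _ in range(100):
        start = time.perf_counter()
        model(inputs["input_ids"], inputs["attention_mask"])
        timings.append((time.perf_counter() - start) * 1000)  # milliseconds

mean = sum(timings) / len(timings)
std = (sum((t - mean) ** 2 for t in timings) / len(timings)) ** 0.5
print(f"Latency: {mean:.2f} ± {std:.2f} ms")
```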
## Citation
If you use this model, please cite:
```bibtex
@misc{tinybert-dual-classifier,
author = {Jay Shah},
title = {TinyBERT Dual Classifier for Investment Reporting},
year = {2025},
publisher = {Hugging Face},
howpublished = {\url{https://huggingface.co/JayShah07/tinybert-dual-classifier}}
}
```
## License
Apache 2.0