---
language: en
tags:
- Text Classification
- TDAMM
- Multi-label Classification
- NASA
- Astrophysics
base_model:
- adsabs/astroBERT
library_name: transformers
license: apache-2.0
---

# TDAMM Multi-Label Classification Model

The TDAMM (Time Domain Multi-Messenger Astronomy) model categorizes NASA's time domain multi-messenger resources into one or more of 36 distinct categories identified by subject matter experts (SMEs).

## Model Description

- **Base Model:** astroBERT, fine-tuned for multi-label classification
- **Task:** Multi-label classification
- **Training Data:** A collection of 408 NASA and non-NASA documents related to TDAMM topics, identified by SMEs
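
Multi-label fine-tuning of this kind typically attaches a per-label sigmoid with a binary cross-entropy loss, rather than a softmax over the 36 categories, so each document can receive any subset of labels. The exact training setup is not documented in this card; the sketch below only illustrates that loss formulation on synthetic logits, with `NUM_LABELS` set to the 36 categories mentioned above.

```python
import torch
import torch.nn as nn

# Each of the 36 categories is an independent binary decision:
# sigmoid per label + binary cross-entropy, not softmax over classes.
NUM_LABELS = 36
logits = torch.randn(2, NUM_LABELS)   # synthetic model outputs for 2 documents
targets = torch.zeros(2, NUM_LABELS)  # multi-hot ground-truth labels
targets[0, [1, 4]] = 1.0              # doc 0 belongs to categories 1 and 4
targets[1, [0]] = 1.0                 # doc 1 belongs to category 0

loss_fn = nn.BCEWithLogitsLoss()      # applies the sigmoid internally
loss = loss_fn(logits, targets)
print(loss.item())
```

This is why the Usage section below applies `torch.sigmoid` to the logits at inference time instead of a softmax.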

## Data Distribution

<img src="https://cdn-uploads.huggingface.co/production/uploads/67804a0abd67e99d000342e1/oOZ3PhRsh6TDEfaSTTpxa.png" width="70%" alt="Distribution 1">
<img src="https://cdn-uploads.huggingface.co/production/uploads/67804a0abd67e99d000342e1/kKpL5XWCtgWiXHLAAmGz5.png" width="70%" alt="Distribution 2">
<img src="https://cdn-uploads.huggingface.co/production/uploads/67804a0abd67e99d000342e1/hJQt5iBKYsVPSHQLIH2RG.png" width="50%" alt="Distribution 3">

## Performance Analysis

<img src="https://cdn-uploads.huggingface.co/production/uploads/67804a0abd67e99d000342e1/aX8X-b7dehTwaA-opBulN.png" width="70%" alt="Threshold 1">
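
A threshold curve like the one above is typically produced by sweeping the decision threshold over held-out sigmoid scores and tracking a metric at each value. The evaluation data and metric behind this plot are not included in the card; the sketch below illustrates the procedure on synthetic scores using micro-averaged F1 as an assumed metric.

```python
import numpy as np
from sklearn.metrics import f1_score

# Synthetic stand-ins for held-out data: multi-hot truths and sigmoid scores
rng = np.random.default_rng(0)
y_true = rng.integers(0, 2, size=(100, 36))
y_score = np.clip(y_true * 0.6 + rng.random((100, 36)) * 0.5, 0.0, 1.0)

# Sweep candidate thresholds and keep the one with the best micro-F1
thresholds = np.arange(0.1, 0.9, 0.05)
scores = [f1_score(y_true, (y_score >= t).astype(int), average="micro")
          for t in thresholds]
best_t = thresholds[int(np.argmax(scores))]
print(best_t, max(scores))
```

On real validation scores, the chosen threshold can replace the default 0.5 used in the Usage snippet below.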

## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

tokenizer = AutoTokenizer.from_pretrained("nasa-impact/tdamm-classification")
model = AutoModelForSequenceClassification.from_pretrained("nasa-impact/tdamm-classification")

# Prepare input
text = "Your astronomical test text here"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)

# Get predictions
with torch.no_grad():
    outputs = model(**inputs)
predictions = torch.sigmoid(outputs.logits)

# Convert to binary predictions (threshold = 0.5)
predictions = (predictions > 0.5).int()
```

## Label Mapping During Inference

After obtaining predictions from the model, you can map the predicted label indices to their actual names using the `model.config.id2label` dictionary.

```python
# Example usage
predicted_indices = [0, 2, 5]
predicted_labels = [model.config.id2label[idx] for idx in predicted_indices]
print(predicted_labels)
```
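
In practice, the hard-coded `predicted_indices` above would come from the thresholded prediction tensor produced in the Usage section. One way to extract them is with `nonzero`; the snippet below uses a synthetic tensor in place of the real model output so it runs stand-alone.

```python
import torch

# Synthetic stand-in for the thresholded sigmoid output from the Usage
# section: one document, 36 labels, with labels 0, 2 and 5 active
predictions = torch.zeros(1, 36, dtype=torch.int)
predictions[0, [0, 2, 5]] = 1

# Indices of the labels predicted for the first (and only) document
predicted_indices = predictions[0].nonzero(as_tuple=True)[0].tolist()
print(predicted_indices)  # [0, 2, 5]
```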