---
language: en
tags:
- roberta
- multilabel-classification
- policy-analysis
- huggingface
datasets:
- custom
license: apache-2.0
---

# RoBERTa for Multi-label Classification of Policy Instruments

This model is a fine-tuned `roberta-base` for multilabel classification of policy texts by policy instrument, target group, and theme.

## Model Details

- Base model: `roberta-base`
- Max length: 512 tokens
- Output: 67 labels across three main categories, each with its own sub-categories: Policy Instrument (PI), Target Group (TG), and Theme (TH)
- Threshold: 0.25 (applied to each label's sigmoid probability)

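The decoding step implied by these details can be sketched in plain Python: each logit goes through an independent sigmoid, and every label whose probability exceeds 0.25 is kept. This is a minimal illustration, not the author's implementation. The function names `sigmoid` and `decode_multilabel`, the example logits, and the grouping by two-letter prefix are assumptions made for this sketch; the label names are borrowed from the label format the model card reports.

```python
import math
from collections import defaultdict

THRESHOLD = 0.25  # decision threshold listed in the model details


def sigmoid(x):
    """Numerically stable logistic function for a single float."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def decode_multilabel(logits, class_names, threshold=THRESHOLD):
    """Keep labels whose sigmoid probability exceeds the threshold.

    Labels are grouped by their two-letter prefix (assumed to be
    "PI", "TG", or "TH", matching the three main categories).
    """
    grouped = defaultdict(list)
    for logit, name in zip(logits, class_names):
        if sigmoid(logit) > threshold:
            grouped[name[:2]].append(name)
    return dict(grouped)


# Hypothetical logits, for illustration only (not real model output)
example_logits = [2.0, -3.0, -1.0, 0.5]
example_names = ["PI007", "PI008", "TG20", "TH31"]

result = decode_multilabel(example_logits, example_names)
print(result)
```

Note that a logit of -1.0 maps to a probability of about 0.27, so it passes the 0.25 threshold even though it would be rejected at the more common 0.5 cut-off. A low threshold like this trades precision for recall.
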
## Intended Use

Classify policy document descriptions into thematic categories.

## How to Use

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import joblib
import requests

# Load the fine-tuned model and its tokenizer from the Hugging Face Hub
model_path = "toqeerehsan/multilabel-indicator-classification"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Download the pickled label encoder; its classes_ attribute lists the label names
mlb_url = "https://huggingface.co/toqeerehsan/multilabel-indicator-classification/resolve/main/mlb.pkl"
mlb_path = "mlb.pkl"

response = requests.get(mlb_url)
response.raise_for_status()  # fail early if the download did not succeed
with open(mlb_path, "wb") as f:
    f.write(response.content)
mlb = joblib.load(mlb_path)

text = "This program supports clean technology and sustainable development in industries."

inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

# Run inference without tracking gradients
model.eval()
with torch.no_grad():
    logits = model(**inputs).logits
    probs = torch.sigmoid(logits).squeeze(0).numpy()

# Keep every label whose probability exceeds the 0.25 threshold
binary_preds = (probs > 0.25).astype(int)
predicted_labels = [label for i, label in enumerate(mlb.classes_) if binary_preds[i] == 1]

print("Predicted Labels:", predicted_labels)
# Predicted Labels: ['PI007', 'PI008', 'TG20', 'TG21', 'TG22', 'TG25', 'TG29', 'TG31', 'TH31']
```