Upload README.md with huggingface_hub
Browse files
README.md
CHANGED
|
@@ -1,3 +1,387 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
tags:
|
| 6 |
+
- text-classification
|
| 7 |
+
- token-classification
|
| 8 |
+
- multi-label-classification
|
| 9 |
+
- named-entity-recognition
|
| 10 |
+
- multitask
|
| 11 |
+
- finance
|
| 12 |
+
- energy
|
| 13 |
+
- news
|
| 14 |
+
- distilbert
|
| 15 |
+
- quantbridge
|
| 16 |
+
pipeline_tag: token-classification
|
| 17 |
+
base_model: QuantBridge/energy-intelligence-multitask-custom-ner
|
| 18 |
+
datasets:
|
| 19 |
+
- ag_news
|
| 20 |
+
- reuters-21578
|
| 21 |
+
- rmisra/news-category-dataset
|
| 22 |
+
---
|
| 23 |
+
|
| 24 |
+
# Energy Intelligence Multitask Model
|
| 25 |
+
|
| 26 |
+
**QuantBridge / energy-intelligence-multitask**
|
| 27 |
+
|
| 28 |
+
A single DistilBERT model with a **shared encoder** and **two task heads** for energy and financial news analysis. One forward pass returns both named entities **and** topic labels simultaneously.
|
| 29 |
+
|
| 30 |
+
| Head | Task | Output shape |
|
| 31 |
+
|------|------|-------------|
|
| 32 |
+
| **NER** | Named entity recognition (BIO) | `(batch, seq_len, 19)` |
|
| 33 |
+
| **CLS** | Multi-label topic classification | `(batch, 10)` |
|
| 34 |
+
|
| 35 |
+
---
|
| 36 |
+
|
| 37 |
+
## Architecture
|
| 38 |
+
|
| 39 |
+
```
|
| 40 |
+
Input headline
|
| 41 |
+
|
|
| 42 |
+
BertTokenizer (do_lower_case=True, max_length=128)
|
| 43 |
+
|
|
| 44 |
+
DistilBERT encoder (6 layers Β· 768 dim Β· 12 heads Β· ~67M params)
|
| 45 |
+
[weights from QuantBridge/energy-intelligence-multitask-custom-ner]
|
| 46 |
+
|
|
| 47 |
+
+βββββββββββββββββββββββββββββββββββββββββββ
|
| 48 |
+
| |
|
| 49 |
+
all token hidden states [CLS] hidden state
|
| 50 |
+
| |
|
| 51 |
+
Dropout(0.1) Linear(768β768) + ReLU
|
| 52 |
+
| Dropout(0.2)
|
| 53 |
+
Linear(768β19) Linear(768β10)
|
| 54 |
+
| |
|
| 55 |
+
NER logits CLS logits
|
| 56 |
+
argmax β BIO entity tags sigmoid β topic probabilities
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
---
|
| 60 |
+
|
| 61 |
+
## NER Label Space β 19 BIO tags
|
| 62 |
+
|
| 63 |
+
| Entity Type | Example extractions from test set |
|
| 64 |
+
|---|---|
|
| 65 |
+
| `COMPANY` | ExxonMobil, Gazprom, Maersk, Shell, Chevron, BP, Equinor |
|
| 66 |
+
| `ORGANIZATION` | OPEC+, US Treasury, Federal Reserve, IMF, FERC, IAEA |
|
| 67 |
+
| `COUNTRY` | Saudi Arabia, Russia, China, Iran, Venezuela, Germany |
|
| 68 |
+
| `COMMODITY` | crude oil, natural gas, LNG, methane, aluminum, hydrogen |
|
| 69 |
+
| `LOCATION` | Strait of Hormuz, Red Sea, Gulf of Mexico, North Sea, Kollsnes |
|
| 70 |
+
| `MARKET` | S&P 500, Brent, WTI |
|
| 71 |
+
| `EVENT` | Hurricane Ida, Houthi attacks |
|
| 72 |
+
| `PERSON` | Elon Musk, Jerome Powell |
|
| 73 |
+
| `INFRASTRUCTURE` | pipelines, refineries, terminals |
|
| 74 |
+
|
| 75 |
+
Each type uses standard BIO tagging: `B-<TYPE>` starts a span, `I-<TYPE>` continues it, `O` marks non-entities.
|
| 76 |
+
|
| 77 |
+
---
|
| 78 |
+
|
| 79 |
+
## Classification Label Space β 10 topic labels
|
| 80 |
+
|
| 81 |
+
| Label | Description | Avg score (test set) |
|
| 82 |
+
|-------|-------------|---------------------|
|
| 83 |
+
| `macro` | GDP, inflation, central bank policy | **0.323** |
|
| 84 |
+
| `politics` | Government policy, sanctions, diplomacy | **0.307** |
|
| 85 |
+
| `business` | Corporate earnings, M&A, operations | 0.219 |
|
| 86 |
+
| `technology` | Tech, innovation, clean-tech | 0.155 |
|
| 87 |
+
| `energy` | Oil, gas, renewables, power grid | 0.070 |
|
| 88 |
+
| `trade` | Tariffs, import/export, agreements | 0.046 |
|
| 89 |
+
| `shipping` | Maritime logistics, ports | 0.038 |
|
| 90 |
+
| `stocks` | Equity markets, share prices | 0.015 |
|
| 91 |
+
| `regulation` | Compliance, legislation, rules | 0.013 |
|
| 92 |
+
| `risk` | Crises, geopolitical tension | 0.013 |
|
| 93 |
+
|
| 94 |
+
> **Note on classification scores:** The classification head was trained on AG News + Reuters + Kaggle β datasets dominated by general `business` and `macro` content. Domain-specific labels (`energy`, `shipping`, `risk`, `regulation`, `stocks`) score lower as a result. The relative ranking of scores is semantically meaningful even when raw values are low. See [Limitations](#limitations).
|
| 95 |
+
|
| 96 |
+
---
|
| 97 |
+
|
| 98 |
+
## Test Results
|
| 99 |
+
|
| 100 |
+
Evaluated on **40 real-world energy & financial news headlines** across 9 domain groups (ENERGY, GEOPOLITICAL, SHIPPING, TRADE, MACRO, CORPORATE, REGULATION, TECHNOLOGY, STOCKS, RISK).
|
| 101 |
+
|
| 102 |
+
### NER Results
|
| 103 |
+
|
| 104 |
+
| Metric | Value |
|
| 105 |
+
|--------|-------|
|
| 106 |
+
| Total entities detected | **86** across 40 headlines |
|
| 107 |
+
| Average entities per headline | **2.1** |
|
| 108 |
+
| Entity types fired | **7 / 9** |
|
| 109 |
+
|
| 110 |
+
**Entity type frequency:**
|
| 111 |
+
|
| 112 |
+
| Entity Type | Detections | Example extractions |
|
| 113 |
+
|---|---|---|
|
| 114 |
+
| COMMODITY | 20 | oil production, crude, LNG, natural gas, aluminum, methane, hydrogen |
|
| 115 |
+
| COUNTRY | 19 | Saudi Arabia, Russia, China, Iran, Venezuela, Poland, Bulgaria, UK |
|
| 116 |
+
| ORGANIZATION | 15 | OPEC+, US Treasury, Federal Reserve, IMF, G7, FERC, IAEA, SEC |
|
| 117 |
+
| COMPANY | 15 | ExxonMobil, Gazprom, Maersk, Shell, Chevron, Equinor, BP, Tesla, Vestas |
|
| 118 |
+
| LOCATION | 14 | Kollsnes, Strait of Hormuz, Red Sea, Panama Canal, North Sea, Gulf of Mexico |
|
| 119 |
+
| EVENT | 2 | Hurricane Ida, Houthi (attacks) |
|
| 120 |
+
| MARKET | 1 | S&P 500 |
|
| 121 |
+
| PERSON | 0 | β (not fired on this test set) |
|
| 122 |
+
| INFRASTRUCTURE | 0 | β (not fired on this test set) |
|
| 123 |
+
|
| 124 |
+
**Key NER observations:**
|
| 125 |
+
- COMMODITY is the top entity type β the model reliably extracts energy goods (oil, crude, LNG, natural gas, hydrogen) and commodities (aluminum, solar panels)
|
| 126 |
+
- COUNTRY and ORGANIZATION fire consistently across all domain groups
|
| 127 |
+
- COMPANY detection is accurate: correctly identifies both energy majors (ExxonMobil, Shell, BP) and non-energy companies (Tesla, Maersk, Vestas)
|
| 128 |
+
- LOCATION captures geopolitically important hotspots correctly (Red Sea, Strait of Hormuz, Gulf of Mexico, North Sea)
|
| 129 |
+
- MARKET fires on "S&P 500" but misses "Brent" and "WTI" β likely a tokenisation artefact where these are split sub-words during BertTokenizer processing
|
| 130 |
+
- PERSON and INFRASTRUCTURE did not fire on this specific test set; these types are present in the model's label vocabulary and will activate on appropriate inputs
|
| 131 |
+
|
| 132 |
+
### Classification Results (threshold = 0.20)
|
| 133 |
+
|
| 134 |
+
**Label activation frequency across 40 headlines:**
|
| 135 |
+
|
| 136 |
+
| Label | Active headlines | % | Avg score |
|
| 137 |
+
|-------|-----------------|---|-----------|
|
| 138 |
+
| macro | 14 / 40 | 35% | 0.323 |
|
| 139 |
+
| politics | 9 / 40 | 22% | 0.307 |
|
| 140 |
+
| business | 1 / 40 | 2% | 0.219 |
|
| 141 |
+
| technology | 0 / 40 | 0% | 0.155 |
|
| 142 |
+
| energy | 0 / 40 | 0% | 0.070 |
|
| 143 |
+
| trade | 0 / 40 | 0% | 0.046 |
|
| 144 |
+
| shipping | 0 / 40 | 0% | 0.038 |
|
| 145 |
+
| stocks | 0 / 40 | 0% | 0.015 |
|
| 146 |
+
| regulation | 0 / 40 | 0% | 0.013 |
|
| 147 |
+
| risk | 0 / 40 | 0% | 0.013 |
|
| 148 |
+
|
| 149 |
+
**Domain-group heatmap** (`>>>` = group average score β₯ 0.35):
|
| 150 |
+
|
| 151 |
+
| Domain group | energy | politics | trade | stocks | regulation | shipping | macro | business | technology | risk |
|
| 152 |
+
|---|---|---|---|---|---|---|---|---|---|---|
|
| 153 |
+
| ENERGY | 0.09 | 0.28 | 0.07 | 0.02 | 0.02 | 0.05 | **>>>** | 0.26 | 0.17 | 0.02 |
|
| 154 |
+
| GEOPOLITICAL | 0.06 | 0.30 | 0.04 | 0.01 | 0.01 | 0.03 | 0.30 | 0.19 | 0.12 | 0.01 |
|
| 155 |
+
| SHIPPING | 0.07 | 0.31 | 0.04 | 0.01 | 0.01 | 0.05 | 0.28 | 0.23 | 0.14 | 0.01 |
|
| 156 |
+
| TRADE | 0.06 | 0.30 | 0.05 | 0.01 | 0.01 | 0.04 | 0.31 | 0.23 | 0.17 | 0.01 |
|
| 157 |
+
| MACRO | 0.07 | 0.30 | 0.06 | 0.02 | 0.02 | 0.04 | **>>>** | 0.22 | 0.17 | 0.02 |
|
| 158 |
+
| CORPORATE | 0.09 | 0.33 | 0.05 | 0.02 | 0.02 | 0.04 | **>>>** | 0.23 | 0.16 | 0.02 |
|
| 159 |
+
| REGULATION | 0.04 | 0.26 | 0.03 | 0.01 | 0.01 | 0.02 | 0.32 | 0.26 | 0.18 | 0.01 |
|
| 160 |
+
| TECHNOLOGY | 0.07 | **>>>** | 0.04 | 0.01 | 0.01 | 0.04 | 0.28 | 0.17 | 0.14 | 0.01 |
|
| 161 |
+
| STOCKS | 0.08 | 0.30 | 0.04 | 0.01 | 0.01 | 0.03 | 0.32 | 0.21 | 0.16 | 0.01 |
|
| 162 |
+
| RISK | 0.08 | 0.32 | 0.04 | 0.01 | 0.01 | 0.04 | 0.34 | 0.19 | 0.14 | 0.01 |
|
| 163 |
+
|
| 164 |
+
**Key classification observations:**
|
| 165 |
+
- `macro` is the dominant label across all domain groups β a direct consequence of training data composition (AG News World category and Kaggle both map heavily to `macro`)
|
| 166 |
+
- `politics` fires on TECHNOLOGY and GEOPOLITICAL groups, which is semantically reasonable (government energy policy, sanctions)
|
| 167 |
+
- Domain-specific labels (`energy`, `shipping`, `risk`, `regulation`, `stocks`) score consistently low β these categories are underrepresented in training data
|
| 168 |
+
- The score **ranking** is meaningful: for ENERGY headlines, `energy` consistently ranks above `trade`, `shipping`, and `regulation` even when below threshold β the model has learned the correct relative associations
|
| 169 |
+
|
| 170 |
+
---
|
| 171 |
+
|
| 172 |
+
## Usage
|
| 173 |
+
|
| 174 |
+
> **Important:** This model uses custom architecture files. Always pass `trust_remote_code=True`.
|
| 175 |
+
|
| 176 |
+
### Installation
|
| 177 |
+
|
| 178 |
+
```bash
|
| 179 |
+
pip install transformers torch
|
| 180 |
+
```
|
| 181 |
+
|
| 182 |
+
### Full inference β NER + Classification
|
| 183 |
+
|
| 184 |
+
```python
|
| 185 |
+
import torch
|
| 186 |
+
import numpy as np
|
| 187 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModel
|
| 188 |
+
|
| 189 |
+
MODEL_ID = "QuantBridge/energy-intelligence-multitask"
|
| 190 |
+
|
| 191 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
|
| 192 |
+
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True)
|
| 193 |
+
model.eval()
|
| 194 |
+
|
| 195 |
+
def sigmoid(x):
|
| 196 |
+
return 1 / (1 + np.exp(-x))
|
| 197 |
+
|
| 198 |
+
def predict(text: str, cls_threshold: float = 0.20):
|
| 199 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
|
| 200 |
+
inputs.pop("token_type_ids", None) # DistilBERT does not use these
|
| 201 |
+
|
| 202 |
+
with torch.no_grad():
|
| 203 |
+
output = model(**inputs)
|
| 204 |
+
|
| 205 |
+
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
| 206 |
+
|
| 207 |
+
# ββ Named Entity Recognition ββββββββββββββββββββββββββββββββββββββββββ
|
| 208 |
+
ner_id2label = {int(k): v for k, v in model.config.ner_id2label.items()}
|
| 209 |
+
tag_ids = output.ner_logits[0].argmax(-1).tolist()
|
| 210 |
+
|
| 211 |
+
entities = []
|
| 212 |
+
current = None
|
| 213 |
+
for token, tag_id in zip(tokens, tag_ids):
|
| 214 |
+
if token in ("[CLS]", "[SEP]", "[PAD]"):
|
| 215 |
+
if current: entities.append(current); current = None
|
| 216 |
+
continue
|
| 217 |
+
tag = ner_id2label[tag_id]
|
| 218 |
+
if tag.startswith("B-"):
|
| 219 |
+
if current: entities.append(current)
|
| 220 |
+
current = {"text": token.replace("##", ""), "type": tag[2:]}
|
| 221 |
+
elif tag.startswith("I-") and current:
|
| 222 |
+
current["text"] += token[2:] if token.startswith("##") else f" {token}"
|
| 223 |
+
else:
|
| 224 |
+
if current: entities.append(current); current = None
|
| 225 |
+
if current: entities.append(current)
|
| 226 |
+
|
| 227 |
+
# ββ Topic Classification ββββββββββββββββββοΏ½οΏ½βββββββββββββββββββββββββββ
|
| 228 |
+
cls_id2label = {int(k): v for k, v in model.config.cls_id2label.items()}
|
| 229 |
+
probs = sigmoid(output.cls_logits[0].numpy())
|
| 230 |
+
topics = {cls_id2label[i]: float(probs[i]) for i in range(len(probs))}
|
| 231 |
+
active = {lbl: p for lbl, p in topics.items() if p >= cls_threshold}
|
| 232 |
+
|
| 233 |
+
return entities, active
|
| 234 |
+
|
| 235 |
+
# Example
|
| 236 |
+
headline = "Russia cuts natural gas flows to Poland and Bulgaria following payment dispute"
|
| 237 |
+
entities, topics = predict(headline)
|
| 238 |
+
|
| 239 |
+
print("Entities found:")
|
| 240 |
+
for e in entities:
|
| 241 |
+
print(f" [{e['type']}] {e['text']}")
|
| 242 |
+
|
| 243 |
+
print("\nActive topic labels:")
|
| 244 |
+
for topic, score in sorted(topics.items(), key=lambda x: -x[1]):
|
| 245 |
+
print(f" {topic}: {score:.3f}")
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
**Expected output:**
|
| 249 |
+
```
|
| 250 |
+
Entities found:
|
| 251 |
+
[COUNTRY] Russia
|
| 252 |
+
[COUNTRY] Poland
|
| 253 |
+
[COUNTRY] Bulgaria
|
| 254 |
+
[COMMODITY] natural gas
|
| 255 |
+
|
| 256 |
+
Active topic labels:
|
| 257 |
+
politics: 0.362
|
| 258 |
+
macro: 0.357
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
### NER only β decode all entity spans
|
| 262 |
+
|
| 263 |
+
```python
|
| 264 |
+
def get_entities(text: str) -> list[dict]:
|
| 265 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
|
| 266 |
+
inputs.pop("token_type_ids", None)
|
| 267 |
+
with torch.no_grad():
|
| 268 |
+
output = model(**inputs)
|
| 269 |
+
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
| 270 |
+
ner_id2label = {int(k): v for k, v in model.config.ner_id2label.items()}
|
| 271 |
+
tag_ids = output.ner_logits[0].argmax(-1).tolist()
|
| 272 |
+
# ... (decode as shown above)
|
| 273 |
+
```
|
| 274 |
+
|
| 275 |
+
### Classification only β get all label scores
|
| 276 |
+
|
| 277 |
+
```python
|
| 278 |
+
def get_topic_scores(text: str) -> dict[str, float]:
|
| 279 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
|
| 280 |
+
inputs.pop("token_type_ids", None)
|
| 281 |
+
with torch.no_grad():
|
| 282 |
+
output = model(**inputs)
|
| 283 |
+
cls_id2label = {int(k): v for k, v in model.config.cls_id2label.items()}
|
| 284 |
+
probs = sigmoid(output.cls_logits[0].numpy())
|
| 285 |
+
return {cls_id2label[i]: float(probs[i]) for i in range(len(probs))}
|
| 286 |
+
```
|
| 287 |
+
|
| 288 |
+
---
|
| 289 |
+
|
| 290 |
+
## Training Details
|
| 291 |
+
|
| 292 |
+
### Encoder
|
| 293 |
+
|
| 294 |
+
Transferred from [`QuantBridge/energy-intelligence-multitask-custom-ner`](https://huggingface.co/QuantBridge/energy-intelligence-multitask-custom-ner) β DistilBERT fine-tuned on energy and financial news for BIO entity recognition.
|
| 295 |
+
|
| 296 |
+
### NER Head
|
| 297 |
+
|
| 298 |
+
Weights transferred directly from the NER backbone (`classifier.*` β `ner_classifier.*`). No additional NER training was performed.
|
| 299 |
+
|
| 300 |
+
### Classification Head
|
| 301 |
+
|
| 302 |
+
Trained separately from scratch on a merged corpus:
|
| 303 |
+
|
| 304 |
+
| Source | HF / NLTK id | Categories used | Mapped to |
|
| 305 |
+
|--------|-------------|-----------------|-----------|
|
| 306 |
+
| AG News | `ag_news` | World (0), Business (2), Sci/Tech (3) | `macro`, `business`, `technology` |
|
| 307 |
+
| Reuters-21578 | `nltk.corpus.reuters` | crude, gas, ship, trade, money-fx, interest, earn, acq | `energy`, `shipping`, `trade`, `macro`, `business` |
|
| 308 |
+
| Kaggle News Category | `rmisra/news-category-dataset` | POLITICS, BUSINESS, TECH, WORLD NEWS | `politics`, `business`, `technology`, `macro` |
|
| 309 |
+
|
| 310 |
+
Training split: 80% train / 10% validation / 10% test, seed 42.
|
| 311 |
+
|
| 312 |
+
**Hyperparameters:**
|
| 313 |
+
|
| 314 |
+
| Parameter | Value |
|
| 315 |
+
|---|---|
|
| 316 |
+
| Epochs | 10 |
|
| 317 |
+
| Train batch size | 32 |
|
| 318 |
+
| Learning rate | 2e-5 |
|
| 319 |
+
| Warmup steps | 500 |
|
| 320 |
+
| Weight decay | 0.01 |
|
| 321 |
+
| Max sequence length | 128 tokens |
|
| 322 |
+
| Loss | BCEWithLogitsLoss |
|
| 323 |
+
| Best checkpoint selected by | micro-F1 on validation set |
|
| 324 |
+
| Hardware | NVIDIA T4 16 GB |
|
| 325 |
+
|
| 326 |
+
---
|
| 327 |
+
|
| 328 |
+
## Model Files
|
| 329 |
+
|
| 330 |
+
```
|
| 331 |
+
energy-intelligence-multitask/
|
| 332 |
+
configuration_energy_multitask.py # EnergyMultitaskConfig (DistilBertConfig subclass)
|
| 333 |
+
modeling_energy_multitask.py # EnergyMultitaskModel (two-head architecture)
|
| 334 |
+
config.json # Serialised config with auto_map
|
| 335 |
+
model.safetensors # Combined weights (~256 MB)
|
| 336 |
+
tokenizer.json # Fast tokenizer
|
| 337 |
+
tokenizer_config.json # Tokenizer settings
|
| 338 |
+
```
|
| 339 |
+
|
| 340 |
+
---
|
| 341 |
+
|
| 342 |
+
## Limitations
|
| 343 |
+
|
| 344 |
+
- **English only** β trained exclusively on English-language news text.
|
| 345 |
+
- **Classification data bias** β training corpora (AG News, Kaggle) are dominated by `business` and `macro` content. Domain-specific labels (`energy`, `shipping`, `risk`, `regulation`, `stocks`) score lower across the board and may not cross common thresholds even when semantically correct. A recommended threshold for this model is **0.20** rather than the default 0.50.
|
| 346 |
+
- **NER on headlines** β the NER head was fine-tuned on short news headlines; performance may be lower on long-form documents.
|
| 347 |
+
- **Max length** β inputs are truncated to 128 tokens. Longer texts should be chunked.
|
| 348 |
+
- **PERSON / INFRASTRUCTURE** β these entity types exist in the label vocabulary but fired less frequently on financial news headlines compared to COMPANY, COUNTRY, and COMMODITY.
|
| 349 |
+
- **Not for trading** β this model is intended as an intelligence tagging layer, not for real-time trading or financial decision-making.
|
| 350 |
+
|
| 351 |
+
---
|
| 352 |
+
|
| 353 |
+
## Intended Use
|
| 354 |
+
|
| 355 |
+
This model is the **tagging layer** in an energy intelligence pipeline:
|
| 356 |
+
|
| 357 |
+
```
|
| 358 |
+
Raw news headline
|
| 359 |
+
β
|
| 360 |
+
EnergyMultitaskModel (this model)
|
| 361 |
+
β
|
| 362 |
+
entities βββββββββββββββββββ who / what / where
|
| 363 |
+
topic labels βββββββββββββββ energy / risk / trade / macro / ...
|
| 364 |
+
β
|
| 365 |
+
Structured intelligence signal for downstream analysis
|
| 366 |
+
```
|
| 367 |
+
|
| 368 |
+
---
|
| 369 |
+
|
| 370 |
+
## Related Models
|
| 371 |
+
|
| 372 |
+
- [`QuantBridge/energy-intelligence-multitask-custom-ner`](https://huggingface.co/QuantBridge/energy-intelligence-multitask-custom-ner) β NER backbone (encoder source)
|
| 373 |
+
- [`QuantBridge/energy-news-classifier-ner-multitask`](https://huggingface.co/QuantBridge/energy-news-classifier-ner-multitask) β Classification-only model (single head)
|
| 374 |
+
|
| 375 |
+
---
|
| 376 |
+
|
| 377 |
+
## Citation
|
| 378 |
+
|
| 379 |
+
```bibtex
|
| 380 |
+
@misc{quantbridge2025multitask,
|
| 381 |
+
author = {QuantBridge},
|
| 382 |
+
title = {Energy Intelligence Multitask Model (NER + Classification)},
|
| 383 |
+
year = {2025},
|
| 384 |
+
publisher = {Hugging Face},
|
| 385 |
+
howpublished = {\url{https://huggingface.co/QuantBridge/energy-intelligence-multitask}},
|
| 386 |
+
}
|
| 387 |
+
```
|