Deploy FastAPI NER-ADREP classifier with fixed requirements
Browse files- Dockerfile +13 -0
- README.md +51 -5
- app.py +361 -0
- requirements_api.txt +8 -0
Dockerfile
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install only what the inference API needs
|
| 6 |
+
COPY requirements_api.txt .
|
| 7 |
+
RUN pip install --no-cache-dir -r requirements_api.txt
|
| 8 |
+
|
| 9 |
+
COPY app.py .
|
| 10 |
+
|
| 11 |
+
EXPOSE 7860
|
| 12 |
+
|
| 13 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,56 @@
|
|
| 1 |
---
|
| 2 |
title: Aviation Report Classification
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
-
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
title: Aviation Report Classification
|
| 3 |
+
emoji: ✈️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
short_description: ADREP classification via NER (SafeAeroBERT)
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# Aviation ADREP Classification API
|
| 12 |
+
|
| 13 |
+
FastAPI endpoint that classifies aviation incident narratives into ICAO ADREP occurrence categories using Named Entity Recognition with NASA SafeAeroBERT.
|
| 14 |
+
|
| 15 |
+
## Pipeline
|
| 16 |
+
|
| 17 |
+
```
|
| 18 |
+
POST /predict
|
| 19 |
+
└── NER inference (theophilusowiti/asn-ner-aerobert)
|
| 20 |
+
└── Multi-word entity merging (B-/I- subword tokens)
|
| 21 |
+
└── ADREP keyword scoring (weighted by entity role)
|
| 22 |
+
└── Confidence normalisation → top-5 response
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## Usage
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
curl -X POST https://rdelyon-aviation-report-classification.hf.space/predict \
|
| 29 |
+
-H "Content-Type: application/json" \
|
| 30 |
+
-d '{"narrative": "The aircraft experienced severe turbulence leading to cabin crew injury."}'
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
## Response Schema
|
| 34 |
+
|
| 35 |
+
```json
|
| 36 |
+
{
|
| 37 |
+
"model_id": "theophilusowiti/asn-ner-aerobert",
|
| 38 |
+
"display_name": "SafeAeroBERT NER + ADREP Classifier",
|
| 39 |
+
"prediction": {
|
| 40 |
+
"top_class": "TURB",
|
| 41 |
+
"confidence": 0.612,
|
| 42 |
+
"top_5": [
|
| 43 |
+
{"class": "TURB", "confidence": 0.612},
|
| 44 |
+
{"class": "CABIN", "confidence": 0.183},
|
| 45 |
+
{"class": "LOC-I", "confidence": 0.091},
|
| 46 |
+
{"class": "WSTRW", "confidence": 0.061},
|
| 47 |
+
{"class": "OTHR", "confidence": 0.011}
|
| 48 |
+
]
|
| 49 |
+
},
|
| 50 |
+
"inference_time_ms": 312
|
| 51 |
+
}
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Part of CMU 18-786 Deep Learning Group Project
|
| 55 |
+
|
| 56 |
+
Model fine-tuned on aviation safety reports for NER-based ADREP classification.
|
app.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastAPI endpoint for Aviation ADREP Classification
|
| 2 |
+
# Pipeline: narrative -> NER (SafeAeroBERT) -> entity extraction -> ADREP scoring -> API response
|
| 3 |
+
|
| 4 |
+
from fastapi import FastAPI
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
import torch
|
| 7 |
+
from transformers import AutoTokenizer, AutoModelForTokenClassification
|
| 8 |
+
import time
|
| 9 |
+
from collections import defaultdict
|
| 10 |
+
|
| 11 |
+
app = FastAPI()
|
| 12 |
+
|
| 13 |
+
MODEL_ID = "theophilusowiti/asn-ner-aerobert"
|
| 14 |
+
MODEL_DISPLAY_NAME = "SafeAeroBERT NER + ADREP Classifier"
|
| 15 |
+
|
| 16 |
+
# Keywords matched by substring against extracted entity text.
|
| 17 |
+
# TRIGGER entities are weighted 3x, so list the most discriminative TRIGGER phrases first.
|
| 18 |
+
ADREP_KEYWORDS: dict = {
|
| 19 |
+
# Codes correspond to ECCAIRS Aviation 7.1.0.0 Attribute 430 "Occurrence category"
|
| 20 |
+
"AMAN": [
|
| 21 |
+
"abrupt manoeuvre", "abrupt maneuver", "evasive manoeuvre", "evasive maneuver",
|
| 22 |
+
"sudden pull", "sudden push", "rapid pitch", "avoidance manoeuvre",
|
| 23 |
+
],
|
| 24 |
+
"ARC": [
|
| 25 |
+
"hard landing", "tail strike", "tailstrike", "nose gear collapse",
|
| 26 |
+
"gear up landing", "bounced", "firm touchdown", "rough landing",
|
| 27 |
+
],
|
| 28 |
+
"ADRM": [
|
| 29 |
+
"aerodrome", "airport obstacle", "runway light", "taxiway sign",
|
| 30 |
+
"apron collision", "airport infrastructure",
|
| 31 |
+
],
|
| 32 |
+
"ATM": [
|
| 33 |
+
"atc error", "air traffic control", "clearance error", "separation instruction",
|
| 34 |
+
"atc instruction", "controller error", "atm failure", "cns failure",
|
| 35 |
+
],
|
| 36 |
+
"BIRD": [
|
| 37 |
+
"bird strike", "bird ingestion", "birdstrike", "bird impact",
|
| 38 |
+
"avian strike", "bird hit",
|
| 39 |
+
],
|
| 40 |
+
"CABIN": [
|
| 41 |
+
"cabin injury", "passenger injury", "turbulence injury", "unsecured",
|
| 42 |
+
"galley fire", "cabin depressurisation", "cabin pressurisation",
|
| 43 |
+
"cabin crew injury", "oxygen mask deployed",
|
| 44 |
+
],
|
| 45 |
+
"CFIT": [
|
| 46 |
+
"controlled flight into terrain", "struck trees", "struck high ground",
|
| 47 |
+
"hit terrain", "terrain impact", "ground impact", "terrain", "mountain",
|
| 48 |
+
"hill", "tree",
|
| 49 |
+
],
|
| 50 |
+
"CTOL": [
|
| 51 |
+
"collision during takeoff", "collision during landing", "obstacle takeoff",
|
| 52 |
+
"obstacle landing", "struck obstacle", "hit obstacle on takeoff",
|
| 53 |
+
"hit obstacle on landing",
|
| 54 |
+
],
|
| 55 |
+
"EVAC": [
|
| 56 |
+
"evacuation", "emergency evacuation", "evacuation slide", "passengers evacuated",
|
| 57 |
+
"rapid disembarkation", "cabin evacuation",
|
| 58 |
+
],
|
| 59 |
+
"EXTL": [
|
| 60 |
+
"external load", "slung load", "underslung", "cargo net", "longline",
|
| 61 |
+
"load release", "load shift",
|
| 62 |
+
],
|
| 63 |
+
"F-NI": [
|
| 64 |
+
"in-flight fire", "engine fire", "electrical fire", "cargo fire",
|
| 65 |
+
"smoke in cockpit", "smoke in cabin", "fumes", "fire warning",
|
| 66 |
+
"fire non-impact", "smoke", "fire",
|
| 67 |
+
],
|
| 68 |
+
"F-POST": [
|
| 69 |
+
"post-crash fire", "post-impact fire", "fuel ignition after impact",
|
| 70 |
+
"post-accident fire",
|
| 71 |
+
],
|
| 72 |
+
"FUEL": [
|
| 73 |
+
"fuel exhaustion", "fuel starvation", "fuel contamination", "fuel imbalance",
|
| 74 |
+
"low fuel", "exhaustion", "starvation", "fuel",
|
| 75 |
+
],
|
| 76 |
+
"RAMP": [
|
| 77 |
+
"ground handling", "ramp incident", "towing accident", "pushback collision",
|
| 78 |
+
"baggage loader", "ground equipment", "jet bridge", "service vehicle",
|
| 79 |
+
"fuelling incident",
|
| 80 |
+
],
|
| 81 |
+
"GCOL": [
|
| 82 |
+
"ground collision", "taxiway collision", "tug", "pushback",
|
| 83 |
+
],
|
| 84 |
+
"GTOW": [
|
| 85 |
+
"glider tow", "aerotow", "tow rope", "glider towing",
|
| 86 |
+
],
|
| 87 |
+
"ICE": [
|
| 88 |
+
"ice accretion", "ice ingestion", "icing", "frost", "frozen",
|
| 89 |
+
"deice", "anti-ice", "ice",
|
| 90 |
+
],
|
| 91 |
+
"ISEC": [
|
| 92 |
+
"cyber attack", "information security", "gps spoofing", "gps jamming",
|
| 93 |
+
"data link compromise", "avionics cyber",
|
| 94 |
+
],
|
| 95 |
+
"LALT": [
|
| 96 |
+
"low altitude", "low level flight", "below minimum altitude", "msaw",
|
| 97 |
+
"controlled flight low", "low flying",
|
| 98 |
+
],
|
| 99 |
+
"LOC-G": [
|
| 100 |
+
"loss of control ground", "ground loss of control", "veer off ground",
|
| 101 |
+
"directional control loss", "ground loop", "skidded", "aquaplaning",
|
| 102 |
+
],
|
| 103 |
+
"LOC-I": [
|
| 104 |
+
"loss of control", "departure from controlled flight", "unusual attitude",
|
| 105 |
+
"uncontrolled descent", "spiral dive", "pitch up", "stall", "upset", "spin",
|
| 106 |
+
],
|
| 107 |
+
"LOLI": [
|
| 108 |
+
"loss of lifting conditions", "downdraft", "loss of lift en route",
|
| 109 |
+
"density altitude", "helicopter settling with power", "vortex ring",
|
| 110 |
+
],
|
| 111 |
+
"MAC": [
|
| 112 |
+
"mid-air collision", "midair collision", "airprox", "near miss", "tcas",
|
| 113 |
+
"traffic alert", "loss of separation", "acas alert",
|
| 114 |
+
],
|
| 115 |
+
"MED": [
|
| 116 |
+
"medical", "pilot incapacitation", "crew incapacitation", "heart attack",
|
| 117 |
+
"stroke", "medical emergency", "unconscious pilot",
|
| 118 |
+
],
|
| 119 |
+
"NAV": [
|
| 120 |
+
"navigation error", "wrong runway", "off course", "navigational error",
|
| 121 |
+
"rnav error", "gps error", "position error", "wrong approach",
|
| 122 |
+
"flew to wrong airport",
|
| 123 |
+
],
|
| 124 |
+
"RE": [
|
| 125 |
+
"runway excursion", "runway overrun", "overran runway", "veered off runway",
|
| 126 |
+
"skidded off runway", "overrun", "excursion",
|
| 127 |
+
],
|
| 128 |
+
"RI": [
|
| 129 |
+
"runway incursion", "unauthorised runway entry", "runway occupied",
|
| 130 |
+
"aircraft on runway", "vehicle on runway", "person on runway",
|
| 131 |
+
],
|
| 132 |
+
"SCF-NP": [
|
| 133 |
+
"gear failure", "gear collapse", "nose gear", "main gear", "landing gear",
|
| 134 |
+
"hydraulic failure", "hydraulic leak", "avionics failure", "flap failure",
|
| 135 |
+
"flight control failure", "structural failure", "electrical failure",
|
| 136 |
+
"hydraulic", "avionics", "flap", "rudder", "elevator", "aileron",
|
| 137 |
+
],
|
| 138 |
+
"SCF-PP": [
|
| 139 |
+
"engine failure", "engine malfunction", "engine problem",
|
| 140 |
+
"engine separation", "engine shutdown", "engine surge", "power loss",
|
| 141 |
+
"flameout", "oil leak", "fuel leak", "compressor stall", "turbine failure",
|
| 142 |
+
"propeller failure", "rpm rollback", "powerplant", "turbine", "compressor",
|
| 143 |
+
"propeller", "engine",
|
| 144 |
+
],
|
| 145 |
+
"SEC": [
|
| 146 |
+
"hijack", "hijacking", "air piracy", "security threat", "bomb threat",
|
| 147 |
+
"weapon", "unruly passenger", "assault", "attack",
|
| 148 |
+
],
|
| 149 |
+
"TURB": [
|
| 150 |
+
"severe turbulence", "clear air turbulence", "wake turbulence",
|
| 151 |
+
"turbulence", "chop", "jolt",
|
| 152 |
+
],
|
| 153 |
+
"UIMC": [
|
| 154 |
+
"inadvertent imc", "vfr into imc", "flew into cloud",
|
| 155 |
+
"instrument meteorological conditions", "unintended flight in imc", "imc",
|
| 156 |
+
],
|
| 157 |
+
"UNK": [],
|
| 158 |
+
"USOS": [
|
| 159 |
+
"undershoot", "overshoot", "short landing", "long landing",
|
| 160 |
+
"landed short", "landed long", "threshold undershoot",
|
| 161 |
+
],
|
| 162 |
+
"WILD": [
|
| 163 |
+
"wildlife strike", "animal strike", "wildlife", "animal",
|
| 164 |
+
],
|
| 165 |
+
"WSTRW": [
|
| 166 |
+
"microburst", "downburst", "windshear", "wind shear", "thunderstorm",
|
| 167 |
+
"shear", "microburst encounter",
|
| 168 |
+
],
|
| 169 |
+
"OTHR": [],
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
# Entity type scoring weights — TRIGGERs are primary accident drivers
|
| 173 |
+
ENTITY_WEIGHTS = {
|
| 174 |
+
"TRIGGER": 3.0,
|
| 175 |
+
"OUTCOME": 2.0,
|
| 176 |
+
"SYSTEM": 1.5,
|
| 177 |
+
"PHASE": 0.5,
|
| 178 |
+
"ACTOR": 0.5,
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
ALL_ADREP_CODES = list(ADREP_KEYWORDS.keys())
|
| 182 |
+
|
| 183 |
+
# Full descriptions from ECCAIRS Aviation 7.1.0.0 Attribute 430 "Occurrence category"
|
| 184 |
+
ADREP_DESCRIPTIONS: dict = {
|
| 185 |
+
"AMAN": "Abrupt Manoeuvre",
|
| 186 |
+
"ARC": "Abnormal Runway Contact",
|
| 187 |
+
"ADRM": "Aerodrome",
|
| 188 |
+
"ATM": "ATM/CNS",
|
| 189 |
+
"BIRD": "Birdstrike",
|
| 190 |
+
"CABIN": "Cabin Safety Events",
|
| 191 |
+
"CFIT": "Controlled Flight Into or Toward Terrain",
|
| 192 |
+
"CTOL": "Collision with Obstacle(s) During Take-off and Landing",
|
| 193 |
+
"EVAC": "Evacuation",
|
| 194 |
+
"EXTL": "External Load Related Occurrences",
|
| 195 |
+
"F-NI": "Fire/Smoke (Non-Impact)",
|
| 196 |
+
"F-POST": "Fire/Smoke (Post-Impact)",
|
| 197 |
+
"FUEL": "Fuel Related",
|
| 198 |
+
"RAMP": "Ground Handling",
|
| 199 |
+
"GCOL": "Ground Collision",
|
| 200 |
+
"GTOW": "Glider Towing Related Events",
|
| 201 |
+
"ICE": "Icing",
|
| 202 |
+
"ISEC": "Information Security Related",
|
| 203 |
+
"LALT": "Low Altitude Operations",
|
| 204 |
+
"LOC-G": "Loss of Control - Ground",
|
| 205 |
+
"LOC-I": "Loss of Control - Inflight",
|
| 206 |
+
"LOLI": "Loss of Lifting Conditions En-Route",
|
| 207 |
+
"MAC": "Airprox/ACAS Alert/Loss of Separation/(Near) Midair Collision",
|
| 208 |
+
"MED": "Medical",
|
| 209 |
+
"NAV": "Navigation Error",
|
| 210 |
+
"RE": "Runway Excursion",
|
| 211 |
+
"RI": "Runway Incursion - Vehicle, Aircraft or Person",
|
| 212 |
+
"SCF-NP": "System/Component Failure or Malfunction (Non-Powerplant)",
|
| 213 |
+
"SCF-PP": "Powerplant Failure or Malfunction",
|
| 214 |
+
"SEC": "Security Related",
|
| 215 |
+
"TURB": "Turbulence Encounter",
|
| 216 |
+
"UIMC": "Unintended Flight in IMC",
|
| 217 |
+
"UNK": "Unknown or Undetermined",
|
| 218 |
+
"USOS": "Undershoot/Overshoot",
|
| 219 |
+
"WILD": "Collision with Wildlife",
|
| 220 |
+
"WSTRW": "Windshear or Thunderstorm",
|
| 221 |
+
"OTHR": "Other",
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
print("Loading NER model...")
|
| 225 |
+
_tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
| 226 |
+
_model = AutoModelForTokenClassification.from_pretrained(MODEL_ID)
|
| 227 |
+
_model.eval()
|
| 228 |
+
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 229 |
+
_model.to(_device)
|
| 230 |
+
print(f"Model loaded on {_device}.")
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class IncidentRequest(BaseModel):
|
| 234 |
+
narrative: str
|
| 235 |
+
event_id: str = None
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def extract_entities(text: str) -> list:
|
| 239 |
+
"""Run NER inference; return [(token, label), ...] with subwords merged."""
|
| 240 |
+
inputs = _tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
| 241 |
+
inputs = {k: v.to(_device) for k, v in inputs.items()}
|
| 242 |
+
|
| 243 |
+
with torch.no_grad():
|
| 244 |
+
logits = _model(**inputs).logits
|
| 245 |
+
|
| 246 |
+
preds = torch.argmax(logits, dim=-1)[0].cpu().numpy()
|
| 247 |
+
tokens = _tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
|
| 248 |
+
|
| 249 |
+
results = []
|
| 250 |
+
for token, label_id in zip(tokens, preds):
|
| 251 |
+
if token in _tokenizer.all_special_tokens:
|
| 252 |
+
continue
|
| 253 |
+
label = _model.config.id2label[label_id]
|
| 254 |
+
if token.startswith("##"):
|
| 255 |
+
if results:
|
| 256 |
+
results[-1] = (results[-1][0] + token[2:], results[-1][1])
|
| 257 |
+
else:
|
| 258 |
+
results.append((token, label))
|
| 259 |
+
|
| 260 |
+
return results
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def build_event_dict(entities: list) -> dict:
|
| 264 |
+
"""
|
| 265 |
+
Merge consecutive B-/I- tokens into multi-word phrases.
|
| 266 |
+
Returns {"ACTOR": [...], "SYSTEM": [...], "PHASE": [...],
|
| 267 |
+
"TRIGGER": [...], "OUTCOME": [...]}
|
| 268 |
+
"""
|
| 269 |
+
event = {role: [] for role in ENTITY_WEIGHTS}
|
| 270 |
+
current_tokens = []
|
| 271 |
+
current_role = None
|
| 272 |
+
|
| 273 |
+
for token, label in entities:
|
| 274 |
+
if label == "O":
|
| 275 |
+
if current_tokens and current_role:
|
| 276 |
+
event[current_role].append(" ".join(current_tokens))
|
| 277 |
+
current_tokens, current_role = [], None
|
| 278 |
+
continue
|
| 279 |
+
|
| 280 |
+
prefix, role = label.split("-", 1)
|
| 281 |
+
if role not in event:
|
| 282 |
+
continue
|
| 283 |
+
|
| 284 |
+
if prefix == "B":
|
| 285 |
+
if current_tokens and current_role:
|
| 286 |
+
event[current_role].append(" ".join(current_tokens))
|
| 287 |
+
current_tokens = [token]
|
| 288 |
+
current_role = role
|
| 289 |
+
elif prefix == "I" and role == current_role:
|
| 290 |
+
current_tokens.append(token)
|
| 291 |
+
else:
|
| 292 |
+
if current_tokens and current_role:
|
| 293 |
+
event[current_role].append(" ".join(current_tokens))
|
| 294 |
+
current_tokens = [token]
|
| 295 |
+
current_role = role
|
| 296 |
+
|
| 297 |
+
if current_tokens and current_role:
|
| 298 |
+
event[current_role].append(" ".join(current_tokens))
|
| 299 |
+
|
| 300 |
+
return event
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def score_adrep(event: dict) -> dict:
|
| 304 |
+
"""
|
| 305 |
+
For each entity role, check which ADREP codes have a keyword that is a
|
| 306 |
+
substring of the extracted phrase. Accumulate weighted scores.
|
| 307 |
+
"""
|
| 308 |
+
scores: dict = defaultdict(float)
|
| 309 |
+
|
| 310 |
+
for role, phrases in event.items():
|
| 311 |
+
weight = ENTITY_WEIGHTS.get(role, 1.0)
|
| 312 |
+
combined = " ".join(phrases).lower()
|
| 313 |
+
for code, keywords in ADREP_KEYWORDS.items():
|
| 314 |
+
for kw in keywords:
|
| 315 |
+
if kw in combined:
|
| 316 |
+
scores[code] += weight
|
| 317 |
+
break # count each code once per entity type
|
| 318 |
+
|
| 319 |
+
if not scores:
|
| 320 |
+
scores["OTHR"] = 1.0
|
| 321 |
+
|
| 322 |
+
return dict(scores)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def scores_to_top5(scores: dict) -> tuple:
|
| 326 |
+
"""Normalise scores → confidences; return (top_class, confidence, top_5)."""
|
| 327 |
+
total = sum(scores.values())
|
| 328 |
+
normalised = {k: v / total for k, v in scores.items()}
|
| 329 |
+
|
| 330 |
+
# Assign a tiny residual to every code not already scored
|
| 331 |
+
unscored = [c for c in ALL_ADREP_CODES if c not in normalised]
|
| 332 |
+
residual = max((1.0 - sum(normalised.values())) / max(len(unscored), 1), 0.001)
|
| 333 |
+
for code in unscored:
|
| 334 |
+
normalised[code] = residual
|
| 335 |
+
|
| 336 |
+
sorted_codes = sorted(normalised.items(), key=lambda x: x[1], reverse=True)
|
| 337 |
+
top_code, top_conf = sorted_codes[0]
|
| 338 |
+
top_5 = [{"class": c, "confidence": round(conf, 4)} for c, conf in sorted_codes[:5]]
|
| 339 |
+
|
| 340 |
+
return top_code, round(top_conf, 4), top_5
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
@app.post("/predict")
|
| 344 |
+
async def predict(request: IncidentRequest):
|
| 345 |
+
start = time.time()
|
| 346 |
+
|
| 347 |
+
entities = extract_entities(request.narrative)
|
| 348 |
+
event = build_event_dict(entities)
|
| 349 |
+
scores = score_adrep(event)
|
| 350 |
+
top_code, confidence, top_5 = scores_to_top5(scores)
|
| 351 |
+
|
| 352 |
+
return {
|
| 353 |
+
"model_id": MODEL_ID,
|
| 354 |
+
"display_name": MODEL_DISPLAY_NAME,
|
| 355 |
+
"prediction": {
|
| 356 |
+
"top_class": ADREP_DESCRIPTIONS.get(top_code, top_code),
|
| 357 |
+
"confidence": confidence,
|
| 358 |
+
"top_5": top_5,
|
| 359 |
+
},
|
| 360 |
+
"inference_time_ms": int((time.time() - start) * 1000),
|
| 361 |
+
}
|
requirements_api.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.0
|
| 2 |
+
uvicorn==0.34.0
|
| 3 |
+
transformers==4.47.0
|
| 4 |
+
torch==2.6.0
|
| 5 |
+
pydantic==2.7.0
|
| 6 |
+
huggingface_hub==0.27.0
|
| 7 |
+
safetensors==0.5.3
|
| 8 |
+
tokenizers==0.21.0
|