Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- README.md +6 -4
- app.py +48 -0
- asr.py +22 -0
- final_model/added_tokens.json +3 -0
- final_model/config.json +32 -0
- final_model/label_spaces.json +44 -0
- final_model/pytorch_model.bin +3 -0
- final_model/special_tokens_map.json +15 -0
- final_model/spm.model +3 -0
- final_model/thresholds.json +15 -0
- final_model/tokenizer.json +0 -0
- final_model/tokenizer_config.json +59 -0
- inference.py +91 -0
- requirements.txt +9 -0
README.md
CHANGED
|
@@ -1,12 +1,14 @@
|
|
| 1 |
---
|
| 2 |
title: 911 Urgency Prototype
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
|
|
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
| 1 |
---
|
| 2 |
title: 911 Urgency Prototype
|
| 3 |
+
emoji: 🚨
|
| 4 |
+
colorFrom: indigo
|
| 5 |
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.0.0
|
| 8 |
+
python_version: 3.10
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
+
license: mit
|
| 12 |
---
|
| 13 |
|
| 14 |
+
Urgency-only prototype for 911 decision support (voice → transcript → score).
|
app.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import gradio as gr
|
| 3 |
+
from inference import UrgencyModel
|
| 4 |
+
from asr import Transcriber
|
| 5 |
+
|
| 6 |
+
urg = UrgencyModel()
|
| 7 |
+
asr = Transcriber()
|
| 8 |
+
|
| 9 |
+
def transcribe_then_score(audio_file, thr=0.5):
    """Run ASR on a recorded/uploaded clip, then score the transcript.

    Returns a 4-tuple: (transcript, urgency_score, urgent_label, rationale).
    """
    # Propagate the UI slider value to the shared model instance.
    urg.threshold = float(thr)
    if audio_file is None:
        return "", 0.0, "Non-Urgent", "No audio provided."
    transcript = asr.transcribe_file(audio_file)
    result = urg.predict(transcript)
    return (
        transcript,
        result["urgency_score"],
        result["urgent_label"],
        result["rationale"],
    )
|
| 16 |
+
|
| 17 |
+
def score_text(text, thr=0.5):
    """Score a pasted transcript directly (no ASR step).

    Returns (urgency_score, urgent_label, rationale).
    """
    urg.threshold = float(thr)
    outcome = urg.predict(text or "")
    return outcome["urgency_score"], outcome["urgent_label"], outcome["rationale"]
|
| 21 |
+
|
| 22 |
+
# --- Gradio UI --------------------------------------------------------------
# Fix: the tab titles / labels / placeholder contained mojibake residue
# ("β" where an arrow, en-dash or ellipsis belongs); repaired here.
with gr.Blocks(title="911 Urgency Prototype") as demo:
    gr.Markdown("# 911 Urgency Prototype\nDecision support (not dispatch).")
    # Shared decision threshold, passed into both scoring callbacks.
    thr = gr.Slider(0, 1, value=0.5, step=0.01, label="Decision threshold")

    with gr.Tab("Voice → Urgency"):
        gr.Markdown("Record or upload a short clip (WAV/MP3).")
        audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath")
        btn_v = gr.Button("Transcribe & Score")
        text_out = gr.Textbox(label="Transcript", lines=8)
        score_out = gr.Number(label="Urgency Score (0–1)")
        label_out = gr.Textbox(label="Urgent / Non-Urgent")
        rationale_out = gr.Textbox(label="Rationale")
        btn_v.click(transcribe_then_score, inputs=[audio_in, thr],
                    outputs=[text_out, score_out, label_out, rationale_out])

    with gr.Tab("Text → Urgency"):
        txt_in = gr.Textbox(label="Paste transcript", lines=8,
                            placeholder="Paste a transcript…")
        btn_t = gr.Button("Score Text")
        score_out2 = gr.Number(label="Urgency Score (0–1)")
        label_out2 = gr.Textbox(label="Urgent / Non-Urgent")
        rationale_out2 = gr.Textbox(label="Rationale")
        btn_t.click(score_text, inputs=[txt_in, thr],
                    outputs=[score_out2, label_out2, rationale_out2])

    gr.Markdown("**Notes:** Prototype for QA/training. No PII stored; processing is in-memory.")

if __name__ == "__main__":
    demo.launch()  # Spaces handles networking
|
asr.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
from faster_whisper import WhisperModel
|
| 4 |
+
|
| 5 |
+
def _pick():
|
| 6 |
+
if torch.cuda.is_available():
|
| 7 |
+
return dict(model_id="Systran/faster-whisper-small.en", device="cuda", compute_type="float16")
|
| 8 |
+
else:
|
| 9 |
+
return dict(model_id="Systran/faster-whisper-base.en", device="cpu", compute_type="int8")
|
| 10 |
+
|
| 11 |
+
class Transcriber:
    """Thin wrapper around faster-whisper for one-shot file transcription."""

    def __init__(self):
        cfg = _pick()
        try:
            self.model = WhisperModel(cfg["model_id"],
                                      device=cfg["device"],
                                      compute_type=cfg["compute_type"])
        except ValueError:
            # The selected compute type (int8/float16) may be unsupported on
            # this host; retry on CPU with plain float32.
            # NOTE(review): only ValueError is caught here — confirm that is
            # the exception ctranslate2 raises for unsupported compute types.
            self.model = WhisperModel(cfg["model_id"], device="cpu",
                                      compute_type="float32")

    def transcribe_file(self, audio_path: str) -> str:
        """Transcribe *audio_path* and return the segments joined as one string."""
        segments, _info = self.model.transcribe(
            audio_path, beam_size=1, vad_filter=True, temperature=0.0)
        return " ".join(seg.text.strip() for seg in segments).strip()
|
final_model/added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"[MASK]": 128000
|
| 3 |
+
}
|
final_model/config.json
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"attention_probs_dropout_prob": 0.1,
|
| 3 |
+
"hidden_act": "gelu",
|
| 4 |
+
"hidden_dropout_prob": 0.1,
|
| 5 |
+
"hidden_size": 768,
|
| 6 |
+
"initializer_range": 0.02,
|
| 7 |
+
"intermediate_size": 3072,
|
| 8 |
+
"layer_norm_eps": 1e-07,
|
| 9 |
+
"legacy": true,
|
| 10 |
+
"max_position_embeddings": 512,
|
| 11 |
+
"max_relative_positions": -1,
|
| 12 |
+
"model_type": "deberta-v2",
|
| 13 |
+
"norm_rel_ebd": "layer_norm",
|
| 14 |
+
"num_attention_heads": 12,
|
| 15 |
+
"num_hidden_layers": 6,
|
| 16 |
+
"pad_token_id": 0,
|
| 17 |
+
"pooler_dropout": 0,
|
| 18 |
+
"pooler_hidden_act": "gelu",
|
| 19 |
+
"pooler_hidden_size": 768,
|
| 20 |
+
"pos_att_type": [
|
| 21 |
+
"p2c",
|
| 22 |
+
"c2p"
|
| 23 |
+
],
|
| 24 |
+
"position_biased_input": false,
|
| 25 |
+
"position_buckets": 256,
|
| 26 |
+
"relative_attention": true,
|
| 27 |
+
"share_att_key": true,
|
| 28 |
+
"torch_dtype": "float32",
|
| 29 |
+
"transformers_version": "4.52.4",
|
| 30 |
+
"type_vocab_size": 0,
|
| 31 |
+
"vocab_size": 128100
|
| 32 |
+
}
|
final_model/label_spaces.json
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"id2label": {
|
| 3 |
+
"urgency": {
|
| 4 |
+
"0": "Non-Urgent",
|
| 5 |
+
"1": "Urgent"
|
| 6 |
+
},
|
| 7 |
+
"call_type": {
|
| 8 |
+
"0": "Active Shooter",
|
| 9 |
+
"1": "Aggravated Assault",
|
| 10 |
+
"2": "Armed Robbery",
|
| 11 |
+
"3": "Disturbance/Nuisance",
|
| 12 |
+
"4": "Domestic Violence",
|
| 13 |
+
"5": "EMS Assist",
|
| 14 |
+
"6": "Homicide",
|
| 15 |
+
"7": "Major Trauma",
|
| 16 |
+
"8": "Other",
|
| 17 |
+
"9": "Suspicious Person/Vehicle",
|
| 18 |
+
"10": "Theft/Larceny",
|
| 19 |
+
"11": "Traffic Crash",
|
| 20 |
+
"12": "Welfare Check"
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"label2id": {
|
| 24 |
+
"urgency": {
|
| 25 |
+
"Non-Urgent": 0,
|
| 26 |
+
"Urgent": 1
|
| 27 |
+
},
|
| 28 |
+
"call_type": {
|
| 29 |
+
"Active Shooter": 0,
|
| 30 |
+
"Aggravated Assault": 1,
|
| 31 |
+
"Armed Robbery": 2,
|
| 32 |
+
"Disturbance/Nuisance": 3,
|
| 33 |
+
"Domestic Violence": 4,
|
| 34 |
+
"EMS Assist": 5,
|
| 35 |
+
"Homicide": 6,
|
| 36 |
+
"Major Trauma": 7,
|
| 37 |
+
"Other": 8,
|
| 38 |
+
"Suspicious Person/Vehicle": 9,
|
| 39 |
+
"Theft/Larceny": 10,
|
| 40 |
+
"Traffic Crash": 11,
|
| 41 |
+
"Welfare Check": 12
|
| 42 |
+
}
|
| 43 |
+
}
|
| 44 |
+
}
|
final_model/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5e535ce3ad77d5c158510e343c2d419b6a9d2643f91d2567e399aa8de37063e6
|
| 3 |
+
size 565308274
|
final_model/special_tokens_map.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": "[CLS]",
|
| 3 |
+
"cls_token": "[CLS]",
|
| 4 |
+
"eos_token": "[SEP]",
|
| 5 |
+
"mask_token": "[MASK]",
|
| 6 |
+
"pad_token": "[PAD]",
|
| 7 |
+
"sep_token": "[SEP]",
|
| 8 |
+
"unk_token": {
|
| 9 |
+
"content": "[UNK]",
|
| 10 |
+
"lstrip": false,
|
| 11 |
+
"normalized": true,
|
| 12 |
+
"rstrip": false,
|
| 13 |
+
"single_word": false
|
| 14 |
+
}
|
| 15 |
+
}
|
final_model/spm.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
|
| 3 |
+
size 2464616
|
final_model/thresholds.json
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"Active Shooter": 0.49999999999999994,
|
| 3 |
+
"Aggravated Assault": 0.49999999999999994,
|
| 4 |
+
"Armed Robbery": 0.5499999999999999,
|
| 5 |
+
"Disturbance/Nuisance": 0.39999999999999997,
|
| 6 |
+
"Domestic Violence": 0.5,
|
| 7 |
+
"EMS Assist": 0.49999999999999994,
|
| 8 |
+
"Homicide": 0.44999999999999996,
|
| 9 |
+
"Major Trauma": 0.49999999999999994,
|
| 10 |
+
"Other": 0.49999999999999994,
|
| 11 |
+
"Suspicious Person/Vehicle": 0.49999999999999994,
|
| 12 |
+
"Theft/Larceny": 0.44999999999999996,
|
| 13 |
+
"Traffic Crash": 0.5499999999999999,
|
| 14 |
+
"Welfare Check": 0.5
|
| 15 |
+
}
|
final_model/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
final_model/tokenizer_config.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"added_tokens_decoder": {
|
| 3 |
+
"0": {
|
| 4 |
+
"content": "[PAD]",
|
| 5 |
+
"lstrip": false,
|
| 6 |
+
"normalized": false,
|
| 7 |
+
"rstrip": false,
|
| 8 |
+
"single_word": false,
|
| 9 |
+
"special": true
|
| 10 |
+
},
|
| 11 |
+
"1": {
|
| 12 |
+
"content": "[CLS]",
|
| 13 |
+
"lstrip": false,
|
| 14 |
+
"normalized": false,
|
| 15 |
+
"rstrip": false,
|
| 16 |
+
"single_word": false,
|
| 17 |
+
"special": true
|
| 18 |
+
},
|
| 19 |
+
"2": {
|
| 20 |
+
"content": "[SEP]",
|
| 21 |
+
"lstrip": false,
|
| 22 |
+
"normalized": false,
|
| 23 |
+
"rstrip": false,
|
| 24 |
+
"single_word": false,
|
| 25 |
+
"special": true
|
| 26 |
+
},
|
| 27 |
+
"3": {
|
| 28 |
+
"content": "[UNK]",
|
| 29 |
+
"lstrip": false,
|
| 30 |
+
"normalized": true,
|
| 31 |
+
"rstrip": false,
|
| 32 |
+
"single_word": false,
|
| 33 |
+
"special": true
|
| 34 |
+
},
|
| 35 |
+
"128000": {
|
| 36 |
+
"content": "[MASK]",
|
| 37 |
+
"lstrip": false,
|
| 38 |
+
"normalized": false,
|
| 39 |
+
"rstrip": false,
|
| 40 |
+
"single_word": false,
|
| 41 |
+
"special": true
|
| 42 |
+
}
|
| 43 |
+
},
|
| 44 |
+
"bos_token": "[CLS]",
|
| 45 |
+
"clean_up_tokenization_spaces": false,
|
| 46 |
+
"cls_token": "[CLS]",
|
| 47 |
+
"do_lower_case": false,
|
| 48 |
+
"eos_token": "[SEP]",
|
| 49 |
+
"extra_special_tokens": {},
|
| 50 |
+
"mask_token": "[MASK]",
|
| 51 |
+
"model_max_length": 1000000000000000019884624838656,
|
| 52 |
+
"pad_token": "[PAD]",
|
| 53 |
+
"sep_token": "[SEP]",
|
| 54 |
+
"sp_model_kwargs": {},
|
| 55 |
+
"split_by_punct": false,
|
| 56 |
+
"tokenizer_class": "DebertaV2Tokenizer",
|
| 57 |
+
"unk_token": "[UNK]",
|
| 58 |
+
"vocab_type": "spm"
|
| 59 |
+
}
|
inference.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 4 |
+
|
| 5 |
+
import torch, json
|
| 6 |
+
from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification
|
| 7 |
+
from transformers.models.deberta_v2 import DebertaV2ForSequenceClassification
|
| 8 |
+
|
| 9 |
+
MODEL_DIR_DEFAULT = os.path.join(os.path.dirname(__file__), "final_model")
|
| 10 |
+
|
| 11 |
+
def _strip_wrappers(k: str) -> str:
|
| 12 |
+
for p in ("model.", "module.", "net."):
|
| 13 |
+
if k.startswith(p): return k[len(p):]
|
| 14 |
+
return k
|
| 15 |
+
|
| 16 |
+
def _remap_keys(sd: dict) -> dict:
|
| 17 |
+
new = {}
|
| 18 |
+
for k, v in sd.items():
|
| 19 |
+
k = _strip_wrappers(k)
|
| 20 |
+
if k.startswith("backbone."):
|
| 21 |
+
k = "deberta." + k[len("backbone."):]
|
| 22 |
+
elif k.startswith(("head.", "heads.", "cls.", "fc.")):
|
| 23 |
+
k = "classifier." + k.split(".", 1)[1]
|
| 24 |
+
elif k.startswith("encoder."):
|
| 25 |
+
k = "deberta." + k
|
| 26 |
+
new[k] = v
|
| 27 |
+
return new
|
| 28 |
+
|
| 29 |
+
class UrgencyModel:
    """Binary urgency classifier over 911-call transcripts (DeBERTa-v2 backbone).

    Loads tokenizer/config/weights from a local directory and exposes
    ``predict(text) -> {"urgency_score", "urgent_label", "rationale"}``.

    Fixes vs. original: JSON files are opened with context managers (the old
    ``json.load(open(...))`` leaked file handles), state-dict mismatches are
    surfaced instead of silently ignored, and the ``model_dir`` default is
    resolved lazily.
    """

    def __init__(self, model_dir=None, device=None, threshold=0.5):
        # Lazy default (backward-compatible with model_dir=MODEL_DIR_DEFAULT):
        # resolved at call time rather than baked in at import time.
        model_dir = model_dir or MODEL_DIR_DEFAULT
        self.model_dir = model_dir
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

        # Optional decision-threshold override shipped with the model.
        # NOTE(review): the repo's thresholds.json is keyed by call type and
        # has no "urgency" key, so this lookup currently always falls back to
        # the constructor argument — confirm which file format is intended.
        thr_path = os.path.join(model_dir, "thresholds.json")
        if os.path.exists(thr_path):
            try:
                with open(thr_path, encoding="utf-8") as fh:
                    threshold = float(json.load(fh).get("urgency", threshold))
            except Exception:
                # Best effort: a malformed file must not block model loading.
                pass
        self.threshold = threshold

        # Human-readable label names; fall back to the binary default.
        try:
            with open(os.path.join(model_dir, "label_spaces.json"), encoding="utf-8") as fh:
                spaces = json.load(fh)
            self.id2label = {int(k): v for k, v in spaces.get("id2label", {}).get("urgency", {}).items()}
        except Exception:
            self.id2label = {0: "Non-Urgent", 1: "Urgent"}

        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, local_files_only=True)
        cfg = AutoConfig.from_pretrained(model_dir, local_files_only=True)
        # Instantiate architecture from config only; weights are loaded below.
        if getattr(cfg, "model_type", None) == "deberta-v2":
            self.model = DebertaV2ForSequenceClassification(cfg)
        else:
            self.model = AutoModelForSequenceClassification.from_config(cfg)

        # Prefer pytorch_model.bin, else model.safetensors; fail loudly if neither exists.
        binp = os.path.join(model_dir, "pytorch_model.bin")
        safep = os.path.join(model_dir, "model.safetensors")
        if os.path.exists(binp):
            sd = torch.load(binp, map_location="cpu")
            if isinstance(sd, dict) and isinstance(sd.get("state_dict"), dict):
                sd = sd["state_dict"]  # unwrap trainer-style checkpoints
        elif os.path.exists(safep):
            from safetensors.torch import load_file
            sd = load_file(safep)
        else:
            raise FileNotFoundError("No model weights found.")

        sd = _remap_keys(sd)
        # strict=False tolerates renamed keys, but report mismatches so a bad
        # checkpoint does not fail silently with random weights.
        missing, unexpected = self.model.load_state_dict(sd, strict=False)
        if missing or unexpected:
            print(f"[UrgencyModel] load_state_dict: {len(missing)} missing, "
                  f"{len(unexpected)} unexpected keys")
        self.model.to(self.device).eval()

    @torch.inference_mode()
    def predict(self, text: str):
        """Score *text* and return urgency_score/urgent_label/rationale.

        Blank or empty input short-circuits to a 0.0 Non-Urgent result.
        """
        if not text or not text.strip():
            return {"urgency_score": 0.0, "urgent_label": "Non-Urgent", "rationale": "Empty input."}
        # NOTE(review): max_length=1024 exceeds the config's
        # max_position_embeddings (512); DeBERTa's relative attention usually
        # tolerates longer inputs — confirm against the training setup.
        inputs = self.tokenizer(text, truncation=True, max_length=1024, return_tensors="pt").to(self.device)
        logits = self.model(**inputs).logits
        if logits.shape[-1] == 1:
            # Single-logit head: sigmoid gives P(urgent).
            score = torch.sigmoid(logits.squeeze(-1)).item()
        else:
            # Two-class head: softmax probability of class 1 ("Urgent").
            score = torch.softmax(logits, dim=-1).squeeze(0)[1].item()
        label = self.id2label.get(int(score >= self.threshold),
                                  "Urgent" if score >= self.threshold else "Non-Urgent")
        return {"urgency_score": round(float(score), 4),
                "urgent_label": label,
                "rationale": self._cheap_rationale(text)}

    def _cheap_rationale(self, text: str, top_n: int = 3):
        """List up to *top_n* urgency keywords found in *text* (naive substring scan)."""
        KEYS = ["shot","shooting","gun","stabbing","blood","not breathing","unconscious",
                "heart","chest pain","stroke","seizure","screaming","help now","immediate",
                "fire","trapped","domestic","assault","weapon"]
        t = text.lower()
        hits = [k for k in KEYS if k in t][:top_n]
        return "Keywords: " + (", ".join(hits) if hits else "none detected")
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers>=4.39
|
| 3 |
+
accelerate
|
| 4 |
+
gradio>=4.0
|
| 5 |
+
faster-whisper==1.0.3
|
| 6 |
+
soundfile
|
| 7 |
+
ffmpeg-python
|
| 8 |
+
safetensors
|
| 9 |
+
huggingface_hub
|