# Clinical NER comparison demo — FastAPI service (file header reconstructed
# after removal of web-scrape residue that preceded the code).
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from transformers import (
pipeline,
BertTokenizerFast,
BertModel,
)
import gensim.downloader as api
import nltk
from nltk.tokenize import word_tokenize
# -------------------------------------------------
# Initialization
# -------------------------------------------------
app = FastAPI(title="Clinical NER Comparison Demo")
# Serve the front-end assets (index.html lives in ./static).
app.mount("/static", StaticFiles(directory="static"), name="static")
# -------------------------------------------------
# Models loaded ONCE (important)
# -------------------------------------------------
# Pipeline 1: fine-tuned clinical NER model; "simple" aggregation merges
# word pieces into whole-entity spans in the pipeline output.
clinical_ner = pipeline(
    "token-classification",
    model="samrawal/bert-base-uncased_clinical-ner",
    aggregation_strategy="simple"
)
# Pipeline 2: vanilla BERT used only as an embedding extractor.
bert_tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased")
bert_model.eval()  # inference only — disable dropout
# Pipeline 3: static Word2Vec embeddings (downloads ~1.6 GB on first run).
w2v = api.load("word2vec-google-news-300")
# Tokenizer data required by nltk.word_tokenize.
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)
# -------------------------------------------------
# API schema
# -------------------------------------------------
class NERRequest(BaseModel):
    """Request payload for POST /run."""
    text: str              # raw input text to run the three pipelines on
    prototypes: list[str]  # prototype words used by the similarity pipelines
# -------------------------------------------------
# Utility functions
# -------------------------------------------------
def build_prototypes_bert(words, embeddings, word_ids, prototype_words):
    """Build a {prototype word -> embedding} map from contextual embeddings.

    For each prototype word that occurs in ``words``, average the embeddings
    of every subword token that maps back to it (via ``word_ids``).
    Prototype words absent from the text are simply omitted.
    """
    proto_vectors = {}
    for proto in prototype_words:
        matching_positions = []
        for position, wid in enumerate(word_ids):
            if wid is None:
                continue  # special token ([CLS]/[SEP]) — no source word
            if words[wid] == proto:
                matching_positions.append(position)
        if matching_positions:
            proto_vectors[proto] = embeddings[matching_positions].mean(axis=0)
    return proto_vectors
def bert_similarity_ner(text, prototype_words, threshold=0.75):
    """Zero-shot NER: tag tokens whose BERT embedding is close to a prototype.

    Args:
        text: input text; lowercased and whitespace-split before encoding.
        prototype_words: words whose in-context embeddings act as class anchors.
        threshold: minimum cosine similarity for a token to be tagged
            (default 0.75, the value previously hard-coded).

    Returns:
        List of {"text", "label", "score"} dicts for tokens above threshold.
    """
    words = text.lower().split()
    encoded = bert_tokenizer(
        words,
        is_split_into_words=True,
        return_tensors="pt",
        return_offsets_mapping=True,
    )
    # The model's forward() does not accept offset_mapping — drop it.
    encoded.pop("offset_mapping")
    with torch.no_grad():
        outputs = bert_model(**encoded)
    embeddings = outputs.last_hidden_state.squeeze(0).numpy()
    tokens = bert_tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    word_ids = encoded.word_ids()
    prototypes = build_prototypes_bert(words, embeddings, word_ids, prototype_words)
    if not prototypes:
        # No prototype word occurs in the text — nothing can be tagged.
        return []
    # Stack prototype vectors once so each token needs a single
    # cosine_similarity call instead of one call per prototype.
    labels = list(prototypes)
    proto_matrix = np.stack([prototypes[pw] for pw in labels])
    results = []
    for token, emb, wid in zip(tokens, embeddings, word_ids):
        # Skip special tokens ([CLS]/[SEP]) and word-piece continuations.
        if wid is None or token.startswith("##"):
            continue
        sims = cosine_similarity(emb.reshape(1, -1), proto_matrix)[0]
        best_idx = int(sims.argmax())  # first max — same tie-break as dict order
        if sims[best_idx] > threshold:
            results.append({
                "text": token,
                "label": labels[best_idx],
                "score": float(sims[best_idx]),
            })
    return results
def w2v_similarity_ner(text, prototype_words, threshold=0.65):
    """Zero-shot NER using static Word2Vec similarity to prototype words.

    Args:
        text: input text; lowercased and tokenized with nltk.word_tokenize.
        prototype_words: candidate labels; words missing from the Word2Vec
            vocabulary are ignored.
        threshold: minimum cosine similarity for a token to be tagged
            (default 0.65, the value previously hard-coded).

    Returns:
        List of {"text", "label", "score"} dicts for tokens above threshold.
    """
    tokens = word_tokenize(text.lower())
    # Look each prototype vector up ONCE, not once per token — the old code
    # redid the vocabulary check and reshape inside the token loop.
    proto_vecs = {
        pw: w2v[pw].reshape(1, -1)
        for pw in prototype_words
        if pw in w2v
    }
    results = []
    for t in tokens:
        if t not in w2v:
            continue  # out-of-vocabulary token — cannot score it
        vec = w2v[t].reshape(1, -1)
        sims = {
            pw: cosine_similarity(vec, pv)[0][0]
            for pw, pv in proto_vecs.items()
        }
        if sims:
            best = max(sims, key=sims.get)
            if sims[best] > threshold:
                results.append({
                    "text": t,
                    "label": best,
                    "score": float(sims[best]),
                })
    return results
def make_json_safe(obj):
    """Recursively convert NumPy types to native Python types for JSON.

    Handles dicts, lists, tuples (returned as lists — JSON has no tuple),
    ``np.ndarray`` (via ``tolist``), and NumPy scalars (``np.generic`` via
    ``item``). Anything else is returned unchanged. The original version
    missed ndarrays and tuples, so NumPy values nested inside them leaked
    through to FastAPI's serializer.
    """
    if isinstance(obj, dict):
        return {k: make_json_safe(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [make_json_safe(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj
# -------------------------------------------------
# Routes
# -------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the demo front-end page.

    Returns the raw HTML of static/index.html; FastAPI wraps it in an
    HTMLResponse via the route's response_class.
    """
    # Explicit encoding — the old bare open() depended on the platform's
    # default locale and could mis-decode UTF-8 content on some systems.
    with open("static/index.html", encoding="utf-8") as f:
        return f.read()
@app.post("/run")
def run_ner(req: NERRequest):
    """Run all three NER pipelines on the request text and return their results.

    Prototype words are trimmed, lowercased, and blanks dropped before being
    passed to the two similarity-based pipelines. A simple progress log of
    which pipeline ran is returned alongside the three result lists.
    """
    text = req.text
    prototype_words = [
        word.strip().lower() for word in req.prototypes if word.strip()
    ]
    log = ["Running Pipeline 1 (fine-tuned clinical BERT)"]
    pipeline_1 = make_json_safe(clinical_ner(text))
    log.append("Running Pipeline 2 (vanilla BERT + similarity)")
    pipeline_2 = make_json_safe(bert_similarity_ner(text, prototype_words))
    log.append("Running Pipeline 3 (Word2Vec + similarity)")
    pipeline_3 = make_json_safe(w2v_similarity_ner(text, prototype_words))
    return {
        "pipeline_1": pipeline_1,
        "pipeline_2": pipeline_2,
        "pipeline_3": pipeline_3,
        "log": log,
    }