|
|
import html
import re
import subprocess
import sys
import unicodedata

import gradio as gr
import spacy
from transformers import pipeline
|
|
|
|
|
|
|
|
try: |
|
|
nlp = spacy.load("en_core_web_sm") |
|
|
except OSError: |
|
|
print("Downloading spaCy model...") |
|
|
subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True) |
|
|
nlp = spacy.load("en_core_web_sm") |
|
|
|
|
|
nlp.add_pipe("sentencizer") |
|
|
|
|
|
model_id = "Statistical-Impossibility/Feline-NER" |
|
|
ner_pipeline = pipeline("token-classification", model=model_id, aggregation_strategy="simple") |
|
|
|
|
|
def clean_text(text): |
|
|
"""Aggressive cleaning for PDF/HTML paste artifacts.""" |
|
|
text = unicodedata.normalize('NFKC', text) |
|
|
text = re.sub(r'<[^>]+>', '', text) |
|
|
text = re.sub(r'(\w+)-\s*\n\s*(\w+)', r'\1\2', text) |
|
|
text = re.sub(r'\n{3,}', '\n\n', text) |
|
|
text = re.sub(r'\s+', ' ', text) |
|
|
text = re.sub(r'-\s+', '', text) |
|
|
return text.strip() |
|
|
|
|
|
def expand_to_word_boundaries(text, start, end): |
|
|
""" |
|
|
Expand entity boundaries to complete words. |
|
|
Prevents highlighting fragments like "itis" from "abnormalities". |
|
|
""" |
|
|
|
|
|
while start > 0 and (text[start - 1].isalnum() or text[start - 1] in ['-', "'"]): |
|
|
start -= 1 |
|
|
|
|
|
|
|
|
while end < len(text) and (text[end].isalnum() or text[end] in ['-', "'"]): |
|
|
end += 1 |
|
|
|
|
|
return start, end |
|
|
|
|
|
def is_valid_entity(text, start, end): |
|
|
""" |
|
|
Filter out garbage entities. |
|
|
Returns False if entity is: |
|
|
- Too short (< 2 chars) |
|
|
- All punctuation |
|
|
- Just a suffix (starts with ##) |
|
|
""" |
|
|
entity_text = text[start:end].strip() |
|
|
|
|
|
|
|
|
if len(entity_text) < 2: |
|
|
return False |
|
|
|
|
|
|
|
|
if not any(c.isalpha() for c in entity_text): |
|
|
return False |
|
|
|
|
|
|
|
|
if entity_text.startswith('##'): |
|
|
return False |
|
|
|
|
|
|
|
|
if len(entity_text) == 1: |
|
|
return False |
|
|
|
|
|
return True |
|
|
|
|
|
def ner_predict(text): |
|
|
if not text.strip(): |
|
|
return "<p>No text provided</p>", "No entities" |
|
|
|
|
|
if len(text) > 100000: |
|
|
return "<p style='color:red;'>Text too long (max 100k characters)</p>", "" |
|
|
|
|
|
|
|
|
text = clean_text(text) |
|
|
|
|
|
|
|
|
doc = nlp(text) |
|
|
sentences = [] |
|
|
for sent in doc.sents: |
|
|
sentences.append({ |
|
|
"text": sent.text, |
|
|
"start": sent.start_char, |
|
|
"end": sent.end_char |
|
|
}) |
|
|
|
|
|
if not sentences: |
|
|
return "<p>No sentences detected</p>", "" |
|
|
|
|
|
|
|
|
sentence_token_counts = [] |
|
|
for sent in sentences: |
|
|
tokens = ner_pipeline.tokenizer.tokenize(sent["text"]) |
|
|
sentence_token_counts.append(len(tokens)) |
|
|
|
|
|
|
|
|
max_tokens = 450 |
|
|
chunks = [] |
|
|
i = 0 |
|
|
|
|
|
while i < len(sentences): |
|
|
chunk_sents = [] |
|
|
token_count = 0 |
|
|
|
|
|
for j in range(i, len(sentences)): |
|
|
sent_token_count = sentence_token_counts[j] |
|
|
|
|
|
|
|
|
if token_count + sent_token_count > max_tokens and chunk_sents: |
|
|
break |
|
|
|
|
|
chunk_sents.append(sentences[j]) |
|
|
token_count += sent_token_count |
|
|
|
|
|
if chunk_sents: |
|
|
chunk_text = " ".join([s["text"] for s in chunk_sents]) |
|
|
chunks.append({ |
|
|
"text": chunk_text, |
|
|
"offset": chunk_sents[0]["start"], |
|
|
"sentence_count": len(chunk_sents) |
|
|
}) |
|
|
|
|
|
sentences_to_skip = max(1, len(chunk_sents) - 2) |
|
|
i += sentences_to_skip |
|
|
|
|
|
|
|
|
all_entities = [] |
|
|
|
|
|
for chunk in chunks: |
|
|
try: |
|
|
results = ner_pipeline(chunk["text"]) |
|
|
|
|
|
for r in results: |
|
|
if r['score'] > 0.50: |
|
|
|
|
|
r['start'] += chunk["offset"] |
|
|
r['end'] += chunk["offset"] |
|
|
|
|
|
|
|
|
r['start'], r['end'] = expand_to_word_boundaries( |
|
|
text, r['start'], r['end'] |
|
|
) |
|
|
|
|
|
|
|
|
if is_valid_entity(text, r['start'], r['end']): |
|
|
all_entities.append(r) |
|
|
except Exception as e: |
|
|
print(f"Chunk processing error: {e}") |
|
|
continue |
|
|
|
|
|
|
|
|
|
|
|
all_entities = sorted(all_entities, key=lambda x: (x['start'], -x['score'])) |
|
|
|
|
|
final_entities = [] |
|
|
for ent in all_entities: |
|
|
|
|
|
if not final_entities or ent['start'] >= final_entities[-1]['end']: |
|
|
final_entities.append(ent) |
|
|
elif ent['score'] > final_entities[-1]['score']: |
|
|
|
|
|
if ent['end'] > final_entities[-1]['end'] or ent['start'] < final_entities[-1]['start']: |
|
|
final_entities[-1] = ent |
|
|
|
|
|
|
|
|
highlighted = "" |
|
|
last_idx = 0 |
|
|
|
|
|
color_map = { |
|
|
"SYMPTOM": "#FFD700", |
|
|
"DISEASE": "#FF6B6B", |
|
|
"MEDICATION": "#90EE90", |
|
|
"PROCEDURE": "#87CEEB", |
|
|
"ANATOMY": "#FFB347" |
|
|
} |
|
|
|
|
|
label_display = { |
|
|
"DISEASE": "disease", |
|
|
"SYMPTOM": "symptom", |
|
|
"MEDICATION": "medication", |
|
|
"PROCEDURE": "procedure", |
|
|
"ANATOMY": "anatomy" |
|
|
} |
|
|
|
|
|
for ent in final_entities: |
|
|
start, end = ent['start'], ent['end'] |
|
|
label = ent['entity_group'] |
|
|
score = ent['score'] |
|
|
|
|
|
|
|
|
if start >= len(text) or end > len(text) or start < 0 or end < 0: |
|
|
continue |
|
|
|
|
|
|
|
|
if start >= end: |
|
|
continue |
|
|
|
|
|
highlighted += text[last_idx:start] |
|
|
|
|
|
color = color_map.get(label, "#E0E0E0") |
|
|
display_label = label_display.get(label, label.lower()) |
|
|
entity_text = text[start:end] |
|
|
|
|
|
highlighted += ( |
|
|
f'<mark style="background-color:{color}; padding:2px 4px; ' |
|
|
f'border-radius:3px; font-weight:500;" ' |
|
|
f'title="{display_label} ({score:.2f})">' |
|
|
f'{entity_text} <sup style="font-size:0.65em; color:#666;">/{display_label}</sup>' |
|
|
f'</mark>' |
|
|
) |
|
|
|
|
|
last_idx = end |
|
|
|
|
|
highlighted += text[last_idx:] |
|
|
highlighted = f'<div style="line-height:1.8; font-family:sans-serif; white-space:pre-wrap;">{highlighted}</div>' |
|
|
|
|
|
|
|
|
if final_entities: |
|
|
entity_list = "\n".join([ |
|
|
f"{label_display.get(e['entity_group'], e['entity_group'])}: " |
|
|
f"{text[e['start']:e['end']]} ({e['score']:.2f})" |
|
|
for e in final_entities |
|
|
]) |
|
|
else: |
|
|
entity_list = "No entities detected" |
|
|
|
|
|
return highlighted, entity_list |
|
|
|
|
|
with gr.Blocks(title="Feline Veterinary NER (Educational Demo)") as demo: |
|
|
gr.Markdown("# π± Feline Veterinary NER System") |
|
|
gr.Markdown( |
|
|
"**Educational and research demo only β NOT for clinical use.**\n\n" |
|
|
"Extracts **disease**, **symptom**, **medication**, **procedure**, " |
|
|
"and **anatomy** from feline veterinary literature. " |
|
|
"Handles PDF/HTML paste artifacts." |
|
|
) |
|
|
|
|
|
|
|
|
input_text = gr.Textbox( |
|
|
label="Input Text", |
|
|
lines=15, |
|
|
placeholder="Paste article text here (handles complex scientific formatting)..." |
|
|
) |
|
|
|
|
|
analyze_btn = gr.Button("π¬ Analyze", variant="primary", size="lg") |
|
|
|
|
|
output_html = gr.HTML(label="π Annotated Text") |
|
|
output_list = gr.Textbox(label="π Detected Entities", lines=10) |
|
|
|
|
|
analyze_btn.click(ner_predict, input_text, [output_html, output_list]) |
|
|
|
|
|
gr.Examples( |
|
|
examples=[ |
|
|
["Chronic kidney disease was diagnosed. The cat received meloxicam and subcutaneous fluids."], |
|
|
["Ultrasound revealed a renal mass. FIV infection was confirmed by PCR in blood samples."], |
|
|
["The patient presented with vomiting, lethargy, and dehydration. Blood work showed elevated creatinine."] |
|
|
], |
|
|
inputs=input_text |
|
|
) |
|
|
|
|
|
demo.launch() |
|
|
|