# Feline-NER-Demo / app.py
# Hugging Face Space: Statistical-Impossibility/Feline-NER-Demo
# Last change: "Update app.py" (commit 3ebde54, verified)
# Standard library
import html
import re
import subprocess
import sys
import unicodedata

# Third-party
import gradio as gr
import spacy
from transformers import pipeline
# Download spaCy model if not present
try:
nlp = spacy.load("en_core_web_sm")
except OSError:
print("Downloading spaCy model...")
subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True)
nlp = spacy.load("en_core_web_sm")
nlp.add_pipe("sentencizer")
model_id = "Statistical-Impossibility/Feline-NER"
ner_pipeline = pipeline("token-classification", model=model_id, aggregation_strategy="simple")
def clean_text(text):
"""Aggressive cleaning for PDF/HTML paste artifacts."""
text = unicodedata.normalize('NFKC', text)
text = re.sub(r'<[^>]+>', '', text)
text = re.sub(r'(\w+)-\s*\n\s*(\w+)', r'\1\2', text)
text = re.sub(r'\n{3,}', '\n\n', text)
text = re.sub(r'\s+', ' ', text)
text = re.sub(r'-\s+', '', text)
return text.strip()
def expand_to_word_boundaries(text, start, end):
"""
Expand entity boundaries to complete words.
Prevents highlighting fragments like "itis" from "abnormalities".
"""
# Expand left until we hit non-alphanumeric
while start > 0 and (text[start - 1].isalnum() or text[start - 1] in ['-', "'"]):
start -= 1
# Expand right until we hit non-alphanumeric
while end < len(text) and (text[end].isalnum() or text[end] in ['-', "'"]):
end += 1
return start, end
def is_valid_entity(text, start, end):
"""
Filter out garbage entities.
Returns False if entity is:
- Too short (< 2 chars)
- All punctuation
- Just a suffix (starts with ##)
"""
entity_text = text[start:end].strip()
# Too short
if len(entity_text) < 2:
return False
# All punctuation or numbers
if not any(c.isalpha() for c in entity_text):
return False
# Starts with subword marker (shouldn't happen after expansion, but check anyway)
if entity_text.startswith('##'):
return False
# Single letter (likely fragment)
if len(entity_text) == 1:
return False
return True
def ner_predict(text):
if not text.strip():
return "<p>No text provided</p>", "No entities"
if len(text) > 100000:
return "<p style='color:red;'>Text too long (max 100k characters)</p>", ""
# Clean text
text = clean_text(text)
# spaCy sentence splitting with exact offsets
doc = nlp(text)
sentences = []
for sent in doc.sents:
sentences.append({
"text": sent.text,
"start": sent.start_char,
"end": sent.end_char
})
if not sentences:
return "<p>No sentences detected</p>", ""
# Pre-tokenize sentences ONCE (cache token counts)
sentence_token_counts = []
for sent in sentences:
tokens = ner_pipeline.tokenizer.tokenize(sent["text"])
sentence_token_counts.append(len(tokens))
# Chunking with cached token counts
max_tokens = 450
chunks = []
i = 0
while i < len(sentences):
chunk_sents = []
token_count = 0
for j in range(i, len(sentences)):
sent_token_count = sentence_token_counts[j]
# Check if adding this sentence exceeds limit
if token_count + sent_token_count > max_tokens and chunk_sents:
break
chunk_sents.append(sentences[j])
token_count += sent_token_count
if chunk_sents:
chunk_text = " ".join([s["text"] for s in chunk_sents])
chunks.append({
"text": chunk_text,
"offset": chunk_sents[0]["start"],
"sentence_count": len(chunk_sents)
})
sentences_to_skip = max(1, len(chunk_sents) - 2)
i += sentences_to_skip
# Predict on chunks (NO CHANGES HERE)
all_entities = []
for chunk in chunks:
try:
results = ner_pipeline(chunk["text"])
for r in results:
if r['score'] > 0.50: # Increased threshold to filter noise
# Adjust offsets to global position
r['start'] += chunk["offset"]
r['end'] += chunk["offset"]
# CRITICAL FIX: Expand to word boundaries
r['start'], r['end'] = expand_to_word_boundaries(
text, r['start'], r['end']
)
# Validate entity
if is_valid_entity(text, r['start'], r['end']):
all_entities.append(r)
except Exception as e:
print(f"Chunk processing error: {e}")
continue
# Sort and deduplicate
all_entities = sorted(all_entities, key=lambda x: (x['start'], -x['score']))
final_entities = []
for ent in all_entities:
# Check overlap with previous entity
if not final_entities or ent['start'] >= final_entities[-1]['end']:
final_entities.append(ent)
elif ent['score'] > final_entities[-1]['score']:
# Replace if higher confidence AND different span
if ent['end'] > final_entities[-1]['end'] or ent['start'] < final_entities[-1]['start']:
final_entities[-1] = ent
# Generate highlighted HTML
highlighted = ""
last_idx = 0
color_map = {
"SYMPTOM": "#FFD700",
"DISEASE": "#FF6B6B",
"MEDICATION": "#90EE90",
"PROCEDURE": "#87CEEB",
"ANATOMY": "#FFB347"
}
label_display = {
"DISEASE": "disease",
"SYMPTOM": "symptom",
"MEDICATION": "medication",
"PROCEDURE": "procedure",
"ANATOMY": "anatomy"
}
for ent in final_entities:
start, end = ent['start'], ent['end']
label = ent['entity_group']
score = ent['score']
# Bounds check
if start >= len(text) or end > len(text) or start < 0 or end < 0:
continue
# Skip if indices are reversed
if start >= end:
continue
highlighted += text[last_idx:start]
color = color_map.get(label, "#E0E0E0")
display_label = label_display.get(label, label.lower())
entity_text = text[start:end]
highlighted += (
f'<mark style="background-color:{color}; padding:2px 4px; '
f'border-radius:3px; font-weight:500;" '
f'title="{display_label} ({score:.2f})">'
f'{entity_text} <sup style="font-size:0.65em; color:#666;">/{display_label}</sup>'
f'</mark>'
)
last_idx = end
highlighted += text[last_idx:]
highlighted = f'<div style="line-height:1.8; font-family:sans-serif; white-space:pre-wrap;">{highlighted}</div>'
# Entity list
if final_entities:
entity_list = "\n".join([
f"{label_display.get(e['entity_group'], e['entity_group'])}: "
f"{text[e['start']:e['end']]} ({e['score']:.2f})"
for e in final_entities
])
else:
entity_list = "No entities detected"
return highlighted, entity_list
with gr.Blocks(title="Feline Veterinary NER (Educational Demo)") as demo:
gr.Markdown("# 🐱 Feline Veterinary NER System")
gr.Markdown(
"**Educational and research demo only β€” NOT for clinical use.**\n\n"
"Extracts **disease**, **symptom**, **medication**, **procedure**, "
"and **anatomy** from feline veterinary literature. "
"Handles PDF/HTML paste artifacts."
)
input_text = gr.Textbox(
label="Input Text",
lines=15,
placeholder="Paste article text here (handles complex scientific formatting)..."
)
analyze_btn = gr.Button("πŸ”¬ Analyze", variant="primary", size="lg")
output_html = gr.HTML(label="πŸ“„ Annotated Text")
output_list = gr.Textbox(label="πŸ“‹ Detected Entities", lines=10)
analyze_btn.click(ner_predict, input_text, [output_html, output_list])
gr.Examples(
examples=[
["Chronic kidney disease was diagnosed. The cat received meloxicam and subcutaneous fluids."],
["Ultrasound revealed a renal mass. FIV infection was confirmed by PCR in blood samples."],
["The patient presented with vomiting, lethargy, and dehydration. Blood work showed elevated creatinine."]
],
inputs=input_text
)
demo.launch()