import gradio as gr from transformers import pipeline import spacy import re import unicodedata import sys import subprocess # Download spaCy model if not present try: nlp = spacy.load("en_core_web_sm") except OSError: print("Downloading spaCy model...") subprocess.run([sys.executable, "-m", "spacy", "download", "en_core_web_sm"], check=True) nlp = spacy.load("en_core_web_sm") nlp.add_pipe("sentencizer") model_id = "Statistical-Impossibility/Feline-NER" ner_pipeline = pipeline("token-classification", model=model_id, aggregation_strategy="simple") def clean_text(text): """Aggressive cleaning for PDF/HTML paste artifacts.""" text = unicodedata.normalize('NFKC', text) text = re.sub(r'<[^>]+>', '', text) text = re.sub(r'(\w+)-\s*\n\s*(\w+)', r'\1\2', text) text = re.sub(r'\n{3,}', '\n\n', text) text = re.sub(r'\s+', ' ', text) text = re.sub(r'-\s+', '', text) return text.strip() def expand_to_word_boundaries(text, start, end): """ Expand entity boundaries to complete words. Prevents highlighting fragments like "itis" from "abnormalities". """ # Expand left until we hit non-alphanumeric while start > 0 and (text[start - 1].isalnum() or text[start - 1] in ['-', "'"]): start -= 1 # Expand right until we hit non-alphanumeric while end < len(text) and (text[end].isalnum() or text[end] in ['-', "'"]): end += 1 return start, end def is_valid_entity(text, start, end): """ Filter out garbage entities. Returns False if entity is: - Too short (< 2 chars) - All punctuation - Just a suffix (starts with ##) """ entity_text = text[start:end].strip() # Too short if len(entity_text) < 2: return False # All punctuation or numbers if not any(c.isalpha() for c in entity_text): return False # Starts with subword marker (shouldn't happen after expansion, but check anyway) if entity_text.startswith('##'): return False # Single letter (likely fragment) if len(entity_text) == 1: return False return True def ner_predict(text): if not text.strip(): return "
No text provided
", "No entities" if len(text) > 100000: return "Text too long (max 100k characters)
", "" # Clean text text = clean_text(text) # spaCy sentence splitting with exact offsets doc = nlp(text) sentences = [] for sent in doc.sents: sentences.append({ "text": sent.text, "start": sent.start_char, "end": sent.end_char }) if not sentences: return "No sentences detected
", "" # Pre-tokenize sentences ONCE (cache token counts) sentence_token_counts = [] for sent in sentences: tokens = ner_pipeline.tokenizer.tokenize(sent["text"]) sentence_token_counts.append(len(tokens)) # Chunking with cached token counts max_tokens = 450 chunks = [] i = 0 while i < len(sentences): chunk_sents = [] token_count = 0 for j in range(i, len(sentences)): sent_token_count = sentence_token_counts[j] # Check if adding this sentence exceeds limit if token_count + sent_token_count > max_tokens and chunk_sents: break chunk_sents.append(sentences[j]) token_count += sent_token_count if chunk_sents: chunk_text = " ".join([s["text"] for s in chunk_sents]) chunks.append({ "text": chunk_text, "offset": chunk_sents[0]["start"], "sentence_count": len(chunk_sents) }) sentences_to_skip = max(1, len(chunk_sents) - 2) i += sentences_to_skip # Predict on chunks (NO CHANGES HERE) all_entities = [] for chunk in chunks: try: results = ner_pipeline(chunk["text"]) for r in results: if r['score'] > 0.50: # Increased threshold to filter noise # Adjust offsets to global position r['start'] += chunk["offset"] r['end'] += chunk["offset"] # CRITICAL FIX: Expand to word boundaries r['start'], r['end'] = expand_to_word_boundaries( text, r['start'], r['end'] ) # Validate entity if is_valid_entity(text, r['start'], r['end']): all_entities.append(r) except Exception as e: print(f"Chunk processing error: {e}") continue # Sort and deduplicate all_entities = sorted(all_entities, key=lambda x: (x['start'], -x['score'])) final_entities = [] for ent in all_entities: # Check overlap with previous entity if not final_entities or ent['start'] >= final_entities[-1]['end']: final_entities.append(ent) elif ent['score'] > final_entities[-1]['score']: # Replace if higher confidence AND different span if ent['end'] > final_entities[-1]['end'] or ent['start'] < final_entities[-1]['start']: final_entities[-1] = ent # Generate highlighted HTML highlighted = "" last_idx = 0 color_map = { "SYMPTOM": "#FFD700", "DISEASE": "#FF6B6B", "MEDICATION": "#90EE90", "PROCEDURE": "#87CEEB", "ANATOMY": "#FFB347" } label_display = { "DISEASE": "disease", "SYMPTOM": "symptom", "MEDICATION": "medication", "PROCEDURE": "procedure", "ANATOMY": "anatomy" } for ent in final_entities: start, end = ent['start'], ent['end'] label = ent['entity_group'] score = ent['score'] # Bounds check if start >= len(text) or end > len(text) or start < 0 or end < 0: continue # Skip if indices are reversed if start >= end: continue highlighted += text[last_idx:start] color = color_map.get(label, "#E0E0E0") display_label = label_display.get(label, label.lower()) entity_text = text[start:end] highlighted += ( f'' f'{entity_text} /{display_label}' f'' ) last_idx = end highlighted += text[last_idx:] highlighted = f'