import logging from contextlib import asynccontextmanager from typing import List, Dict, Any from fastapi import FastAPI, HTTPException from fastapi.staticfiles import StaticFiles from pydantic import BaseModel, Field from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Config class NERRequest(BaseModel): text: str = Field(..., title="Input Text", description="Text to analyze") class NEREntity(BaseModel): entity_group: str score: float word: str start: int end: int class NERResponse(BaseModel): entities: List[NEREntity] # Constants SHORT_TEXT_THRESHOLD = 128 MODEL_MAX_LENGTH = 512 WINDOW_OVERLAP = 128 # Core Logic def refine_boundaries(text: str, start: int, end: int) -> (int, int, str): """ Adjusts start/end indices. 1. Expands selection to the end of the word if the model stopped mid-word. 2. Trims leading/trailing whitespace. """ while end < len(text) and text[end].isalnum(): end += 1 # while end < len(text) and (text[end].isalnum() or text[end] == '-'): # end += 1 span = text[start:end] # Shift start index forward if there is leading whitespace while span and span[0].isspace(): start += 1 span = span[1:] # Shift end index backward if there is trailing whitespace while span and span[-1].isspace(): end -= 1 span = span[:-1] return start, end, span def refine_boundaries1(text: str, start: int, end: int) -> (int, int, str): """ Adjusts start/end indices to exclude leading/trailing whitespace. This ensures the HTML highlight is tight around the word. """ # Extract the raw span using original indices span = text[start:end] # Shift start index forward if there is leading whitespace while span and span[0].isspace(): start += 1 span = span[1:] # Shift end index backward if there is trailing whitespace while span and span[-1].isspace(): end -= 1 span = span[:-1] return start, end, span def save_current_entity(entity_parts: List[Dict], full_text: str, aggregated_entities: List[Dict]): """ Finalizes a group of tokens into a single entity. """ if not entity_parts: return # 1. Determine the raw range raw_start = entity_parts[0]['start'] raw_end = entity_parts[-1]['end'] # 2. Refine boundaries (Trim spaces from indices) final_start, final_end, clean_word = refine_boundaries(full_text, raw_start, raw_end) if not clean_word: return # 3. Calculate score avg_score = sum(part['score'] for part in entity_parts) / len(entity_parts) # 4. Determine label (remove B/I prefix) # We take the label from the first token usually, or the most frequent one raw_label = entity_parts[0]['entity'] entity_group = raw_label.split('-')[-1] # e.g., "B-ORG" -> "ORG" aggregated_entities.append({ 'word': clean_word, 'score': float(avg_score), 'entity_group': entity_group, 'start': final_start, 'end': final_end }) def aggregate_entities_manual(ner_results: List[Dict], full_text: str) -> List[Dict]: """ Aggregates subword tokens into whole entities. Handles SentencePiece artifacts and BIO tagging. """ if not ner_results: return [] aggregated_entities = [] current_entity_parts = [] for entity in ner_results: entity_label = entity['entity'] # Skip 'O' (Outside) if entity_label == 'O': if current_entity_parts: save_current_entity(current_entity_parts, full_text, aggregated_entities) current_entity_parts = [] continue # Parse Label (e.g., "B-ORG", "I-ORG") if '-' in entity_label: prefix, label_type = entity_label.split('-', 1) else: prefix, label_type = None, entity_label # Decision logic for merging if not current_entity_parts: # Start new entity current_entity_parts.append(entity) else: prev_label = current_entity_parts[-1]['entity'] prev_type = prev_label.split('-')[-1] if '-' in prev_label else prev_label # Merge condition: # 1. Same Entity Type (ORG == ORG) # 2. Adjacent indices (current start == prev end) # 3. Logic: If it's "I-" tag, it MUST merge. If it's "B-" tag, it usually starts new, # BUT some models are messy. We prioritize adjacency + type match for smoother highlighting. if label_type == prev_type and entity['start'] == current_entity_parts[-1]['end']: current_entity_parts.append(entity) else: # Close previous and start new save_current_entity(current_entity_parts, full_text, aggregated_entities) current_entity_parts = [entity] # Save tail if current_entity_parts: save_current_entity(current_entity_parts, full_text, aggregated_entities) return aggregated_entities # Smart Processing Logic def process_text_smart(text: str, pipe, tokenizer) -> List[Dict]: """ Hybrid strategy: Direct inference for short texts, Sliding Window for long ones. Returns RAW tokens (unaggregated). """ tokenized = tokenizer( text, return_offsets_mapping=True, add_special_tokens=False, verbose=False ) offsets = tokenized["offset_mapping"] total_tokens = len(offsets) # STRATEGY A: Short Text if total_tokens <= SHORT_TEXT_THRESHOLD: return pipe(text) # STRATEGY B: Sliding Window all_raw_tokens = [] step = MODEL_MAX_LENGTH - WINDOW_OVERLAP for start_idx in range(0, total_tokens, step): end_idx = min(start_idx + MODEL_MAX_LENGTH, total_tokens) char_start = offsets[start_idx][0] char_end = offsets[end_idx - 1][1] chunk_text = text[char_start:char_end] if not chunk_text.strip(): continue chunk_results = pipe(chunk_text) for ent in chunk_results: ent["start"] += char_start ent["end"] += char_start all_raw_tokens.append(ent) if end_idx == total_tokens: break # Deduplicate raw tokens based on start index all_raw_tokens.sort(key=lambda x: x['start']) unique_tokens = [] seen_indices = set() for t in all_raw_tokens: idx_key = (t['start'], t['end']) if idx_key not in seen_indices: unique_tokens.append(t) seen_indices.add(idx_key) return unique_tokens # Lifespan ml_models: Dict[str, Any] = {} @asynccontextmanager async def lifespan(app: FastAPI): model_name = "rustemgareev/mdeberta-ner-ontonotes5" logger.info(f"Loading model: {model_name}...") try: tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=MODEL_MAX_LENGTH) model = AutoModelForTokenClassification.from_pretrained(model_name) ner_pipe = pipeline( "ner", model=model, tokenizer=tokenizer, aggregation_strategy="none", device=-1 ) ml_models["ner"] = ner_pipe ml_models["tokenizer"] = tokenizer logger.info("Model loaded.") except Exception as e: logger.error(f"CRITICAL ERROR loading model: {e}") yield ml_models.clear() # App Init app = FastAPI(title="mDeBERTa NER API", version="3.3.0", lifespan=lifespan) # API Endpoints @app.post("/predict", response_model=NERResponse) def predict(request: NERRequest): if "ner" not in ml_models: raise HTTPException(status_code=503, detail="Model not loaded") if not request.text.strip(): return NERResponse(entities=[]) try: # 1. Get Raw Tokens raw_tokens = process_text_smart( request.text, ml_models["ner"], ml_models["tokenizer"] ) # 2. Aggressive Aggregation & Boundary Refinement # We pass request.text to allow precise index trimming aggregated = aggregate_entities_manual(raw_tokens, request.text) return NERResponse(entities=[NEREntity(**item) for item in aggregated]) except Exception as e: logger.error(f"Prediction error: {e}") raise HTTPException(status_code=500, detail=str(e)) # Static Files app.mount("/", StaticFiles(directory="static", html=True), name="static")