|
|
import logging |
|
|
from contextlib import asynccontextmanager |
|
|
from typing import List, Dict, Any |
|
|
from fastapi import FastAPI, HTTPException |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from pydantic import BaseModel, Field |
|
|
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification |
|
|
|
|
|
# Module-wide logging configuration; INFO level so model-load progress is visible.
logging.basicConfig(level=logging.INFO)


logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class NERRequest(BaseModel):
    """Request payload for the /predict endpoint."""

    # Raw text to analyze; whitespace-only input yields an empty entity list.
    text: str = Field(..., title="Input Text", description="Text to analyze")
|
|
|
|
|
class NEREntity(BaseModel):
    """One aggregated named entity as returned by /predict."""

    # Entity type with any BIO prefix stripped (e.g. "PER", not "B-PER").
    entity_group: str

    # Mean confidence over the merged subword tokens.
    score: float

    # The exact substring of the input text covered by the entity.
    word: str

    # Character offsets into the original input text (start inclusive, end exclusive).
    start: int

    end: int
|
|
|
|
|
class NERResponse(BaseModel):
    """Response payload for the /predict endpoint."""

    # All entities found in the input text.
    entities: List[NEREntity]
|
|
|
|
|
|
|
|
# Texts with at most this many tokens are sent to the pipeline in one pass.
SHORT_TEXT_THRESHOLD = 128

# Token window size per inference call; also passed as the tokenizer's
# model_max_length at load time.
MODEL_MAX_LENGTH = 512

# Tokens shared between consecutive sliding windows so entities straddling
# a window edge are still seen whole by at least one window.
WINDOW_OVERLAP = 128
|
|
|
|
|
|
|
|
def refine_boundaries(text: str, start: int, end: int) -> "tuple[int, int, str]":
    """Adjust an entity's character boundaries within *text*.

    1. Expands *end* forward to the end of the word if the model stopped
       mid-word (consumes any trailing alphanumeric run).
    2. Trims leading/trailing whitespace from the selection.

    Args:
        text: The full source text the offsets refer to.
        start: Proposed start offset (inclusive).
        end: Proposed end offset (exclusive).

    Returns:
        ``(start, end, span)`` — the adjusted offsets and the corresponding
        substring (empty string if the span was all whitespace).
    """
    # Extend forward while still inside an alphanumeric run.
    text_len = len(text)
    while end < text_len and text[end].isalnum():
        end += 1

    span = text[start:end]

    # Shift start past leading whitespace (lstrip strips the same Unicode
    # whitespace set that str.isspace matches).
    stripped = span.lstrip()
    start += len(span) - len(stripped)

    # Pull end back past trailing whitespace.
    span = stripped.rstrip()
    end -= len(stripped) - len(span)

    return start, end, span
|
|
|
|
|
def refine_boundaries1(text: str, start: int, end: int) -> "tuple[int, int, str]":
    """Shrink ``[start, end)`` so it excludes leading/trailing whitespace.

    This ensures the HTML highlight is tight around the word.
    NOTE(review): appears unused in this module — `refine_boundaries` adds
    word-boundary expansion on top of the same trimming; confirm before removal.
    """
    span = text[start:end]

    # Advance start past any leading whitespace.
    without_lead = span.lstrip()
    start += len(span) - len(without_lead)

    # Retreat end past any trailing whitespace.
    span = without_lead.rstrip()
    end -= len(without_lead) - len(span)

    return start, end, span
|
|
|
|
|
def save_current_entity(entity_parts: List[Dict], full_text: str, aggregated_entities: List[Dict]):
    """Collapse a run of token-level predictions into a single entity dict.

    Appends the finalized entity to *aggregated_entities* in place. Does
    nothing when *entity_parts* is empty or the refined span is blank.
    """
    if not entity_parts:
        return

    # Character span covered by the whole token group, then cleaned up
    # (word-boundary expansion + whitespace trimming).
    first, last = entity_parts[0], entity_parts[-1]
    final_start, final_end, clean_word = refine_boundaries(
        full_text, first['start'], last['end']
    )

    if not clean_word:
        return

    # Mean confidence across the grouped tokens.
    scores = [part['score'] for part in entity_parts]
    mean_score = sum(scores) / len(scores)

    # Strip any BIO prefix from the first token's label ("B-PER" -> "PER").
    group = first['entity'].split('-')[-1]

    aggregated_entities.append({
        'word': clean_word,
        'score': float(mean_score),
        'entity_group': group,
        'start': final_start,
        'end': final_end
    })
|
|
|
|
|
def aggregate_entities_manual(ner_results: List[Dict], full_text: str) -> List[Dict]:
    """Merge raw subword predictions into whole-entity dicts.

    Groups consecutive tokens that share an entity type and are directly
    character-adjacent in the text (i.e. subword continuations), handling
    SentencePiece artifacts and BIO-style labels ("B-PER"/"I-PER").
    """
    if not ner_results:
        return []

    merged: List[Dict] = []
    pending: List[Dict] = []

    for token in ner_results:
        label = token['entity']

        # Outside tag: close whatever entity is currently open.
        if label == 'O':
            if pending:
                save_current_entity(pending, full_text, merged)
                pending = []
            continue

        # Extract the type portion of a BIO label ("B-PER" -> "PER").
        if '-' in label:
            label_type = label.split('-', 1)[1]
        else:
            label_type = label

        if not pending:
            # Start a fresh entity group.
            pending = [token]
            continue

        prev_label = pending[-1]['entity']
        prev_type = prev_label.split('-')[-1] if '-' in prev_label else prev_label

        # Continue the open entity only when the type matches AND the token
        # starts exactly where the previous one ended (subword continuation).
        if label_type == prev_type and token['start'] == pending[-1]['end']:
            pending.append(token)
        else:
            # Type change or gap: flush and begin a new group.
            save_current_entity(pending, full_text, merged)
            pending = [token]

    # Flush the trailing group, if any.
    if pending:
        save_current_entity(pending, full_text, merged)

    return merged
|
|
|
|
|
|
|
|
def process_text_smart(text: str, pipe, tokenizer) -> List[Dict]:
    """
    Hybrid strategy: Direct inference for short texts, Sliding Window for long ones.
    Returns RAW tokens (unaggregated).

    Args:
        text: Full input text to analyze.
        pipe: transformers token-classification pipeline (aggregation "none").
        tokenizer: Tokenizer matching the pipeline's model; must support
            return_offsets_mapping (i.e. a "fast" tokenizer).

    Returns:
        List of raw per-token predictions whose 'start'/'end' offsets are
        rebased onto the original *text*, de-duplicated by span.
    """
    # Tokenize once (no special tokens) to measure length and get char offsets.
    tokenized = tokenizer(
        text,
        return_offsets_mapping=True,
        add_special_tokens=False,
        verbose=False
    )
    offsets = tokenized["offset_mapping"]
    total_tokens = len(offsets)

    # Short text: a single direct pass is enough.
    if total_tokens <= SHORT_TEXT_THRESHOLD:
        return pipe(text)

    # Long text: slide a MODEL_MAX_LENGTH-token window with WINDOW_OVERLAP
    # tokens shared between consecutive windows.
    all_raw_tokens = []
    step = MODEL_MAX_LENGTH - WINDOW_OVERLAP

    for start_idx in range(0, total_tokens, step):
        end_idx = min(start_idx + MODEL_MAX_LENGTH, total_tokens)

        # Map the token window back to a character span of the original text.
        char_start = offsets[start_idx][0]
        char_end = offsets[end_idx - 1][1]

        chunk_text = text[char_start:char_end]
        if not chunk_text.strip():
            continue

        # NOTE(review): the pipeline re-tokenizes the chunk and adds special
        # tokens on top of the 512-token window — confirm it truncates rather
        # than erroring on the longest chunks.
        chunk_results = pipe(chunk_text)

        # Rebase chunk-relative offsets onto the full text.
        for ent in chunk_results:
            ent["start"] += char_start
            ent["end"] += char_start
            all_raw_tokens.append(ent)

        if end_idx == total_tokens:
            break

    # Overlapping windows produce duplicate predictions; keep the first one
    # seen for each (start, end) span after sorting by position.
    all_raw_tokens.sort(key=lambda x: x['start'])
    unique_tokens = []
    seen_indices = set()

    for t in all_raw_tokens:
        idx_key = (t['start'], t['end'])
        if idx_key not in seen_indices:
            unique_tokens.append(t)
            seen_indices.add(idx_key)

    return unique_tokens
|
|
|
|
|
|
|
|
ml_models: Dict[str, Any] = {} |
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the NER model on startup, free it on shutdown.

    On load failure the error is logged with its traceback and the app still
    starts — /predict answers 503 until ml_models holds a pipeline.
    """
    model_name = "rustemgareev/mdeberta-ner-ontonotes5"
    # Lazy %-style args so the message is only formatted when emitted.
    logger.info("Loading model: %s...", model_name)

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=MODEL_MAX_LENGTH)
        model = AutoModelForTokenClassification.from_pretrained(model_name)

        # aggregation_strategy="none" -> raw per-token output; aggregation is
        # done manually downstream. device=-1 forces CPU inference.
        ner_pipe = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="none",
            device=-1
        )

        ml_models["ner"] = ner_pipe
        ml_models["tokenizer"] = tokenizer
        logger.info("Model loaded.")

    except Exception:
        # Deliberate best-effort: keep the server up so it can report 503.
        # logger.exception records the full traceback, not just str(e).
        logger.exception("CRITICAL ERROR loading model")

    yield
    # Shutdown: drop references so the model can be garbage-collected.
    ml_models.clear()
|
|
|
|
|
|
|
|
app = FastAPI(title="mDeBERTa NER API", version="3.3.0", lifespan=lifespan) |
|
|
|
|
|
|
|
|
@app.post("/predict", response_model=NERResponse)
def predict(request: NERRequest):
    """Run NER over request.text and return aggregated entities.

    Raises:
        HTTPException 503: model not loaded (startup failed or still running).
        HTTPException 500: unexpected inference error.
    """
    if "ner" not in ml_models:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Whitespace-only input short-circuits to an empty result.
    if not request.text.strip():
        return NERResponse(entities=[])

    try:
        # Raw (unaggregated) token predictions; chunked if the text is long.
        raw_tokens = process_text_smart(
            request.text,
            ml_models["ner"],
            ml_models["tokenizer"]
        )

        # Collapse subword tokens into whole entities with clean boundaries.
        aggregated = aggregate_entities_manual(raw_tokens, request.text)

        return NERResponse(entities=[NEREntity(**item) for item in aggregated])

    except Exception as e:
        # logger.exception keeps the traceback in the server log (logger.error
        # with an f-string would discard it); chain the cause for debugging.
        logger.exception("Prediction error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
app.mount("/", StaticFiles(directory="static", html=True), name="static") |