import logging
from contextlib import asynccontextmanager
from typing import List, Dict, Any, Tuple
from fastapi import FastAPI, HTTPException
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel, Field
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Request / Response Schemas
class NERRequest(BaseModel):
    text: str = Field(..., title="Input Text", description="Text to analyze")

class NEREntity(BaseModel):
    entity_group: str
    score: float
    word: str
    start: int
    end: int

class NERResponse(BaseModel):
    entities: List[NEREntity]
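# Example response payload (illustrative values, not captured model output):
# {
#     "entities": [
#         {"entity_group": "ORG", "score": 0.98, "word": "Acme Corp", "start": 10, "end": 19}
#     ]
# }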
# Constants
SHORT_TEXT_THRESHOLD = 128
MODEL_MAX_LENGTH = 512
WINDOW_OVERLAP = 128
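# Worked example of the window math with the constants above: the stride is
# MODEL_MAX_LENGTH - WINDOW_OVERLAP = 512 - 128 = 384 tokens, so a 1000-token
# input yields windows covering tokens [0, 512), [384, 896) and [768, 1000).
# Consecutive windows share WINDOW_OVERLAP = 128 tokens, so an entity that
# straddles a window boundary is seen whole by at least one pass.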
# Core Logic
def refine_boundaries(text: str, start: int, end: int) -> Tuple[int, int, str]:
    """
    Adjusts start/end indices.
    1. Expands the selection to the end of the word if the model stopped mid-word.
    2. Trims leading/trailing whitespace.
    """
    # Expand to the next word boundary (end of the alphanumeric run).
    # To also merge across hyphens, extend the condition with `or text[end] == '-'`.
    while end < len(text) and text[end].isalnum():
        end += 1
    span = text[start:end]
    # Shift the start index forward past any leading whitespace
    while span and span[0].isspace():
        start += 1
        span = span[1:]
    # Shift the end index backward past any trailing whitespace
    while span and span[-1].isspace():
        end -= 1
        span = span[:-1]
    return start, end, span
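# Illustrative behavior (a sketch, not output captured from a live run):
#   refine_boundaries("Visit OpenAI today", 6, 10) -> (6, 12, "OpenAI")
#     (the model stopped mid-word after "Open"; end expands to the word boundary)
#   refine_boundaries("in  Paris ", 2, 10) -> (4, 9, "Paris")
#     (leading/trailing whitespace is trimmed from both the span and the indices)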
def save_current_entity(entity_parts: List[Dict], full_text: str, aggregated_entities: List[Dict]):
    """
    Finalizes a group of tokens into a single entity.
    """
    if not entity_parts:
        return
    # 1. Determine the raw character range
    raw_start = entity_parts[0]['start']
    raw_end = entity_parts[-1]['end']
    # 2. Refine boundaries (trim whitespace from the indices)
    final_start, final_end, clean_word = refine_boundaries(full_text, raw_start, raw_end)
    if not clean_word:
        return
    # 3. Average the per-token scores
    avg_score = sum(part['score'] for part in entity_parts) / len(entity_parts)
    # 4. Determine the label from the first token, stripping the B-/I- prefix
    raw_label = entity_parts[0]['entity']
    entity_group = raw_label.split('-')[-1]  # e.g., "B-ORG" -> "ORG"
    aggregated_entities.append({
        'word': clean_word,
        'score': float(avg_score),
        'entity_group': entity_group,
        'start': final_start,
        'end': final_end
    })
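# Sketch of a single aggregation step, using hypothetical token parts for the
# text "New York" (scores are made up):
#   parts = [{'entity': 'B-GPE', 'score': 0.99, 'start': 0, 'end': 3},
#            {'entity': 'I-GPE', 'score': 0.97, 'start': 3, 'end': 8}]
#   save_current_entity(parts, "New York", out) appends
#   {'word': 'New York', 'score': 0.98, 'entity_group': 'GPE', 'start': 0, 'end': 8}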
def aggregate_entities_manual(ner_results: List[Dict], full_text: str) -> List[Dict]:
    """
    Aggregates subword tokens into whole entities.
    Handles SentencePiece artifacts and BIO tagging.
    """
    if not ner_results:
        return []
    aggregated_entities = []
    current_entity_parts = []
    for entity in ner_results:
        entity_label = entity['entity']
        # 'O' (Outside) closes any open entity
        if entity_label == 'O':
            if current_entity_parts:
                save_current_entity(current_entity_parts, full_text, aggregated_entities)
                current_entity_parts = []
            continue
        # Parse the label (e.g., "B-ORG", "I-ORG"); the B/I prefix is not used
        # by the merge decision below
        if '-' in entity_label:
            _prefix, label_type = entity_label.split('-', 1)
        else:
            label_type = entity_label
        # Decision logic for merging
        if not current_entity_parts:
            # Start a new entity
            current_entity_parts.append(entity)
        else:
            prev_label = current_entity_parts[-1]['entity']
            prev_type = prev_label.split('-')[-1] if '-' in prev_label else prev_label
            # Merge condition:
            # 1. Same entity type (ORG == ORG)
            # 2. Adjacent indices (current start == previous end)
            # Strict BIO would force "I-" tags to merge and "B-" tags to start a
            # new entity, but some models emit messy prefixes, so adjacency plus
            # type match gives smoother highlighting.
            if label_type == prev_type and entity['start'] == current_entity_parts[-1]['end']:
                current_entity_parts.append(entity)
            else:
                # Close the previous entity and start a new one
                save_current_entity(current_entity_parts, full_text, aggregated_entities)
                current_entity_parts = [entity]
    # Save the tail
    if current_entity_parts:
        save_current_entity(current_entity_parts, full_text, aggregated_entities)
    return aggregated_entities
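# End-to-end sketch with hypothetical raw pipeline output for "Bank of America":
#   [{'entity': 'B-ORG', 'start': 0, 'end': 4,  'score': 0.99},
#    {'entity': 'I-ORG', 'start': 4, 'end': 7,  'score': 0.98},
#    {'entity': 'I-ORG', 'start': 7, 'end': 15, 'score': 0.99}]
# collapses into a single {'entity_group': 'ORG', 'word': 'Bank of America',
# 'start': 0, 'end': 15} entry: the types match and each span starts where the
# previous one ends.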
# Smart Processing Logic
def process_text_smart(text: str, pipe, tokenizer) -> List[Dict]:
    """
    Hybrid strategy: direct inference for short texts, sliding window for long ones.
    Returns RAW tokens (unaggregated).
    """
    tokenized = tokenizer(
        text,
        return_offsets_mapping=True,
        add_special_tokens=False,
        verbose=False
    )
    offsets = tokenized["offset_mapping"]
    total_tokens = len(offsets)
    # STRATEGY A: short text, single pass
    if total_tokens <= SHORT_TEXT_THRESHOLD:
        return pipe(text)
    # STRATEGY B: sliding window over token offsets
    all_raw_tokens = []
    step = MODEL_MAX_LENGTH - WINDOW_OVERLAP
    for start_idx in range(0, total_tokens, step):
        end_idx = min(start_idx + MODEL_MAX_LENGTH, total_tokens)
        char_start = offsets[start_idx][0]
        char_end = offsets[end_idx - 1][1]
        chunk_text = text[char_start:char_end]
        if not chunk_text.strip():
            continue
        chunk_results = pipe(chunk_text)
        # Shift chunk-local offsets back into the coordinates of the full text
        for ent in chunk_results:
            ent["start"] += char_start
            ent["end"] += char_start
            all_raw_tokens.append(ent)
        if end_idx == total_tokens:
            break
    # Deduplicate tokens produced by overlapping windows, keyed on (start, end)
    all_raw_tokens.sort(key=lambda x: x['start'])
    unique_tokens = []
    seen_indices = set()
    for t in all_raw_tokens:
        idx_key = (t['start'], t['end'])
        if idx_key not in seen_indices:
            unique_tokens.append(t)
            seen_indices.add(idx_key)
    return unique_tokens
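# Usage sketch (illustrative; `long_text` is a placeholder, the pipeline and
# tokenizer come from the lifespan hook below):
#   raw = process_text_smart(long_text, ml_models["ner"], ml_models["tokenizer"])
#   entities = aggregate_entities_manual(raw, long_text)
# Because chunk-local offsets are shifted by char_start above, the start/end
# values in `raw` always index into the original `long_text`.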
# Lifespan
ml_models: Dict[str, Any] = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    model_name = "rustemgareev/mdeberta-ner-ontonotes5"
    logger.info(f"Loading model: {model_name}...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=MODEL_MAX_LENGTH)
        model = AutoModelForTokenClassification.from_pretrained(model_name)
        ner_pipe = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="none",
            device=-1  # CPU
        )
        ml_models["ner"] = ner_pipe
        ml_models["tokenizer"] = tokenizer
        logger.info("Model loaded.")
    except Exception as e:
        logger.error(f"CRITICAL ERROR loading model: {e}")
    yield
    ml_models.clear()
# App Init
app = FastAPI(title="mDeBERTa NER API", version="3.3.0", lifespan=lifespan)
# API Endpoints
@app.post("/predict", response_model=NERResponse)
def predict(request: NERRequest):
    if "ner" not in ml_models:
        raise HTTPException(status_code=503, detail="Model not loaded")
    if not request.text.strip():
        return NERResponse(entities=[])
    try:
        # 1. Get raw, unaggregated tokens
        raw_tokens = process_text_smart(
            request.text,
            ml_models["ner"],
            ml_models["tokenizer"]
        )
        # 2. Aggressive aggregation & boundary refinement; request.text is
        #    passed through so indices can be trimmed against the original string
        aggregated = aggregate_entities_manual(raw_tokens, request.text)
        return NERResponse(entities=[NEREntity(**item) for item in aggregated])
    except Exception as e:
        logger.error(f"Prediction error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
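# Example request (assuming uvicorn's default port 8000; the entities returned
# depend on the loaded model, so the response below is only illustrative):
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Tim Cook leads Apple."}'
#   -> {"entities": [{"entity_group": "PERSON", "word": "Tim Cook", ...},
#                    {"entity_group": "ORG", "word": "Apple", ...}]}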
# Static Files
app.mount("/", StaticFiles(directory="static", html=True), name="static")