# russian-ner / processing.py
"""
Логика обработки NER: загрузка моделей, обработка текста, история запросов.
"""
import time
import io
from collections import deque
from datetime import datetime
import pandas as pd
from transformers import pipeline
from config import (
MAX_CHARS,
MAX_BATCH_ROWS,
HISTORY_SIZE,
MODELS,
ENTITY_LABELS
)
# Global variables
pipelines_cache = {}  # cache of loaded model pipelines
history = deque(maxlen=HISTORY_SIZE)  # request history, newest entries first
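# Both live for the lifetime of the process; in a shared deployment (e.g. a
# Hugging Face Space) the model cache and the history are shared by all users.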
# ============== MODEL LOADING ==============
def load_model(model_key: str):
"""Ленивая загрузка модели по ключу."""
if model_key not in pipelines_cache:
model_name = MODELS[model_key]
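        # "simple" aggregation merges word-piece tokens into whole-entity spans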
pipelines_cache[model_key] = pipeline(
"ner",
model=model_name,
aggregation_strategy="simple"
)
return pipelines_cache[model_key]
# ============== VALIDATION AND PROCESSING ==============
def validate_input(text: str) -> tuple[bool, str]:
"""Валидация входного текста."""
if text is None or not text.strip():
return False, "Ошибка: введите текст для анализа."
text = text.strip()
if len(text) > MAX_CHARS:
return False, f"Ошибка: текст слишком длинный ({len(text)} символов). Максимум: {MAX_CHARS}."
return True, text
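# Example: validate_input("  Привет  ") -> (True, "Привет");
# validate_input("") -> (False, "Ошибка: введите текст для анализа.")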
def normalize_entity_type(entity_group: str) -> str:
"""Нормализация типа сущности (убираем префиксы B-, I- и т.д.)."""
for prefix in ["B-", "I-", "E-", "S-", "L-", "U-"]:
if entity_group.startswith(prefix):
return entity_group[2:]
return entity_group
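# Example: normalize_entity_type("B-PER") -> "PER"; unprefixed labels pass through.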
def process_entities(entities: list) -> list[dict]:
"""Обработка и нормализация списка сущностей."""
processed = []
for ent in entities:
entity_type = normalize_entity_type(ent.get("entity_group", ent.get("entity", "UNKNOWN")))
processed.append({
"text": ent["word"],
"type": entity_type,
"label": ENTITY_LABELS.get(entity_type, entity_type),
"score": round(ent["score"], 4),
"start": ent["start"],
"end": ent["end"]
})
return processed
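# Each resulting dict has the shape (values illustrative):
# {"text": "Москве", "type": "LOC", "label": ENTITY_LABELS["LOC"],
#  "score": 0.9987, "start": 13, "end": 19}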
def create_highlighted_text(text: str, entities: list) -> list:
"""Создание данных для подсветки текста."""
if not entities:
return [(text, None)]
sorted_entities = sorted(entities, key=lambda x: x["start"])
highlighted = []
last_end = 0
    for ent in sorted_entities:
        start, end = ent["start"], ent["end"]
        if start < last_end:
            continue  # skip overlapping spans to avoid duplicating text
        if start > last_end:
            highlighted.append((text[last_end:start], None))
        highlighted.append((text[start:end], ent["type"]))
        last_end = end
if last_end < len(text):
highlighted.append((text[last_end:], None))
return highlighted
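# Example: "Иван живёт в Москве" with PER and LOC hits yields
# [("Иван", "PER"), (" живёт в ", None), ("Москве", "LOC")], the (text, label)
# tuple format accepted by Gradio's HighlightedText component.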
def entities_to_dataframe(entities: list) -> pd.DataFrame:
"""Преобразование списка сущностей в DataFrame."""
if not entities:
return pd.DataFrame(columns=["Текст", "Тип", "Описание", "Уверенность"])
data = []
for ent in entities:
data.append({
"Текст": ent["text"],
"Тип": ent["type"],
"Описание": ent["label"],
"Уверенность": f"{ent['score']:.2%}"
})
return pd.DataFrame(data)
# ============== REQUEST HISTORY ==============
def add_to_history(text: str, model: str, entities: list, latency: float):
"""Добавление запроса в историю."""
timestamp = datetime.now().strftime("%H:%M:%S")
entity_count = len(entities)
entity_types = ", ".join(set(e["type"] for e in entities)) if entities else "—"
history.appendleft({
"Время": timestamp,
"Модель": model.split()[0],
"Текст": text[:50] + "..." if len(text) > 50 else text,
"Найдено": entity_count,
"Типы": entity_types,
"Latency": f"{latency} мс"
})
def get_history_df():
"""Получение истории запросов как DataFrame."""
if not history:
return pd.DataFrame(columns=["Время", "Модель", "Текст", "Найдено", "Типы", "Latency"])
return pd.DataFrame(list(history))
def clear_history():
"""Очистка истории запросов."""
history.clear()
return pd.DataFrame(columns=["Время", "Модель", "Текст", "Найдено", "Типы", "Latency"]), "История очищена"
# ============== MAIN PROCESSING FUNCTIONS ==============
def process_single_text(text: str, model_choice: str):
"""Обработка одиночного текста."""
is_valid, result = validate_input(text)
if not is_valid:
return result, None, None, "—"
text = result
try:
pipe = load_model(model_choice)
t0 = time.time()
raw_entities = pipe(text)
latency = round((time.time() - t0) * 1000, 1)
entities = process_entities(raw_entities)
highlighted = create_highlighted_text(text, entities)
df = entities_to_dataframe(entities)
add_to_history(text, model_choice, entities, latency)
status = f"Найдено сущностей: {len(entities)}"
return status, highlighted, df, f"{latency} мс"
except Exception as e:
return f"Ошибка: {type(e).__name__}: {e}", None, None, "—"
def compare_models(text: str):
"""Сравнение результатов двух моделей."""
is_valid, result = validate_input(text)
if not is_valid:
return result, None, None, "—", None, None, "—"
text = result
results = {}
try:
for model_key in MODELS.keys():
pipe = load_model(model_key)
t0 = time.time()
raw_entities = pipe(text)
latency = round((time.time() - t0) * 1000, 1)
entities = process_entities(raw_entities)
highlighted = create_highlighted_text(text, entities)
df = entities_to_dataframe(entities)
results[model_key] = {
"highlighted": highlighted,
"df": df,
"latency": f"{latency} мс",
"count": len(entities)
}
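        # Assumes config.MODELS defines at least two models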
model_keys = list(MODELS.keys())
m1, m2 = model_keys[0], model_keys[1]
status = f"Модель 1: {results[m1]['count']} сущностей | Модель 2: {results[m2]['count']} сущностей"
return (
status,
results[m1]["highlighted"],
results[m1]["df"],
results[m1]["latency"],
results[m2]["highlighted"],
results[m2]["df"],
results[m2]["latency"]
)
except Exception as e:
error_msg = f"Ошибка: {type(e).__name__}: {e}"
return error_msg, None, None, "—", None, None, "—"
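# Note: compare_models runs the text through every model in config.MODELS, so the
# first call also pays each model's download/load cost.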
def process_batch(file, model_choice: str):
"""Пакетная обработка файла (CSV или TXT)."""
if file is None:
return "Ошибка: загрузите файл.", None, None
try:
file_path = file.name
        if file_path.lower().endswith('.csv'):
            df_input = pd.read_csv(file_path)
            if 'text' not in df_input.columns:
                return "Ошибка: CSV должен содержать колонку 'text'.", None, None
            # dropna/astype guard against missing or non-string cells
            texts = df_input['text'].dropna().astype(str).tolist()
        else:  # plain text: one text per line
with open(file_path, 'r', encoding='utf-8') as f:
texts = [line.strip() for line in f if line.strip()]
if len(texts) > MAX_BATCH_ROWS:
return f"Ошибка: слишком много строк ({len(texts)}). Максимум: {MAX_BATCH_ROWS}.", None, None
if not texts:
return "Ошибка: файл пустой или не содержит текстов.", None, None
pipe = load_model(model_choice)
results = []
t0 = time.time()
for i, text in enumerate(texts):
if len(text) > MAX_CHARS:
text = text[:MAX_CHARS]
try:
raw_entities = pipe(text)
entities = process_entities(raw_entities)
per_list = [e["text"] for e in entities if e["type"] == "PER"]
org_list = [e["text"] for e in entities if e["type"] == "ORG"]
loc_list = [e["text"] for e in entities if e["type"] == "LOC"]
misc_list = [e["text"] for e in entities if e["type"] == "MISC"]
results.append({
"№": i + 1,
"Текст": text[:100] + "..." if len(text) > 100 else text,
"PER": ", ".join(per_list) if per_list else "—",
"ORG": ", ".join(org_list) if org_list else "—",
"LOC": ", ".join(loc_list) if loc_list else "—",
"MISC": ", ".join(misc_list) if misc_list else "—",
"Всего": len(entities)
})
except Exception as e:
results.append({
"№": i + 1,
"Текст": text[:100] + "...",
"PER": "ОШИБКА",
"ORG": str(e)[:30],
"LOC": "—",
"MISC": "—",
"Всего": 0
})
total_latency = round((time.time() - t0) * 1000, 1)
df_results = pd.DataFrame(results)
        csv_buffer = io.StringIO()
        df_results.to_csv(csv_buffer, index=False)  # encoding is irrelevant for a text buffer
        csv_content = csv_buffer.getvalue()
status = f"Обработано: {len(texts)} текстов за {total_latency} мс"
return status, df_results, csv_content
except Exception as e:
return f"Ошибка: {type(e).__name__}: {e}", None, None