Spaces:
Running
Running
| """ | |
| Логика обработки NER: загрузка моделей, обработка текста, история запросов. | |
| """ | |
| import time | |
| import io | |
| from collections import deque | |
| from datetime import datetime | |
| import pandas as pd | |
| from transformers import pipeline | |
| from config import ( | |
| MAX_CHARS, | |
| MAX_BATCH_ROWS, | |
| HISTORY_SIZE, | |
| MODELS, | |
| ENTITY_LABELS | |
| ) | |
# Module-level state shared by the functions below
pipelines_cache = {}  # Cache of loaded pipelines, keyed by model key
history = deque(maxlen=HISTORY_SIZE)  # Request history, capped at HISTORY_SIZE entries
| # ============== ЗАГРУЗКА МОДЕЛЕЙ ============== | |
def load_model(model_key: str):
    """Return the NER pipeline for ``model_key``, creating it on first use.

    Created pipelines are stored in the module-level ``pipelines_cache`` so
    that later calls with the same key reuse the same object.

    Args:
        model_key: Key into ``MODELS``; the mapped value is passed to
            ``pipeline`` as the ``model`` argument.

    Returns:
        The cached or newly created pipeline object.
    """
    if model_key in pipelines_cache:
        return pipelines_cache[model_key]

    ner_pipeline = pipeline(
        "ner",
        model=MODELS[model_key],
        aggregation_strategy="simple",
    )
    pipelines_cache[model_key] = ner_pipeline
    return ner_pipeline
| # ============== ВАЛИДАЦИЯ И ОБРАБОТКА ============== | |
def validate_input(text: str) -> tuple[bool, str]:
    """Check that ``text`` is non-blank and not longer than ``MAX_CHARS``.

    Args:
        text: Raw user input; may be ``None``.

    Returns:
        ``(True, stripped_text)`` when the input is acceptable, otherwise
        ``(False, error_message)``.
    """
    cleaned = text.strip() if text is not None else ""
    if not cleaned:
        return False, "Ошибка: введите текст для анализа."

    length = len(cleaned)
    if length > MAX_CHARS:
        return False, f"Ошибка: текст слишком длинный ({length} символов). Максимум: {MAX_CHARS}."
    return True, cleaned
def normalize_entity_type(entity_group: str) -> str:
    """Strip one leading tagging prefix (B-, I-, E-, S-, L-, U-) from a label.

    Args:
        entity_group: Entity label as reported by the model, e.g. ``"B-PER"``.

    Returns:
        The label without its two-character prefix, or the label unchanged
        when it does not start with one of the known prefixes.
    """
    tagging_prefixes = ("B-", "I-", "E-", "S-", "L-", "U-")
    if entity_group.startswith(tagging_prefixes):
        return entity_group[2:]
    return entity_group
def process_entities(entities: list) -> list[dict]:
    """Convert raw pipeline entities into normalized records.

    Args:
        entities: Iterable of dicts with ``word``, ``score``, ``start`` and
            ``end`` keys plus an ``entity_group`` (or ``entity``) label.

    Returns:
        One dict per input entity with keys ``text``, ``type``, ``label``,
        ``score`` (rounded to 4 decimals), ``start`` and ``end``.
    """

    def _to_record(raw: dict) -> dict:
        # Use "entity_group" when present, otherwise "entity", otherwise "UNKNOWN".
        raw_label = raw.get("entity_group", raw.get("entity", "UNKNOWN"))
        kind = normalize_entity_type(raw_label)
        return {
            "text": raw["word"],
            "type": kind,
            # Fall back to the bare type when no display label is configured.
            "label": ENTITY_LABELS.get(kind, kind),
            "score": round(raw["score"], 4),
            "start": raw["start"],
            "end": raw["end"],
        }

    return [_to_record(raw) for raw in entities]
def create_highlighted_text(text: str, entities: list) -> list:
    """Split ``text`` into ``(segment, label)`` pairs for highlighted display.

    Entity spans are emitted with their ``type`` as the label; the text
    between them is emitted with ``None``. Spans are processed in order of
    their ``start`` offset. A span that overlaps text already emitted is
    clipped to its not-yet-emitted part, and a span that is empty or fully
    covered by earlier spans is skipped, so the concatenation of all returned
    segments never repeats characters of ``text``.

    Args:
        text: The analysed text.
        entities: Dicts with integer ``start``/``end`` offsets into ``text``
            and a ``type`` label.

    Returns:
        A list of ``(segment, label_or_None)`` tuples. When there is nothing
        to highlight, a single ``(text, None)`` tuple.
    """
    if not entities:
        return [(text, None)]

    segments = []
    cursor = 0  # Index of the first character that has not been emitted yet.
    for ent in sorted(entities, key=lambda item: item["start"]):
        # Clip to the unconsumed part so overlapping spans never duplicate text.
        start = max(ent["start"], cursor)
        end = ent["end"]
        if start >= end:
            # Empty span, or one lying entirely inside already-emitted text.
            continue
        if start > cursor:
            segments.append((text[cursor:start], None))
        segments.append((text[start:end], ent["type"]))
        cursor = end

    if cursor < len(text):
        segments.append((text[cursor:], None))
    # Every span may have been skipped on an empty text; keep the return shape.
    return segments or [(text, None)]
def entities_to_dataframe(entities: list) -> pd.DataFrame:
    """Build the display table for a list of processed entities.

    Args:
        entities: Dicts with ``text``, ``type``, ``label`` and ``score`` keys.

    Returns:
        A DataFrame with the columns "Текст", "Тип", "Описание" and
        "Уверенность" (the score formatted as a percentage with two
        decimals). An empty input yields an empty frame with those columns.
    """
    if not entities:
        empty_columns = ["Текст", "Тип", "Описание", "Уверенность"]
        return pd.DataFrame(columns=empty_columns)

    rows = [
        {
            "Текст": item["text"],
            "Тип": item["type"],
            "Описание": item["label"],
            "Уверенность": f"{item['score']:.2%}",
        }
        for item in entities
    ]
    return pd.DataFrame(rows)
| # ============== ИСТОРИЯ ЗАПРОСОВ ============== | |
def add_to_history(text: str, model: str, entities: list, latency: float):
    """Record one processed request at the front of the module-level history.

    ``history`` is a ``deque`` created with ``maxlen``, so once it is full
    each new entry pushes the oldest one out.

    Args:
        text: The analysed text; only the first 50 characters are stored.
        model: Model name; only its first whitespace-separated token is stored.
        entities: Processed entities (dicts with a ``type`` key).
        latency: Inference time in milliseconds.
    """
    timestamp = datetime.now().strftime("%H:%M:%S")

    if entities:
        # Sort the distinct types: raw set iteration order depends on string
        # hashing and would make this column differ from run to run.
        entity_types = ", ".join(sorted({ent["type"] for ent in entities}))
    else:
        entity_types = "—"

    # A blank model name has no tokens; fall back to the raw value instead of
    # raising IndexError.
    model_tokens = model.split()
    short_model = model_tokens[0] if model_tokens else model

    preview = text[:50] + "..." if len(text) > 50 else text

    history.appendleft({
        "Время": timestamp,
        "Модель": short_model,
        "Текст": preview,
        "Найдено": len(entities),
        "Типы": entity_types,
        "Latency": f"{latency} мс",
    })
def get_history_df():
    """Return the request history as a DataFrame, newest entry first.

    When the history is empty, an empty frame carrying the expected column
    names is returned.
    """
    if history:
        return pd.DataFrame(list(history))

    history_columns = ["Время", "Модель", "Текст", "Найдено", "Типы", "Latency"]
    return pd.DataFrame(columns=history_columns)
def clear_history():
    """Empty the request history.

    Returns:
        A tuple ``(empty_history_frame, status_message)``.
    """
    history.clear()
    # With the history now empty, get_history_df() yields the empty frame
    # that carries only the column headers.
    emptied = get_history_df()
    return emptied, "История очищена"
| # ============== ОСНОВНЫЕ ФУНКЦИИ ОБРАБОТКИ ============== | |
def process_single_text(text: str, model_choice: str):
    """Run NER on one text with the selected model.

    Args:
        text: Raw user input; validated with ``validate_input`` first.
        model_choice: Key of the model to use, passed to ``load_model``.

    Returns:
        A 4-tuple ``(status, highlighted, dataframe, latency_label)``. On a
        validation or processing error the status holds the error message
        and the remaining items are ``None, None, "—"``.
    """
    is_valid, result = validate_input(text)
    if not is_valid:
        return result, None, None, "—"
    text = result

    try:
        pipe = load_model(model_choice)

        # perf_counter is monotonic and high-resolution; the wall clock
        # (time.time) can jump when the system time is adjusted.
        started = time.perf_counter()
        raw_entities = pipe(text)
        latency = round((time.perf_counter() - started) * 1000, 1)

        entities = process_entities(raw_entities)
        highlighted = create_highlighted_text(text, entities)
        df = entities_to_dataframe(entities)
        add_to_history(text, model_choice, entities, latency)

        status = f"Найдено сущностей: {len(entities)}"
        return status, highlighted, df, f"{latency} мс"
    except Exception as e:
        # Boundary handler: report any failure as a status message rather
        # than letting it propagate to the caller.
        return f"Ошибка: {type(e).__name__}: {e}", None, None, "—"
def compare_models(text: str):
    """Run the first two configured models on ``text`` and return both results.

    Only the first two keys of ``MODELS`` are loaded and executed, because
    only two result slots are returned.

    Args:
        text: Raw user input; validated with ``validate_input`` first.

    Returns:
        A 7-tuple ``(status, highlighted_1, df_1, latency_1, highlighted_2,
        df_2, latency_2)``. On any error the status holds the message and
        the remaining items are ``None, None, "—", None, None, "—"``.
    """
    empty_outputs = (None, None, "—", None, None, "—")

    is_valid, result = validate_input(text)
    if not is_valid:
        return (result, *empty_outputs)
    text = result

    # Fail early with a clear message instead of an IndexError raised after
    # the available model has already been run.
    model_keys = list(MODELS)[:2]
    if len(model_keys) < 2:
        return ("Ошибка: для сравнения нужны как минимум две модели.", *empty_outputs)

    try:
        results = {}
        for model_key in model_keys:
            pipe = load_model(model_key)

            # Monotonic clock: unaffected by system time adjustments.
            started = time.perf_counter()
            raw_entities = pipe(text)
            latency = round((time.perf_counter() - started) * 1000, 1)

            entities = process_entities(raw_entities)
            results[model_key] = {
                "highlighted": create_highlighted_text(text, entities),
                "df": entities_to_dataframe(entities),
                "latency": f"{latency} мс",
                "count": len(entities),
            }

        first, second = (results[key] for key in model_keys)
        status = f"Модель 1: {first['count']} сущностей | Модель 2: {second['count']} сущностей"
        return (
            status,
            first["highlighted"],
            first["df"],
            first["latency"],
            second["highlighted"],
            second["df"],
            second["latency"],
        )
    except Exception as e:
        # Boundary handler: report any failure as a status message.
        return (f"Ошибка: {type(e).__name__}: {e}", *empty_outputs)
def process_batch(file, model_choice: str):
    """Run NER over every text in an uploaded CSV or TXT file.

    A path ending in ``.csv`` (case-insensitive) is read with pandas and must
    contain a ``text`` column; any other file is read as UTF-8 text with one
    input per line. Blank entries are skipped and each text is truncated to
    ``MAX_CHARS`` characters before processing.

    Args:
        file: Either a path string or an object whose ``name`` attribute is
            the path of the uploaded file.
        model_choice: Key of the model to use, passed to ``load_model``.

    Returns:
        A 3-tuple ``(status, results_dataframe, csv_text)``. On error the
        status holds the message and the other two items are ``None``.
    """
    if file is None:
        return "Ошибка: загрузите файл.", None, None

    entity_columns = ("PER", "ORG", "LOC", "MISC")
    try:
        # Accept a plain path string as well as an upload wrapper with `.name`.
        file_path = file if isinstance(file, str) else file.name

        if file_path.lower().endswith(".csv"):
            df_input = pd.read_csv(file_path)
            if "text" not in df_input.columns:
                return "Ошибка: CSV должен содержать колонку 'text'.", None, None
            # Empty cells are read as NaN (a float) and numeric cells as
            # numbers; drop the former and stringify the rest so the
            # len()/slicing below cannot abort the whole batch.
            raw_texts = df_input["text"].dropna().astype(str).tolist()
        else:  # TXT
            with open(file_path, "r", encoding="utf-8") as f:
                raw_texts = list(f)

        # Same clean-up for both formats: trim whitespace, skip blank entries.
        texts = [stripped for stripped in (raw.strip() for raw in raw_texts) if stripped]

        if len(texts) > MAX_BATCH_ROWS:
            return f"Ошибка: слишком много строк ({len(texts)}). Максимум: {MAX_BATCH_ROWS}.", None, None
        if not texts:
            return "Ошибка: файл пустой или не содержит текстов.", None, None

        pipe = load_model(model_choice)
        results = []
        started = time.perf_counter()  # Monotonic clock for the elapsed time.
        for number, text in enumerate(texts, start=1):
            text = text[:MAX_CHARS]
            # One preview for success and error rows, so a short text never
            # gets a spurious "..." suffix.
            preview = text[:100] + "..." if len(text) > 100 else text
            try:
                entities = process_entities(pipe(text))
            except Exception as exc:
                # Keep going: a failing row is reported in place, with the
                # start of the error message in the ORG column.
                results.append({
                    "№": number,
                    "Текст": preview,
                    "PER": "ОШИБКА",
                    "ORG": str(exc)[:30],
                    "LOC": "—",
                    "MISC": "—",
                    "Всего": 0,
                })
                continue

            row = {"№": number, "Текст": preview}
            for entity_type in entity_columns:
                found = [ent["text"] for ent in entities if ent["type"] == entity_type]
                row[entity_type] = ", ".join(found) if found else "—"
            row["Всего"] = len(entities)
            results.append(row)

        total_latency = round((time.perf_counter() - started) * 1000, 1)
        df_results = pd.DataFrame(results)
        # Without a target, to_csv returns the CSV text directly.
        csv_content = df_results.to_csv(index=False)
        status = f"Обработано: {len(texts)} текстов за {total_latency} мс"
        return status, df_results, csv_content
    except Exception as e:
        # Boundary handler: report any failure as a status message.
        return f"Ошибка: {type(e).__name__}: {e}", None, None