model-embeddng / app.py
NikitaMY's picture
Update app.py
d75604e verified
import gradio as gr
import pandas as pd
import torch
import re
import numpy as np
from sentence_transformers import SentenceTransformer
from typing import List, Tuple, Optional
from functools import lru_cache
#ЗАДАНИЕ 2: модели
MODEL_CHOICES = {
"all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
"paraphrase-multilingual-MiniLM-L12-v2": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
}
DEFAULT_MODEL_NAME = "all-MiniLM-L6-v2"
@lru_cache(maxsize=2)
def load_model(model_key: str):
"""Кэшируем модели в памяти"""
model_path = MODEL_CHOICES[model_key]
return SentenceTransformer(model_path)
# Инициализация начальной модели
model = load_model(DEFAULT_MODEL_NAME)
# Ограничения
MAX_TEXTS = 500
MAX_CHARS_PER_TEXT = 2000
MAX_QUERY_CHARS = 2000
# Модуль А - нормализация текста
def clean_text(s: str) -> str:
s = "" if s is None else str(s)
s = re.sub(r'\s+', ' ', s).strip()
return s
# Модуль B - парсинг текста при ручном вводе
def parser_manual_texts(raw: str) -> List[str]:
raw = "" if raw is None else str(raw)
lines = [clean_text(x) for x in raw.splitlines()]
lines = [x for x in lines if x]
return lines[:MAX_TEXTS]
# Модуль C - парсинг текста для файла
def parser_file(file_obj) -> List[str]:
if file_obj is None:
return []
path = file_obj.name
if path.lower().endswith(".txt"):
with open(path, "r", encoding="utf-8", errors="ignore") as f:
lines = [clean_text(x) for x in f.read().splitlines()]
lines = [x for x in lines if x]
return lines[:MAX_TEXTS]
if path.lower().endswith(".csv"):
df = pd.read_csv(path)
if "text" in df.columns:
col = "text"
else:
col = df.columns[0]
texts = [clean_text(x) for x in df[col].astype(str).tolist()]
texts = [x for x in texts if x]
return texts[:MAX_TEXTS]
return []
# Модуль D - косинусное сходство
def cosine_sim_matrix(query_emb: np.ndarray, docs_emb: np.ndarray) -> np.ndarray:
q = query_emb / (np.linalg.norm(query_emb) + 1e-12)
d = docs_emb / (np.linalg.norm(docs_emb, axis=1, keepdims=True) + 1e-12)
return d @ q
# Модуль E - Построение индекса (с учетом выбранной модели)
def build_index(texts: List[str], model_name: str) -> Tuple[List[str], Optional[np.ndarray], str]:
if not texts:
return [], None, "База пуста, добавьте текст"
texts = [t[:MAX_CHARS_PER_TEXT] for t in texts]
try:
current_model = load_model(model_name)
emb = current_model.encode(texts, convert_to_numpy=True, show_progress_bar=False)
return texts, emb, f"Индекс построен: {len(texts)} текстов (модель: {model_name})"
except Exception as e:
return [], None, f"Ошибка построения индекса: {type(e).__name__}: {e}"
# Модуль F - Основной обработчик кнопки
def search_similar(
query: str,
manual_texts: str,
file_obj,
top_k: int,
min_sim: float, # ЗАДАНИЕ A: параметр минимальной похожести
model_name: str, # ЗАДАНИЕ B: выбранная модель
state_texts,
state_emb,
state_model # Новое состояние для хранения модели индекса
):
query = clean_text(query)[:MAX_QUERY_CHARS]
if not query:
return None, "Введите запрос", state_texts, state_emb, state_model
# Парсинг текстов
texts = parser_manual_texts(manual_texts)
texts_from_file = parser_file(file_obj)
texts.extend(texts_from_file)
# Удаление дубликатов
uniq = []
seen = set()
for t in texts:
if t not in seen:
uniq.append(t)
seen.add(t)
texts = uniq[:MAX_TEXTS]
status_msgs = []
# Проверка необходимости перестроения индекса
needs_rebuild = (
state_texts is None or
texts != state_texts or
state_emb is None or
state_model != model_name # ЗАДАНИЕ B: перестраиваем если сменили модель
)
if needs_rebuild:
idx_texts, idx_emb, msg = build_index(texts, model_name)
status_msgs.append(msg)
state_texts, state_emb = idx_texts, idx_emb
state_model = model_name # Сохраняем модель индекса
else:
status_msgs.append(f"Используем готовый индекс: {len(state_texts)} текстов")
if state_emb is None or not state_texts:
return None, "\n".join(status_msgs), state_texts, state_emb, state_model
try:
current_model = load_model(model_name)
q_emb = current_model.encode([query], convert_to_numpy=True, show_progress_bar=False)[0]
except Exception as e:
return None, f"Ошибка эмбеддинга: {type(e).__name__}: {e}", state_texts, state_emb, state_model
# Расчет сходства
sims = cosine_sim_matrix(q_emb, state_emb)
# ЗАДАНИЕ 1: фильтрация по минимальному порогу
above_threshold = sims >= min_sim
if not above_threshold.any():
return None, "Ничего не найдено (похожесть ниже порога)", state_texts, state_emb, state_model
# Получаем индексы, удовлетворяющие порогу
valid_indices = np.where(above_threshold)[0]
valid_sims = sims[valid_indices]
# Сортировка и выбор top_k
top_k = int(top_k)
top_k = max(1, min(top_k, len(valid_indices)))
# Сортируем по убыванию сходства среди отфильтрованных
top_indices = valid_indices[np.argsort(-valid_sims)[:top_k]]
# Формирование результатов
rows = []
for rank, i in enumerate(top_indices, start=1):
rows.append({
"rank": rank,
"similarity": float(sims[i]),
"text": state_texts[i]
})
df = pd.DataFrame(rows)
status_msgs.append(f"Найдено {len(valid_indices)} текстов с похожестью ≥ {min_sim}. Показано топ-{top_k}")
return df, "\n".join(status_msgs), state_texts, state_emb, state_model
# Модуль G - Интерфейс Gradio
with gr.Blocks(title="Поиск похожих текстов") as demo:
gr.Markdown("""
# Поиск похожих текстов с использованием эмбеддингов
**Инструкция:**
1. Введите тексты в поле или загрузите файл
2. Выберите модель и настройте параметры
3. Введите запрос и нажмите "Найти похожее"
""")
with gr.Row():
with gr.Column(scale=2):
query = gr.Textbox(
label="Что ищем",
lines=4,
placeholder="Клиент жалуется на задержку"
)
with gr.Column(scale=1):
# ЗАДАНИЕ B: выбор модели
model_dropdown = gr.Dropdown(
choices=list(MODEL_CHOICES.keys()),
value=DEFAULT_MODEL_NAME,
label="Модель эмбеддингов"
)
# ЗАДАНИЕ 1: ползунок минимальной похожести
min_sim = gr.Slider(
0.0, 1.0,
value=0.3,
step=0.01,
label="Минимальная похожесть",
info="Показывать только результаты с похожестью выше этого значения"
)
top_k = gr.Slider(
1, 20,
value=5,
step=1,
label="Количество результатов"
)
with gr.Row():
manual_texts = gr.Textbox(
label="База текстов (каждая строка - отдельный документ)",
lines=10,
placeholder="Текст 1\nТекст 2\nТекст 3"
)
file_obj = gr.File(
label="Или загрузите файл (txt, csv)",
file_types=[".txt", ".csv"]
)
run_btn = gr.Button(
"Найти похожее",
variant="primary"
)
with gr.Row():
out_table = gr.Dataframe(
label="Результаты поиска",
interactive=False,
wrap=True
)
status = gr.Textbox(
label="Статус выполнения",
lines=4
)
# Состояния
state_texts = gr.State(None)
state_emb = gr.State(None)
state_model = gr.State(DEFAULT_MODEL_NAME) # Новое состояние для модели
# Обработчик кнопки
run_btn.click(
search_similar,
inputs=[
query, manual_texts, file_obj, top_k,
min_sim, model_dropdown, # Добавлены новые параметры
state_texts, state_emb, state_model
],
outputs=[out_table, status, state_texts, state_emb, state_model]
)
# Примеры
gr.Examples(
examples=[
[
"У меня списали деньги дважды за заказ.",
"Деньги списались два раза за один и тот же заказ.\n"
"Не могу войти в личный кабинет, пишет неверный пароль.\n"
"Доставка задерживается уже на 5 дней, где мой заказ?\n"
"Хочу вернуть товар, он не подошел по размеру.\n"
"Поддержка не отвечает, жду ответа третий день.\n"
"Оплата не проходит, ошибка на этапе подтверждения."
],
[
"Не могу оплатить, постоянно ошибка.",
"Оплата не проходит, ошибка на этапе подтверждения.\n"
"Платеж отклоняется банком, хотя карта рабочая.\n"
"Хочу отменить заказ и вернуть деньги.\n"
"Курьер не пришел, доставка переносится.\n"
"Личный кабинет не открывается."
],
],
inputs=[query, manual_texts],
label="Примеры запросов"
)
# Информация о моделях
with gr.Accordion("Информация о моделях", open=False):
gr.Markdown("""
**all-MiniLM-L6-v2:**
- Размер: 80 МБ
- Размерность эмбеддингов: 384
- Язык: английский
**paraphrase-multilingual-MiniLM-L12-v2:**
- Размер: 420 МБ
- Размерность эмбеддингов: 384
- Языки: мультиязычная (поддерживает русский)
""")
if __name__ == "__main__":
demo.launch(share=False)