Spaces:
Sleeping
Sleeping
| # app.py | |
| import os | |
| import torch | |
| import gradio as gr | |
| import pandas as pd | |
| from datetime import datetime | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| APP_CSS = """ | |
| /* Контейнер пошире под 2 колонки */ | |
| #app-wrap { | |
| max-width: 1100px; | |
| margin: 0 auto !important; | |
| padding: 0 16px; | |
| } | |
| #left-col, #right-col { gap: 12px; } | |
| """ | |
| MODEL_PATH = os.getenv("MODEL_ID", "mipatov/tech_support_intent_classifier") | |
| MAX_LEN = 512 | |
| ID2LABEL = { | |
| 0: "0. Нет подключения", | |
| 1: "1. Низкая скорость", | |
| 2: "2. Смена пароля Wi‑Fi", | |
| 3: "3. Обрыв кабеля", | |
| 4: "4. Нестабильный интернет", | |
| 5: "5. Узнать пароль Wi‑Fi", | |
| 6: "6. Высокий пинг", | |
| 7: "7. Настройка роутера", | |
| 8: "8. Замена роутера", | |
| 9: "9. Вызов мастера", | |
| 10: "10. Другое", | |
| } | |
| LABEL2ID = {v: k for k, v in ID2LABEL.items()} | |
| EXAMPLE_TICKETS = [ | |
| 'День добрый! Не подскажете как сменить пароль на роутере?', | |
| 'Очень медленная скорость', | |
| 'Здравствуйте, у нас собака перегрызла провод интернет и сломался роутер, как нам поступить какие есть варианты?', | |
| 'Здравствуйте, каждый вечер отваливается интернет', | |
| 'Добрый день подскажите пожалуйста как узнать пароль от вайфая если документы утеряны', | |
| 'Не работает ни телевизор ни интернет, пишет ошибку. Проверьте оборудование пожалуйста', | |
| 'Здравствуйте низкая скорость и высокий пинг невозможно пользоваться поточными сервисами с этим можно что-то сделать?', | |
| 'Роутэр на новый можно поменять? Или он от старого не отличается', | |
| 'Здравствуйте, необходимо перенести точку доступа внутри дома. К кому обратиться?' | |
| ] | |
| # Загрузка модели | |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) | |
| model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH) | |
| model.config.id2label = ID2LABEL | |
| model.config.label2id = LABEL2ID | |
| model.eval() | |
| labels = [model.config.id2label[i] for i in range(model.config.num_labels)] | |
| os.makedirs("logs", exist_ok=True) | |
| FEEDBACK_CSV = "logs/feedback.csv" | |
| BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| IMAGES_DIR = os.path.join(BASE_DIR, "images") | |
| def svg_file_to_html(file_path: str) -> str: | |
| """Читает SVG и заворачивает в адаптивный контейнер для gr.HTML.""" | |
| try: | |
| with open(file_path, "r", encoding="utf-8") as f: | |
| svg = f.read() | |
| return f''' | |
| <div style="display:flex;justify-content:center;"> | |
| <div style="max-width: 360px; width: 100%">{svg}</div> | |
| </div> | |
| ''' | |
| except Exception: | |
| return "" | |
| BASE_SVG_HTML = svg_file_to_html(os.path.join(IMAGES_DIR, "base.svg")) | |
| def label_to_image_path(label: str): | |
| if not label: | |
| return None | |
| idx = LABEL2ID.get(label) | |
| if idx is None: | |
| return None | |
| path = os.path.join(IMAGES_DIR, f"{idx}.svg") | |
| return path if os.path.exists(path) else None | |
| def label_to_image_html(label: str) -> str: | |
| """Возвращает HTML со встроенным SVG (чтобы не зависеть от поддержки SVG в gr.Image).""" | |
| path = label_to_image_path(label) | |
| if not path: | |
| return "" | |
| try: | |
| with open(path, "r", encoding="utf-8") as f: | |
| svg = f.read() | |
| # Оборачиваем для адаптивности | |
| return f''' | |
| <div style="display:flex;justify-content:center;"> | |
| <div style="max-width: 360px; width: 100%">{svg}</div> | |
| </div> | |
| ''' | |
| except Exception: | |
| return "" | |
| def predict_with_probs(text: str): | |
| if not text or not text.strip(): | |
| return "", {} | |
| inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LEN) | |
| with torch.no_grad(): | |
| logits = model(**inputs).logits[0] | |
| probs = torch.softmax(logits, dim=-1).tolist() | |
| label_probs = {labels[i]: float(probs[i]) for i in range(len(labels))} | |
| label_probs = {k: round(v, 4) for k, v in sorted(label_probs.items(), key=lambda x: x[1], reverse=True)} | |
| top_label = next(iter(label_probs)) if label_probs else "" | |
| return top_label, label_probs | |
| def on_classify(t: str): | |
| top_label, label_probs = predict_with_probs(t) | |
| img_html = label_to_image_html(top_label) | |
| correct_reset = gr.update(value=None) | |
| # Возвращаем: HTML со svg, вероятности, сброс dropdown, скрытый textbox с меткой | |
| return img_html, label_probs, correct_reset, top_label | |
| with gr.Blocks(css=APP_CSS) as demo: | |
| with gr.Column(elem_id="app-wrap"): | |
| gr.Markdown("## Классификация обращений в техподдержку") | |
| with gr.Row(): | |
| # Левая колонка | |
| with gr.Column(scale=1, min_width=400, elem_id="left-col"): | |
| text = gr.Textbox(label="Текст обращения", lines=6, placeholder="Вставь сюда текст тикета") | |
| gr.Examples( | |
| examples=[[e] for e in EXAMPLE_TICKETS], | |
| inputs=[text], | |
| label="Примеры обращений", | |
| cache_examples=False | |
| ) | |
| # Правая колонка | |
| with gr.Column(scale=1, min_width=400, elem_id="right-col"): | |
| pred_img = gr.HTML(label="Предсказанный класс", value=BASE_SVG_HTML) # HTML со встроенным SVG | |
| probs = gr.Label(label="Вероятности по классам") | |
| # Скрытая «переменная» | |
| pred_label_hidden = gr.Textbox(visible=False, interactive=False) | |
| text.change( | |
| on_classify, | |
| inputs=text, | |
| outputs=[pred_img, probs, pred_label_hidden], | |
| ) | |
| if __name__ == "__main__": | |
| gr.close_all() # опционально | |
| demo.queue().launch() |