mipatov commited on
Commit
6bd1823
·
1 Parent(s): 0ec252e
Files changed (15) hide show
  1. app.py +191 -0
  2. images/0.svg +27 -0
  3. images/1.svg +27 -0
  4. images/10.svg +18 -0
  5. images/2.svg +66 -0
  6. images/3.svg +24 -0
  7. images/4.svg +30 -0
  8. images/5.svg +23 -0
  9. images/6.svg +33 -0
  10. images/7.svg +30 -0
  11. images/8.svg +51 -0
  12. images/9.svg +29 -0
  13. images/base.svg +43 -0
  14. requirements.txt +0 -0
  15. runtime.txt +1 -0
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import torch
4
+ import gradio as gr
5
+ import pandas as pd
6
+ from datetime import datetime
7
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
+
9
+ APP_CSS = """
10
+ /* Контейнер пошире под 2 колонки */
11
+ #app-wrap {
12
+ max-width: 1100px;
13
+ margin: 0 auto !important;
14
+ padding: 0 16px;
15
+ }
16
+ #left-col, #right-col { gap: 12px; }
17
+ """
18
+
19
+ MODEL_PATH = os.getenv("MODEL_ID", "mipatov/tech_support_intent_classifier")
20
+ MAX_LEN = 512
21
+
22
+ ID2LABEL = {
23
+ 0: "0. Нет подключения",
24
+ 1: "1. Низкая скорость",
25
+ 2: "2. Смена пароля Wi‑Fi",
26
+ 3: "3. Обрыв кабеля",
27
+ 4: "4. Нестабильный интернет",
28
+ 5: "5. Узнать пароль Wi‑Fi",
29
+ 6: "6. Высокий пинг",
30
+ 7: "7. Настройка роутера",
31
+ 8: "8. Замена роутера",
32
+ 9: "9. Вызов мастера",
33
+ 10: "10. Другое",
34
+ }
35
+ LABEL2ID = {v: k for k, v in ID2LABEL.items()}
36
+
37
+ EXAMPLE_TICKETS = [
38
+ 'День добрый! Не подскажете как сменить пароль на роутере?',
39
+ 'Мне необходимо восстановить логин и пароль в админку. Чтобы изменить имя и пароль WiFi',
40
+ 'Здравствуйте, у нас собака перегрызла провод интернет и сломался роутер, как нам поступить какие есть варианты?',
41
+ 'как оформить заявку на ремонт? не работает ни телефон .ни интернет',
42
+ 'Добрый день. У меня на роутере время от времени мигает лампочка wi fi. Интернет не работает. В чем дело? Подскажите, пожалуйста',
43
+ 'Не работает ни телевизор ни интернет, пишет ошибку. Проверьте оборудование пожалуйста',
44
+ 'Здравствуйте низкая скорость и высокий пинг невозможно пользоваться поточными сервисами с этим можно что-то сделать?',
45
+ 'Роутэр на новый можно поменять? Или он от старого не отличается',
46
+ 'Нет инета'
47
+ ]
48
+
49
+ # Загрузка модели
50
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
51
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
52
+ model.config.id2label = ID2LABEL
53
+ model.config.label2id = LABEL2ID
54
+ model.eval()
55
+
56
+ labels = [model.config.id2label[i] for i in range(model.config.num_labels)]
57
+ os.makedirs("logs", exist_ok=True)
58
+ FEEDBACK_CSV = "logs/feedback.csv"
59
+
60
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
61
+ IMAGES_DIR = os.path.join(BASE_DIR, "images")
62
+
63
+ def svg_file_to_html(file_path: str) -> str:
64
+ """Читает SVG и заворачивает в адаптивный контейнер для gr.HTML."""
65
+ try:
66
+ with open(file_path, "r", encoding="utf-8") as f:
67
+ svg = f.read()
68
+ return f'''
69
+ <div style="display:flex;justify-content:center;">
70
+ <div style="max-width: 360px; width: 100%">{svg}</div>
71
+ </div>
72
+ '''
73
+ except Exception:
74
+ return ""
75
+
76
+ BASE_SVG_HTML = svg_file_to_html(os.path.join(IMAGES_DIR, "base.svg"))
77
+
78
+ def label_to_image_path(label: str):
79
+ if not label:
80
+ return None
81
+ idx = LABEL2ID.get(label)
82
+ if idx is None:
83
+ return None
84
+ path = os.path.join(IMAGES_DIR, f"{idx}.svg")
85
+ return path if os.path.exists(path) else None
86
+
87
+ def label_to_image_html(label: str) -> str:
88
+ """Возвращает HTML со встроенным SVG (чтобы не зависеть от поддержки SVG в gr.Image)."""
89
+ path = label_to_image_path(label)
90
+ if not path:
91
+ return ""
92
+ try:
93
+ with open(path, "r", encoding="utf-8") as f:
94
+ svg = f.read()
95
+ # Оборачиваем для адаптивности
96
+ return f'''
97
+ <div style="display:flex;justify-content:center;">
98
+ <div style="max-width: 360px; width: 100%">{svg}</div>
99
+ </div>
100
+ '''
101
+ except Exception:
102
+ return ""
103
+
104
+ def predict_with_probs(text: str):
105
+ if not text or not text.strip():
106
+ return "", {}
107
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=MAX_LEN)
108
+ with torch.no_grad():
109
+ logits = model(**inputs).logits[0]
110
+ probs = torch.softmax(logits, dim=-1).tolist()
111
+ label_probs = {labels[i]: float(probs[i]) for i in range(len(labels))}
112
+ label_probs = {k: round(v, 4) for k, v in sorted(label_probs.items(), key=lambda x: x[1], reverse=True)}
113
+ top_label = next(iter(label_probs)) if label_probs else ""
114
+ return top_label, label_probs
115
+
116
+ def save_feedback(text, pred_label, correct_label):
117
+ if not text or not text.strip():
118
+ return "Введите текст обращения."
119
+ final_label = correct_label if correct_label else pred_label
120
+ row = {
121
+ "timestamp": datetime.utcnow().isoformat(),
122
+ "text": text,
123
+ "pred_label": pred_label,
124
+ "correct_label": correct_label if correct_label else "",
125
+ "final_label": final_label,
126
+ }
127
+ df = pd.DataFrame([row])
128
+ header = not os.path.exists(FEEDBACK_CSV)
129
+ df.to_csv(FEEDBACK_CSV, mode="a", header=header, index=False, encoding="utf-8")
130
+ return "Фидбек сохранён ✅"
131
+
132
+ def on_classify(t: str):
133
+ top_label, label_probs = predict_with_probs(t)
134
+ img_html = label_to_image_html(top_label)
135
+ correct_reset = gr.update(value=None)
136
+ caption_md = f"Предсказанный класс: <b>{top_label}</b>" if top_label else ""
137
+ caption_md = ""
138
+ # Возвращаем: HTML со svg, вероятности, сброс dropdown, скрытый textbox с меткой, подпись
139
+ return img_html, label_probs, correct_reset, top_label, caption_md
140
+
141
+ with gr.Blocks(css=APP_CSS) as demo:
142
+ with gr.Column(elem_id="app-wrap"):
143
+ gr.Markdown("## Классификация обращений в техподдержку")
144
+
145
+ with gr.Row():
146
+ # Левая колонка
147
+ with gr.Column(scale=1, min_width=400, elem_id="left-col"):
148
+ text = gr.Textbox(label="Текст обращения", lines=6, placeholder="Вставь сюда текст тикета")
149
+ gr.Examples(
150
+ examples=[[e] for e in EXAMPLE_TICKETS],
151
+ inputs=[text],
152
+ label="Примеры обращений",
153
+ cache_examples=False,
154
+ examples_per_page=3,
155
+ )
156
+ # btn = gr.Button("Классифицировать", variant="primary")
157
+
158
+ # Правая колонка
159
+ with gr.Column(scale=1, min_width=400, elem_id="right-col"):
160
+ pred_img = gr.HTML(label="Предсказанный класс", value=BASE_SVG_HTML) # HTML со встроенным SVG
161
+ pred_caption = gr.Markdown()
162
+
163
+ probs = gr.Label(label="Вероятности по классам")
164
+ # correct = gr.Dropdown(choices=labels, label="Если неверно — выбери правильный класс")
165
+ # save = gr.Button("Сохранить фидбек")
166
+ # msg = gr.Markdown()
167
+
168
+ # Скрытая «переменная» без gr.State (во избежание бага с JSON Schema)
169
+ pred_label_hidden = gr.Textbox(visible=False, interactive=False)
170
+
171
+ # События
172
+ # btn.click(
173
+ # on_classify,
174
+ # inputs=text,
175
+ # outputs=[pred_img, probs, correct, pred_label_hidden, pred_caption],
176
+ # )
177
+ text.change(
178
+ on_classify,
179
+ inputs=text,
180
+ outputs=[pred_img, probs, pred_label_hidden, pred_caption],
181
+ )
182
+
183
+ # save.click(
184
+ # save_feedback,
185
+ # inputs=[text, pred_label_hidden, correct],
186
+ # outputs=msg,
187
+ # )
188
+
189
+ if __name__ == "__main__":
190
+ gr.close_all() # опционально
191
+ demo.queue().launch()
images/0.svg ADDED
images/1.svg ADDED
images/10.svg ADDED
images/2.svg ADDED
images/3.svg ADDED
images/4.svg ADDED
images/5.svg ADDED
images/6.svg ADDED
images/7.svg ADDED
images/8.svg ADDED
images/9.svg ADDED
images/base.svg ADDED
requirements.txt ADDED
File without changes
runtime.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ python-3.12