sashadd commited on
Commit
c9548f6
·
verified ·
1 Parent(s): 7879a11

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +186 -0
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM
4
+ import time
5
+ import re
6
+ from typing import Tuple, Dict
7
+
8
+ # ------------------------------------------------------------
9
+ # Конфигурация
10
+ # ------------------------------------------------------------
11
+ MODEL_NAMES = [
12
+ "tinkoff-ai/ruDialoGPT-small",
13
+ "tinkoff-ai/ruDialoGPT-medium"
14
+ ]
15
+ DEFAULT_MODEL = MODEL_NAMES[0]
16
+
17
+ # Лимиты на длину ввода (в символах)
18
+ MAX_DOCUMENT_CHARS = 2000
19
+ MAX_QUESTION_CHARS = 1000
20
+ MAX_TOTAL_CHARS = MAX_DOCUMENT_CHARS + MAX_QUESTION_CHARS
21
+
22
+ # Кэш для моделей и токенизаторов
23
+ model_cache: Dict[str, Tuple] = {} # имя -> (tokenizer, model)
24
+
25
+ def load_model(model_name: str):
26
+ """Загружает токенизатор и модель, если ещё не загружены."""
27
+ if model_name not in model_cache:
28
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
29
+ model = AutoModelForCausalLM.from_pretrained(model_name)
30
+ model_cache[model_name] = (tokenizer, model)
31
+ return model_cache[model_name]
32
+
33
+ def truncate_text(text: str, max_chars: int) -> str:
34
+ """Обрезает текст до указанного количества символов (грубо, по символам)."""
35
+ if len(text) > max_chars:
36
+ return text[:max_chars] + "..."
37
+ return text
38
+
39
+ def generate_response(
40
+ document: str,
41
+ question: str,
42
+ model_name: str,
43
+ max_new_tokens: int,
44
+ temperature: float
45
+ ) -> Tuple[str, float]:
46
+ """
47
+ Генерирует ответ модели на основе документа и вопроса.
48
+ Возвращает (ответ, время_генерации_сек).
49
+ """
50
+ # Проверка на пустые входные данные
51
+ if not document.strip():
52
+ return "Ошибка: документ не может быть пустым.", 0.0
53
+ if not question.strip():
54
+ return "Ошибка: вопрос не может быть пустым.", 0.0
55
+
56
+ # Обрезка по длине
57
+ document = truncate_text(document, MAX_DOCUMENT_CHARS)
58
+ question = truncate_text(question, MAX_QUESTION_CHARS)
59
+
60
+ # Формирование промпта (простая инструкция)
61
+ prompt = f"Документ: {document}\nВопрос: {question}\nОтвет:"
62
+
63
+ # Загрузка модели
64
+ try:
65
+ tokenizer, model = load_model(model_name)
66
+ except Exception as e:
67
+ return f"Ошибка загрузки модели: {type(e).__name__}: {e}", 0.0
68
+
69
+ # Токенизация с учётом максимальной длины модели
70
+ try:
71
+ inputs = tokenizer(
72
+ prompt,
73
+ return_tensors="pt",
74
+ truncation=True,
75
+ max_length=tokenizer.model_max_length
76
+ )
77
+ except Exception as e:
78
+ return f"Ошибка токенизации: {type(e).__name__}: {e}", 0.0
79
+
80
+ # Генерация
81
+ start_time = time.time()
82
+ try:
83
+ with torch.no_grad():
84
+ outputs = model.generate(
85
+ inputs.input_ids,
86
+ max_new_tokens=max_new_tokens,
87
+ temperature=temperature,
88
+ do_sample=True,
89
+ top_p=0.95,
90
+ pad_token_id=tokenizer.eos_token_id
91
+ )
92
+ latency = time.time() - start_time
93
+ except Exception as e:
94
+ return f"Ошибка генерации: {type(e).__name__}: {e}", time.time() - start_time
95
+
96
+ # Декодирование ответа
97
+ response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
98
+ if not response.strip():
99
+ response = "[модель не дала ответа]"
100
+
101
+ return response.strip(), latency
102
+
103
+ # ------------------------------------------------------------
104
+ # Интерфейс Gradio
105
+ # ------------------------------------------------------------
106
+ with gr.Blocks(title="Мини-чат по документу (русский язык)") as demo:
107
+ gr.Markdown("""
108
+ ## Чат с моделью на основе одного документа
109
+ Задайте вопрос по предоставленному тексту. Модель ответит, используя только информацию из документа.
110
+ """)
111
+
112
+ with gr.Row():
113
+ with gr.Column(scale=2):
114
+ document_input = gr.Textbox(
115
+ label="Документ (контекст)",
116
+ lines=6,
117
+ placeholder="Вставьте текст документа здесь..."
118
+ )
119
+ question_input = gr.Textbox(
120
+ label="Ваш вопрос",
121
+ lines=2,
122
+ placeholder="Например: О чём говорится в документе?"
123
+ )
124
+ with gr.Row():
125
+ model_selector = gr.Dropdown(
126
+ choices=MODEL_NAMES,
127
+ value=DEFAULT_MODEL,
128
+ label="Модель"
129
+ )
130
+ max_tokens_slider = gr.Slider(
131
+ 10, 200, value=50, step=5,
132
+ label="Макс. новых токенов"
133
+ )
134
+ temperature_slider = gr.Slider(
135
+ 0.1, 2.0, value=0.7, step=0.1,
136
+ label="Температура"
137
+ )
138
+ submit_btn = gr.Button("Спросить", variant="primary")
139
+
140
+ with gr.Column(scale=1):
141
+ answer_output = gr.Textbox(
142
+ label="Ответ модели",
143
+ lines=6,
144
+ interactive=False
145
+ )
146
+ latency_output = gr.Textbox(
147
+ label="Время генерации (сек)",
148
+ lines=1,
149
+ interactive=False
150
+ )
151
+
152
+ # Примеры (заполняют документ и вопрос, остальные параметры остаются текущими)
153
+ gr.Examples(
154
+ examples=[
155
+ [
156
+ "Кофе эспрессо готовится путём пропускания горячей воды под давлением через молотые зёрна. Температура воды 90-96°C, давление 9 бар. Выход напитка 25-35 мл.",
157
+ "Как приготовить эспрессо?"
158
+ ],
159
+ [
160
+ "Солнечная система состоит из Солнца и планет: Меркурий, Венера, Земля, Марс, Юпитер, Сатурн, Уран, Нептун. Земля — третья планета от Солнца, единственная известная планета с жизнью.",
161
+ "Какая планета третья от Солнца?"
162
+ ],
163
+ [
164
+ "Для сборки стола необходимо: столешница, 4 ножки, 8 шурупов, отвёртка. Сначала прикрутить ножки к столешнице, затянув шурупы крест-накрест.",
165
+ "Какие инструменты нужны для сборки стола?"
166
+ ]
167
+ ],
168
+ inputs=[document_input, question_input],
169
+ label="Примеры запросов"
170
+ )
171
+
172
+ # Функция обработки
173
+ def process(document, question, model_name, max_tokens, temperature):
174
+ answer, latency = generate_response(
175
+ document, question, model_name,
176
+ max_tokens, temperature
177
+ )
178
+ return answer, f"{latency:.3f}"
179
+
180
+ submit_btn.click(
181
+ fn=process,
182
+ inputs=[document_input, question_input, model_selector, max_tokens_slider, temperature_slider],
183
+ outputs=[answer_output, latency_output]
184
+ )
185
+
186
+ demo.launch()