Kolyadual commited on
Commit
25777f3
·
verified ·
1 Parent(s): 8e402c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -250
app.py CHANGED
@@ -1,259 +1,127 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 
3
  import gradio as gr
4
- import time
5
- from typing import Tuple
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- # Настройки модели
8
- MODEL_NAME = "Kolyadual/MIXdevAI-llama"
9
- DEFAULT_MAX_LENGTH = 512
10
- DEFAULT_TEMPERATURE = 0.7
11
- DEFAULT_TOP_P = 0.9
12
 
13
- class ChatBot:
14
- def __init__(self):
15
- self.model = None
16
- self.tokenizer = None
17
- self.pipe = None
18
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
19
- self.is_loaded = False
20
-
21
- def load_model(self):
22
- """Загрузка модели"""
23
- if self.is_loaded:
24
- return True
25
-
26
- try:
27
- print("⏳ Загрузка токенизатора...")
28
- self.tokenizer = AutoTokenizer.from_pretrained(
29
- MODEL_NAME,
30
- trust_remote_code=True
31
- )
32
-
33
- print("⏳ Загрузка модели...")
34
- self.model = AutoModelForCausalLM.from_pretrained(
35
- MODEL_NAME,
36
- torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
37
- device_map="auto" if self.device == "cuda" else None,
38
- trust_remote_code=True,
39
- low_cpu_mem_usage=True
40
- )
41
-
42
- if self.device == "cpu":
43
- self.model = self.model.to(self.device)
44
-
45
- print("⏳ Создание пайплайна...")
46
- self.pipe = pipeline(
47
- "text-generation",
48
- model=self.model,
49
- tokenizer=self.tokenizer,
50
- device=0 if self.device == "cuda" else -1
51
- )
52
-
53
- self.is_loaded = True
54
- print("✅ Модель успешно загружена!")
55
- return True
56
-
57
- except Exception as e:
58
- print(f"❌ Ошибка загрузки модели: {e}")
59
- return False
60
-
61
- def generate_response(self,
62
- message: str,
63
- history: list,
64
- max_length: int,
65
- temperature: float,
66
- top_p: float) -> Tuple[str, list]:
67
- """Генерация ответа"""
68
- if not self.is_loaded:
69
- if not self.load_model():
70
- return "Ошибка: модель не загружена", history
71
-
72
- try:
73
- # Форматируем историю для модели
74
- prompt = self._format_chat_prompt(message, history)
75
-
76
- # Генерируем ответ
77
- with torch.no_grad():
78
- outputs = self.pipe(
79
- prompt,
80
- max_new_tokens=max_length,
81
- temperature=temperature,
82
- top_p=top_p,
83
- do_sample=True,
84
- pad_token_id=self.tokenizer.eos_token_id,
85
- eos_token_id=self.tokenizer.eos_token_id,
86
- repetition_penalty=1.1
87
- )
88
-
89
- # Извлекаем ответ
90
- full_response = outputs[0]['generated_text']
91
- response = full_response[len(prompt):].strip()
92
-
93
- # Добавляем в историю
94
- history.append((message, response))
95
-
96
- return "", history
97
-
98
- except Exception as e:
99
- print(f"❌ Ошибка генерации: {e}")
100
- return f"Ошибка: {str(e)}", history
101
-
102
- def _format_chat_prompt(self, message: str, history: list) -> str:
103
- """Форматирование промпта из истории чата"""
104
- prompt = ""
105
-
106
- # Добавляем историю
107
- for user_msg, assistant_msg in history:
108
- prompt += f"### User: {user_msg}\n### Assistant: {assistant_msg}\n"
109
-
110
- # Добавляем текущее сообщение
111
- prompt += f"### User: {message}\n### Assistant: "
112
-
113
- return prompt
114
-
115
- def clear_history(self):
116
- """Очистка истории"""
117
- return []
118
 
119
- # Создаем экземпляр бота
120
- chatbot = ChatBot()
121
 
122
- def predict(message: str,
123
- history: list,
124
- max_length: int,
125
- temperature: float,
126
- top_p: float) -> Tuple[str, list]:
127
- """Функция для Gradio"""
128
- if not message.strip():
129
- return "", history
130
-
131
- # Генерируем ответ
132
- _, updated_history = chatbot.generate_response(
133
- message, history, max_length, temperature, top_p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  )
135
-
136
- return "", updated_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
- def ui():
139
- """Создание интерфейса"""
140
- with gr.Blocks(
141
- title="MIXdevAI-llama Chat",
142
- css="""
143
- .gradio-container {
144
- max-width: 900px;
145
- margin: auto;
146
- }
147
- .chatbot {
148
- min-height: 500px;
149
- }
150
- .footer {
151
- text-align: center;
152
- margin-top: 20px;
153
- color: #666;
154
- font-size: 0.9em;
155
- }
156
- """
157
- ) as demo:
158
- # Заголовок
159
- gr.Markdown("""
160
- # 🤖 MIXdevAI-llama Chat Assistant
161
- Модель: [Kolyadual/MIXdevAI-llama](https://huggingface.co/Kolyadual/MIXdevAI-llama)
162
- """)
163
-
164
- # Информация о загрузке
165
- status = gr.Markdown("Статус: ⏳ Загрузка модели...")
166
-
167
- # Чат
168
- with gr.Row():
169
- with gr.Column(scale=3):
170
- chatbot_ui = gr.Chatbot(
171
- label="Диалог",
172
- height=500
173
- )
174
-
175
- msg = gr.Textbox(
176
- label="Ваше сообщение",
177
- placeholder="Введите сообщение...",
178
- lines=2,
179
- max_lines=5
180
- )
181
-
182
- with gr.Row():
183
- submit_btn = gr.Button("📤 Отправить", variant="primary")
184
- clear_btn = gr.Button("🗑️ Очистить историю")
185
-
186
- # Настройки
187
- with gr.Accordion("⚙️ Настройки генерации", open=False):
188
- with gr.Row():
189
- max_length = gr.Slider(
190
- minimum=64,
191
- maximum=2048,
192
- value=DEFAULT_MAX_LENGTH,
193
- step=64,
194
- label="Максимальная длина ответа"
195
- )
196
- temperature = gr.Slider(
197
- minimum=0.1,
198
- maximum=2.0,
199
- value=DEFAULT_TEMPERATURE,
200
- step=0.1,
201
- label="Температура (креативность)"
202
- )
203
- top_p = gr.Slider(
204
- minimum=0.1,
205
- maximum=1.0,
206
- value=DEFAULT_TOP_P,
207
- step=0.05,
208
- label="Top-p (разнообразие)"
209
- )
210
-
211
- # Футер
212
- gr.Markdown("""
213
- <div class="footer">
214
- <p>Модель автоматически загружается при первшем запросе</p>
215
- <p>Для работы на CPU потребуется время на загрузку (~5-10 минут)</p>
216
- </div>
217
- """)
218
-
219
- # Обработчики событий
220
- submit_event = msg.submit(
221
- predict,
222
- [msg, chatbot_ui, max_length, temperature, top_p],
223
- [msg, chatbot_ui]
224
- )
225
-
226
- submit_btn.click(
227
- predict,
228
- [msg, chatbot_ui, max_length, temperature, top_p],
229
- [msg, chatbot_ui]
230
- )
231
-
232
- clear_btn.click(
233
- chatbot.clear_history,
234
- outputs=chatbot_ui
235
- )
236
-
237
- # Загружаем модель при старте
238
- demo.load(
239
- chatbot.load_model,
240
- outputs=status
241
- ).then(
242
- lambda: "Статус: ✅ Модель готова к использованию!",
243
- outputs=status
244
- )
245
-
246
- return demo
247
 
248
  if __name__ == "__main__":
249
- # Запускаем приложение
250
- demo = ui()
251
- demo.launch(
252
- server_name="0.0.0.0",
253
- server_port=7860,
254
- share=False,
255
- theme=gr.themes.Soft(
256
- primary_hue="blue",
257
- secondary_hue="gray"
258
- )
259
- )
 
1
+ import os
2
+ from collections.abc import Iterator
3
+ from threading import Thread
4
+
5
  import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+
10
+ DESCRIPTION = """\
11
+ # MIXdevAI Llama
12
+
13
+ MIXdevAI-llama is fine-tuned Russian model based on Llama 3.2 1B Instruct. Model for chating, coding and other! Created by Kolyadual
14
+ """
15
+
16
+ MAX_MAX_NEW_TOKENS = 2048
17
+ DEFAULT_MAX_NEW_TOKENS = 1024
18
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
19
 
20
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 
 
 
21
 
22
+ model_id = "Kolyadual/MIXdevAI-llama"
23
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
24
+ model = AutoModelForCausalLM.from_pretrained(
25
+ model_id,
26
+ device_map="auto",
27
+ torch_dtype=torch.bfloat16,
28
+ )
29
+ model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
 
 
31
 
32
+ @spaces.GPU(duration=90)
33
+ def generate(
34
+ message: str,
35
+ chat_history: list[dict],
36
+ max_new_tokens: int = 1024,
37
+ temperature: float = 0.6,
38
+ top_p: float = 0.9,
39
+ top_k: int = 50,
40
+ repetition_penalty: float = 1.2,
41
+ ) -> Iterator[str]:
42
+ conversation = [*chat_history, {"role": "user", "content": message}]
43
+
44
+ input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
45
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
46
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
47
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
48
+ input_ids = input_ids.to(model.device)
49
+
50
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
51
+ generate_kwargs = dict(
52
+ {"input_ids": input_ids},
53
+ streamer=streamer,
54
+ max_new_tokens=max_new_tokens,
55
+ do_sample=True,
56
+ top_p=top_p,
57
+ top_k=top_k,
58
+ temperature=temperature,
59
+ num_beams=1,
60
+ repetition_penalty=repetition_penalty,
61
  )
62
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
63
+ t.start()
64
+
65
+ outputs = []
66
+ for text in streamer:
67
+ outputs.append(text)
68
+ yield "".join(outputs)
69
+
70
+
71
+ demo = gr.ChatInterface(
72
+ fn=generate,
73
+ additional_inputs=[
74
+ gr.Slider(
75
+ label="Max new tokens",
76
+ minimum=1,
77
+ maximum=MAX_MAX_NEW_TOKENS,
78
+ step=1,
79
+ value=DEFAULT_MAX_NEW_TOKENS,
80
+ ),
81
+ gr.Slider(
82
+ label="Temperature",
83
+ minimum=0.1,
84
+ maximum=4.0,
85
+ step=0.1,
86
+ value=0.6,
87
+ ),
88
+ gr.Slider(
89
+ label="Top-p (nucleus sampling)",
90
+ minimum=0.05,
91
+ maximum=1.0,
92
+ step=0.05,
93
+ value=0.9,
94
+ ),
95
+ gr.Slider(
96
+ label="Top-k",
97
+ minimum=1,
98
+ maximum=1000,
99
+ step=1,
100
+ value=50,
101
+ ),
102
+ gr.Slider(
103
+ label="Repetition penalty",
104
+ minimum=1.0,
105
+ maximum=2.0,
106
+ step=0.05,
107
+ value=1.2,
108
+ ),
109
+ ],
110
+ stop_btn=None,
111
+ examples=[
112
+ ["Привет! Кто ты и кто тебя создал?"],
113
+ ["Можете вкратце объяснить, что такое язык программирования Python?"],
114
+ ["Объясните сюжет «Золушки» одним предложением."],
115
+ ["Сколько часов потребуется человеку, чтобы съесть вертолет?"],
116
+ ["Напишите статью объемом 100 слов на тему «Преимущества открытого исходного кода в исследованиях в области искусственного интеллекта»."],
117
+ ],
118
+ cache_examples=False,
119
+ type="messages",
120
+ description=DESCRIPTION,
121
+ css_paths="style.css",
122
+ fill_height=True,
123
+ )
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  if __name__ == "__main__":
127
+ demo.launch()