Kenan023214 commited on
Commit
41ec8f3
·
verified ·
1 Parent(s): d17b162

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -40
app.py CHANGED
@@ -1,35 +1,32 @@
1
  import gradio as gr
2
  import torch
3
- import os
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  from huggingface_hub import hf_hub_download
6
  from functools import lru_cache
7
 
8
- # --- Конфигурация Hugging Face Space ---
9
- # Загрузка модели и токенизатора один раз при запуске приложения
10
  MODEL_NAME = "Kenan023214/PyroNet-mini"
11
- DEVICE = "cpu" # Используем CPU, как указано для Basic Space
12
  MAX_NEW_TOKENS = 256
13
  MAX_CONTEXT_TOKENS = 2048
14
 
15
- # Загрузка модели и токенизатора
16
  @lru_cache(maxsize=1)
17
  def load_model():
18
- """Загружает модель и токенайзер, кешируя их для производительности."""
19
  print("Loading model and tokenizer...")
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
  model = AutoModelForCausalLM.from_pretrained(
22
  MODEL_NAME,
23
  device_map=DEVICE,
24
- torch_dtype=torch.float32 # Используем float32 для совместимости с CPU
25
  )
26
  print("Model loaded.")
27
  return tokenizer, model
28
 
29
- # Загрузка файлов шаблонов из репозитория
30
  @lru_cache(maxsize=1)
31
  def download_templates():
32
- """Скачивает файлы шаблонов из репозитория модели."""
33
  print("Downloading chat templates...")
34
  for lang in ["ru", "en", "uk"]:
35
  hf_hub_download(
@@ -43,13 +40,13 @@ def download_templates():
43
  tokenizer, model = load_model()
44
  download_templates()
45
 
46
- # --- Утилиты ---
47
  def num_tokens_of_text(text: str) -> int:
48
- """Приближённое количество токенов."""
49
  return len(tokenizer.encode(text, add_special_tokens=False))
50
 
51
  def trim_history_to_max_tokens(messages, max_tokens):
52
- """Обрезает историю сообщений."""
53
  rev = list(reversed(messages))
54
  total = 0
55
  kept = []
@@ -62,7 +59,7 @@ def trim_history_to_max_tokens(messages, max_tokens):
62
  return list(reversed(kept))
63
 
64
  def build_messages_for_template(history_messages, reasoning: bool, language: str):
65
- """Подготавливает сообщения для шаблона."""
66
  if language == 'ru':
67
  system_message = "Ты — дружелюбный ассистент, который говорит на русском. Отвечай кратко, но по делу."
68
  reasoning_instruction = ("[REASONING MODE]\n"
@@ -87,7 +84,7 @@ def build_messages_for_template(history_messages, reasoning: bool, language: str
87
  return messages
88
 
89
  def extract_assistant_reply(raw_generated_text: str) -> str:
90
- """Убирает лишние токены и оставляет только ответ ассистента."""
91
  text = raw_generated_text
92
  if "<|assistant|>" in text:
93
  text = text.split("<|assistant|>")[-1]
@@ -95,23 +92,19 @@ def extract_assistant_reply(raw_generated_text: str) -> str:
95
  text = text.replace(tag, "")
96
  return text.strip()
97
 
98
- # --- Основная функция для Gradio ---
99
  def generate_response(user_text: str, history, reasoning: bool, language: str):
100
- """Обрабатывает пользовательский запрос и генерирует ответ."""
101
 
102
- # Добавляем user-сообщение в историю
103
  history.append({"role": "user", "content": user_text})
104
 
105
- # Подрезаем историю, чтобы вход не стал слишком большим
106
  trimmed_history = trim_history_to_max_tokens(history, MAX_CONTEXT_TOKENS)
107
 
108
- # Собираем messages с возможной инструкцией reasoning
109
  messages_for_template = build_messages_for_template(trimmed_history, reasoning, language)
110
 
111
- # Выбираем шаблон из локальных файлов
112
  template_file = f"chat_template_{language}.jinja"
113
 
114
- # Применяем шаблон и токенизируем
115
  text = tokenizer.apply_chat_template(
116
  messages_for_template,
117
  template_path=template_file,
@@ -121,7 +114,6 @@ def generate_response(user_text: str, history, reasoning: bool, language: str):
121
 
122
  inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
123
 
124
- # Генерация
125
  with torch.no_grad():
126
  outputs = model.generate(
127
  **inputs,
@@ -132,48 +124,40 @@ def generate_response(user_text: str, history, reasoning: bool, language: str):
132
  pad_token_id=tokenizer.eos_token_id
133
  )
134
 
135
- # Декодируем и очищаем ответ
136
  raw = tokenizer.decode(outputs[0], skip_special_tokens=False)
137
  reply = extract_assistant_reply(raw)
138
 
139
- # Добавляем ассистента в историю
140
  history.append({"role": "assistant", "content": reply})
141
 
142
- # Gradio ожидает возвращение списка [пользователь, ассистент]
143
- # Мы возвращаем всю историю для корректного отображения
144
  return "", history
145
 
146
- # --- Интерфейс Gradio ---
147
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
148
  gr.Markdown("# PyroNet-mini Chat")
149
- gr.Markdown("Демонстрация работы PyroNet-mini (на базе Phi-4-mini-instruct) с кастомными шаблонами и режимом рассуждения.")
150
 
151
  chatbot = gr.Chatbot(height=500)
152
 
153
  with gr.Row():
154
  with gr.Column(scale=4):
155
  msg = gr.Textbox(
156
- label="Ваш запрос",
157
- placeholder="Напишите здесь...",
158
  container=False
159
  )
160
  with gr.Column(scale=1, min_width=100):
161
  language_dropdown = gr.Dropdown(
162
  choices=["ru", "en", "uk"],
163
- value="ru",
164
- label="Язык",
165
  container=False
166
  )
167
  reasoning_checkbox = gr.Checkbox(
168
- label="Включить режим рассуждения"
169
  )
170
 
171
- btn_send = gr.Button("Отправить")
172
- btn_clear = gr.Button("Очистить")
173
-
174
- # Обработчики событий
175
- def reset_history():
176
- return [], None
177
 
178
  btn_send.click(
179
  fn=generate_response,
@@ -193,4 +177,3 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
193
 
194
  if __name__ == "__main__":
195
  demo.launch()
196
-
 
1
  import gradio as gr
2
  import torch
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  from huggingface_hub import hf_hub_download
5
  from functools import lru_cache
6
 
7
+ # --- Hugging Face Space Configuration ---
8
+ # Load the model and tokenizer only once when the app starts
9
  MODEL_NAME = "Kenan023214/PyroNet-mini"
10
+ DEVICE = "cpu" # Use CPU for basic Space
11
  MAX_NEW_TOKENS = 256
12
  MAX_CONTEXT_TOKENS = 2048
13
 
 
14
  @lru_cache(maxsize=1)
15
  def load_model():
16
+ """Loads the model and tokenizer, caching them for performance."""
17
  print("Loading model and tokenizer...")
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
19
  model = AutoModelForCausalLM.from_pretrained(
20
  MODEL_NAME,
21
  device_map=DEVICE,
22
+ torch_dtype=torch.float32 # Use float32 for CPU compatibility
23
  )
24
  print("Model loaded.")
25
  return tokenizer, model
26
 
 
27
  @lru_cache(maxsize=1)
28
  def download_templates():
29
+ """Downloads template files from the model repository."""
30
  print("Downloading chat templates...")
31
  for lang in ["ru", "en", "uk"]:
32
  hf_hub_download(
 
40
  tokenizer, model = load_model()
41
  download_templates()
42
 
43
+ # --- Utilities ---
44
  def num_tokens_of_text(text: str) -> int:
45
+ """Approximate number of tokens for a given text."""
46
  return len(tokenizer.encode(text, add_special_tokens=False))
47
 
48
  def trim_history_to_max_tokens(messages, max_tokens):
49
+ """Trims the message history to fit within a token limit."""
50
  rev = list(reversed(messages))
51
  total = 0
52
  kept = []
 
59
  return list(reversed(kept))
60
 
61
  def build_messages_for_template(history_messages, reasoning: bool, language: str):
62
+ """Prepares messages for the chat template."""
63
  if language == 'ru':
64
  system_message = "Ты — дружелюбный ассистент, который говорит на русском. Отвечай кратко, но по делу."
65
  reasoning_instruction = ("[REASONING MODE]\n"
 
84
  return messages
85
 
86
  def extract_assistant_reply(raw_generated_text: str) -> str:
87
+ """Removes extra tokens and returns only the assistant's reply."""
88
  text = raw_generated_text
89
  if "<|assistant|>" in text:
90
  text = text.split("<|assistant|>")[-1]
 
92
  text = text.replace(tag, "")
93
  return text.strip()
94
 
95
+ # --- Main function for Gradio ---
96
  def generate_response(user_text: str, history, reasoning: bool, language: str):
97
+ """Processes user input and generates a response."""
98
 
 
99
  history.append({"role": "user", "content": user_text})
100
 
 
101
  trimmed_history = trim_history_to_max_tokens(history, MAX_CONTEXT_TOKENS)
102
 
 
103
  messages_for_template = build_messages_for_template(trimmed_history, reasoning, language)
104
 
105
+ # Select the template file from the local files
106
  template_file = f"chat_template_{language}.jinja"
107
 
 
108
  text = tokenizer.apply_chat_template(
109
  messages_for_template,
110
  template_path=template_file,
 
114
 
115
  inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
116
 
 
117
  with torch.no_grad():
118
  outputs = model.generate(
119
  **inputs,
 
124
  pad_token_id=tokenizer.eos_token_id
125
  )
126
 
 
127
  raw = tokenizer.decode(outputs[0], skip_special_tokens=False)
128
  reply = extract_assistant_reply(raw)
129
 
 
130
  history.append({"role": "assistant", "content": reply})
131
 
 
 
132
  return "", history
133
 
134
+ # --- Gradio Interface ---
135
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
136
  gr.Markdown("# PyroNet-mini Chat")
137
+ gr.Markdown("A demonstration of PyroNet-mini (based on a custom model) with multilingual templates and a reasoning mode.")
138
 
139
  chatbot = gr.Chatbot(height=500)
140
 
141
  with gr.Row():
142
  with gr.Column(scale=4):
143
  msg = gr.Textbox(
144
+ label="Your Prompt",
145
+ placeholder="Write your message here...",
146
  container=False
147
  )
148
  with gr.Column(scale=1, min_width=100):
149
  language_dropdown = gr.Dropdown(
150
  choices=["ru", "en", "uk"],
151
+ value="en",
152
+ label="Language",
153
  container=False
154
  )
155
  reasoning_checkbox = gr.Checkbox(
156
+ label="Enable Reasoning Mode"
157
  )
158
 
159
+ btn_send = gr.Button("Send")
160
+ btn_clear = gr.Button("Clear")
 
 
 
 
161
 
162
  btn_send.click(
163
  fn=generate_response,
 
177
 
178
  if __name__ == "__main__":
179
  demo.launch()