Kenan023214 commited on
Commit
67c514f
·
verified ·
1 Parent(s): 6694a15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -34
app.py CHANGED
@@ -1,55 +1,38 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
- from huggingface_hub import hf_hub_download
5
  from functools import lru_cache
6
 
7
- # --- Hugging Face Space Configuration ---
 
8
  MODEL_NAME = "Kenan023214/PyroNet-mini"
9
- DEVICE = "cpu" # Use CPU for basic Space
10
- MAX_NEW_TOKENS = 1024
11
  MAX_CONTEXT_TOKENS = 2048
12
 
13
- # Dictionary to store the full paths of downloaded templates
14
- TEMPLATE_PATHS = {}
15
-
16
  @lru_cache(maxsize=1)
17
  def load_model():
18
- """Loads the model and tokenizer, caching them for performance."""
19
  print("Loading model and tokenizer...")
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
  model = AutoModelForCausalLM.from_pretrained(
22
  MODEL_NAME,
23
  device_map=DEVICE,
24
- torch_dtype=torch.float32 # Use float32 for CPU compatibility
25
  )
26
  print("Model loaded.")
27
  return tokenizer, model
28
 
29
- def download_templates():
30
- """Downloads template files from the model repository and stores their paths."""
31
- print("Downloading chat templates...")
32
- for lang in ["ru", "en", "uk"]:
33
- filename = f"chat_template_{lang}.jinja"
34
- file_path = hf_hub_download(
35
- repo_id=MODEL_NAME,
36
- filename=filename,
37
- local_dir=".",
38
- local_dir_use_symlinks=False
39
- )
40
- TEMPLATE_PATHS[lang] = file_path
41
- print("Templates downloaded.")
42
-
43
  tokenizer, model = load_model()
44
- download_templates()
45
 
46
- # --- Utilities ---
47
  def num_tokens_of_text(text: str) -> int:
48
- """Approximate number of tokens for a given text."""
49
  return len(tokenizer.encode(text, add_special_tokens=False))
50
 
51
  def trim_history_to_max_tokens(messages, max_tokens):
52
- """Trims the message history to fit within a token limit."""
53
  rev = list(reversed(messages))
54
  total = 0
55
  kept = []
@@ -62,7 +45,7 @@ def trim_history_to_max_tokens(messages, max_tokens):
62
  return list(reversed(kept))
63
 
64
  def build_messages_for_template(history_messages, reasoning: bool, language: str):
65
- """Prepares messages for the chat template."""
66
  if language == 'ru':
67
  system_message = "Ты — дружелюбный ассистент, который говорит на русском. Отвечай кратко, но по делу."
68
  reasoning_instruction = ("[REASONING MODE]\n"
@@ -87,7 +70,7 @@ def build_messages_for_template(history_messages, reasoning: bool, language: str
87
  return messages
88
 
89
  def extract_assistant_reply(raw_generated_text: str) -> str:
90
- """Removes extra tokens and returns only the assistant's reply."""
91
  text = raw_generated_text
92
  if "<|assistant|>" in text:
93
  text = text.split("<|assistant|>")[-1]
@@ -95,9 +78,9 @@ def extract_assistant_reply(raw_generated_text: str) -> str:
95
  text = text.replace(tag, "")
96
  return text.strip()
97
 
98
- # --- Main function for Gradio ---
99
  def generate_response(user_text: str, history, reasoning: bool, language: str):
100
- """Processes user input and generates a response."""
101
 
102
  history.append({"role": "user", "content": user_text})
103
 
@@ -105,8 +88,8 @@ def generate_response(user_text: str, history, reasoning: bool, language: str):
105
 
106
  messages_for_template = build_messages_for_template(trimmed_history, reasoning, language)
107
 
108
- # Use the full path from the TEMPLATE_PATHS dictionary
109
- template_file = TEMPLATE_PATHS.get(language, TEMPLATE_PATHS["en"])
110
 
111
  text = tokenizer.apply_chat_template(
112
  messages_for_template,
@@ -134,7 +117,7 @@ def generate_response(user_text: str, history, reasoning: bool, language: str):
134
 
135
  return "", history
136
 
137
- # --- Gradio Interface ---
138
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
139
  gr.Markdown("# PyroNet-mini Chat")
140
  gr.Markdown("A demonstration of PyroNet-mini with multilingual templates and a reasoning mode.")
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
4
  from functools import lru_cache
5
 
6
+ # --- Конфигурация Hugging Face Space ---
7
+ # Загрузка модели и токенизатора один раз при запуске приложения
8
  MODEL_NAME = "Kenan023214/PyroNet-mini"
9
+ DEVICE = "cpu" # Используем CPU, как указано для Basic Space
10
+ MAX_NEW_TOKENS = 256
11
  MAX_CONTEXT_TOKENS = 2048
12
 
13
+ # Загрузка модели и токенизатора
 
 
14
  @lru_cache(maxsize=1)
15
  def load_model():
16
+ """Загружает модель и токенайзер, кешируя их для производительности."""
17
  print("Loading model and tokenizer...")
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
19
  model = AutoModelForCausalLM.from_pretrained(
20
  MODEL_NAME,
21
  device_map=DEVICE,
22
+ torch_dtype=torch.float32 # Используем float32 для совместимости с CPU
23
  )
24
  print("Model loaded.")
25
  return tokenizer, model
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  tokenizer, model = load_model()
 
28
 
29
+ # --- Утилиты ---
30
  def num_tokens_of_text(text: str) -> int:
31
+ """Приближённое количество токенов для заданного текста."""
32
  return len(tokenizer.encode(text, add_special_tokens=False))
33
 
34
  def trim_history_to_max_tokens(messages, max_tokens):
35
+ """Обрезает историю сообщений, чтобы она соответствовала лимиту токенов."""
36
  rev = list(reversed(messages))
37
  total = 0
38
  kept = []
 
45
  return list(reversed(kept))
46
 
47
  def build_messages_for_template(history_messages, reasoning: bool, language: str):
48
+ """Подготавливает сообщения для шаблона, включая системное сообщение."""
49
  if language == 'ru':
50
  system_message = "Ты — дружелюбный ассистент, который говорит на русском. Отвечай кратко, но по делу."
51
  reasoning_instruction = ("[REASONING MODE]\n"
 
70
  return messages
71
 
72
  def extract_assistant_reply(raw_generated_text: str) -> str:
73
+ """Убирает лишние токены и возвращает только ответ ассистента."""
74
  text = raw_generated_text
75
  if "<|assistant|>" in text:
76
  text = text.split("<|assistant|>")[-1]
 
78
  text = text.replace(tag, "")
79
  return text.strip()
80
 
81
+ # --- Основная функция для Gradio ---
82
  def generate_response(user_text: str, history, reasoning: bool, language: str):
83
+ """Обрабатывает пользовательский запрос и генерирует ответ."""
84
 
85
  history.append({"role": "user", "content": user_text})
86
 
 
88
 
89
  messages_for_template = build_messages_for_template(trimmed_history, reasoning, language)
90
 
91
+ # Выбираем шаблон из файлов в репозитории
92
+ template_file = f"chat_template_{language}.jinja"
93
 
94
  text = tokenizer.apply_chat_template(
95
  messages_for_template,
 
117
 
118
  return "", history
119
 
120
+ # --- Интерфейс Gradio ---
121
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
122
  gr.Markdown("# PyroNet-mini Chat")
123
  gr.Markdown("A demonstration of PyroNet-mini with multilingual templates and a reasoning mode.")