MarkProMaster229 commited on
Commit
d53a92a
·
verified ·
1 Parent(s): 807907a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +170 -60
app.py CHANGED
@@ -1,70 +1,180 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
 
 
 
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
-
19
- messages = [{"role": "system", "content": system_message}]
20
-
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
 
25
- response = ""
 
 
 
 
26
 
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
- )
62
-
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
67
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
 
69
  if __name__ == "__main__":
70
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from decoderOnly import TransformerRun
3
+ from transformers import AutoTokenizer
4
+ import torch
5
+ import os
6
 
7
+ class ChatBot:
8
+ def __init__(self, model_path="."):
9
+ """
10
+ Инициализация бота.
11
+ В Space файлы модели должны находиться в корневой директории.
12
+ """
13
+ print(f"Загрузка модели из: {model_path}")
14
+
15
+ try:
16
+ # Загружаем токенизатор
17
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
18
+ print("Токенизатор загружен успешно.")
19
+
20
+ # Если у токенизатора нет pad_token, устанавливаем его
21
+ if self.tokenizer.pad_token is None:
22
+ self.tokenizer.pad_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "[PAD]"
23
+ print(f"Установлен pad_token: {self.tokenizer.pad_token}")
24
+
25
+ # Создаем модель с параметрами токенизатора
26
+ self.model = TransformerRun(
27
+ vocabSize=len(self.tokenizer),
28
+ maxLong=256,
29
+ sizeVector=128,
30
+ block=2
31
+ )
32
+
33
+ # Загружаем веса модели (в Space файл будет в корне)
34
+ weights_path = f"{model_path}/model_weights.pth"
35
+ if not os.path.exists(weights_path):
36
+ # Пробуем найти веса без подпапки
37
+ weights_path = "model_weights.pth"
38
+
39
+ print(f"Загрузка весов из: {weights_path}")
40
+ self.model.load_state_dict(
41
+ torch.load(weights_path, map_location='cpu', weights_only=True)
42
+ )
43
+
44
+ # Настраиваем устройство
45
+ self.device = torch.device("cpu")
46
+ self.model.to(self.device)
47
+ self.model.eval()
48
+ print("Модель загружена и готова к работе!")
49
+
50
+ except Exception as e:
51
+ print(f"Ошибка при загрузке модели: {e}")
52
+ raise
53
 
54
+ def generate(self, prompt, max_length=100, temperature=0.5, top_k=50):
55
+ """
56
+ Генерация ответа на промпт пользователя.
57
+ """
58
+ try:
59
+ if not prompt or prompt.strip() == "":
60
+ return "Пожалуйста, введите сообщение."
61
+
62
+ print(f"Генерация ответа для промпта: '{prompt[:50]}...'")
63
+
64
+ # Токенизируем промпт
65
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=200)
66
+ input_ids = inputs["input_ids"].to(self.device)
67
+
68
+ # Если последовательность пустая после токенизации
69
+ if input_ids.size(1) == 0:
70
+ return "Не удалось обработать запрос. Попробуйте другие слова."
71
+
72
+ generated_ids = input_ids.clone()
73
 
74
+ with torch.no_grad():
75
+ for _ in range(max_length):
76
+ # Прямой проход модели
77
+ outputs = self.model(generated_ids)
78
+ logits = outputs[0, -1, :] / temperature # учитываем температуру
79
 
80
+ # Top-k sampling
81
+ if top_k > 0:
82
+ topk_values, topk_indices = torch.topk(logits, min(top_k, logits.size(-1)))
83
+ probs = torch.zeros_like(logits).scatter(0, topk_indices, torch.softmax(topk_values, dim=-1))
84
+ else:
85
+ probs = torch.softmax(logits, dim=-1)
 
 
 
 
 
86
 
87
+ # Выбираем следующий токен
88
+ next_token = torch.multinomial(probs, num_samples=1)
89
+
90
+ # Добавляем к сгенерированной последовательности
91
+ generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)
92
+
93
+ # Останавливаемся на EOS или PAD
94
+ stop_tokens = []
95
+ if self.tokenizer.eos_token_id is not None:
96
+ stop_tokens.append(self.tokenizer.eos_token_id)
97
+ if self.tokenizer.pad_token_id is not None:
98
+ stop_tokens.append(self.tokenizer.pad_token_id)
99
+
100
+ if next_token.item() in stop_tokens:
101
+ print(f"Остановка на токене: {next_token.item()}")
102
+ break
103
 
104
+ # Декодируем обратно в текст
105
+ response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
106
+
107
+ # Убираем оригинальный промпт из ответа
108
+ if response.startswith(prompt):
109
+ response = response[len(prompt):].strip()
110
+
111
+ print(f"Сгенерирован ответ длиной {len(response)} символов.")
112
+ return response
113
+
114
+ except Exception as e:
115
+ print(f"Ошибка при генерации: {e}")
116
+ return f"Произошла ошибка: {str(e)}"
117
 
118
+ def create_interface():
119
+ """
120
+ Создание Gradio интерфейса.
121
+ """
122
+ try:
123
+ # Инициализируем бота
124
+ # В Space модель будет находиться в корневой директории
125
+ bot = ChatBot(model_path=".")
126
+ print("Интерфейс запускается...")
127
+
128
+ def respond(message, history):
129
+ """
130
+ Функция для обработки сообщений в интерфейсе чата.
131
+ """
132
+ # history содержит предыдущие сообщения в формате [[user1, bot1], [user2, bot2], ...]
133
+ # Мы будем генерировать ответ только на последнее сообщение
134
+ response = bot.generate(
135
+ prompt=message,
136
+ max_length=100,
137
+ temperature=0.7,
138
+ top_k=50
139
+ )
140
+ return response
141
+
142
+ # Создаем интерфейс чата
143
+ demo = gr.ChatInterface(
144
+ fn=respond,
145
+ title="BasicSmall ChatBot",
146
+ description="Демонстрация модели MarkProMaster229/BasicSmall. Напишите сообщение и нажмите Submit.",
147
+ examples=["Привет!", "Расскажи что-нибудь интересное", "Как дела?"],
148
+ theme="soft"
149
+ )
150
+
151
+ return demo
152
+
153
+ except Exception as e:
154
+ print(f"Критическая ошибка при создании интерфейса: {e}")
155
+
156
+ # Создаем простой интерфейс с сообщением об ошибке
157
+ def error_response(message, history):
158
+ return f"Модель не загружена. Ошибка: {str(e)}"
159
+
160
+ return gr.ChatInterface(
161
+ fn=error_response,
162
+ title="BasicSmall ChatBot (Ошибка)",
163
+ description="Не удалось загрузить модель. Проверьте файлы модели."
164
+ )
165
 
166
+ # Создаем и запускаем интерфейс
167
  if __name__ == "__main__":
168
+ # Устанавливаем уровень логирования
169
+ import logging
170
+ logging.basicConfig(level=logging.INFO)
171
+
172
+ # Создаем интерфейс
173
+ demo = create_interface()
174
+
175
+ # Запускаем с параметрами для Hugging Face Spaces
176
+ demo.launch(
177
+ server_name="0.0.0.0", # Обязательно для Spaces
178
+ server_port=7860, # Стандартный порт для Spaces
179
+ share=False # Не создавать публичную ссылку (в Spaces это не нужно)
180
+ )