Spaces:
Sleeping
Sleeping
File size: 8,460 Bytes
7136781 d53a92a 7136781 d53a92a 7136781 d53a92a 7136781 d53a92a 7136781 d53a92a 7136781 d53a92a 7136781 d53a92a 7136781 d53a92a 7136781 d53a92a 7136781 d53a92a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import gradio as gr
from decoderOnly import TransformerRun
from transformers import AutoTokenizer
import torch
import os
class ChatBot:
def __init__(self, model_path="."):
"""
Инициализация бота.
В Space файлы модели должны находиться в корневой директории.
"""
print(f"Загрузка модели из: {model_path}")
try:
# Загружаем токенизатор
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
print("Токенизатор загружен успешно.")
# Если у токенизатора нет pad_token, устанавливаем его
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "[PAD]"
print(f"Установлен pad_token: {self.tokenizer.pad_token}")
# Создаем модель с параметрами токенизатора
self.model = TransformerRun(
vocabSize=len(self.tokenizer),
maxLong=256,
sizeVector=128,
block=2
)
# Загружаем веса модели (в Space файл будет в корне)
weights_path = f"{model_path}/model_weights.pth"
if not os.path.exists(weights_path):
# Пробуем найти веса без подпапки
weights_path = "model_weights.pth"
print(f"Загрузка весов из: {weights_path}")
self.model.load_state_dict(
torch.load(weights_path, map_location='cpu', weights_only=True)
)
# Настраиваем устройство
self.device = torch.device("cpu")
self.model.to(self.device)
self.model.eval()
print("Модель загружена и готова к работе!")
except Exception as e:
print(f"Ошибка при загрузке модели: {e}")
raise
def generate(self, prompt, max_length=100, temperature=0.5, top_k=50):
"""
Генерация ответа на промпт пользователя.
"""
try:
if not prompt or prompt.strip() == "":
return "Пожалуйста, введите сообщение."
print(f"Генерация ответа для промпта: '{prompt[:50]}...'")
# Токенизируем промпт
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=200)
input_ids = inputs["input_ids"].to(self.device)
# Если последовательность пустая после токенизации
if input_ids.size(1) == 0:
return "Не удалось обработать запрос. Попробуйте другие слова."
generated_ids = input_ids.clone()
with torch.no_grad():
for _ in range(max_length):
# Прямой проход модели
outputs = self.model(generated_ids)
logits = outputs[0, -1, :] / temperature # учитываем температуру
# Top-k sampling
if top_k > 0:
topk_values, topk_indices = torch.topk(logits, min(top_k, logits.size(-1)))
probs = torch.zeros_like(logits).scatter(0, topk_indices, torch.softmax(topk_values, dim=-1))
else:
probs = torch.softmax(logits, dim=-1)
# Выбираем следующий токен
next_token = torch.multinomial(probs, num_samples=1)
# Добавляем к сгенерированной последовательности
generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)
# Останавливаемся на EOS или PAD
stop_tokens = []
if self.tokenizer.eos_token_id is not None:
stop_tokens.append(self.tokenizer.eos_token_id)
if self.tokenizer.pad_token_id is not None:
stop_tokens.append(self.tokenizer.pad_token_id)
if next_token.item() in stop_tokens:
print(f"Остановка на токене: {next_token.item()}")
break
# Декодируем обратно в текст
response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Убираем оригинальный промпт из ответа
if response.startswith(prompt):
response = response[len(prompt):].strip()
print(f"Сгенерирован ответ длиной {len(response)} символов.")
return response
except Exception as e:
print(f"Ошибка при генерации: {e}")
return f"Произошла ошибка: {str(e)}"
def create_interface():
"""
Создание Gradio интерфейса.
"""
try:
# Инициализируем бота
# В Space модель будет находиться в корневой директории
bot = ChatBot(model_path=".")
print("Интерфейс запускается...")
def respond(message, history):
"""
Функция для обработки сообщений в интерфейсе чата.
"""
# history содержит предыдущие сообщения в формате [[user1, bot1], [user2, bot2], ...]
# Мы будем генерировать ответ только на последнее сообщение
response = bot.generate(
prompt=message,
max_length=100,
temperature=0.7,
top_k=50
)
return response
# Создаем интерфейс чата
demo = gr.ChatInterface(
fn=respond,
title="BasicSmall ChatBot",
description="Демонстрация модели MarkProMaster229/BasicSmall. Напишите сообщение и нажмите Submit.",
examples=["Привет!", "Расскажи что-нибудь интересное", "Как дела?"],
theme="soft"
)
return demo
except Exception as e:
print(f"Критическая ошибка при создании интерфейса: {e}")
# Создаем простой интерфейс с сообщением об ошибке
def error_response(message, history):
return f"Модель не загружена. Ошибка: {str(e)}"
return gr.ChatInterface(
fn=error_response,
title="BasicSmall ChatBot (Ошибка)",
description="Не удалось загрузить модель. Проверьте файлы модели."
)
# Создаем и запускаем интерфейс
if __name__ == "__main__":
# Устанавливаем уровень логирования
import logging
logging.basicConfig(level=logging.INFO)
# Создаем интерфейс
demo = create_interface()
# Запускаем с параметрами для Hugging Face Spaces
demo.launch(
server_name="0.0.0.0", # Обязательно для Spaces
server_port=7860, # Стандартный порт для Spaces
share=False # Не создавать публичную ссылку (в Spaces это не нужно)
) |