BasicSmall / app.py
MarkProMaster229's picture
Update app.py
d53a92a verified
import gradio as gr
from decoderOnly import TransformerRun
from transformers import AutoTokenizer
import torch
import os
class ChatBot:
def __init__(self, model_path="."):
"""
Инициализация бота.
В Space файлы модели должны находиться в корневой директории.
"""
print(f"Загрузка модели из: {model_path}")
try:
# Загружаем токенизатор
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
print("Токенизатор загружен успешно.")
# Если у токенизатора нет pad_token, устанавливаем его
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "[PAD]"
print(f"Установлен pad_token: {self.tokenizer.pad_token}")
# Создаем модель с параметрами токенизатора
self.model = TransformerRun(
vocabSize=len(self.tokenizer),
maxLong=256,
sizeVector=128,
block=2
)
# Загружаем веса модели (в Space файл будет в корне)
weights_path = f"{model_path}/model_weights.pth"
if not os.path.exists(weights_path):
# Пробуем найти веса без подпапки
weights_path = "model_weights.pth"
print(f"Загрузка весов из: {weights_path}")
self.model.load_state_dict(
torch.load(weights_path, map_location='cpu', weights_only=True)
)
# Настраиваем устройство
self.device = torch.device("cpu")
self.model.to(self.device)
self.model.eval()
print("Модель загружена и готова к работе!")
except Exception as e:
print(f"Ошибка при загрузке модели: {e}")
raise
def generate(self, prompt, max_length=100, temperature=0.5, top_k=50):
"""
Генерация ответа на промпт пользователя.
"""
try:
if not prompt or prompt.strip() == "":
return "Пожалуйста, введите сообщение."
print(f"Генерация ответа для промпта: '{prompt[:50]}...'")
# Токенизируем промпт
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=200)
input_ids = inputs["input_ids"].to(self.device)
# Если последовательность пустая после токенизации
if input_ids.size(1) == 0:
return "Не удалось обработать запрос. Попробуйте другие слова."
generated_ids = input_ids.clone()
with torch.no_grad():
for _ in range(max_length):
# Прямой проход модели
outputs = self.model(generated_ids)
logits = outputs[0, -1, :] / temperature # учитываем температуру
# Top-k sampling
if top_k > 0:
topk_values, topk_indices = torch.topk(logits, min(top_k, logits.size(-1)))
probs = torch.zeros_like(logits).scatter(0, topk_indices, torch.softmax(topk_values, dim=-1))
else:
probs = torch.softmax(logits, dim=-1)
# Выбираем следующий токен
next_token = torch.multinomial(probs, num_samples=1)
# Добавляем к сгенерированной последовательности
generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1)
# Останавливаемся на EOS или PAD
stop_tokens = []
if self.tokenizer.eos_token_id is not None:
stop_tokens.append(self.tokenizer.eos_token_id)
if self.tokenizer.pad_token_id is not None:
stop_tokens.append(self.tokenizer.pad_token_id)
if next_token.item() in stop_tokens:
print(f"Остановка на токене: {next_token.item()}")
break
# Декодируем обратно в текст
response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
# Убираем оригинальный промпт из ответа
if response.startswith(prompt):
response = response[len(prompt):].strip()
print(f"Сгенерирован ответ длиной {len(response)} символов.")
return response
except Exception as e:
print(f"Ошибка при генерации: {e}")
return f"Произошла ошибка: {str(e)}"
def create_interface():
"""
Создание Gradio интерфейса.
"""
try:
# Инициализируем бота
# В Space модель будет находиться в корневой директории
bot = ChatBot(model_path=".")
print("Интерфейс запускается...")
def respond(message, history):
"""
Функция для обработки сообщений в интерфейсе чата.
"""
# history содержит предыдущие сообщения в формате [[user1, bot1], [user2, bot2], ...]
# Мы будем генерировать ответ только на последнее сообщение
response = bot.generate(
prompt=message,
max_length=100,
temperature=0.7,
top_k=50
)
return response
# Создаем интерфейс чата
demo = gr.ChatInterface(
fn=respond,
title="BasicSmall ChatBot",
description="Демонстрация модели MarkProMaster229/BasicSmall. Напишите сообщение и нажмите Submit.",
examples=["Привет!", "Расскажи что-нибудь интересное", "Как дела?"],
theme="soft"
)
return demo
except Exception as e:
print(f"Критическая ошибка при создании интерфейса: {e}")
# Создаем простой интерфейс с сообщением об ошибке
def error_response(message, history):
return f"Модель не загружена. Ошибка: {str(e)}"
return gr.ChatInterface(
fn=error_response,
title="BasicSmall ChatBot (Ошибка)",
description="Не удалось загрузить модель. Проверьте файлы модели."
)
# Создаем и запускаем интерфейс
if __name__ == "__main__":
# Устанавливаем уровень логирования
import logging
logging.basicConfig(level=logging.INFO)
# Создаем интерфейс
demo = create_interface()
# Запускаем с параметрами для Hugging Face Spaces
demo.launch(
server_name="0.0.0.0", # Обязательно для Spaces
server_port=7860, # Стандартный порт для Spaces
share=False # Не создавать публичную ссылку (в Spaces это не нужно)
)