Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from decoderOnly import TransformerRun | |
| from transformers import AutoTokenizer | |
| import torch | |
| import os | |
| class ChatBot: | |
| def __init__(self, model_path="."): | |
| """ | |
| Инициализация бота. | |
| В Space файлы модели должны находиться в корневой директории. | |
| """ | |
| print(f"Загрузка модели из: {model_path}") | |
| try: | |
| # Загружаем токенизатор | |
| self.tokenizer = AutoTokenizer.from_pretrained(model_path) | |
| print("Токенизатор загружен успешно.") | |
| # Если у токенизатора нет pad_token, устанавливаем его | |
| if self.tokenizer.pad_token is None: | |
| self.tokenizer.pad_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "[PAD]" | |
| print(f"Установлен pad_token: {self.tokenizer.pad_token}") | |
| # Создаем модель с параметрами токенизатора | |
| self.model = TransformerRun( | |
| vocabSize=len(self.tokenizer), | |
| maxLong=256, | |
| sizeVector=128, | |
| block=2 | |
| ) | |
| # Загружаем веса модели (в Space файл будет в корне) | |
| weights_path = f"{model_path}/model_weights.pth" | |
| if not os.path.exists(weights_path): | |
| # Пробуем найти веса без подпапки | |
| weights_path = "model_weights.pth" | |
| print(f"Загрузка весов из: {weights_path}") | |
| self.model.load_state_dict( | |
| torch.load(weights_path, map_location='cpu', weights_only=True) | |
| ) | |
| # Настраиваем устройство | |
| self.device = torch.device("cpu") | |
| self.model.to(self.device) | |
| self.model.eval() | |
| print("Модель загружена и готова к работе!") | |
| except Exception as e: | |
| print(f"Ошибка при загрузке модели: {e}") | |
| raise | |
| def generate(self, prompt, max_length=100, temperature=0.5, top_k=50): | |
| """ | |
| Генерация ответа на промпт пользователя. | |
| """ | |
| try: | |
| if not prompt or prompt.strip() == "": | |
| return "Пожалуйста, введите сообщение." | |
| print(f"Генерация ответа для промпта: '{prompt[:50]}...'") | |
| # Токенизируем промпт | |
| inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=200) | |
| input_ids = inputs["input_ids"].to(self.device) | |
| # Если последовательность пустая после токенизации | |
| if input_ids.size(1) == 0: | |
| return "Не удалось обработать запрос. Попробуйте другие слова." | |
| generated_ids = input_ids.clone() | |
| with torch.no_grad(): | |
| for _ in range(max_length): | |
| # Прямой проход модели | |
| outputs = self.model(generated_ids) | |
| logits = outputs[0, -1, :] / temperature # учитываем температуру | |
| # Top-k sampling | |
| if top_k > 0: | |
| topk_values, topk_indices = torch.topk(logits, min(top_k, logits.size(-1))) | |
| probs = torch.zeros_like(logits).scatter(0, topk_indices, torch.softmax(topk_values, dim=-1)) | |
| else: | |
| probs = torch.softmax(logits, dim=-1) | |
| # Выбираем следующий токен | |
| next_token = torch.multinomial(probs, num_samples=1) | |
| # Добавляем к сгенерированной последовательности | |
| generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1) | |
| # Останавливаемся на EOS или PAD | |
| stop_tokens = [] | |
| if self.tokenizer.eos_token_id is not None: | |
| stop_tokens.append(self.tokenizer.eos_token_id) | |
| if self.tokenizer.pad_token_id is not None: | |
| stop_tokens.append(self.tokenizer.pad_token_id) | |
| if next_token.item() in stop_tokens: | |
| print(f"Остановка на токене: {next_token.item()}") | |
| break | |
| # Декодируем обратно в текст | |
| response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) | |
| # Убираем оригинальный промпт из ответа | |
| if response.startswith(prompt): | |
| response = response[len(prompt):].strip() | |
| print(f"Сгенерирован ответ длиной {len(response)} символов.") | |
| return response | |
| except Exception as e: | |
| print(f"Ошибка при генерации: {e}") | |
| return f"Произошла ошибка: {str(e)}" | |
| def create_interface(): | |
| """ | |
| Создание Gradio интерфейса. | |
| """ | |
| try: | |
| # Инициализируем бота | |
| # В Space модель будет находиться в корневой директории | |
| bot = ChatBot(model_path=".") | |
| print("Интерфейс запускается...") | |
| def respond(message, history): | |
| """ | |
| Функция для обработки сообщений в интерфейсе чата. | |
| """ | |
| # history содержит предыдущие сообщения в формате [[user1, bot1], [user2, bot2], ...] | |
| # Мы будем генерировать ответ только на последнее сообщение | |
| response = bot.generate( | |
| prompt=message, | |
| max_length=100, | |
| temperature=0.7, | |
| top_k=50 | |
| ) | |
| return response | |
| # Создаем интерфейс чата | |
| demo = gr.ChatInterface( | |
| fn=respond, | |
| title="BasicSmall ChatBot", | |
| description="Демонстрация модели MarkProMaster229/BasicSmall. Напишите сообщение и нажмите Submit.", | |
| examples=["Привет!", "Расскажи что-нибудь интересное", "Как дела?"], | |
| theme="soft" | |
| ) | |
| return demo | |
| except Exception as e: | |
| print(f"Критическая ошибка при создании интерфейса: {e}") | |
| # Создаем простой интерфейс с сообщением об ошибке | |
| def error_response(message, history): | |
| return f"Модель не загружена. Ошибка: {str(e)}" | |
| return gr.ChatInterface( | |
| fn=error_response, | |
| title="BasicSmall ChatBot (Ошибка)", | |
| description="Не удалось загрузить модель. Проверьте файлы модели." | |
| ) | |
| # Создаем и запускаем интерфейс | |
| if __name__ == "__main__": | |
| # Устанавливаем уровень логирования | |
| import logging | |
| logging.basicConfig(level=logging.INFO) | |
| # Создаем интерфейс | |
| demo = create_interface() | |
| # Запускаем с параметрами для Hugging Face Spaces | |
| demo.launch( | |
| server_name="0.0.0.0", # Обязательно для Spaces | |
| server_port=7860, # Стандартный порт для Spaces | |
| share=False # Не создавать публичную ссылку (в Spaces это не нужно) | |
| ) |