import gradio as gr from decoderOnly import TransformerRun from transformers import AutoTokenizer import torch import os class ChatBot: def __init__(self, model_path="."): """ Инициализация бота. В Space файлы модели должны находиться в корневой директории. """ print(f"Загрузка модели из: {model_path}") try: # Загружаем токенизатор self.tokenizer = AutoTokenizer.from_pretrained(model_path) print("Токенизатор загружен успешно.") # Если у токенизатора нет pad_token, устанавливаем его if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "[PAD]" print(f"Установлен pad_token: {self.tokenizer.pad_token}") # Создаем модель с параметрами токенизатора self.model = TransformerRun( vocabSize=len(self.tokenizer), maxLong=256, sizeVector=128, block=2 ) # Загружаем веса модели (в Space файл будет в корне) weights_path = f"{model_path}/model_weights.pth" if not os.path.exists(weights_path): # Пробуем найти веса без подпапки weights_path = "model_weights.pth" print(f"Загрузка весов из: {weights_path}") self.model.load_state_dict( torch.load(weights_path, map_location='cpu', weights_only=True) ) # Настраиваем устройство self.device = torch.device("cpu") self.model.to(self.device) self.model.eval() print("Модель загружена и готова к работе!") except Exception as e: print(f"Ошибка при загрузке модели: {e}") raise def generate(self, prompt, max_length=100, temperature=0.5, top_k=50): """ Генерация ответа на промпт пользователя. """ try: if not prompt or prompt.strip() == "": return "Пожалуйста, введите сообщение." print(f"Генерация ответа для промпта: '{prompt[:50]}...'") # Токенизируем промпт inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=200) input_ids = inputs["input_ids"].to(self.device) # Если последовательность пустая после токенизации if input_ids.size(1) == 0: return "Не удалось обработать запрос. Попробуйте другие слова." generated_ids = input_ids.clone() with torch.no_grad(): for _ in range(max_length): # Прямой проход модели outputs = self.model(generated_ids) logits = outputs[0, -1, :] / temperature # учитываем температуру # Top-k sampling if top_k > 0: topk_values, topk_indices = torch.topk(logits, min(top_k, logits.size(-1))) probs = torch.zeros_like(logits).scatter(0, topk_indices, torch.softmax(topk_values, dim=-1)) else: probs = torch.softmax(logits, dim=-1) # Выбираем следующий токен next_token = torch.multinomial(probs, num_samples=1) # Добавляем к сгенерированной последовательности generated_ids = torch.cat([generated_ids, next_token.unsqueeze(0)], dim=1) # Останавливаемся на EOS или PAD stop_tokens = [] if self.tokenizer.eos_token_id is not None: stop_tokens.append(self.tokenizer.eos_token_id) if self.tokenizer.pad_token_id is not None: stop_tokens.append(self.tokenizer.pad_token_id) if next_token.item() in stop_tokens: print(f"Остановка на токене: {next_token.item()}") break # Декодируем обратно в текст response = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True) # Убираем оригинальный промпт из ответа if response.startswith(prompt): response = response[len(prompt):].strip() print(f"Сгенерирован ответ длиной {len(response)} символов.") return response except Exception as e: print(f"Ошибка при генерации: {e}") return f"Произошла ошибка: {str(e)}" def create_interface(): """ Создание Gradio интерфейса. """ try: # Инициализируем бота # В Space модель будет находиться в корневой директории bot = ChatBot(model_path=".") print("Интерфейс запускается...") def respond(message, history): """ Функция для обработки сообщений в интерфейсе чата. """ # history содержит предыдущие сообщения в формате [[user1, bot1], [user2, bot2], ...] # Мы будем генерировать ответ только на последнее сообщение response = bot.generate( prompt=message, max_length=100, temperature=0.7, top_k=50 ) return response # Создаем интерфейс чата demo = gr.ChatInterface( fn=respond, title="BasicSmall ChatBot", description="Демонстрация модели MarkProMaster229/BasicSmall. Напишите сообщение и нажмите Submit.", examples=["Привет!", "Расскажи что-нибудь интересное", "Как дела?"], theme="soft" ) return demo except Exception as e: print(f"Критическая ошибка при создании интерфейса: {e}") # Создаем простой интерфейс с сообщением об ошибке def error_response(message, history): return f"Модель не загружена. Ошибка: {str(e)}" return gr.ChatInterface( fn=error_response, title="BasicSmall ChatBot (Ошибка)", description="Не удалось загрузить модель. Проверьте файлы модели." ) # Создаем и запускаем интерфейс if __name__ == "__main__": # Устанавливаем уровень логирования import logging logging.basicConfig(level=logging.INFO) # Создаем интерфейс demo = create_interface() # Запускаем с параметрами для Hugging Face Spaces demo.launch( server_name="0.0.0.0", # Обязательно для Spaces server_port=7860, # Стандартный порт для Spaces share=False # Не создавать публичную ссылку (в Spaces это не нужно) )