import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer import torch # --- Глобальные переменные для модели и токенизатора --- model = None tokenizer = None model_name = "microsoft/DialoGPT-medium" # Можно вынести в константу # --- 1. Функция для загрузки модели и токенизатора --- # Эту функцию можно было бы вызывать один раз при старте, но для простоты # и учитывая, как Gradio работает со Spaces, оставим загрузку в глобальной области. # Однако, обернем в try-except для надежности. try: print(f"Загрузка токенизатора для: {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name) print("Токенизатор загружен.") if tokenizer.pad_token_id is None: print(f"Установка tokenizer.pad_token_id = tokenizer.eos_token_id для {model_name}") tokenizer.pad_token_id = tokenizer.eos_token_id print(f"Загрузка модели: {model_name}") model = AutoModelForCausalLM.from_pretrained(model_name) # Для DialoGPT-medium device_map="auto" не сильно нужен, он и так на CPU поместится. # Если бы Space имел GPU, можно было бы model.to("cuda") print(f"Модель {model_name} загружена. Устройство: {model.device if model else 'N/A'}") except Exception as e: print(f"ОШИБКА ЗАГРУЗКИ МОДЕЛИ/ТОКЕНИЗАТОРА: {e}") # В этом случае model и tokenizer останутся None # --- 2. Функция-обработчик для ChatInterface --- # Принимает текущее сообщение и историю чата. # История будет в формате list[dict[str, str]], если Chatbot имеет type="messages". def simple_predict(message: str, history: list[dict[str, str]]): global model, tokenizer # Указываем, что используем глобальные переменные print(f"\nПолучено сообщение: '{message}'") print(f"История: {history}") if not model or not tokenizer: print("Модель или токенизатор не были загружены. Возвращаю ошибку.") return "Ошибка: Модель не загружена. Проверьте логи Space." # Формируем строку для DialoGPT: конкатенация всех предыдущих реплик и текущего сообщения # через eos_token. prompt_parts = [] for turn in history: # turn это {"role": "user/assistant", "content": "текст"} prompt_parts.append(turn["content"]) prompt_parts.append(message) # Добавляем текущее сообщение пользователя # Соединяем все части eos_token. # НЕ добавляем eos_token в самый конец строки здесь. # Токенизатор и/или model.generate сами обработают EOS для генерации. input_text = tokenizer.eos_token.join(prompt_parts) print(f"Сформированный промпт для модели: {repr(input_text)}") try: # Токенизация inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=1024) # Перемещаем на устройство модели (если оно отличается, хотя для CPU это не актуально) # inputs = {k: v.to(model.device) for k, v in inputs.items()} # Можно раскомментировать если используется GPU # Генерация ответа # max_length включает длину входного промпта + количество новых токенов # Для DialoGPT-medium, 150 новых токенов - разумно. print("Начало генерации...") outputs = model.generate( inputs.input_ids.to(model.device), # Убедимся, что тензоры на том же устройстве, что и модель attention_mask=inputs.attention_mask.to(model.device) if "attention_mask" in inputs else None, max_length=inputs.input_ids.shape[-1] + 150, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, do_sample=True, top_k=50, top_p=0.95, temperature=0.8 ) print("Генерация завершена.") # Декодирование только нового сгенерированного текста input_length = inputs.input_ids.shape[1] generated_tokens = outputs[0][input_length:] bot_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip() print(f"Ответ модели: '{bot_response}'") return bot_response except Exception as e: print(f"ОШИБКА ВНУТРИ PREDICT: {e}") import traceback traceback.print_exc() # Печатаем полный трейсбек ошибки в predict return "Произошла ошибка при обработке вашего запроса." # --- 3. Создание и запуск Gradio интерфейса --- # Используем самый простой ChatInterface без явного Chatbot компонента. # Gradio сам создаст Chatbot с type="messages" по умолчанию в новых версиях. demo = gr.ChatInterface( fn=simple_predict, title="Простой Чат с DialoGPT-medium", description="Введите сообщение для начала диалога.", examples=[["Привет!"], ["Как дела?"]], # Примеры должны быть списком списков строк cache_examples=False # Отключаем кэширование примеров для простоты отладки ) if __name__ == "__main__": print("Запуск Gradio приложения...") demo.launch() print("Gradio приложение запущено (или была попытка запуска).")