my-chat-model / app.py
hypo69's picture
Update app.py
32be262 verified
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# --- Глобальные переменные для модели и токенизатора ---
model = None
tokenizer = None
model_name = "microsoft/DialoGPT-medium" # Можно вынести в константу
# --- 1. Функция для загрузки модели и токенизатора ---
# Эту функцию можно было бы вызывать один раз при старте, но для простоты
# и учитывая, как Gradio работает со Spaces, оставим загрузку в глобальной области.
# Однако, обернем в try-except для надежности.
try:
print(f"Загрузка токенизатора для: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Токенизатор загружен.")
if tokenizer.pad_token_id is None:
print(f"Установка tokenizer.pad_token_id = tokenizer.eos_token_id для {model_name}")
tokenizer.pad_token_id = tokenizer.eos_token_id
print(f"Загрузка модели: {model_name}")
model = AutoModelForCausalLM.from_pretrained(model_name)
# Для DialoGPT-medium device_map="auto" не сильно нужен, он и так на CPU поместится.
# Если бы Space имел GPU, можно было бы model.to("cuda")
print(f"Модель {model_name} загружена. Устройство: {model.device if model else 'N/A'}")
except Exception as e:
print(f"ОШИБКА ЗАГРУЗКИ МОДЕЛИ/ТОКЕНИЗАТОРА: {e}")
# В этом случае model и tokenizer останутся None
# --- 2. Функция-обработчик для ChatInterface ---
# Принимает текущее сообщение и историю чата.
# История будет в формате list[dict[str, str]], если Chatbot имеет type="messages".
def simple_predict(message: str, history: list[dict[str, str]]):
global model, tokenizer # Указываем, что используем глобальные переменные
print(f"\nПолучено сообщение: '{message}'")
print(f"История: {history}")
if not model or not tokenizer:
print("Модель или токенизатор не были загружены. Возвращаю ошибку.")
return "Ошибка: Модель не загружена. Проверьте логи Space."
# Формируем строку для DialoGPT: конкатенация всех предыдущих реплик и текущего сообщения
# через eos_token.
prompt_parts = []
for turn in history: # turn это {"role": "user/assistant", "content": "текст"}
prompt_parts.append(turn["content"])
prompt_parts.append(message) # Добавляем текущее сообщение пользователя
# Соединяем все части eos_token.
# НЕ добавляем eos_token в самый конец строки здесь.
# Токенизатор и/или model.generate сами обработают EOS для генерации.
input_text = tokenizer.eos_token.join(prompt_parts)
print(f"Сформированный промпт для модели: {repr(input_text)}")
try:
# Токенизация
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=1024)
# Перемещаем на устройство модели (если оно отличается, хотя для CPU это не актуально)
# inputs = {k: v.to(model.device) for k, v in inputs.items()} # Можно раскомментировать если используется GPU
# Генерация ответа
# max_length включает длину входного промпта + количество новых токенов
# Для DialoGPT-medium, 150 новых токенов - разумно.
print("Начало генерации...")
outputs = model.generate(
inputs.input_ids.to(model.device), # Убедимся, что тензоры на том же устройстве, что и модель
attention_mask=inputs.attention_mask.to(model.device) if "attention_mask" in inputs else None,
max_length=inputs.input_ids.shape[-1] + 150,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
do_sample=True,
top_k=50,
top_p=0.95,
temperature=0.8
)
print("Генерация завершена.")
# Декодирование только нового сгенерированного текста
input_length = inputs.input_ids.shape[1]
generated_tokens = outputs[0][input_length:]
bot_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
print(f"Ответ модели: '{bot_response}'")
return bot_response
except Exception as e:
print(f"ОШИБКА ВНУТРИ PREDICT: {e}")
import traceback
traceback.print_exc() # Печатаем полный трейсбек ошибки в predict
return "Произошла ошибка при обработке вашего запроса."
# --- 3. Создание и запуск Gradio интерфейса ---
# Используем самый простой ChatInterface без явного Chatbot компонента.
# Gradio сам создаст Chatbot с type="messages" по умолчанию в новых версиях.
demo = gr.ChatInterface(
fn=simple_predict,
title="Простой Чат с DialoGPT-medium",
description="Введите сообщение для начала диалога.",
examples=[["Привет!"], ["Как дела?"]], # Примеры должны быть списком списков строк
cache_examples=False # Отключаем кэширование примеров для простоты отладки
)
if __name__ == "__main__":
print("Запуск Gradio приложения...")
demo.launch()
print("Gradio приложение запущено (или была попытка запуска).")