Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| # Загрузка токенизатора и модели | |
| model_name = "GoidaAlignment/GOIDA-0.5B" # Замените на вашу модель | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| model = AutoModelForCausalLM.from_pretrained(model_name) | |
| model = model.to("cuda" if torch.cuda.is_available() else "cpu") | |
| # Шаблонная функция для форматирования диалога | |
| def apply_chat_template(chat, add_generation_prompt=True): | |
| formatted_chat = "" | |
| for message in chat: | |
| role = message["role"] | |
| content = message["content"] | |
| if role == "user": | |
| formatted_chat += f"User: {content}\n" | |
| elif role == "assistant": | |
| formatted_chat += f"Assistant: {content}\n" | |
| if add_generation_prompt: | |
| formatted_chat += "Assistant: " | |
| return formatted_chat | |
| # Функция генерации ответа | |
| def generate_response(user_input, chat_history): | |
| chat_history.append({"role": "user", "content": user_input}) | |
| formatted_chat = apply_chat_template(chat_history, add_generation_prompt=True) | |
| # Токенизация | |
| inputs = tokenizer(formatted_chat, return_tensors="pt", add_special_tokens=False) | |
| inputs = {key: tensor.to(model.device) for key, tensor in inputs.items()} | |
| # Генерация | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=64, | |
| temperature=0.7, | |
| top_p=0.9, | |
| do_sample=True | |
| ) | |
| # Декодирование | |
| decoded_output = tokenizer.decode(outputs[0][inputs["input_ids"].size(1):], skip_special_tokens=True) | |
| chat_history.append({"role": "assistant", "content": decoded_output}) | |
| return decoded_output, chat_history | |
| # Интерфейс Gradio | |
| with gr.Blocks() as demo: | |
| gr.Markdown("# Chatbot на основе модели ГОЙДАААА\nВзаимодействуйте с языковой моделью.") | |
| chatbot = gr.Chatbot() | |
| user_input = gr.Textbox(placeholder="Введите ваше сообщение...") | |
| clear = gr.Button("Очистить чат") | |
| chat_history = gr.State([]) # Состояние для хранения истории чата | |
| user_input.submit( | |
| generate_response, | |
| [user_input, chat_history], | |
| [chatbot, chat_history] | |
| ) | |
| clear.click(lambda: ([], []), None, [chatbot, chat_history]) | |
| if __name__ == "__main__": | |
| demo.launch() | |