|
|
|
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
from huggingface_hub import hf_hub_download |
|
|
from transformers import GPTJForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
|
|
|
|
|
|
# Hugging Face model id for EleutherAI's 6-billion-parameter GPT-J.
MODEL_REPO = "EleutherAI/gpt-j-6B"

# Pick the device first so the weight dtype can be matched to it: float16
# kernels are only well supported on CUDA, while CPU inference needs float32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)

if DEVICE == "cuda":
    # The "float16" revision ships half-precision weights (~12 GB vs ~24 GB),
    # halving download size and GPU memory use.
    model = GPTJForCausalLM.from_pretrained(
        MODEL_REPO,
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    ).to(DEVICE)
else:
    # BUG FIX: the original unconditionally loaded float16 weights and moved
    # them to the CPU when no GPU was available; many float16 ops are
    # unimplemented or extremely slow on CPU, so fall back to full precision.
    model = GPTJForCausalLM.from_pretrained(
        MODEL_REPO,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    ).to(DEVICE)
|
|
|
|
|
|
|
|
def generate_response(prompt, history=None):
    """Generate one chat turn with GPT-J and append it to the history.

    Args:
        prompt: The user's latest message.
        history: List of (user, bot) message pairs from previous turns.
            Defaults to None (fresh conversation). BUG FIX: the original used
            a mutable default argument (history=[]), which is shared across
            calls and silently accumulates state between sessions.

    Returns:
        A (history, textbox_value) tuple: the updated history for the Chatbot
        component, and an empty string so the input textbox is cleared.
        (The original returned the history twice, which wrote the whole
        history list into the textbox.)
    """
    if history is None:
        history = []

    # Re-serialize the conversation in the "User:/Bot:" format the model
    # is prompted with.
    input_text = "".join(f"User: {user}\nBot: {bot}\n" for user, bot in history)
    input_text += f"User: {prompt}\nBot:"

    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            input_ids,
            # BUG FIX: max_length counts the prompt tokens too, so a long
            # conversation would exceed it and generate() would fail or emit
            # nothing new; max_new_tokens bounds only the reply length.
            max_new_tokens=256,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    generated_ids = output[0][input_ids.shape[-1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True)
    # The model may keep hallucinating further "User:" turns; keep only the
    # text up to the first one. BUG FIX: the original split on the LAST
    # occurrence, which returned a hallucinated turn instead of the reply.
    response = response.split("User:")[0].strip()

    history.append((prompt, response))
    return history, ""
|
|
|
|
|
|
|
|
# CSS injected into the Gradio page (Blocks(css=...)): card-style chat
# container with fade/slide-in animations, blue right-aligned user bubbles,
# grey left-aligned bot bubbles, and styled textbox/buttons.
# NOTE(review): the .chat-message / .message-content selectors target
# Gradio-internal markup — confirm they match the pinned Gradio version.
custom_css = """
body {
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    background-color: #f0f2f5;
    margin: 0;
    padding: 0;
}
#chatbot-container {
    width: 80%;
    max-width: 800px;
    margin: 50px auto;
    background: #fff;
    border-radius: 10px;
    box-shadow: 0 4px 20px rgba(0, 0, 0, 0.1);
    overflow: hidden;
    animation: fadeIn 0.5s ease-in-out;
}
@keyframes fadeIn {
    from { opacity: 0; transform: translateY(20px); }
    to { opacity: 1; transform: translateY(0); }
}
.chat-message {
    display: flex;
    align-items: flex-start;
    padding: 15px;
    border-bottom: 1px solid #f0f0f0;
    animation: slideIn 0.3s ease-in-out;
}
@keyframes slideIn {
    from { opacity: 0; transform: translateX(-20px); }
    to { opacity: 1; transform: translateX(0); }
}
.chat-message.user .message-content {
    background-color: #0078d7;
    color: #fff;
    border-radius: 15px 15px 0 15px;
    padding: 10px 15px;
    max-width: 70%;
    margin-left: auto;
}
.chat-message.bot .message-content {
    background-color: #e1e1e1;
    color: #333;
    border-radius: 15px 15px 15px 0;
    padding: 10px 15px;
    max-width: 70%;
    margin-right: auto;
}
input[type="text"] {
    width: calc(100% - 20px);
    margin: 10px;
    padding: 10px;
    border: 1px solid #ccc;
    border-radius: 5px;
    font-size: 16px;
}
button {
    background-color: #0078d7;
    color: #fff;
    border: none;
    padding: 10px 20px;
    margin: 10px;
    border-radius: 5px;
    cursor: pointer;
    font-size: 16px;
    transition: background-color 0.3s;
}
button:hover {
    background-color: #005bb5;
}
"""
|
|
|
|
|
|
|
|
# Assemble the Gradio UI: a titled page with the chat history, a text input,
# and a clear button inside the styled #chatbot-container card.
# NOTE(review): gr.Box was deprecated in Gradio 3.x and removed in 4.x —
# confirm the pinned gradio version still provides it (gr.Group/gr.Column
# are the modern replacements).
with gr.Blocks(css=custom_css) as chat_interface:
    gr.Markdown("# 🧠 Умный Чат с GPT-J-6B")
    with gr.Box(elem_id="chatbot-container"):
        chatbot = gr.Chatbot()  # conversation display; also holds the history state
        user_input = gr.Textbox(placeholder="Введите сообщение...")  # message entry
        clear_btn = gr.Button("Очистить историю")  # "clear history" button

    # On Enter: generate_response(text, history) runs and its two return
    # values are written back to `chatbot` and `user_input` respectively.
    # NOTE(review): generate_response returns the history twice, so the whole
    # history list is written into the textbox — verify this is intended.
    user_input.submit(generate_response, [user_input, chatbot], [chatbot, user_input])
    # Clearing: writing None into the Chatbot component empties the display.
    clear_btn.click(lambda: None, None, chatbot)

# Start the local web server; debug=True surfaces tracebacks in the console.
# NOTE(review): unguarded launch at module level — runs on any import;
# consider an `if __name__ == "__main__":` guard.
chat_interface.launch(debug=True)
|
|
|