Richard / app.py
Walter130209's picture
Update app.py
0e3009c verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
MODEL_NAME = "bertin-project/bertin-gpt-j-6B"

# Load the tokenizer and model once, at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Check for CUDA a single time so the dtype choice and the device move agree.
_use_cuda = torch.cuda.is_available()

# The checkpoint name advertises ~6B parameters; at 4 bytes each the default
# float32 load is on the order of 24 GB. Half precision halves that on GPU.
# On CPU we keep float32, which is the library default dtype.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if _use_cuda else torch.float32,
)
model.eval()  # inference only: disables dropout and similar train-time layers
if _use_cuda:
    model = model.to("cuda")
# Generate Richard's reply for a single user message
def generar_respuesta(user_input, chat_history):
    """Generate Richard's reply to ``user_input`` and append the turn to the chat.

    Args:
        user_input: Text typed by the user.
        chat_history: List of ``{"role": ..., "content": ...}`` message dicts
            displayed by the chatbot. It is extended in place. ``None`` is
            treated as an empty history.

    Returns:
        A ``(chat_history, chat_history)`` tuple, one element per output
        component wired to this callback.
    """
    if chat_history is None:
        chat_history = []

    # A blank message has nothing to answer; skip the expensive generation.
    if not user_input or not user_input.strip():
        return chat_history, chat_history

    # Build the prompt
    prompt = "Usuario: " + user_input + "\nRichard: "

    # Put the inputs on whatever device the model lives on instead of
    # re-checking CUDA availability here.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    generation_config = GenerationConfig(
        # Without do_sample=True decoding is greedy and temperature/top_p
        # are ignored.
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        max_new_tokens=256,
        num_beams=1,
        # Set explicitly so generate() does not have to guess a padding id.
        pad_token_id=tokenizer.eos_token_id,
    )

    chat_history.append({"role": "user", "content": user_input})

    # NOTE: a "thinking..." placeholder is intentionally not appended here.
    # This callback returns exactly once, so an intermediate message would be
    # overwritten before the UI ever receives it; showing one would require
    # turning this function into a generator.

    # Generate the reply
    output = model.generate(
        **inputs,
        generation_config=generation_config,
        return_dict_in_generate=True,
        output_scores=False,
    )

    # Decode only the newly generated tokens. Splitting the full decoded text
    # on "Richard:" picks the wrong segment whenever the user's message or the
    # completion itself contains that marker.
    new_tokens = output.sequences[0][prompt_length:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True)

    # Keep only Richard's first turn: a base language model tends to carry on
    # writing further dialogue turns after its answer.
    for marker in ("\nUsuario:", "\nRichard:"):
        reply = reply.split(marker)[0]
    reply = reply.strip()

    chat_history.append({"role": "assistant", "content": reply})
    return chat_history, chat_history
# Gradio interface
# Custom styling passed to gr.Blocks below.
_CUSTOM_CSS = """
.gradio-container {border-radius: 20px; background-color: #f7f7f8; padding: 15px;}
.gradio-button {border-radius: 12px; background-color: #10a37f; color: white; font-weight: bold;}
.gradio-textbox {border-radius: 12px;}
"""

with gr.Blocks(css=_CUSTOM_CSS) as demo:
    # Seed the conversation with Richard's greeting
    initial_history = [
        {
            "role": "assistant",
            "content": "¡Hola! Soy Richard. ¿En qué te puedo ayudar hoy?",
        }
    ]
    chatbot = gr.Chatbot(value=initial_history, type="messages")
    user_input = gr.Textbox(
        label="Escribe aquí...",
        placeholder="Hola, soy Richard",
        lines=2,
    )
    enviar = gr.Button("Enviar")

    # Wire the button to the reply function. The callback returns two values,
    # so the chatbot component is listed twice as an output.
    enviar.click(
        fn=generar_respuesta,
        inputs=[user_input, chatbot],
        outputs=[chatbot, chatbot],
    )

demo.launch()