chatbot / app.py
raaulcs's picture
Update app.py
ead5d53 verified
import gradio as gr
from datasets import load_dataset
import os
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
# Configuración y token
token = os.getenv("HF_TOKEN")
# Modelo de embeddings y dataset para la fase de RAG
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia",revision = "embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset
# Configuración de quantization
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
# Diccionario con modelos disponibles
# Puedes agregar tantos modelos como desees, con su respectiva configuración
MODELOS = {
"Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
"DeepSeek-R1": "deepseek-ai/DeepSeek-R1"
}
# Opcional: Diccionario para almacenar modelos cargados y evitar recargas
MODELOS_CARGADOS = {}
# Función para obtener (o cargar) el modelo y tokenizer según la selección
def get_model(selected_model):
model_id = MODELOS[selected_model]
if selected_model not in MODELOS_CARGADOS:
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config,
token=token
)
MODELOS_CARGADOS[selected_model] = (model, tokenizer)
return MODELOS_CARGADOS[selected_model]
# Tokens de finalización basados en el tokenizer del modelo, se actualizarán luego
def get_terminators(tokenizer):
return [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
# Prompt del sistema (podrías parametrizarlo si es necesario)
SYS_PROMPT = (
"Tu tarea es analizar un listado de unidades de competencia y devolver únicamente aquellas relacionadas "
"con la profesión de {texto_usuario}. Debes buscar palabras clave o términos relacionados con la profesión "
"en los nombres de las unidades de competencia. If you don't know the answer, just say 'I do not know.' "
"Don't make up an answer."
)
def search(query: str, k: int = 3):
"""Embebe la consulta y retorna los resultados más probables."""
embedded_query = ST.encode(query) # embed de la nueva consulta
scores, retrieved_examples = data.get_nearest_examples(
"embeddings",
embedded_query,
k=k
)
return scores, retrieved_examples
def format_prompt(prompt, retrieved_documents, k):
"""Construye el prompt a partir de los documentos recuperados."""
PROMPT = f"Question: {prompt}\nContext:\n"
for idx in range(k):
PROMPT += f"{retrieved_documents['text'][idx]}\n"
return PROMPT
def talk(prompt, selected_model, history=[]):
# Obtiene (o carga) el modelo y tokenizer basado en la selección
model, tokenizer = get_model(selected_model)
terminators = get_terminators(tokenizer)
# Número de documentos a recuperar
k = 1
scores, retrieved_documents = search(prompt, k)
formatted_prompt = format_prompt(prompt, retrieved_documents, k)
formatted_prompt = formatted_prompt[:2000] # prevenir OOM en GPU
# Construye los mensajes para la generación
messages = [
{"role": "system", "content": SYS_PROMPT},
{"role": "user", "content": formatted_prompt}
]
# Prepara la entrada usando la plantilla de chat del tokenizer
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
# Define el streamer para obtener la salida progresivamente
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=1024,
do_sample=True,
top_p=0.95,
temperature=0.75,
eos_token_id=terminators,
)
# Ejecuta la generación en un hilo para ir transmitiendo resultados
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
# Interfaz de Gradio: se añade un dropdown para la selección del modelo
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# RAG Chatbot")
gr.Markdown("Selecciona el modelo de inteligencia artificial:")
with gr.Row():
# Dropdown para elegir el modelo
modelo_selector = gr.Dropdown(choices=list(MODELOS.keys()), value="Llama-3.2-3B-Instruct", label="Modelo")
chatbot = gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, layout="bubble")
# Caja de entrada para la consulta del usuario
prompt_input = gr.Textbox(lines=2, label="Ingresa tu pregunta")
# Botón para enviar la consulta
send_btn = gr.Button("Enviar")
# Llamada a la función talk, pasando el prompt y la selección de modelo
send_btn.click(fn=talk, inputs=[prompt_input, modelo_selector], outputs=chatbot)
if __name__ == "__main__":
demo.launch(debug=True)