File size: 5,508 Bytes
15f98fb 8a635ef f76ed2b 8a635ef 15f98fb f76ed2b e2281d1 15f98fb f76ed2b 4c28d4a 8a635ef 4c28d4a 15f98fb f76ed2b 8a635ef f76ed2b 8a635ef 15f98fb f76ed2b cb7bd8d f76ed2b 4c28d4a f76ed2b 8a635ef f76ed2b 8a635ef f76ed2b 8a635ef f76ed2b 8a635ef f76ed2b 8a635ef f76ed2b 8a635ef f76ed2b 8a635ef f76ed2b 8a635ef f76ed2b 8a635ef f76ed2b 8a635ef f76ed2b ead5d53 f76ed2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import gradio as gr
from datasets import load_dataset
import os
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer
# Configuración y token
token = os.getenv("HF_TOKEN")
# Modelo de embeddings y dataset para la fase de RAG
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
dataset = load_dataset("not-lain/wikipedia",revision = "embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset
# Configuración de quantization
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
# Diccionario con modelos disponibles
# Puedes agregar tantos modelos como desees, con su respectiva configuración
MODELOS = {
"Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
"DeepSeek-R1": "deepseek-ai/DeepSeek-R1"
}
# Opcional: Diccionario para almacenar modelos cargados y evitar recargas
MODELOS_CARGADOS = {}
# Función para obtener (o cargar) el modelo y tokenizer según la selección
def get_model(selected_model):
model_id = MODELOS[selected_model]
if selected_model not in MODELOS_CARGADOS:
tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
device_map="auto",
quantization_config=bnb_config,
token=token
)
MODELOS_CARGADOS[selected_model] = (model, tokenizer)
return MODELOS_CARGADOS[selected_model]
# Tokens de finalización basados en el tokenizer del modelo, se actualizarán luego
def get_terminators(tokenizer):
return [
tokenizer.eos_token_id,
tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
# Prompt del sistema (podrías parametrizarlo si es necesario)
SYS_PROMPT = (
"Tu tarea es analizar un listado de unidades de competencia y devolver únicamente aquellas relacionadas "
"con la profesión de {texto_usuario}. Debes buscar palabras clave o términos relacionados con la profesión "
"en los nombres de las unidades de competencia. If you don't know the answer, just say 'I do not know.' "
"Don't make up an answer."
)
def search(query: str, k: int = 3):
"""Embebe la consulta y retorna los resultados más probables."""
embedded_query = ST.encode(query) # embed de la nueva consulta
scores, retrieved_examples = data.get_nearest_examples(
"embeddings",
embedded_query,
k=k
)
return scores, retrieved_examples
def format_prompt(prompt, retrieved_documents, k):
"""Construye el prompt a partir de los documentos recuperados."""
PROMPT = f"Question: {prompt}\nContext:\n"
for idx in range(k):
PROMPT += f"{retrieved_documents['text'][idx]}\n"
return PROMPT
def talk(prompt, selected_model, history=[]):
# Obtiene (o carga) el modelo y tokenizer basado en la selección
model, tokenizer = get_model(selected_model)
terminators = get_terminators(tokenizer)
# Número de documentos a recuperar
k = 1
scores, retrieved_documents = search(prompt, k)
formatted_prompt = format_prompt(prompt, retrieved_documents, k)
formatted_prompt = formatted_prompt[:2000] # prevenir OOM en GPU
# Construye los mensajes para la generación
messages = [
{"role": "system", "content": SYS_PROMPT},
{"role": "user", "content": formatted_prompt}
]
# Prepara la entrada usando la plantilla de chat del tokenizer
input_ids = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
return_tensors="pt"
).to(model.device)
# Define el streamer para obtener la salida progresivamente
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=input_ids,
streamer=streamer,
max_new_tokens=1024,
do_sample=True,
top_p=0.95,
temperature=0.75,
eos_token_id=terminators,
)
# Ejecuta la generación en un hilo para ir transmitiendo resultados
t = Thread(target=model.generate, kwargs=generate_kwargs)
t.start()
outputs = []
for text in streamer:
outputs.append(text)
yield "".join(outputs)
# Interfaz de Gradio: se añade un dropdown para la selección del modelo
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# RAG Chatbot")
gr.Markdown("Selecciona el modelo de inteligencia artificial:")
with gr.Row():
# Dropdown para elegir el modelo
modelo_selector = gr.Dropdown(choices=list(MODELOS.keys()), value="Llama-3.2-3B-Instruct", label="Modelo")
chatbot = gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, layout="bubble")
# Caja de entrada para la consulta del usuario
prompt_input = gr.Textbox(lines=2, label="Ingresa tu pregunta")
# Botón para enviar la consulta
send_btn = gr.Button("Enviar")
# Llamada a la función talk, pasando el prompt y la selección de modelo
send_btn.click(fn=talk, inputs=[prompt_input, modelo_selector], outputs=chatbot)
if __name__ == "__main__":
demo.launch(debug=True)
|