import gradio as gr from datasets import load_dataset import os from threading import Thread import torch from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig from sentence_transformers import SentenceTransformer # Configuración y token token = os.getenv("HF_TOKEN") # Modelo de embeddings y dataset para la fase de RAG ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1") dataset = load_dataset("not-lain/wikipedia",revision = "embedded") data = dataset["train"] data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset # Configuración de quantization bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16 ) # Diccionario con modelos disponibles # Puedes agregar tantos modelos como desees, con su respectiva configuración MODELOS = { "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct", "DeepSeek-R1": "deepseek-ai/DeepSeek-R1" } # Opcional: Diccionario para almacenar modelos cargados y evitar recargas MODELOS_CARGADOS = {} # Función para obtener (o cargar) el modelo y tokenizer según la selección def get_model(selected_model): model_id = MODELOS[selected_model] if selected_model not in MODELOS_CARGADOS: tokenizer = AutoTokenizer.from_pretrained(model_id, token=token) model = AutoModelForCausalLM.from_pretrained( model_id, torch_dtype=torch.bfloat16, device_map="auto", quantization_config=bnb_config, token=token ) MODELOS_CARGADOS[selected_model] = (model, tokenizer) return MODELOS_CARGADOS[selected_model] # Tokens de finalización basados en el tokenizer del modelo, se actualizarán luego def get_terminators(tokenizer): return [ tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>") ] # Prompt del sistema (podrías parametrizarlo si es necesario) SYS_PROMPT = ( "Tu tarea es analizar un listado de unidades de competencia y devolver únicamente aquellas relacionadas " "con la profesión de {texto_usuario}. Debes buscar palabras clave o términos relacionados con la profesión " "en los nombres de las unidades de competencia. If you don't know the answer, just say 'I do not know.' " "Don't make up an answer." ) def search(query: str, k: int = 3): """Embebe la consulta y retorna los resultados más probables.""" embedded_query = ST.encode(query) # embed de la nueva consulta scores, retrieved_examples = data.get_nearest_examples( "embeddings", embedded_query, k=k ) return scores, retrieved_examples def format_prompt(prompt, retrieved_documents, k): """Construye el prompt a partir de los documentos recuperados.""" PROMPT = f"Question: {prompt}\nContext:\n" for idx in range(k): PROMPT += f"{retrieved_documents['text'][idx]}\n" return PROMPT def talk(prompt, selected_model, history=[]): # Obtiene (o carga) el modelo y tokenizer basado en la selección model, tokenizer = get_model(selected_model) terminators = get_terminators(tokenizer) # Número de documentos a recuperar k = 1 scores, retrieved_documents = search(prompt, k) formatted_prompt = format_prompt(prompt, retrieved_documents, k) formatted_prompt = formatted_prompt[:2000] # prevenir OOM en GPU # Construye los mensajes para la generación messages = [ {"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt} ] # Prepara la entrada usando la plantilla de chat del tokenizer input_ids = tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt" ).to(model.device) # Define el streamer para obtener la salida progresivamente streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) generate_kwargs = dict( input_ids=input_ids, streamer=streamer, max_new_tokens=1024, do_sample=True, top_p=0.95, temperature=0.75, eos_token_id=terminators, ) # Ejecuta la generación en un hilo para ir transmitiendo resultados t = Thread(target=model.generate, kwargs=generate_kwargs) t.start() outputs = [] for text in streamer: outputs.append(text) yield "".join(outputs) # Interfaz de Gradio: se añade un dropdown para la selección del modelo with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("# RAG Chatbot") gr.Markdown("Selecciona el modelo de inteligencia artificial:") with gr.Row(): # Dropdown para elegir el modelo modelo_selector = gr.Dropdown(choices=list(MODELOS.keys()), value="Llama-3.2-3B-Instruct", label="Modelo") chatbot = gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, layout="bubble") # Caja de entrada para la consulta del usuario prompt_input = gr.Textbox(lines=2, label="Ingresa tu pregunta") # Botón para enviar la consulta send_btn = gr.Button("Enviar") # Llamada a la función talk, pasando el prompt y la selección de modelo send_btn.click(fn=talk, inputs=[prompt_input, modelo_selector], outputs=chatbot) if __name__ == "__main__": demo.launch(debug=True)