File size: 5,508 Bytes
15f98fb
8a635ef
 
 
f76ed2b
 
8a635ef
15f98fb
f76ed2b
e2281d1
15f98fb
f76ed2b
 
4c28d4a
 
8a635ef
4c28d4a
15f98fb
f76ed2b
8a635ef
f76ed2b
 
 
 
8a635ef
15f98fb
f76ed2b
 
 
 
cb7bd8d
f76ed2b
 
4c28d4a
 
f76ed2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a635ef
 
f76ed2b
 
 
 
 
 
 
8a635ef
 
 
f76ed2b
 
 
 
 
8a635ef
 
f76ed2b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8a635ef
f76ed2b
 
 
8a635ef
f76ed2b
 
 
8a635ef
f76ed2b
8a635ef
 
 
 
 
 
 
f76ed2b
 
8a635ef
 
f76ed2b
8a635ef
 
 
 
 
f76ed2b
 
 
 
8a635ef
f76ed2b
 
 
 
 
 
 
 
 
 
 
 
 
ead5d53
f76ed2b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
import gradio as gr
from datasets import load_dataset
import os
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
from sentence_transformers import SentenceTransformer

# Configuración y token
token = os.getenv("HF_TOKEN")

# Modelo de embeddings y dataset para la fase de RAG
ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

dataset = load_dataset("not-lain/wikipedia",revision = "embedded")
data = dataset["train"]
data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset

# Configuración de quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True, 
    bnb_4bit_use_double_quant=True, 
    bnb_4bit_quant_type="nf4", 
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Diccionario con modelos disponibles
# Puedes agregar tantos modelos como desees, con su respectiva configuración
MODELOS = {
    "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
    "DeepSeek-R1": "deepseek-ai/DeepSeek-R1"
}



# Opcional: Diccionario para almacenar modelos cargados y evitar recargas
MODELOS_CARGADOS = {}

# Función para obtener (o cargar) el modelo y tokenizer según la selección
def get_model(selected_model):
    model_id = MODELOS[selected_model]
    if selected_model not in MODELOS_CARGADOS:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            quantization_config=bnb_config,
            token=token
        )
        MODELOS_CARGADOS[selected_model] = (model, tokenizer)
    return MODELOS_CARGADOS[selected_model]

# Tokens de finalización basados en el tokenizer del modelo, se actualizarán luego
def get_terminators(tokenizer):
    return [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

# Prompt del sistema (podrías parametrizarlo si es necesario)
SYS_PROMPT = (
    "Tu tarea es analizar un listado de unidades de competencia y devolver únicamente aquellas relacionadas "
    "con la profesión de {texto_usuario}. Debes buscar palabras clave o términos relacionados con la profesión "
    "en los nombres de las unidades de competencia. If you don't know the answer, just say 'I do not know.' "
    "Don't make up an answer."
)

def search(query: str, k: int = 3):
    """Embebe la consulta y retorna los resultados más probables."""
    embedded_query = ST.encode(query)  # embed de la nueva consulta
    scores, retrieved_examples = data.get_nearest_examples(
        "embeddings",
        embedded_query,
        k=k
    )
    return scores, retrieved_examples

def format_prompt(prompt, retrieved_documents, k):
    """Construye el prompt a partir de los documentos recuperados."""
    PROMPT = f"Question: {prompt}\nContext:\n"
    for idx in range(k):
        PROMPT += f"{retrieved_documents['text'][idx]}\n"
    return PROMPT

def talk(prompt, selected_model, history=[]):
    # Obtiene (o carga) el modelo y tokenizer basado en la selección
    model, tokenizer = get_model(selected_model)
    terminators = get_terminators(tokenizer)
    
    # Número de documentos a recuperar
    k = 1
    scores, retrieved_documents = search(prompt, k)
    formatted_prompt = format_prompt(prompt, retrieved_documents, k)
    formatted_prompt = formatted_prompt[:2000]  # prevenir OOM en GPU

    # Construye los mensajes para la generación
    messages = [
        {"role": "system", "content": SYS_PROMPT},
        {"role": "user", "content": formatted_prompt}
    ]
    
    # Prepara la entrada usando la plantilla de chat del tokenizer
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    
    # Define el streamer para obtener la salida progresivamente
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        temperature=0.75,
        eos_token_id=terminators,
    )
    
    # Ejecuta la generación en un hilo para ir transmitiendo resultados
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

# Interfaz de Gradio: se añade un dropdown para la selección del modelo
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# RAG Chatbot")
    gr.Markdown("Selecciona el modelo de inteligencia artificial:")
    
    with gr.Row():
        # Dropdown para elegir el modelo
        modelo_selector = gr.Dropdown(choices=list(MODELOS.keys()), value="Llama-3.2-3B-Instruct", label="Modelo")
    
    chatbot = gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, layout="bubble")
    
    # Caja de entrada para la consulta del usuario
    prompt_input = gr.Textbox(lines=2, label="Ingresa tu pregunta")
    
    # Botón para enviar la consulta
    send_btn = gr.Button("Enviar")
    
    # Llamada a la función talk, pasando el prompt y la selección de modelo
    send_btn.click(fn=talk, inputs=[prompt_input, modelo_selector], outputs=chatbot)
    
if __name__ == "__main__":
    demo.launch(debug=True)