raaulcs commited on
Commit
f76ed2b
·
verified ·
1 Parent(s): 2e25512

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -102
app.py CHANGED
@@ -1,93 +1,113 @@
1
  import gradio as gr
2
  from datasets import load_dataset
3
-
4
  import os
5
- import spaces
6
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
7
- import torch
8
  from threading import Thread
 
 
9
  from sentence_transformers import SentenceTransformer
10
- from datasets import load_dataset
11
- import time
12
 
 
13
  token = os.environ["HF_TOKEN"]
14
- ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
15
-
16
- dataset = load_dataset("not-lain/wikipedia",revision = "embedded")
17
 
 
 
 
18
  data = dataset["train"]
19
- data = data.add_faiss_index("embeddings") # column name that has the embeddings of the dataset
20
 
21
-
22
- model_id = "meta-llama/Llama-3.2-3B-Instruct"
23
-
24
- # use quantization to lower GPU usage
25
  bnb_config = BitsAndBytesConfig(
26
- load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
 
 
 
27
  )
28
 
29
- tokenizer = AutoTokenizer.from_pretrained(model_id,token=token)
30
- model = AutoModelForCausalLM.from_pretrained(
31
- model_id,
32
- torch_dtype=torch.bfloat16,
33
- device_map="auto",
34
- quantization_config=bnb_config,
35
- token=token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  )
37
- terminators = [
38
- tokenizer.eos_token_id,
39
- tokenizer.convert_tokens_to_ids("<|eot_id|>")
40
- ]
41
-
42
- SYS_PROMPT = """Tu tarea es analizar un listado de unidades de competencia y devolver únicamente aquellas relacionadas con la profesión de {texto_usuario}.
43
- Debes buscar palabras clave o términos relacionados con la profesión en los nombres de las unidades de competencia.
44
- If you don't know the answer, just say "I do not know." Don't make up an answer."""
45
-
46
 
47
-
48
- def search(query: str, k: int = 3 ):
49
- """a function that embeds a new query and returns the most probable results"""
50
- embedded_query = ST.encode(query) # embed new query
51
- scores, retrieved_examples = data.get_nearest_examples( # retrieve results
52
- "embeddings", embedded_query, # compare our new embedded query with the dataset embeddings
53
- k=k # get only top k results
54
  )
55
  return scores, retrieved_examples
56
 
57
- def format_prompt(prompt,retrieved_documents,k):
58
- """using the retrieved documents we will prompt the model to generate our responses"""
59
- PROMPT = f"Question:{prompt}\nContext:"
60
- for idx in range(k) :
61
- PROMPT+= f"{retrieved_documents['text'][idx]}\n"
62
  return PROMPT
63
 
64
-
65
- @spaces.GPU
66
- def talk(prompt,history=[]):
67
- k = 1 # number of retrieved documents
68
- scores , retrieved_documents = search(prompt, k)
69
- formatted_prompt = format_prompt(prompt,retrieved_documents,k)
70
- formatted_prompt = formatted_prompt[:2000] # to avoid GPU OOM
71
- messages = [{"role":"system","content":SYS_PROMPT},{"role":"user","content":formatted_prompt}]
72
- # tell the model to generate
 
 
 
 
 
 
 
 
 
73
  input_ids = tokenizer.apply_chat_template(
74
- messages,
75
- add_generation_prompt=True,
76
- return_tensors="pt"
77
  ).to(model.device)
78
- outputs = model.generate(
79
- input_ids,
80
- max_new_tokens=1024,
81
- eos_token_id=terminators,
82
- do_sample=True,
83
- temperature=0.6,
84
- top_p=0.9,
85
- )
86
- streamer = TextIteratorStreamer(
87
- tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
88
- )
89
  generate_kwargs = dict(
90
- input_ids= input_ids,
91
  streamer=streamer,
92
  max_new_tokens=1024,
93
  do_sample=True,
@@ -95,46 +115,35 @@ def talk(prompt,history=[]):
95
  temperature=0.75,
96
  eos_token_id=terminators,
97
  )
 
 
98
  t = Thread(target=model.generate, kwargs=generate_kwargs)
99
  t.start()
100
-
101
  outputs = []
102
  for text in streamer:
103
  outputs.append(text)
104
  yield "".join(outputs)
105
 
106
-
107
- TITLE = "# RAG"
108
-
109
- DESCRIPTION = """
110
- A rag pipeline with a chatbot feature
111
-
112
- Resources used to build this project :
113
-
114
- * embedding model : https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
115
- * dataset : https://huggingface.co/datasets/not-lain/wikipedia
116
- * faiss docs : https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
117
- * chatbot : https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct
118
- * Full documentation : https://huggingface.co/blog/not-lain/rag-chatbot-using-llama3
119
- """
120
-
121
-
122
- demo = gr.ChatInterface(
123
- fn=talk,
124
- chatbot=gr.Chatbot(
125
- show_label=True,
126
- show_share_button=True,
127
- show_copy_button=True,
128
- layout="bubble",
129
- bubble_full_width=False,
130
- ),
131
- theme="Soft",
132
- type="tuples",
133
- examples=["what's anarchy ? "],
134
- title=TITLE,
135
- description=DESCRIPTION,
136
- autofocus=False,
137
- autoscroll = False,
138
 
139
- )
140
- demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from datasets import load_dataset
 
3
  import os
 
 
 
4
  from threading import Thread
5
+ import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig
7
  from sentence_transformers import SentenceTransformer
 
 
8
 
9
+ # Configuración y token
10
  token = os.environ["HF_TOKEN"]
 
 
 
11
 
12
+ # Modelo de embeddings y dataset para la fase de RAG
13
+ ST = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
14
+ dataset = load_dataset("not-lain/wikipedia", revision="embedded")
15
  data = dataset["train"]
16
+ data = data.add_faiss_index("embeddings") # columna con los embeddings
17
 
18
+ # Configuración de quantization
 
 
 
19
  bnb_config = BitsAndBytesConfig(
20
+ load_in_4bit=True,
21
+ bnb_4bit_use_double_quant=True,
22
+ bnb_4bit_quant_type="nf4",
23
+ bnb_4bit_compute_dtype=torch.bfloat16
24
  )
25
 
26
+ # Diccionario con modelos disponibles
27
+ # Puedes agregar tantos modelos como desees, con su respectiva configuración
28
+ MODELOS = {
29
+ "Llama-3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
30
+ "Otro Modelo": "ruta/o/identificador-del-otro-modelo"
31
+ }
32
+
33
+ # Opcional: Diccionario para almacenar modelos cargados y evitar recargas
34
+ MODELOS_CARGADOS = {}
35
+
36
+ # Función para obtener (o cargar) el modelo y tokenizer según la selección
37
+ def get_model(selected_model):
38
+ model_id = MODELOS[selected_model]
39
+ if selected_model not in MODELOS_CARGADOS:
40
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ model_id,
43
+ torch_dtype=torch.bfloat16,
44
+ device_map="auto",
45
+ quantization_config=bnb_config,
46
+ token=token
47
+ )
48
+ MODELOS_CARGADOS[selected_model] = (model, tokenizer)
49
+ return MODELOS_CARGADOS[selected_model]
50
+
51
+ # Tokens de finalización basados en el tokenizer del modelo, se actualizarán luego
52
+ def get_terminators(tokenizer):
53
+ return [
54
+ tokenizer.eos_token_id,
55
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
56
+ ]
57
+
58
+ # Prompt del sistema (podrías parametrizarlo si es necesario)
59
+ SYS_PROMPT = (
60
+ "Tu tarea es analizar un listado de unidades de competencia y devolver únicamente aquellas relacionadas "
61
+ "con la profesión de {texto_usuario}. Debes buscar palabras clave o términos relacionados con la profesión "
62
+ "en los nombres de las unidades de competencia. If you don't know the answer, just say 'I do not know.' "
63
+ "Don't make up an answer."
64
  )
 
 
 
 
 
 
 
 
 
65
 
66
+ def search(query: str, k: int = 3):
67
+ """Embebe la consulta y retorna los resultados más probables."""
68
+ embedded_query = ST.encode(query) # embed de la nueva consulta
69
+ scores, retrieved_examples = data.get_nearest_examples(
70
+ "embeddings",
71
+ embedded_query,
72
+ k=k
73
  )
74
  return scores, retrieved_examples
75
 
76
+ def format_prompt(prompt, retrieved_documents, k):
77
+ """Construye el prompt a partir de los documentos recuperados."""
78
+ PROMPT = f"Question: {prompt}\nContext:\n"
79
+ for idx in range(k):
80
+ PROMPT += f"{retrieved_documents['text'][idx]}\n"
81
  return PROMPT
82
 
83
+ def talk(prompt, selected_model, history=[]):
84
+ # Obtiene (o carga) el modelo y tokenizer basado en la selección
85
+ model, tokenizer = get_model(selected_model)
86
+ terminators = get_terminators(tokenizer)
87
+
88
+ # Número de documentos a recuperar
89
+ k = 1
90
+ scores, retrieved_documents = search(prompt, k)
91
+ formatted_prompt = format_prompt(prompt, retrieved_documents, k)
92
+ formatted_prompt = formatted_prompt[:2000] # prevenir OOM en GPU
93
+
94
+ # Construye los mensajes para la generación
95
+ messages = [
96
+ {"role": "system", "content": SYS_PROMPT},
97
+ {"role": "user", "content": formatted_prompt}
98
+ ]
99
+
100
+ # Prepara la entrada usando la plantilla de chat del tokenizer
101
  input_ids = tokenizer.apply_chat_template(
102
+ messages,
103
+ add_generation_prompt=True,
104
+ return_tensors="pt"
105
  ).to(model.device)
106
+
107
+ # Define el streamer para obtener la salida progresivamente
108
+ streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
 
 
 
 
 
 
 
 
109
  generate_kwargs = dict(
110
+ input_ids=input_ids,
111
  streamer=streamer,
112
  max_new_tokens=1024,
113
  do_sample=True,
 
115
  temperature=0.75,
116
  eos_token_id=terminators,
117
  )
118
+
119
+ # Ejecuta la generación en un hilo para ir transmitiendo resultados
120
  t = Thread(target=model.generate, kwargs=generate_kwargs)
121
  t.start()
122
+
123
  outputs = []
124
  for text in streamer:
125
  outputs.append(text)
126
  yield "".join(outputs)
127
 
128
+ # Interfaz de Gradio: se añade un dropdown para la selección del modelo
129
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
130
+ gr.Markdown("# RAG Chatbot")
131
+ gr.Markdown("Selecciona el modelo de inteligencia artificial:")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
+ with gr.Row():
134
+ # Dropdown para elegir el modelo
135
+ modelo_selector = gr.Dropdown(choices=list(MODELOS.keys()), value="Llama-3.2-3B-Instruct", label="Modelo")
136
+
137
+ chatbot = gr.Chatbot(show_label=True, show_share_button=True, show_copy_button=True, layout="bubble")
138
+
139
+ # Caja de entrada para la consulta del usuario
140
+ prompt_input = gr.Textbox(lines=2, label="Ingresa tu pregunta")
141
+
142
+ # Botón para enviar la consulta
143
+ send_btn = gr.Button("Enviar")
144
+
145
+ # Llamada a la función talk, pasando el prompt y la selección de modelo
146
+ send_btn.click(fn=talk, inputs=[prompt_input, modelo_selector], outputs=chatbot, api_name="chatbot")
147
+
148
+ if __name__ == "__main__":
149
+ demo.launch(debug=True)