JMAA00 commited on
Commit
9dfbce3
·
1 Parent(s): 9fa6053
Files changed (1) hide show
  1. app.py +133 -83
app.py CHANGED
@@ -1,30 +1,61 @@
1
  import os
2
- import torch
3
  import gradio as gr
4
- from transformers import (
5
- AutoTokenizer,
6
- AutoModelForCausalLM,
7
- TextIteratorStreamer,
8
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
- # URL al modelo validado:
11
- MODEL_URL = "https://llama3-1.llamameta.net/*?Policy=eyJTdGF0ZW1lbnQiOlt7InVuaXF1ZV9oYXNoIjoicTNqOHYzcTJyZ3B0eWl6ZTd6dTFkNXN1IiwiUmVzb3VyY2UiOiJodHRwczpcL1wvbGxhbWEzLTEubGxhbWFtZXRhLm5ldFwvKiIsIkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0MzQ0MDM3Nn19fV19&Signature=bFkdKBkrmrAe6vKEmYlEblQ5a7O9UB09rcCrEYKTo%7EL-d5bY3qgR8TWzGp3WrzdcIm9lK1srSp5t4Oz%7EctElYCeLwYTlmrV-DmPm6cvwTpW75yDMnKHkZOWw2eETT7w6TkX1HqtMU2rKmN1Yx9vfz0guaKrgtIrVr4sq0pY-6DZqr0G6wkFDAFcok6qEK%7ExzqQms8zHjvJuEUTzWRpcJ2zwL6pO1GBDB8OYPzu%7EwSDEqmLMMLn3AFfQKpFlkGUQDlF0-9RePPecRtldBK-AaJMgoQpzsxcrmP3PblAJgVR3ujoJj2MVS7RzwUOOX3yrxir0en7GK-BAxiT8QGRPDSA__&Key-Pair-Id=K15QRJLYKIFSLZ&Download-Request-ID=1214046803621383"
12
-
13
- # 1) Cargamos el tokenizer y el modelo desde la URL firmada:
14
- print("Cargando tokenizer...")
15
- tokenizer = AutoTokenizer.from_pretrained(MODEL_URL, trust_remote_code=True)
16
-
17
- print("Cargando modelo (puede tardar varios minutos)...")
18
- # device_map="auto" intenta usar GPU si está disponible;
19
- # si no hay GPU, lo cargará en CPU (podría requerir mucha RAM).
20
- # Ajusta "torch_dtype" a float16 si dispones de GPU con FP16.
21
- model = AutoModelForCausalLM.from_pretrained(
22
- MODEL_URL,
23
- trust_remote_code=True,
24
- device_map="auto",
25
- torch_dtype=torch.float16 # Si tienes GPU. Si solo CPU, usa float32
26
- )
27
- model.eval()
28
 
29
  def respond(
30
  message,
@@ -33,62 +64,75 @@ def respond(
33
  max_tokens,
34
  temperature,
35
  top_p,
 
36
  ):
37
  """
38
- Mantenemos la estructura anterior:
39
- - history: [(usuario, asistente), ...]
40
- - system_message: texto con rol 'system'
41
- - message: el mensaje más reciente del usuario
42
  """
43
- # Preparamos el "prompt" reconstruyendo la conversación en un formato simple.
44
- # Podrías mejorarlo usando un formateo estilo "ChatGPT" con roles y saltos de línea.
45
- full_prompt = f"[SYSTEM] {system_message}\n"
46
- for user_msg, assistant_msg in history:
47
- if user_msg:
48
- full_prompt += f"[USER] {user_msg}\n"
49
- if assistant_msg:
50
- full_prompt += f"[ASSISTANT] {assistant_msg}\n"
51
- full_prompt += f"[USER] {message}\n[ASSISTANT]"
52
-
53
- # Preparamos la generación con streaming usando TextIteratorStreamer
54
- # (similar a la API de chat_completion con stream=True)
55
- streamer = TextIteratorStreamer(
56
- tokenizer=tokenizer,
57
- skip_special_tokens=True
58
- )
59
-
60
- # Ajustamos parámetros de decodificación (lo que antes hacíamos con pipeline)
61
- generation_kwargs = dict(
62
- inputs=tokenizer(full_prompt, return_tensors="pt").to(model.device),
63
- streamer=streamer,
64
- max_new_tokens=max_tokens,
 
 
 
 
 
 
 
 
65
  temperature=temperature,
66
  top_p=top_p,
67
- do_sample=True,
68
- # Recomendado no usar 'repetition_penalty=1.0' en Llama3 si no se sugiere.
69
- )
70
-
71
- # Disparamos la generación en un hilo:
72
- # streamer irá soltando tokens a medida que se generen.
73
- generation_thread = torch.Thread(
74
- target=model.generate,
75
- kwargs=generation_kwargs
76
- )
77
- generation_thread.start()
78
-
79
- # Ahora leemos tokens a medida que se generen y los enviamos a Gradio (yield)
80
- output_tokens = ""
81
- for new_token in streamer:
82
- output_tokens += new_token
83
- yield output_tokens
84
-
85
-
86
- # Interfaz usando ChatInterface de Gradio
 
 
 
 
 
87
  demo = gr.ChatInterface(
88
- respond,
89
  additional_inputs=[
90
  gr.Textbox(
91
- label="Mensaje del sistema",
92
  value=(
93
  "Eres Juan, un asistente virtual en español. "
94
  "Debes responder con mucha paciencia y empatía a usuarios que "
@@ -96,20 +140,21 @@ demo = gr.ChatInterface(
96
  "Provee explicaciones simples, procura entender la intención del usuario "
97
  "aunque la frase esté mal escrita, y mantén siempre un tono amable."
98
  ),
 
99
  ),
100
  gr.Slider(
101
- minimum=1,
102
- maximum=2048,
103
- value=512,
104
- step=1,
105
- label="Máxima cantidad de tokens",
106
  ),
107
  gr.Slider(
108
- minimum=0.1,
109
- maximum=4.0,
110
- value=0.7,
111
- step=0.1,
112
- label="Temperatura",
113
  ),
114
  gr.Slider(
115
  minimum=0.1,
@@ -118,9 +163,14 @@ demo = gr.ChatInterface(
118
  step=0.05,
119
  label="Top-p (muestreo por núcleo)",
120
  ),
 
 
 
 
 
 
121
  ],
122
  )
123
 
124
  if __name__ == "__main__":
125
- print("Iniciando servidor Gradio...")
126
  demo.launch()
 
1
  import os
 
2
  import gradio as gr
3
+ import requests
4
+ from huggingface_hub import InferenceClient
5
+
6
+ """
7
+ For more information on `huggingface_hub` Inference API support,
8
+ please check the docs:
9
+ https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
10
+ """
11
+
12
+ # ----------------------------------------------------------------
13
+ # CONFIGURACIÓN DE SERPER (búsqueda web)
14
+ # ----------------------------------------------------------------
15
+ SERPER_API_KEY = os.getenv("SERPER_API_KEY")
16
+
17
+ def do_websearch(query: str) -> str:
18
+ """
19
+ Llama a serper.dev para hacer la búsqueda en Google y devolver
20
+ un texto resumido de los resultados.
21
+ """
22
+ if not SERPER_API_KEY:
23
+ return "(SERPER_API_KEY no está configurado)"
24
+
25
+ url = "https://google.serper.dev/search"
26
+ headers = {
27
+ "X-API-KEY": SERPER_API_KEY,
28
+ "Content-Type": "application/json",
29
+ }
30
+ payload = {"q": query}
31
+
32
+ try:
33
+ resp = requests.post(url, json=payload, headers=headers, timeout=10)
34
+ data = resp.json()
35
+ except Exception as e:
36
+ return f"(Error al llamar a serper.dev: {e})"
37
+
38
+ # Se espera un campo 'organic' con resultados
39
+ if "organic" not in data:
40
+ return "No se encontraron resultados en serper.dev."
41
+
42
+ results = data["organic"]
43
+ if not results:
44
+ return "No hay resultados relevantes."
45
+
46
+ text = []
47
+ for i, item in enumerate(results, start=1):
48
+ title = item.get("title", "Sin título")
49
+ link = item.get("link", "Sin enlace")
50
+ text.append(f"{i}. {title}\n {link}")
51
+
52
+ return "\n".join(text)
53
 
54
+
55
+ # ----------------------------------------------------------------
56
+ # CONFIGURACIÓN DEL MODELO
57
+ # ----------------------------------------------------------------
58
+ client = InferenceClient("meta-llama/Llama-3.1-8B-Instruct")
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  def respond(
61
  message,
 
64
  max_tokens,
65
  temperature,
66
  top_p,
67
+ use_search # <-- Nuevo parámetro: si está "activado" el botón
68
  ):
69
  """
70
+ - system_message: Texto del rol "system"
71
+ - history: lista de (user_msg, assistant_msg)
72
+ - message: Mensaje actual del usuario
73
+ - use_search: booleano que indica si se habilita la búsqueda en serper
74
  """
75
+
76
+ # ----------------------------------------------------------------
77
+ # 1) Si el toggle está activo, hacemos búsqueda y la agregamos al prompt
78
+ # ----------------------------------------------------------------
79
+ if use_search:
80
+ web_info = do_websearch(message)
81
+ # Agregamos info al final del texto del usuario
82
+ message = f"{message}\nInformación de la web:\n{web_info}"
83
+
84
+ # ----------------------------------------------------------------
85
+ # 2) Construimos la lista de mensajes para la API de chat
86
+ # ----------------------------------------------------------------
87
+ messages = [{"role": "system", "content": system_message}]
88
+ for val in history:
89
+ if val[0]:
90
+ messages.append({"role": "user", "content": val[0]})
91
+ if val[1]:
92
+ messages.append({"role": "assistant", "content": val[1]})
93
+
94
+ # Añadimos el mensaje nuevo del usuario (posiblemente complementado con la info web)
95
+ messages.append({"role": "user", "content": message})
96
+
97
+ # ----------------------------------------------------------------
98
+ # 3) Llamamos a la API con streaming de tokens
99
+ # ----------------------------------------------------------------
100
+ response = ""
101
+ for chunk in client.chat_completion(
102
+ messages,
103
+ max_tokens=max_tokens,
104
+ stream=True,
105
  temperature=temperature,
106
  top_p=top_p,
107
+ ):
108
+ token = chunk.choices[0].delta.get("content", "")
109
+ response += token
110
+ yield response
111
+
112
+
113
+ # ----------------------------------------------------------------
114
+ # CONFIGURACIÓN DE LA INTERFAZ
115
+ # ----------------------------------------------------------------
116
+ # Para usar Tailwind, podemos asignar clases en "elem_classes".
117
+ # Ejemplo de clases genéricas (puedes cambiarlas a tu gusto):
118
+ tailwind_toggle_classes = [
119
+ "inline-flex",
120
+ "items-center",
121
+ "bg-blue-500",
122
+ "hover:bg-blue-700",
123
+ "text-white",
124
+ "font-bold",
125
+ "py-1",
126
+ "px-2",
127
+ "rounded",
128
+ "cursor-pointer"
129
+ ]
130
+
131
+ # ChatInterface, con un input Checkbox para "🌐 Búsqueda"
132
  demo = gr.ChatInterface(
133
+ fn=respond,
134
  additional_inputs=[
135
  gr.Textbox(
 
136
  value=(
137
  "Eres Juan, un asistente virtual en español. "
138
  "Debes responder con mucha paciencia y empatía a usuarios que "
 
140
  "Provee explicaciones simples, procura entender la intención del usuario "
141
  "aunque la frase esté mal escrita, y mantén siempre un tono amable."
142
  ),
143
+ label="Mensaje del sistema",
144
  ),
145
  gr.Slider(
146
+ minimum=1,
147
+ maximum=2048,
148
+ value=512,
149
+ step=1,
150
+ label="Máxima cantidad de tokens"
151
  ),
152
  gr.Slider(
153
+ minimum=0.1,
154
+ maximum=4.0,
155
+ value=0.7,
156
+ step=0.1,
157
+ label="Temperatura"
158
  ),
159
  gr.Slider(
160
  minimum=0.1,
 
163
  step=0.05,
164
  label="Top-p (muestreo por núcleo)",
165
  ),
166
+ # Un checkbox que hace de "toggle" para la búsqueda
167
+ gr.Checkbox(
168
+ value=False, # Por defecto desactivado
169
+ label="🌐 Búsqueda", # Etiqueta
170
+ elem_classes=tailwind_toggle_classes
171
+ ),
172
  ],
173
  )
174
 
175
  if __name__ == "__main__":
 
176
  demo.launch()