VeuReu commited on
Commit
e09a32e
·
verified ·
1 Parent(s): 4abe767

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +255 -231
app.py CHANGED
@@ -1,231 +1,255 @@
1
- # app.py — veureu/stools (Salamandra 7B Tools · ZeroGPU) — compatible con ENGINE
2
- from __future__ import annotations
3
- import os, json, re
4
- from typing import List, Dict, Any, Optional, Tuple
5
-
6
- import gradio as gr
7
- import spaces
8
- import torch
9
- from transformers import AutoTokenizer, AutoModelForCausalLM
10
-
11
- # ================= Config =================
12
- MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-tools")
13
- DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
14
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
-
16
- _tok = None
17
- _model = None
18
-
19
- def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
20
- global _tok, _model
21
- if _tok is None or _model is None:
22
- _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
23
- _model = AutoModelForCausalLM.from_pretrained(
24
- MODEL_ID,
25
- torch_dtype=DTYPE,
26
- low_cpu_mem_usage=True,
27
- use_safetensors=True,
28
- trust_remote_code=True,
29
- device_map=None,
30
- ).to(DEVICE)
31
- return _tok, _model
32
-
33
-
34
- # =============== Helpers ===============
35
-
36
- def _render_tools_md(tools: List[Dict[str, Any]]) -> str:
37
- """Convierte la especificación OpenAI-style de tools a un bloque breve markdown para el prompt."""
38
- if not tools:
39
- return ""
40
- lines = ["Herramientas disponibles (formato JSON):"]
41
- for t in tools:
42
- name = t.get("function", {}).get("name") or t.get("name") or "tool"
43
- desc = t.get("function", {}).get("description") or t.get("description") or ""
44
- params = t.get("function", {}).get("parameters") or t.get("parameters") or {}
45
- lines.append(f"- **{name}**: {desc} | parámetros: {json.dumps(params)[:600]}")
46
- return "\n".join(lines)
47
-
48
- def _compose_chat_prompt(messages: List[Dict[str, str]], tools_md: str) -> str:
49
- """
50
- Soporta mensajes estilo OpenAI: [{"role":"system|user|assistant", "content":"..."}]
51
- Usa chat_template si está disponible.
52
- """
53
- tok, _ = _lazy_load()
54
- sys_text = ""
55
- usr_msgs: List[Dict[str, str]] = []
56
- for m in messages:
57
- role = m.get("role", "")
58
- content = (m.get("content") or "").strip()
59
- if role == "system":
60
- sys_text += ("\n" + content) if sys_text else content
61
- else:
62
- usr_msgs.append({"role": role, "content": content})
63
-
64
- # injerta descripción de tools en el system
65
- if tools_md:
66
- sys_text = (sys_text + "\n\n" if sys_text else "") + tools_md + \
67
- "\n\nSi decides llamar a una herramienta, devuelve un objeto JSON con la clave 'tool_calls' " \
68
- "y describe tus razonamientos de forma concisa en 'thought' (opcional)."
69
-
70
- # reconstruimos la conversación con system delante
71
- conv: List[Dict[str, str]] = []
72
- if sys_text:
73
- conv.append({"role":"system", "content": sys_text})
74
- conv.extend(usr_msgs)
75
-
76
- chat_template = getattr(tok, "chat_template", None)
77
- if chat_template:
78
- return tok.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)
79
-
80
- # Fallback sin plantilla
81
- rendered = ""
82
- if sys_text:
83
- rendered += f"<<SYS>>\n{sys_text}\n<</SYS>>\n\n"
84
- for m in usr_msgs:
85
- if m["role"] == "user":
86
- rendered += f"### Usuario\n{m['content']}\n\n"
87
- elif m["role"] == "assistant":
88
- rendered += f"### Asistente\n{m['content']}\n\n"
89
- rendered += "### Asistente\n"
90
- return rendered
91
-
92
-
93
- # =============== (Opcional) Mini-ejecutor local de herramientas seguras ===============
94
- # Si el LLM devuelve {"tool_calls":[{"name":"calculator","arguments":{"expr":"2+2"}}]}
95
- # podemos ejecutar algunas herramientas inofensivas de ejemplo.
96
- # Nota: mantén esto muy simple/seguro. Puedes desactivarlo poniendo EXECUTE_TOOLS=False.
97
- EXECUTE_TOOLS = True
98
-
99
- def _safe_calculator(expr: str) -> str:
100
- # Permite solo dígitos, espacios, (), y +-*/.%**
101
- if not re.fullmatch(r"[0-9\.\s\+\-\*\/\%\(\)\^eE]+", expr.replace("**","^")):
102
- return "Rejected expression."
103
- # soporta ^ como potencia -> **
104
- expr = expr.replace("^", "**")
105
- try:
106
- return str(eval(expr, {"__builtins__":{}}, {}))
107
- except Exception as e:
108
- return f"Error: {e}"
109
-
110
- LOCAL_TOOLBOX = {
111
- "calculator": lambda args: _safe_calculator(str(args.get("expr",""))),
112
- }
113
-
114
- def maybe_execute_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
115
- if not EXECUTE_TOOLS:
116
- return []
117
- results = []
118
- for call in tool_calls:
119
- name = call.get("name")
120
- args = call.get("arguments", {})
121
- fn = LOCAL_TOOLBOX.get(name)
122
- if fn is None:
123
- results.append({"name": name, "error": "tool_not_available"})
124
- continue
125
- try:
126
- out = fn(args)
127
- results.append({"name": name, "output": out})
128
- except Exception as e:
129
- results.append({"name": name, "error": str(e)})
130
- return results
131
-
132
-
133
- # =============== Core generation ===============
134
-
135
- @spaces.GPU # usa GPU si está disponible (ZeroGPU)
136
- def _generate_with_tools(
137
- messages: List[Dict[str, str]],
138
- tools: List[Dict[str, Any]],
139
- max_new_tokens: int = 512,
140
- temperature: float = 0.7,
141
- top_p: float = 0.95,
142
- ) -> Dict[str, Any]:
143
- tok, model = _lazy_load()
144
- tools_md = _render_tools_md(tools)
145
- prompt = _compose_chat_prompt(messages, tools_md)
146
-
147
- inputs = tok(prompt, return_tensors="pt").to(DEVICE)
148
- with torch.inference_mode():
149
- out = model.generate(
150
- **inputs,
151
- max_new_tokens=int(max_new_tokens),
152
- temperature=float(temperature),
153
- top_p=float(top_p),
154
- do_sample=True if temperature > 0 else False,
155
- pad_token_id=tok.eos_token_id,
156
- eos_token_id=tok.eos_token_id,
157
- )
158
- text = tok.decode(out[0], skip_special_tokens=True).strip()
159
-
160
- # Si el modelo devuelve un bloque JSON con 'tool_calls', lo intentamos extraer.
161
- tool_calls: List[Dict[str, Any]] = []
162
- try:
163
- # busca el último {...} que contenga "tool_calls"
164
- matches = list(re.finditer(r"\{.*?\"tool_calls\".*?\}", text, flags=re.S))
165
- if matches:
166
- block = text[matches[-1].start():matches[-1].end()]
167
- obj = json.loads(block)
168
- tc = obj.get("tool_calls", [])
169
- if isinstance(tc, list):
170
- tool_calls = tc
171
- except Exception:
172
- pass
173
-
174
- tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []
175
-
176
- return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}
177
-
178
-
179
- # =================== Gradio Endpoints ===================
180
-
181
- def predict_for_engine(messages_json: str, tools_json: str) -> Dict[str, Any]:
182
- """
183
- Endpoint esperado por ENGINE (ToolsClient.chat):
184
- - messages_json: JSON de [{"role":"user|assistant|system","content":"..."}]
185
- - tools_json: JSON OpenAI-like de herramientas (opcional)
186
- Devuelve: {"text": "...", "tool_calls": [...], "tool_results": [...]}
187
- """
188
- try:
189
- messages = json.loads(messages_json) if messages_json else []
190
- except Exception:
191
- messages = []
192
- try:
193
- tools = json.loads(tools_json) if tools_json else []
194
- except Exception:
195
- tools = []
196
- return _generate_with_tools(messages, tools, max_new_tokens=512, temperature=0.7, top_p=0.95)
197
-
198
- def chat_advanced(messages_json: str, tools_json: str, max_new_tokens: int, temperature: float, top_p: float) -> Dict[str, Any]:
199
- try:
200
- messages = json.loads(messages_json) if messages_json else []
201
- except Exception:
202
- messages = []
203
- try:
204
- tools = json.loads(tools_json) if tools_json else []
205
- except Exception:
206
- tools = []
207
- return _generate_with_tools(messages, tools, max_new_tokens=int(max_new_tokens), temperature=float(temperature), top_p=float(top_p))
208
-
209
-
210
- # =================== UI ===================
211
-
212
- with gr.Blocks(title="Salamandra 7B Tools · ZeroGPU") as demo:
213
- gr.Markdown("## Salamandra-7B-Tools · ZeroGPU\nChat con especificación de herramientas (function-calling).")
214
-
215
- with gr.Row():
216
- with gr.Column():
217
- messages = gr.Textbox(label="messages_json", value='[{"role":"user","content":"¿Cuánto es (2+2)^3?"}]', lines=6)
218
- tools = gr.Textbox(label="tools_json (opcional)", value='[{"type":"function","function":{"name":"calculator","description":"Evalúa expresiones aritméticas básicas.","parameters":{"type":"object","properties":{"expr":{"type":"string"}},"required":["expr"]}}}]', lines=6)
219
- max_new = gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens")
220
- temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
221
- topp = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
222
- btn = gr.Button("Generar", variant="primary")
223
- with gr.Column():
224
- out = gr.JSON(label="Salida")
225
-
226
- btn.click(chat_advanced, [messages, tools, max_new, temp, topp], out, api_name="chat", concurrency_limit=1)
227
-
228
- # Endpoint minimalista /predict para ENGINE (mensajes + tools)
229
- gr.Button("Probar /predict").click(predict_for_engine, [messages, tools], out, api_name="predict", concurrency_limit=1)
230
-
231
- demo.queue(max_size=16).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py — veureu/stools (Salamandra 7B Tools · ZeroGPU) — compatible con ENGINE
2
+ from __future__ import annotations
3
+ import os, json, re
4
+ from typing import List, Dict, Any, Optional, Tuple
5
+
6
+ import gradio as gr
7
+ import spaces
8
+ import torch
9
+ from transformers import AutoTokenizer, AutoModelForCausalLM
10
+ from moe_tools import SalamandraClient
11
+
12
# ================= Config =================
# Model id is overridable through the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-tools")
# Use half precision only when a CUDA device is present.
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily-populated singletons; filled on first use by _lazy_load().
_tok = None
_model = None
19
+
20
def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Return the (tokenizer, model) pair, loading them once and caching in module globals."""
    global _tok, _model
    if _tok is not None and _model is not None:
        return _tok, _model

    _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        trust_remote_code=True,
        device_map=None,
    )
    _model = loaded.to(DEVICE)
    return _tok, _model
33
+
34
+
35
+ # =============== Helpers ===============
36
+
37
+ def _render_tools_md(tools: List[Dict[str, Any]]) -> str:
38
+ """Convierte la especificación OpenAI-style de tools a un bloque breve markdown para el prompt."""
39
+ if not tools:
40
+ return ""
41
+ lines = ["Herramientas disponibles (formato JSON):"]
42
+ for t in tools:
43
+ name = t.get("function", {}).get("name") or t.get("name") or "tool"
44
+ desc = t.get("function", {}).get("description") or t.get("description") or ""
45
+ params = t.get("function", {}).get("parameters") or t.get("parameters") or {}
46
+ lines.append(f"- **{name}**: {desc} | parámetros: {json.dumps(params)[:600]}")
47
+ return "\n".join(lines)
48
+
49
def _compose_chat_prompt(messages: List[Dict[str, str]], tools_md: str) -> str:
    """Build the final prompt string from OpenAI-style messages.

    Supports [{"role": "system|user|assistant", "content": "..."}]. All
    system messages are merged into one, the tool description (if any) is
    injected into the system text, and the conversation is rendered via the
    tokenizer's chat template when available, otherwise via a simple tagged
    fallback layout.
    """
    tok, _ = _lazy_load()

    system_parts: List[str] = []
    chat_msgs: List[Dict[str, str]] = []
    for msg in messages:
        text = (msg.get("content") or "").strip()
        if msg.get("role", "") == "system":
            system_parts.append(text)
        else:
            chat_msgs.append({"role": msg.get("role", ""), "content": text})

    # Fold system parts together, separating non-empty accumulations with "\n".
    sys_text = ""
    for part in system_parts:
        sys_text = f"{sys_text}\n{part}" if sys_text else part

    # Inject the tool description into the system text.
    if tools_md:
        instruction = (
            "\n\nSi decides llamar a una herramienta, devuelve un objeto JSON con la clave 'tool_calls' "
            "y describe tus razonamientos de forma concisa en 'thought' (opcional)."
        )
        prefix = sys_text + "\n\n" if sys_text else ""
        sys_text = prefix + tools_md + instruction

    # Rebuild the conversation with the system message in front.
    conv: List[Dict[str, str]] = []
    if sys_text:
        conv.append({"role": "system", "content": sys_text})
    conv.extend(chat_msgs)

    if getattr(tok, "chat_template", None):
        return tok.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)

    # Fallback when the tokenizer has no chat template.
    pieces: List[str] = []
    if sys_text:
        pieces.append(f"<<SYS>>\n{sys_text}\n<</SYS>>\n\n")
    for msg in chat_msgs:
        if msg["role"] == "user":
            pieces.append(f"### Usuario\n{msg['content']}\n\n")
        elif msg["role"] == "assistant":
            pieces.append(f"### Asistente\n{msg['content']}\n\n")
    pieces.append("### Asistente\n")
    return "".join(pieces)
92
+
93
+
94
# =============== (Optional) Mini local executor for safe tools ===============
# If the LLM returns {"tool_calls":[{"name":"calculator","arguments":{"expr":"2+2"}}]}
# we can execute a few harmless example tools.
# Note: keep this very simple/safe. Disable by setting EXECUTE_TOOLS=False.
EXECUTE_TOOLS = True
99
+
100
+ def _safe_calculator(expr: str) -> str:
101
+ # Permite solo dígitos, espacios, (), y +-*/.%**
102
+ if not re.fullmatch(r"[0-9\.\s\+\-\*\/\%\(\)\^eE]+", expr.replace("**","^")):
103
+ return "Rejected expression."
104
+ # soporta ^ como potencia -> **
105
+ expr = expr.replace("^", "**")
106
+ try:
107
+ return str(eval(expr, {"__builtins__":{}}, {}))
108
+ except Exception as e:
109
+ return f"Error: {e}"
110
+
111
# Registry of locally executable tools: name -> callable(arguments_dict) -> str.
LOCAL_TOOLBOX = {
    "calculator": lambda args: _safe_calculator(str(args.get("expr",""))),
}
114
+
115
def maybe_execute_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Run each requested tool call through LOCAL_TOOLBOX.

    Produces one result dict per call: {"name", "output"} on success,
    {"name", "error"} for unknown tools or raised exceptions. Returns []
    when EXECUTE_TOOLS is disabled.
    """
    if not EXECUTE_TOOLS:
        return []

    outcomes: List[Dict[str, Any]] = []
    for request in tool_calls:
        tool_name = request.get("name")
        handler = LOCAL_TOOLBOX.get(tool_name)
        if handler is None:
            outcomes.append({"name": tool_name, "error": "tool_not_available"})
            continue
        try:
            outcomes.append({"name": tool_name, "output": handler(request.get("arguments", {}))})
        except Exception as exc:
            outcomes.append({"name": tool_name, "error": str(exc)})
    return outcomes
132
+
133
+
134
+ # =============== Core generation ===============
135
+
136
def _extract_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Extract the last JSON object containing a "tool_calls" list from *text*.

    Uses balanced-brace scanning: the previous non-greedy regex stopped at the
    first closing brace, which truncated any nested JSON (e.g. object-valued
    "arguments"), so such tool calls never parsed. Candidates that fail
    json.loads are skipped; returns [] when nothing usable is found.
    """
    key_pos = text.rfind('"tool_calls"')
    if key_pos == -1:
        return []
    # Try each opening brace before the key, nearest first.
    start = text.rfind("{", 0, key_pos)
    while start != -1:
        depth = 0
        for idx in range(start, len(text)):
            ch = text[idx]
            if ch == "{":
                depth += 1
            elif ch == "}":
                depth -= 1
                if depth == 0:
                    # Only accept spans that actually enclose the key.
                    if idx > key_pos:
                        try:
                            obj = json.loads(text[start:idx + 1])
                            calls = obj.get("tool_calls")
                            if isinstance(calls, list):
                                return calls
                        except Exception:
                            pass
                    break
        start = text.rfind("{", 0, start)
    return []


@spaces.GPU  # use the GPU when available (ZeroGPU)
def _generate_with_tools(
    messages: List[Dict[str, str]],
    tools: List[Dict[str, Any]],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> Dict[str, Any]:
    """Generate a completion and, if the model emitted tool calls, parse and run them.

    Args:
        messages: OpenAI-style chat messages.
        tools: OpenAI-like tool specification (may be empty).
        max_new_tokens / temperature / top_p: standard sampling knobs.

    Returns:
        {"text": decoded output, "tool_calls": [...], "tool_results": [...]}.
    """
    tok, model = _lazy_load()
    prompt = _compose_chat_prompt(messages, _render_tools_md(tools))

    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0,  # greedy decoding when temperature == 0
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    text = tok.decode(out[0], skip_special_tokens=True).strip()

    # If the model returned a JSON block with 'tool_calls', extract and execute it.
    tool_calls = _extract_tool_calls(text)
    tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []

    return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}
178
+
179
+
180
+ # =================== Gradio Endpoints ===================
181
+
182
def predict_for_engine(messages_json: str, tools_json: str) -> Dict[str, Any]:
    """
    Endpoint expected by ENGINE (ToolsClient.chat):
      - messages_json: JSON of [{"role":"user|assistant|system","content":"..."}]
      - tools_json: OpenAI-like JSON tool spec (optional)
    Returns: {"text": "...", "tool_calls": [...], "tool_results": [...]}
    """
    def _parse(raw: str):
        # Empty/invalid payloads degrade to an empty list rather than erroring.
        if not raw:
            return []
        try:
            return json.loads(raw)
        except Exception:
            return []

    return _generate_with_tools(
        _parse(messages_json),
        _parse(tools_json),
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
    )
198
+
199
def chat_advanced(messages_json: str, tools_json: str, max_new_tokens: int, temperature: float, top_p: float) -> Dict[str, Any]:
    """Full-parameter chat endpoint backing the Gradio UI; tolerates bad JSON inputs."""
    try:
        parsed_messages = json.loads(messages_json) if messages_json else []
    except Exception:
        parsed_messages = []
    try:
        parsed_tools = json.loads(tools_json) if tools_json else []
    except Exception:
        parsed_tools = []
    return _generate_with_tools(
        parsed_messages,
        parsed_tools,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
    )
209
+
210
+
211
# Lazily-created SalamandraClient singleton (from moe_tools).
_salamandra = None

def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
    """Send *prompt* through a shared SalamandraClient and wrap the reply as {"text": ...}.

    Client construction is deferred to the first call; any failure during the
    chat is reported inside the payload instead of raising.
    """
    global _salamandra
    if _salamandra is None:
        _salamandra = SalamandraClient()  # uses your class

    try:
        reply = _salamandra.chat(prompt)
    except Exception as exc:
        reply = f"Error ejecutando SalamandraClient: {str(exc)}"

    return {"text": reply}
224
+
225
# =================== UI ===================

with gr.Blocks(title="Salamandra 7B Tools · ZeroGPU") as demo:
    gr.Markdown("## Salamandra-7B-Tools · ZeroGPU\nChat con especificación de herramientas (function-calling).")

    with gr.Row():
        with gr.Column():
            # JSON inputs mirror the ENGINE wire format (see predict_for_engine).
            messages = gr.Textbox(label="messages_json", value='[{"role":"user","content":"¿Cuánto es (2+2)^3?"}]', lines=6)
            tools = gr.Textbox(label="tools_json (opcional)", value='[{"type":"function","function":{"name":"calculator","description":"Evalúa expresiones aritméticas básicas.","parameters":{"type":"object","properties":{"expr":{"type":"string"}},"required":["expr"]}}}]', lines=6)
            # Sampling controls forwarded to chat_advanced.
            max_new = gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens")
            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
            topp = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
            btn = gr.Button("Generar", variant="primary")
        with gr.Column():
            out = gr.JSON(label="Salida")

    # /chat API: full-parameter generation.
    btn.click(chat_advanced, [messages, tools, max_new, temp, topp], out, api_name="chat", concurrency_limit=1)

    # Minimal /predict endpoint for ENGINE (messages + tools, fixed sampling).
    gr.Button("Probar /predict").click(predict_for_engine, [messages, tools], out, api_name="predict", concurrency_limit=1)

    # Raw-prompt path through SalamandraClient (moe_tools).
    with gr.Row():
        prompt = gr.Textbox(label="prompt", lines=10)
    with gr.Row():
        btn2 = gr.Button("Generar", variant="primary")
    with gr.Row():
        out2 = gr.JSON(label="Salida")

    btn2.click(salamandra_chat_endpoint, [prompt], out2, api_name="generate_out_from_prompt", concurrency_limit=1)

demo.queue(max_size=16).launch()