# stools / app.py
# VeuReu's picture
# Update app.py
# 277c5a2 verified
# app.py — veureu/stools (Salamandra 7B Tools · ZeroGPU) — compatible con ENGINE
from __future__ import annotations
import os, json, re
from typing import List, Dict, Any, Optional, Tuple
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from salamandra_tools import SalamandraClient
# ================= Config =================
# Hugging Face model id; overridable through the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "BSC-LT/salamandra-7b-tools")

# Half precision on CUDA when a GPU is visible, full precision on CPU otherwise.
if torch.cuda.is_available():
    DEVICE, DTYPE = "cuda", torch.float16
else:
    DEVICE, DTYPE = "cpu", torch.float32

# Tokenizer / model singletons, filled in on first use by _lazy_load().
_tok = None
_model = None
def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use.

    The pair is cached in the module-level ``_tok`` / ``_model`` globals. When
    either of them is still ``None`` both are loaded from ``MODEL_ID`` and the
    model is moved to ``DEVICE``.
    """
    global _tok, _model

    # Fast path: everything is already in memory.
    if _tok is not None and _model is not None:
        return _tok, _model

    _tok = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=True,
        trust_remote_code=True,
    )
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        trust_remote_code=True,
        device_map=None,
    )
    # Only publish the model once it has been moved to the target device.
    _model = loaded.to(DEVICE)
    return _tok, _model
# =============== Helpers ===============
def _render_tools_md(tools: List[Dict[str, Any]]) -> str:
    """Render an OpenAI-style tool specification as a short markdown list for the prompt.

    Every tool may be nested (``{"type": "function", "function": {...}}``) or
    flat (``{"name": ..., "description": ..., "parameters": ...}``); nested
    fields take precedence over flat ones.

    Args:
        tools: Decoded tool specification. It comes from user-supplied JSON, so
            anything that is not a list/tuple, and any entry that is not a
            dict, is ignored instead of raising.

    Returns:
        A markdown block with one bullet per tool, or ``""`` when there is no
        usable tool.
    """
    if not tools or not isinstance(tools, (list, tuple)):
        return ""

    lines = ["Herramientas disponibles (formato JSON):"]
    for tool in tools:
        if not isinstance(tool, dict):
            continue
        # "function" may be missing, null or malformed in hand-written specs.
        fn = tool.get("function")
        if not isinstance(fn, dict):
            fn = {}
        name = fn.get("name") or tool.get("name") or "tool"
        desc = fn.get("description") or tool.get("description") or ""
        params = fn.get("parameters") or tool.get("parameters") or {}
        # ensure_ascii=False keeps accented text readable instead of spending
        # the 600-character budget on \uXXXX escapes.
        params_json = json.dumps(params, ensure_ascii=False)[:600]
        lines.append(f"- **{name}**: {desc} | parámetros: {params_json}")

    # Header alone means every entry was unusable: emit nothing.
    if len(lines) == 1:
        return ""
    return "\n".join(lines)
def _compose_chat_prompt(messages: List[Dict[str, str]], tools_md: str) -> str:
    """
    Build the text prompt for a list of OpenAI-style chat messages
    (``[{"role": "system|user|assistant", "content": "..."}]``).

    All system messages are merged into a single system turn placed first; when
    ``tools_md`` is non-empty it is appended to that system text together with
    the instruction on how to emit tool calls. The tokenizer's chat template is
    used when it has one; otherwise a plain-text layout is produced.
    """
    tokenizer, _ = _lazy_load()

    # Split the incoming messages into merged system text and the dialogue turns.
    system_text = ""
    dialogue: List[Dict[str, str]] = []
    for msg in messages:
        role = msg.get("role", "")
        body = (msg.get("content") or "").strip()
        if role != "system":
            dialogue.append({"role": role, "content": body})
        elif system_text:
            system_text = system_text + "\n" + body
        else:
            system_text = body

    # Graft the tool description (plus calling instructions) onto the system text.
    if tools_md:
        prefix = system_text + "\n\n" if system_text else ""
        instructions = (
            "\n\nSi decides llamar a una herramienta, devuelve un objeto JSON con la clave 'tool_calls' "
            "y describe tus razonamientos de forma concisa en 'thought' (opcional)."
        )
        system_text = prefix + tools_md + instructions

    if getattr(tokenizer, "chat_template", None):
        conversation: List[Dict[str, str]] = []
        if system_text:
            conversation.append({"role": "system", "content": system_text})
        conversation.extend(dialogue)
        return tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=True
        )

    # Fallback layout for tokenizers without a chat template.
    pieces: List[str] = []
    if system_text:
        pieces.append(f"<<SYS>>\n{system_text}\n<</SYS>>\n\n")
    for turn in dialogue:
        speaker = turn["role"]
        if speaker == "user":
            heading = "### Usuario"
        elif speaker == "assistant":
            heading = "### Asistente"
        else:
            # Any other role is left out of the fallback prompt.
            continue
        pieces.append(f"{heading}\n{turn['content']}\n\n")
    pieces.append("### Asistente\n")
    return "".join(pieces)
# =============== (Optional) Local mini-executor for safe tools ===============
# If the LLM returns {"tool_calls":[{"name":"calculator","arguments":{"expr":"2+2"}}]}
# we can run a few harmless example tools locally.
# Note: keep this very simple/safe. It can be disabled by setting EXECUTE_TOOLS=False.
EXECUTE_TOOLS = True
def _safe_calculator(expr: str) -> str:
    """Evaluate a basic arithmetic expression and return the result as text.

    Supported syntax: int/float literals (including ``1e3`` notation),
    parentheses, unary ``+``/``-`` and the binary operators ``+ - * / // %``
    and ``**`` (``^`` is accepted as an alias for ``**``).

    The expression is produced by the model, so it is treated as untrusted:
    instead of ``eval`` it is parsed with ``ast`` and only the node types listed
    above are evaluated. Integer powers are bounded so that inputs such as
    ``9**9**9`` cannot stall the worker.

    Args:
        expr: The expression to evaluate.

    Returns:
        ``str(result)`` on success, ``"Rejected expression."`` when the text
        contains characters outside the whitelist, or ``"Error: <reason>"`` when
        parsing or evaluation fails.
    """
    import ast
    import operator

    # Whitelist: digits, '.', whitespace, parentheses, + - * / %, '^' and e/E.
    if not re.fullmatch(r"[0-9\.\s\+\-\*\/\%\(\)\^eE]+", expr.replace("**", "^")):
        return "Rejected expression."

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}
    max_pow_bits = 10_000  # upper bound on the size of an integer power result

    def evaluate(node):
        if isinstance(node, ast.Expression):
            return evaluate(node.body)
        if isinstance(node, ast.Constant) and type(node.value) in (int, float):
            return node.value
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](evaluate(node.operand))
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = evaluate(node.left)
            right = evaluate(node.right)
            if (
                isinstance(node.op, ast.Pow)
                and isinstance(left, int)
                and isinstance(right, int)
                and right > 0
                and abs(left) > 1
                and left.bit_length() * right > max_pow_bits
            ):
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        raise ValueError("unsupported expression")

    try:
        # '^' means power here; strip() mirrors eval()'s tolerance of outer whitespace.
        tree = ast.parse(expr.replace("^", "**").strip(), mode="eval")
        return str(evaluate(tree))
    except Exception as e:
        return f"Error: {e}"
def _calculator_tool(args):
    """Adapter for the ``calculator`` tool: evaluates ``args["expr"]`` (empty string if absent)."""
    expression = str(args.get("expr", ""))
    return _safe_calculator(expression)


# Registry of tools that may be executed locally, keyed by tool name.
LOCAL_TOOLBOX = {"calculator": _calculator_tool}
def maybe_execute_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Run the requested tool calls against ``LOCAL_TOOLBOX``.

    The calls come from model output, so each entry is validated on its own and
    a bad entry never aborts the others. Two shapes are accepted:

    * flat: ``{"name": "calculator", "arguments": {...}}``
    * nested (OpenAI style): ``{"function": {"name": ..., "arguments": ...}}``

    ``arguments`` may be a dict or a JSON-encoded string.

    Args:
        tool_calls: Tool call entries extracted from the model output.

    Returns:
        One result per entry, in order: ``{"name", "output"}`` on success or
        ``{"name", "error"}`` on failure (``"invalid_tool_call"`` for entries
        that are not dicts, ``"tool_not_available"`` for unknown tools, or the
        exception text raised by the tool). Returns ``[]`` when
        ``EXECUTE_TOOLS`` is false.
    """
    if not EXECUTE_TOOLS:
        return []

    results: List[Dict[str, Any]] = []
    for call in tool_calls:
        if not isinstance(call, dict):
            results.append({"name": None, "error": "invalid_tool_call"})
            continue

        spec = call.get("function")
        if not isinstance(spec, dict):
            spec = {}

        # Flat fields win; fall back to the nested "function" object.
        name = call.get("name") or spec.get("name")
        if "arguments" in call:
            args = call["arguments"]
        else:
            args = spec.get("arguments", {})

        fn = LOCAL_TOOLBOX.get(name) if isinstance(name, str) else None
        if fn is None:
            results.append({"name": name, "error": "tool_not_available"})
            continue

        if isinstance(args, str):
            # OpenAI-style calls carry the arguments as a JSON string.
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                pass  # keep the raw string; the tool call below reports the failure

        try:
            out = fn(args)
            results.append({"name": name, "output": out})
        except Exception as e:
            results.append({"name": name, "error": str(e)})
    return results
# =============== Core generation ===============
def _extract_tool_calls(text: str) -> List[Dict[str, Any]]:
    """Return the ``tool_calls`` list of the last JSON object in *text* that has one.

    Each ``{`` is tried as the start of a JSON value with
    ``json.JSONDecoder.raw_decode``, which handles nested braces and braces
    inside strings. Entries that are not dicts are dropped. Returns ``[]`` when
    no such object is found.
    """
    decoder = json.JSONDecoder()
    calls: List[Dict[str, Any]] = []
    start = text.find("{")
    while start != -1:
        try:
            obj, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        if isinstance(obj, dict) and isinstance(obj.get("tool_calls"), list):
            calls = [c for c in obj["tool_calls"] if isinstance(c, dict)]
            # Skip past the whole object so its inner braces are not rescanned.
            start = text.find("{", end)
        else:
            # Valid JSON without tool calls: keep looking inside/after it.
            start = text.find("{", start + 1)
    return calls


@spaces.GPU  # use the GPU when available (ZeroGPU)
def _generate_with_tools(
    messages: List[Dict[str, str]],
    tools: List[Dict[str, Any]],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> Dict[str, Any]:
    """Generate a reply for *messages*, extract tool calls and optionally run them.

    Args:
        messages: OpenAI-style chat messages.
        tools: OpenAI-style tool specification (may be empty).
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; ``0`` (or less) selects greedy decoding.
        top_p: Nucleus sampling threshold, used only when sampling.

    Returns:
        ``{"text": ..., "tool_calls": [...], "tool_results": [...]}`` where
        ``text`` is the newly generated text only (the prompt is not echoed).
    """
    tok, model = _lazy_load()
    tools_md = _render_tools_md(tools)
    prompt = _compose_chat_prompt(messages, tools_md)
    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    prompt_len = inputs["input_ids"].shape[-1]

    sampling = float(temperature) > 0
    gen_kwargs: Dict[str, Any] = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": sampling,
        "pad_token_id": tok.eos_token_id,
        "eos_token_id": tok.eos_token_id,
    }
    if sampling:
        # Sampling knobs are only meaningful (and only valid) when sampling.
        gen_kwargs["temperature"] = float(temperature)
        gen_kwargs["top_p"] = float(top_p)

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; decode the continuation only so
    # the prompt is neither echoed back nor scanned for tool calls.
    text = tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

    tool_calls = _extract_tool_calls(text)
    tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []
    return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}
# =================== Gradio Endpoints ===================
def predict_for_engine(messages_json: str, tools_json: str) -> Dict[str, Any]:
    """
    Endpoint expected by ENGINE (ToolsClient.chat):
    - messages_json: JSON of [{"role":"user|assistant|system","content":"..."}]
    - tools_json: OpenAI-like JSON tool specification (optional)
    Returns: {"text": "...", "tool_calls": [...], "tool_results": [...]}

    Empty or undecodable JSON is treated as an empty list. Generation always
    uses the fixed settings below.
    """
    # Same JSON handling as chat_advanced, with fixed generation settings.
    return chat_advanced(
        messages_json,
        tools_json,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.95,
    )
def chat_advanced(messages_json: str, tools_json: str, max_new_tokens: int, temperature: float, top_p: float) -> Dict[str, Any]:
    """Decode the JSON inputs and generate a reply with caller-chosen settings.

    Empty or undecodable ``messages_json`` / ``tools_json`` are treated as an
    empty list. The numeric settings are coerced with ``int`` / ``float``
    before being forwarded to ``_generate_with_tools``.
    """

    def _loads_or_empty(raw: str) -> Any:
        # Best-effort decode: anything empty or unparsable becomes [].
        if not raw:
            return []
        try:
            return json.loads(raw)
        except Exception:
            return []

    messages = _loads_or_empty(messages_json)
    tools = _loads_or_empty(tools_json)
    return _generate_with_tools(
        messages,
        tools,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
    )
# SalamandraClient instance shared across requests; created on the first call.
_salamandra = None


def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
    """Send a raw prompt to ``SalamandraClient.chat`` and wrap the reply as ``{"text": ...}``.

    The client is constructed on first use. If ``chat`` raises, the error
    message is returned in ``text`` instead of propagating.
    """
    global _salamandra
    if _salamandra is None:
        _salamandra = SalamandraClient()  # project-provided client class

    try:
        reply = _salamandra.chat(prompt)
    except Exception as exc:
        reply = "Error ejecutando SalamandraClient: " + str(exc)
    return {"text": reply}
# =================== UI ===================
# Extra CSS passed to gr.Blocks below: it restyles <h2> headings as a
# full-width, centred, rounded grey banner with a soft shadow.
custom_css = """
h2 {
background: #e3e4e6 !important;
padding: 14px 22px !important;
border-radius: 14px !important;
box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
display: block !important; /* ocupa tot l'ample */
width: 100% !important; /* assegura 100% */
margin: 20px auto !important;
text-align:center;
}
"""
# Main interface for Salamandra 7B Tools: supports tool specification (function calling).
# Three click handlers are registered, each with its own api_name:
#   "chat"                     -> chat_advanced
#   "predict"                  -> predict_for_engine
#   "generate_out_from_prompt" -> salamandra_chat_endpoint
with gr.Blocks(title="Salamandra 7B Tools · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header description for the UI
    gr.Markdown("## Salamandra-7B-Tools · ZeroGPU\nXat amb especificació d'eines (function-calling).")
    with gr.Row():
        with gr.Column():
            # JSON array of messages passed to the model
            messages = gr.Textbox(
                label="Missatges (JSON)",
                value='[{"role":"user","content":"Quant és (2+2)^3?"}]',
                lines=6
            )
            # Tool definitions in JSON schema format (pre-filled with the calculator tool)
            tools = gr.Textbox(
                label="Eines (JSON, opcional)",
                value='[{"type":"function","function":{"name":"calculator","description":"Avalua expressions aritmètiques bàsiques.","parameters":{"type":"object","properties":{"expr":{"type":"string"}},"required":["expr"]}}}]',
                lines=6
            )
            # Maximum generation length
            max_new = gr.Slider(16, 2048, value=512, step=16, label="Màxim de tokens nous")
            # Temperature for randomness
            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperatura")
            # Nucleus sampling threshold
            topp = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
            # Button to generate a response
            btn = gr.Button("Generar", variant="primary")
        with gr.Column():
            # JSON output from the model (shared by the "chat" and "predict" handlers)
            out = gr.JSON(label="Sortida")
    # Bind chat-with-tools generation: all five inputs go to chat_advanced
    btn.click(
        chat_advanced,
        [messages, tools, max_new, temp, topp],
        out,
        api_name="chat",
        concurrency_limit=1
    )
    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------
    # Minimal /predict endpoint for ENGINE compatibility
    # Accepts messages + tool definitions only; generation settings are fixed
    # inside predict_for_engine
    gr.Button("Provar /predict").click(
        predict_for_engine,
        [messages, tools],
        out,
        api_name="predict",
        concurrency_limit=1
    )
    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------
    # Endpoint: raw prompt → model output (JSON), served by salamandra_chat_endpoint
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=10)
    with gr.Row():
        btn2 = gr.Button("Generar", variant="primary")
    with gr.Row():
        out2 = gr.JSON(label="Sortida")
    btn2.click(
        salamandra_chat_endpoint,
        [prompt],
        out2,
        api_name="generate_out_from_prompt",
        concurrency_limit=1
    )
    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------
# Enable the request queue (max_size=16) and start the app
demo.queue(max_size=16).launch()