# app.py — veureu/stools (Salamandra 7B Tools · ZeroGPU) — compatible with ENGINE
from __future__ import annotations

import os, json, re
from typing import List, Dict, Any, Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from salamandra_tools import SalamandraClient

# ================= Config =================
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-tools")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

_tok = None
_model = None


def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Load the tokenizer and model once, on first use."""
    global _tok, _model
    if _tok is None or _model is None:
        _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            trust_remote_code=True,
            device_map=None,
        ).to(DEVICE)
    return _tok, _model


# =============== Helpers ===============
def _render_tools_md(tools: List[Dict[str, Any]]) -> str:
    """Render the OpenAI-style tool spec as a short markdown block for the prompt."""
    if not tools:
        return ""
    lines = ["Herramientas disponibles (formato JSON):"]
    for t in tools:
        name = t.get("function", {}).get("name") or t.get("name") or "tool"
        desc = t.get("function", {}).get("description") or t.get("description") or ""
        params = t.get("function", {}).get("parameters") or t.get("parameters") or {}
        lines.append(f"- **{name}**: {desc} | parámetros: {json.dumps(params)[:600]}")
    return "\n".join(lines)


def _compose_chat_prompt(messages: List[Dict[str, str]], tools_md: str) -> str:
    """
    Supports OpenAI-style messages:
        [{"role": "system|user|assistant", "content": "..."}]
    Uses the tokenizer's chat_template when available.
    """
    tok, _ = _lazy_load()
    sys_text = ""
    usr_msgs: List[Dict[str, str]] = []
    for m in messages:
        role = m.get("role", "")
        content = (m.get("content") or "").strip()
        if role == "system":
            sys_text += ("\n" + content) if sys_text else content
        else:
            usr_msgs.append({"role": role, "content": content})

    # Graft the tool descriptions onto the system message.
    if tools_md:
        sys_text = (sys_text + "\n\n" if sys_text else "") + tools_md + \
            "\n\nSi decides llamar a una herramienta, devuelve un objeto JSON con la clave 'tool_calls' " \
            "y describe tus razonamientos de forma concisa en 'thought' (opcional)."

    # Rebuild the conversation with the system message first.
    conv: List[Dict[str, str]] = []
    if sys_text:
        conv.append({"role": "system", "content": sys_text})
    conv.extend(usr_msgs)

    chat_template = getattr(tok, "chat_template", None)
    if chat_template:
        return tok.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)

    # Fallback without a template
    rendered = ""
    if sys_text:
        rendered += f"<<SYS>>\n{sys_text}\n<</SYS>>\n\n"
    for m in usr_msgs:
        if m["role"] == "user":
            rendered += f"### Usuario\n{m['content']}\n\n"
        elif m["role"] == "assistant":
            rendered += f"### Asistente\n{m['content']}\n\n"
    rendered += "### Asistente\n"
    return rendered


# =============== (Optional) Minimal local executor for safe tools ===============
# If the LLM returns {"tool_calls":[{"name":"calculator","arguments":{"expr":"2+2"}}]}
# we can execute a few harmless example tools.
# Note: keep this very simple/safe. You can disable it by setting EXECUTE_TOOLS = False.
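# Illustrative round trip for the executor below:
#   maybe_execute_tool_calls([{"name": "calculator", "arguments": {"expr": "(2+2)^3"}}])
#   -> [{"name": "calculator", "output": "64"}]   ("^" is mapped to "**")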
EXECUTE_TOOLS = True


def _safe_calculator(expr: str) -> str:
    # Allow only digits, whitespace, parentheses, and + - * / . % ^ (power).
    if not re.fullmatch(r"[0-9\.\s\+\-\*\/\%\(\)\^eE]+", expr.replace("**", "^")):
        return "Rejected expression."
    # Support ^ as power -> **
    expr = expr.replace("^", "**")
    try:
        return str(eval(expr, {"__builtins__": {}}, {}))
    except Exception as e:
        return f"Error: {e}"


LOCAL_TOOLBOX = {
    "calculator": lambda args: _safe_calculator(str(args.get("expr", ""))),
}


def maybe_execute_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    if not EXECUTE_TOOLS:
        return []
    results = []
    for call in tool_calls:
        name = call.get("name")
        args = call.get("arguments", {})
        fn = LOCAL_TOOLBOX.get(name)
        if fn is None:
            results.append({"name": name, "error": "tool_not_available"})
            continue
        try:
            out = fn(args)
            results.append({"name": name, "output": out})
        except Exception as e:
            results.append({"name": name, "error": str(e)})
    return results


# =============== Core generation ===============
@spaces.GPU  # use the GPU when available (ZeroGPU)
def _generate_with_tools(
    messages: List[Dict[str, str]],
    tools: List[Dict[str, Any]],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> Dict[str, Any]:
    tok, model = _lazy_load()
    tools_md = _render_tools_md(tools)
    prompt = _compose_chat_prompt(messages, tools_md)

    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    text = tok.decode(out[0], skip_special_tokens=True).strip()

    # If the model returns a JSON block with 'tool_calls', try to extract it.
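    # Expected wrapper shape (illustrative; the exact keys depend on the model):
    #   {"thought": "...",
    #    "tool_calls": [{"name": "calculator", "arguments": {"expr": "(2+2)^3"}}]}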
    tool_calls: List[Dict[str, Any]] = []
    try:
        # Decode every balanced JSON object in the output and keep the last one
        # that carries a "tool_calls" list. (A lazy regex such as
        # r"\{.*?\"tool_calls\".*?\}" stops at the first '}', truncating nested
        # JSON, so we let the JSON decoder find the real object boundaries.)
        decoder = json.JSONDecoder()
        for m in re.finditer(r"\{", text):
            try:
                obj, _ = decoder.raw_decode(text, m.start())
            except json.JSONDecodeError:
                continue
            if isinstance(obj, dict) and isinstance(obj.get("tool_calls"), list):
                tool_calls = obj["tool_calls"]
    except Exception:
        pass

    tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []
    return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}


# =================== Gradio Endpoints ===================
def predict_for_engine(messages_json: str, tools_json: str) -> Dict[str, Any]:
    """
    Endpoint expected by ENGINE (ToolsClient.chat):
      - messages_json: JSON of [{"role":"user|assistant|system","content":"..."}]
      - tools_json: OpenAI-like JSON tool spec (optional)
    Returns: {"text": "...", "tool_calls": [...], "tool_results": [...]}
    """
    try:
        messages = json.loads(messages_json) if messages_json else []
    except Exception:
        messages = []
    try:
        tools = json.loads(tools_json) if tools_json else []
    except Exception:
        tools = []
    return _generate_with_tools(messages, tools, max_new_tokens=512, temperature=0.7, top_p=0.95)


def chat_advanced(messages_json: str, tools_json: str, max_new_tokens: int,
                  temperature: float, top_p: float) -> Dict[str, Any]:
    try:
        messages = json.loads(messages_json) if messages_json else []
    except Exception:
        messages = []
    try:
        tools = json.loads(tools_json) if tools_json else []
    except Exception:
        tools = []
    return _generate_with_tools(messages, tools, max_new_tokens=int(max_new_tokens),
                                temperature=float(temperature), top_p=float(top_p))


_salamandra = None


def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
    global _salamandra
    if _salamandra is None:
        _salamandra = SalamandraClient()  # uses your own client class
    try:
        text = _salamandra.chat(prompt)
    except Exception as e:
        text = f"Error ejecutando SalamandraClient: {str(e)}"
    return {"text": text}
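# --- Usage sketch (not executed by this app) ----------------------------------
# A minimal sketch of how an external caller (e.g. ENGINE's ToolsClient.chat)
# might hit the /predict endpoint via gradio_client; the Space id comes from the
# header comment and both payloads are illustrative:
#
#   from gradio_client import Client
#   client = Client("veureu/stools")
#   result = client.predict(
#       '[{"role": "user", "content": "Quant és (2+2)^3?"}]',  # messages_json
#       "[]",                                                   # tools_json
#       api_name="/predict",
#   )
#   print(result["text"], result["tool_calls"])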
# =================== UI ===================
custom_css = """
h2 {
    background: #e3e4e6 !important;
    padding: 14px 22px !important;
    border-radius: 14px !important;
    box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
    display: block !important;   /* span the full width */
    width: 100% !important;      /* ensure 100% */
    margin: 20px auto !important;
    text-align: center;
}
"""

# Main interface for Salamandra 7B Tools: supports tool specification (function calling)
with gr.Blocks(title="Salamandra 7B Tools · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header description for the UI
    gr.Markdown("## Salamandra-7B-Tools · ZeroGPU\nXat amb especificació d'eines (function-calling).")

    with gr.Row():
        with gr.Column():
            # JSON array of messages passed to the model
            messages = gr.Textbox(
                label="Missatges (JSON)",
                value='[{"role":"user","content":"Quant és (2+2)^3?"}]',
                lines=6
            )
            # Tool definitions in JSON schema format
            tools = gr.Textbox(
                label="Eines (JSON, opcional)",
                value='[{"type":"function","function":{"name":"calculator","description":"Avalua expressions aritmètiques bàsiques.","parameters":{"type":"object","properties":{"expr":{"type":"string"}},"required":["expr"]}}}]',
                lines=6
            )
            # Maximum generation length
            max_new = gr.Slider(16, 2048, value=512, step=16, label="Màxim de tokens nous")
            # Temperature for randomness
            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperatura")
            # Nucleus sampling threshold
            topp = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
            # Button to generate a response
            btn = gr.Button("Generar", variant="primary")
        with gr.Column():
            # JSON output from the model
            out = gr.JSON(label="Sortida")

    # Bind chat-with-tools generation
    btn.click(
        chat_advanced,
        [messages, tools, max_new, temp, topp],
        out,
        api_name="chat",
        concurrency_limit=1
    )

    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------

    # Minimal /predict endpoint for ENGINE compatibility
    # Accepts messages + tool definitions
    gr.Button("Provar /predict").click(
        predict_for_engine,
        [messages, tools],
        out,
        api_name="predict",
        concurrency_limit=1
    )

    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------

    # Endpoint: raw prompt → model output (JSON)
    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=10)
    with gr.Row():
        btn2 = gr.Button("Generar", variant="primary")
    with gr.Row():
        out2 = gr.JSON(label="Sortida")

    btn2.click(
        salamandra_chat_endpoint,
        [prompt],
        out2,
        api_name="generate_out_from_prompt",
        concurrency_limit=1
    )

    # --------------------------------------------------------------
    gr.Markdown("---")
    # --------------------------------------------------------------

# Enable request queue for concurrency safety
demo.queue(max_size=16).launch()
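# Hypothetical local smoke test (the module-level launch() above blocks, so run
# this from a separate script or REPL with the launch commented out):
#   resp = _generate_with_tools(
#       [{"role": "user", "content": "Quant és (2+2)^3?"}],
#       [],  # no tool spec: plain chat
#       max_new_tokens=64,
#   )
#   # resp -> {"text": "...", "tool_calls": [...], "tool_results": [...]}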