from __future__ import annotations

import os, json, re
from typing import List, Dict, Any, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from salamandra_tools import SalamandraClient
|
|
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-tools")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily initialized singletons; the model is only loaded on the first request.
_tok = None
_model = None
|
|
def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Load the tokenizer and model once and cache them in module globals."""
    global _tok, _model
    if _tok is None or _model is None:
        _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            trust_remote_code=True,
            device_map=None,  # placed on DEVICE explicitly below (ZeroGPU-friendly)
        ).to(DEVICE)
    return _tok, _model
|
|
def _render_tools_md(tools: List[Dict[str, Any]]) -> str:
    """Convert an OpenAI-style tool spec into a short markdown block for the prompt."""
    if not tools:
        return ""
    lines = ["Available tools (JSON format):"]
    for t in tools:
        name = t.get("function", {}).get("name") or t.get("name") or "tool"
        desc = t.get("function", {}).get("description") or t.get("description") or ""
        params = t.get("function", {}).get("parameters") or t.get("parameters") or {}
        lines.append(f"- **{name}**: {desc} | parameters: {json.dumps(params)[:600]}")
    return "\n".join(lines)
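
# Illustrative output (hypothetical tool spec, mirroring the demo UI below):
#   >>> print(_render_tools_md([{"type": "function", "function": {
#   ...     "name": "calculator",
#   ...     "description": "Evaluates basic arithmetic expressions.",
#   ...     "parameters": {"type": "object", "properties": {"expr": {"type": "string"}}}}}]))
#   Available tools (JSON format):
#   - **calculator**: Evaluates basic arithmetic expressions. | parameters: {"type": "object", "properties": {"expr": {"type": "string"}}}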
|
|
def _compose_chat_prompt(messages: List[Dict[str, str]], tools_md: str) -> str:
    """
    Supports OpenAI-style messages: [{"role": "system|user|assistant", "content": "..."}].
    Uses the tokenizer's chat_template when available; otherwise falls back to a
    plain-text rendering (sketched in the comment below this function).
    """
    tok, _ = _lazy_load()
    sys_text = ""
    usr_msgs: List[Dict[str, str]] = []
    for m in messages:
        role = m.get("role", "")
        content = (m.get("content") or "").strip()
        if role == "system":
            sys_text += ("\n" + content) if sys_text else content
        else:
            usr_msgs.append({"role": role, "content": content})

    if tools_md:
        sys_text = (sys_text + "\n\n" if sys_text else "") + tools_md + \
            "\n\nIf you decide to call a tool, return a JSON object with the key 'tool_calls' " \
            "and, optionally, describe your reasoning concisely under 'thought'."

    conv: List[Dict[str, str]] = []
    if sys_text:
        conv.append({"role": "system", "content": sys_text})
    conv.extend(usr_msgs)

    chat_template = getattr(tok, "chat_template", None)
    if chat_template:
        return tok.apply_chat_template(conv, tokenize=False, add_generation_prompt=True)

    # Fallback rendering for tokenizers without a chat template.
    rendered = ""
    if sys_text:
        rendered += f"<<SYS>>\n{sys_text}\n<</SYS>>\n\n"
    for m in usr_msgs:
        if m["role"] == "user":
            rendered += f"### User\n{m['content']}\n\n"
        elif m["role"] == "assistant":
            rendered += f"### Assistant\n{m['content']}\n\n"
    rendered += "### Assistant\n"
    return rendered
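
# Fallback rendering sketch (only used when the tokenizer has no chat_template;
# the content values are hypothetical):
#   <<SYS>>
#   ...system text plus the tools block...
#   <</SYS>>
#
#   ### User
#   What is (2+2)^3?
#
#   ### Assistant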
|
|
# Set to False to only report tool calls without executing them locally.
EXECUTE_TOOLS = True


def _safe_calculator(expr: str) -> str:
    """Evaluate a basic arithmetic expression after a strict character-whitelist check."""
    # Normalize '**' to '^' before the check so both power spellings pass the whitelist.
    if not re.fullmatch(r"[0-9\.\s\+\-\*\/\%\(\)\^eE]+", expr.replace("**", "^")):
        return "Rejected expression."
    expr = expr.replace("^", "**")
    try:
        # eval() with empty builtins and no locals; the whitelist keeps input arithmetic-only.
        return str(eval(expr, {"__builtins__": {}}, {}))
    except Exception as e:
        return f"Error: {e}"


LOCAL_TOOLBOX = {
    "calculator": lambda args: _safe_calculator(str(args.get("expr", ""))),
}
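
# Illustrative usage (doctest-style; outputs verified by hand, not a live run):
#   >>> _safe_calculator("(2+2)^3")
#   '64'
#   >>> LOCAL_TOOLBOX["calculator"]({"expr": "3 * (4 + 1)"})
#   '15'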
|
|
def maybe_execute_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Execute each requested tool via LOCAL_TOOLBOX, collecting an output or error per call."""
    if not EXECUTE_TOOLS:
        return []
    results = []
    for call in tool_calls:
        name = call.get("name")
        args = call.get("arguments", {})
        fn = LOCAL_TOOLBOX.get(name)
        if fn is None:
            results.append({"name": name, "error": "tool_not_available"})
            continue
        try:
            out = fn(args)
            results.append({"name": name, "output": out})
        except Exception as e:
            results.append({"name": name, "error": str(e)})
    return results
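
# Example (assumed call shape, matching the JSON contract requested in the prompt):
#   >>> maybe_execute_tool_calls([{"name": "calculator", "arguments": {"expr": "(2+2)^3"}}])
#   [{'name': 'calculator', 'output': '64'}]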
|
|
@spaces.GPU
def _generate_with_tools(
    messages: List[Dict[str, str]],
    tools: List[Dict[str, Any]],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> Dict[str, Any]:
    tok, model = _lazy_load()
    tools_md = _render_tools_md(tools)
    prompt = _compose_chat_prompt(messages, tools_md)

    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    text = tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()

    # Best-effort extraction of the last well-formed JSON object carrying "tool_calls".
    # raw_decode handles nested braces, which a non-greedy regex would truncate.
    tool_calls: List[Dict[str, Any]] = []
    decoder = json.JSONDecoder()
    for m in re.finditer(r"\{", text):
        try:
            obj, _ = decoder.raw_decode(text[m.start():])
        except Exception:
            continue
        if isinstance(obj, dict) and isinstance(obj.get("tool_calls"), list):
            tool_calls = obj["tool_calls"]

    tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []

    return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}
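
# Shape of the result (sketch; the actual text depends entirely on the model):
#   {
#     "text": "{\"tool_calls\": [{\"name\": \"calculator\", \"arguments\": {\"expr\": \"(2+2)^3\"}}]}",
#     "tool_calls": [{"name": "calculator", "arguments": {"expr": "(2+2)^3"}}],
#     "tool_results": [{"name": "calculator", "output": "64"}]
#   }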
|
|
def predict_for_engine(messages_json: str, tools_json: str) -> Dict[str, Any]:
    """
    Endpoint expected by ENGINE (ToolsClient.chat):
    - messages_json: JSON list of [{"role": "user|assistant|system", "content": "..."}]
    - tools_json: OpenAI-like JSON tool spec (optional)
    Returns: {"text": "...", "tool_calls": [...], "tool_results": [...]}
    """
    try:
        messages = json.loads(messages_json) if messages_json else []
    except Exception:
        messages = []
    try:
        tools = json.loads(tools_json) if tools_json else []
    except Exception:
        tools = []
    return _generate_with_tools(messages, tools, max_new_tokens=512, temperature=0.7, top_p=0.95)
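
# Example invocation (hypothetical arguments; the tool spec is elided here):
#   predict_for_engine(
#       '[{"role": "user", "content": "What is (2+2)^3?"}]',
#       '[{"type": "function", "function": {"name": "calculator", ...}}]',
#   )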
|
|
def chat_advanced(messages_json: str, tools_json: str, max_new_tokens: int, temperature: float, top_p: float) -> Dict[str, Any]:
    """Same contract as predict_for_engine, with the generation parameters exposed."""
    try:
        messages = json.loads(messages_json) if messages_json else []
    except Exception:
        messages = []
    try:
        tools = json.loads(tools_json) if tools_json else []
    except Exception:
        tools = []
    return _generate_with_tools(messages, tools, max_new_tokens=int(max_new_tokens), temperature=float(temperature), top_p=float(top_p))
|
|
_salamandra = None


def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
    """Plain-prompt endpoint backed by a lazily created SalamandraClient."""
    global _salamandra
    if _salamandra is None:
        _salamandra = SalamandraClient()
    try:
        text = _salamandra.chat(prompt)
    except Exception as e:
        text = f"Error running SalamandraClient: {e}"
    return {"text": text}
|
|
custom_css = """
h2 {
    background: #e3e4e6 !important;
    padding: 14px 22px !important;
    border-radius: 14px !important;
    box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
    display: block !important;  /* span the full width */
    width: 100% !important;     /* keep it at 100% */
    margin: 20px auto !important;
    text-align: center;
}
"""
|
|
with gr.Blocks(title="Salamandra 7B Tools · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:

    gr.Markdown("## Salamandra-7B-Tools · ZeroGPU\nChat with a tool specification (function calling).")

    with gr.Row():
        with gr.Column():
            messages = gr.Textbox(
                label="Messages (JSON)",
                value='[{"role":"user","content":"What is (2+2)^3?"}]',
                lines=6
            )
            tools = gr.Textbox(
                label="Tools (JSON, optional)",
                value='[{"type":"function","function":{"name":"calculator","description":"Evaluates basic arithmetic expressions.","parameters":{"type":"object","properties":{"expr":{"type":"string"}},"required":["expr"]}}}]',
                lines=6
            )
            max_new = gr.Slider(16, 2048, value=512, step=16, label="Max new tokens")
            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperature")
            topp = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column():
            out = gr.JSON(label="Output")

    btn.click(
        chat_advanced,
        [messages, tools, max_new, temp, topp],
        out,
        api_name="chat",
        concurrency_limit=1
    )

    gr.Markdown("---")

    gr.Button("Try /predict").click(
        predict_for_engine,
        [messages, tools],
        out,
        api_name="predict",
        concurrency_limit=1
    )

    gr.Markdown("---")

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=10)
    with gr.Row():
        btn2 = gr.Button("Generate", variant="primary")
    with gr.Row():
        out2 = gr.JSON(label="Output")

    btn2.click(
        salamandra_chat_endpoint,
        [prompt],
        out2,
        api_name="generate_out_from_prompt",
        concurrency_limit=1
    )

    gr.Markdown("---")


demo.queue(max_size=16).launch()