# app.py — veureu/schat (Salamandra 7B Instruct · ZeroGPU) — ENGINE-compatible
from __future__ import annotations
import os, json, re
from typing import List, Dict, Any, Optional, Tuple
import gradio as gr
import spaces
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
)
from moe_tools import SalamandraClient
# ===== Config =====
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-instruct")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_tok = None
_model = None
_salamandra = None
def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
global _tok, _model
if _tok is None or _model is None:
_tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
_model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=DTYPE,
low_cpu_mem_usage=True,
use_safetensors=True,
trust_remote_code=True,
device_map=None,
).to(DEVICE)
return _tok, _model
def _build_prompt(prompt: str, system: Optional[str]) -> str:
"""
Si el tokenizer posee 'chat_template', lo usamos con mensajes [system?, user].
Si no, hacemos un prompt plano con system arriba.
"""
tok, _ = _lazy_load()
messages = []
if system and system.strip():
messages.append({"role": "system", "content": system.strip()})
messages.append({"role": "user", "content": prompt})
chat_template = getattr(tok, "chat_template", None)
if chat_template:
return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Fallback without a chat template
sys_part = (f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else "")
return sys_part + f"### Instrucción\n{prompt}\n\n### Respuesta\n"
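# For reference, a rough sketch of what the template-free fallback above produces
# (the system and prompt values are only illustrative):
#
#   <<SYS>>
#   Responde de forma breve.
#   <</SYS>>
#
#   ### Instrucción
#   Di hola en una frase.
#
#   ### Respuesta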
@spaces.GPU  # use the GPU when available (ZeroGPU)
def _generate_with_tools(
messages: List[Dict[str, str]],
tools: List[Dict[str, Any]],
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.95,
) -> Dict[str, Any]:
tok, model = _lazy_load()
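    # Note: _render_tools_md, _compose_chat_prompt and maybe_execute_tool_calls are not
    # defined in this file; they are assumed to be provided elsewhere (e.g. by moe_tools).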
tools_md = _render_tools_md(tools)
prompt = _compose_chat_prompt(messages, tools_md)
inputs = tok(prompt, return_tensors="pt").to(DEVICE)
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=int(max_new_tokens),
temperature=float(temperature),
top_p=float(top_p),
            do_sample=temperature > 0,
pad_token_id=tok.eos_token_id,
eos_token_id=tok.eos_token_id,
)
    # Decode only the newly generated tokens (drop the echoed prompt).
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    text = tok.decode(gen_tokens, skip_special_tokens=True).strip()
    # If the model returned a JSON block with 'tool_calls', try to extract it.
tool_calls: List[Dict[str, Any]] = []
try:
        # look for the last {...} block that contains "tool_calls"
matches = list(re.finditer(r"\{.*?\"tool_calls\".*?\}", text, flags=re.S))
if matches:
block = text[matches[-1].start():matches[-1].end()]
obj = json.loads(block)
tc = obj.get("tool_calls", [])
if isinstance(tc, list):
tool_calls = tc
except Exception:
pass
tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []
return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}
@spaces.GPU  # use the GPU when available (ZeroGPU)
def _generate(
prompt: str,
system: str = "",
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.95,
) -> str:
tok, model = _lazy_load()
text = _build_prompt(prompt, system or "")
inputs = tok(text, return_tensors="pt").to(DEVICE)
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=int(max_new_tokens),
temperature=float(temperature),
top_p=float(top_p),
            do_sample=temperature > 0,
pad_token_id=tok.eos_token_id,
eos_token_id=tok.eos_token_id,
)
    # Decode only the newly generated tokens (drop the echoed prompt).
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    return tok.decode(gen_tokens, skip_special_tokens=True).strip()
# ------------------- Gradio Endpoints -------------------
# 1) /predict — what the ENGINE expects (just 'prompt' → string)
def predict_for_engine(prompt: str) -> str:
return _generate(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
# 2) /generate — more controls (prompt + system + params)
def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
return _generate(prompt=prompt, system=system, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
global _salamandra
if _salamandra is None:
        _salamandra = SalamandraClient()  # uses your own client class
try:
text = _salamandra.chat(prompt)
except Exception as e:
text = f"Error ejecutando SalamandraClient: {str(e)}"
return {"text": text}
# ------------------- HTTP (optional, for plain HTTP clients) -------------------
# If needed, an HTTP POST /generate endpoint (FastAPI) could be added here,
# but the Gradio Client is enough for engine/local use; see the sketch below.
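# A minimal sketch of how an external client could call the /predict endpoint through
# the Gradio Client (the Space id "veureu/schat" comes from the header of this file):
#
#   from gradio_client import Client
#   client = Client("veureu/schat")
#   answer = client.predict("Di hola en una frase.", api_name="/predict")
#   print(answer)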
# ------------------- UI -------------------
with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU") as demo:
gr.Markdown("## Salamandra-7B-Instruct · ZeroGPU\nTexto → respuesta instruccional.")
with gr.Row():
with gr.Column(scale=1):
in_system = gr.Textbox(label="System (opcional)", value="")
in_prompt = gr.Textbox(label="Prompt", placeholder="Escribe tu instrucción…", lines=6)
max_new = gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens")
temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
btn = gr.Button("Generar", variant="primary")
with gr.Column(scale=1):
out = gr.Textbox(label="Respuesta", lines=18)
btn.click(generate_advanced, [in_prompt, in_system, max_new, temp, top_p], out, api_name="generate", concurrency_limit=1)
    # Minimal endpoint compatible with the ENGINE (/predict: prompt only)
in_prompt_engine = gr.Textbox(label="Prompt (ENGINE)", value="Di hola en una frase.")
out_engine = gr.Textbox(label="Respuesta (ENGINE)")
gr.Button("Probar /predict").click(predict_for_engine, [in_prompt_engine], out_engine, api_name="predict", concurrency_limit=1)
with gr.Row():
prompt = gr.Textbox(label="prompt", lines=10)
with gr.Row():
btn2 = gr.Button("Generar", variant="primary")
with gr.Row():
out2 = gr.JSON(label="Salida")
btn2.click(salamandra_chat_endpoint, [prompt], out2, api_name="generate_out_from_prompt", concurrency_limit=1)
demo.queue(max_size=16).launch()