# app.py — veureu/schat (Salamandra 7B Instruct · ZeroGPU) — compatible with the ENGINE
from __future__ import annotations

import os
from typing import Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# ===== Config =====
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-instruct")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

_tok = None
_model = None


def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Load the tokenizer and model on first use and cache them at module level."""
    global _tok, _model
    if _tok is None or _model is None:
        _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            trust_remote_code=True,
            device_map=None,
        ).to(DEVICE)
    return _tok, _model


def _build_prompt(prompt: str, system: Optional[str]) -> str:
    """
    If the tokenizer has a 'chat_template', use it with [system?, user] messages.
    Otherwise, build a plain prompt with the system text on top.
    """
    tok, _ = _lazy_load()
    messages = []
    if system and system.strip():
        messages.append({"role": "system", "content": system.strip()})
    messages.append({"role": "user", "content": prompt})

    chat_template = getattr(tok, "chat_template", None)
    if chat_template:
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Fallback without a chat template
    sys_part = (f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else "")
    return sys_part + f"### Instrucción\n{prompt}\n\n### Respuesta\n"


@spaces.GPU  # use the GPU when available (ZeroGPU)
def _generate(
    prompt: str,
    system: str = "",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> str:
    tok, model = _lazy_load()
    text = _build_prompt(prompt, system or "")
    inputs = tok(text, return_tensors="pt").to(DEVICE)
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return tok.decode(new_tokens, skip_special_tokens=True).strip()


# ------------------- Gradio Endpoints -------------------

# 1) /predict — what the ENGINE expects ('prompt' only → string)
def predict_for_engine(prompt: str) -> str:
    return _generate(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)


# 2) /generate — more controls (prompt + system + params)
def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    return _generate(prompt=prompt, system=system, max_new_tokens=max_new_tokens,
                     temperature=temperature, top_p=top_p)


# ------------------- HTTP (optional, plain clients) -------------------
# If you want, you can add an HTTP POST /generate endpoint (FastAPI),
# but the Gradio Client is enough for the engine / local use.
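# A minimal sketch of that optional FastAPI endpoint, kept commented out so the Space
# keeps working as-is. The route name, request fields, and response shape are
# illustrative assumptions, and fastapi/pydantic would need to be added to requirements.txt.
#
# from fastapi import FastAPI
# from pydantic import BaseModel
#
# api = FastAPI()
#
# class GenerateBody(BaseModel):
#     prompt: str
#     system: str = ""
#     max_new_tokens: int = 512
#     temperature: float = 0.7
#     top_p: float = 0.95
#
# @api.post("/generate")
# def http_generate(body: GenerateBody) -> dict:
#     # Reuses the same generation path as the Gradio endpoints.
#     return {"text": _generate(body.prompt, body.system, body.max_new_tokens,
#                               body.temperature, body.top_p)}
#
# # The Gradio UI defined below could then be mounted on the same server
# # instead of calling demo.launch():
# # app = gr.mount_gradio_app(api, demo, path="/")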
# ------------------- UI -------------------
with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU") as demo:
    gr.Markdown("## Salamandra-7B-Instruct · ZeroGPU\nTexto → respuesta instruccional.")
    with gr.Row():
        with gr.Column(scale=1):
            in_system = gr.Textbox(label="System (opcional)", value="")
            in_prompt = gr.Textbox(label="Prompt", placeholder="Escribe tu instrucción…", lines=6)
            max_new = gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens")
            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
            btn = gr.Button("Generar", variant="primary")
        with gr.Column(scale=1):
            out = gr.Textbox(label="Respuesta", lines=18)

    btn.click(generate_advanced, [in_prompt, in_system, max_new, temp, top_p], out, api_name="generate")

    # Minimal endpoint compatible with the ENGINE (/predict: prompt only)
    in_prompt_engine = gr.Textbox(label="Prompt (ENGINE)", value="Di hola en una frase.")
    out_engine = gr.Textbox(label="Respuesta (ENGINE)")
    gr.Button("Probar /predict").click(predict_for_engine, [in_prompt_engine], out_engine, api_name="predict")

# Gradio 4 (required by ZeroGPU) replaced queue(concurrency_count=...) with default_concurrency_limit.
demo.queue(default_concurrency_limit=1, max_size=16).launch()
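# Example client-side usage (a sketch, assuming the Space id "veureu/schat" from the
# header; run it from a separate process or machine, not from this file):
#
# from gradio_client import Client
#
# client = Client("veureu/schat")
# # ENGINE-style call: only the prompt, via the /predict endpoint.
# print(client.predict("Di hola en una frase.", api_name="/predict"))
# # Full control: prompt + system + sampling parameters, via /generate.
# print(client.predict("Resume el texto siguiente.", "Eres un asistente útil.",
#                      256, 0.7, 0.95, api_name="/generate"))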