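"""Gradio app serving BSC-LT/salamandra-7b-instruct on a Hugging Face ZeroGPU Space."""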
from __future__ import annotations

import os
from typing import Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)


# Model and runtime configuration; MODEL_ID can be overridden via the environment.
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-instruct")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level cache so the tokenizer and model are only loaded once per process.
_tok = None
_model = None


def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Load and cache the tokenizer and model on first use."""
    global _tok, _model
    if _tok is None or _model is None:
        _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            trust_remote_code=True,
            device_map=None,  # move the whole model manually below
        ).to(DEVICE)
    return _tok, _model


def _build_prompt(prompt: str, system: Optional[str]) -> str:
    """
    If the tokenizer defines a 'chat_template', apply it to the [system?, user]
    messages. Otherwise, build a plain prompt with the system text on top.
    """
    tok, _ = _lazy_load()
    messages = []
    if system and system.strip():
        messages.append({"role": "system", "content": system.strip()})
    messages.append({"role": "user", "content": prompt})

    chat_template = getattr(tok, "chat_template", None)
    if chat_template:
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Plain-text fallback when the tokenizer ships no chat template.
    sys_part = f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else ""
    return sys_part + f"### Instrucción\n{prompt}\n\n### Respuesta\n"
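
# Example of the fallback format (no chat template, with a system message):
#   _build_prompt("What time is it?", "Be brief") ->
#   "<<SYS>>\nBe brief\n<</SYS>>\n\n### Instrucción\nWhat time is it?\n\n### Respuesta\n"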


@spaces.GPU
def _generate(
    prompt: str,
    system: str = "",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> str:
    tok, model = _lazy_load()
    text = _build_prompt(prompt, system or "")
    inputs = tok(text, return_tensors="pt").to(DEVICE)

    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    # Decode only the newly generated tokens so the prompt is not echoed back.
    gen_tokens = out[0][inputs["input_ids"].shape[1]:]
    return tok.decode(gen_tokens, skip_special_tokens=True).strip()
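
# Note: @spaces.GPU also accepts a duration hint, e.g. @spaces.GPU(duration=120),
# to request a longer ZeroGPU allocation for slow generations (see the `spaces`
# package docs); the bare decorator above uses the default window.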


def predict_for_engine(prompt: str) -> str:
    """Fixed-parameter wrapper exposed as the /predict endpoint."""
    return _generate(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)


def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Full-parameter wrapper backing the UI and the /generate endpoint."""
    return _generate(prompt=prompt, system=system, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)


with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU") as demo:
    gr.Markdown("## Salamandra-7B-Instruct · ZeroGPU\nText in → instruction-following response.")
    with gr.Row():
        with gr.Column(scale=1):
            in_system = gr.Textbox(label="System (optional)", value="")
            in_prompt = gr.Textbox(label="Prompt", placeholder="Write your instruction…", lines=6)
            max_new = gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens")
            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
            btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            out = gr.Textbox(label="Response", lines=18)

    btn.click(generate_advanced, [in_prompt, in_system, max_new, temp, top_p], out, api_name="generate")

    # Minimal fixed-parameter widgets so external engines can hit /predict.
    in_prompt_engine = gr.Textbox(label="Prompt (ENGINE)", value="Say hello in one sentence.")
    out_engine = gr.Textbox(label="Response (ENGINE)")
    gr.Button("Try /predict").click(predict_for_engine, [in_prompt_engine], out_engine, api_name="predict")

# Serve requests one at a time, queueing up to 16.
demo.queue(max_size=16, default_concurrency_limit=1).launch()
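
# Client-side usage sketch (assumptions: the app is reachable at the URL below
# and the `gradio_client` package is installed; adjust both to your deployment):
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860/")
#   # Fixed-parameter endpoint (api_name="predict"):
#   print(client.predict("Say hello in one sentence.", api_name="/predict"))
#   # Full-parameter endpoint (api_name="generate"):
#   print(client.predict("Summarize this text…", "", 256, 0.7, 0.95, api_name="/generate"))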