File size: 4,871 Bytes
5c6fcdb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba54b37
5c6fcdb
 
 
 
ba54b37
5c6fcdb
ba54b37
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# app.py — veureu/schat (Salamandra 7B Instruct · ZeroGPU) — compatible con ENGINE
from __future__ import annotations
import os, json
from typing import List, Dict, Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

# ===== Config =====
# Model repository id; the MODEL_ID environment variable overrides the default.
MODEL_ID = os.getenv("MODEL_ID", "BSC-LT/salamandra-7b-instruct")

# Half precision on CUDA, full precision on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

# Module-level cache for the tokenizer and model; both start out unloaded.
_tok = None
_model = None

def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Return the cached ``(tokenizer, model)`` pair, loading both on first use.

    If either cached object is missing, both are (re)loaded from ``MODEL_ID``
    and the model is moved to ``DEVICE``. Later calls reuse the module-level
    cache.
    """
    global _tok, _model

    # Fast path: everything is already in memory.
    if _tok is not None and _model is not None:
        return _tok, _model

    _tok = AutoTokenizer.from_pretrained(
        MODEL_ID,
        use_fast=True,
        trust_remote_code=True,
    )
    loaded = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        trust_remote_code=True,
        device_map=None,
    )
    # Only publish the model to the cache once it has been moved to DEVICE.
    _model = loaded.to(DEVICE)
    return _tok, _model

def _build_prompt(prompt: str, system: Optional[str]) -> str:
    """Render the text that is fed to the model for a single-turn request.

    If the tokenizer has a ``chat_template``, it is applied to the messages
    ``[system?, user]`` with the generation prompt appended. Otherwise a plain
    prompt is built, with the system text (if any) placed at the top.

    Args:
        prompt: The user instruction.
        system: Optional system message; ignored when empty or whitespace-only.

    Returns:
        The rendered prompt as a string (not tokenized).
    """
    tok, _ = _lazy_load()

    # Normalise once: empty/None/whitespace-only all mean "no system message".
    sys_text = system.strip() if system else ""

    messages = [{"role": "user", "content": prompt}]
    if sys_text:
        messages.insert(0, {"role": "system", "content": sys_text})

    if getattr(tok, "chat_template", None):
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Fallback when the tokenizer has no chat template.
    header = f"<<SYS>>\n{sys_text}\n<</SYS>>\n\n" if sys_text else ""
    return f"{header}### Instrucción\n{prompt}\n\n### Respuesta\n"

@spaces.GPU  # use the GPU when one is available (ZeroGPU)
def _generate(
    prompt: str,
    system: str = "",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> str:
    """Generate the model's reply to ``prompt`` and return it as plain text.

    Args:
        prompt: The user instruction.
        system: Optional system message; empty means none.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value of 0 (or below) switches to
            greedy decoding.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Returns:
        The newly generated text only (the prompt is not echoed back), with
        special tokens removed and surrounding whitespace stripped.
    """
    tok, model = _lazy_load()
    text = _build_prompt(prompt, system or "")
    inputs = tok(text, return_tensors="pt").to(DEVICE)
    # generate() returns prompt tokens followed by new tokens for causal LMs;
    # remember where the prompt ends so only the completion is decoded.
    prompt_len = inputs["input_ids"].shape[-1]

    temperature = float(temperature)
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": int(max_new_tokens),
        "do_sample": do_sample,
        "pad_token_id": tok.eos_token_id,
        "eos_token_id": tok.eos_token_id,
    }
    if do_sample:
        # Sampling-only knobs: passing temperature=0 together with
        # do_sample=False is rejected or warned about by generation-config
        # validation, so they are only forwarded when actually sampling.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = float(top_p)

    with torch.inference_mode():
        out = model.generate(**inputs, **gen_kwargs)

    completion_ids = out[0][prompt_len:]
    return tok.decode(completion_ids, skip_special_tokens=True).strip()

# ------------------- Gradio Endpoints -------------------
# 1) /predict — lo que espera el ENGINE (solo 'prompt' → string)
def predict_for_engine(prompt: str) -> str:
    """Answer ``prompt`` with fixed generation settings and no system message.

    This is the prompt-only entry point: it forwards to ``_generate`` with
    ``max_new_tokens=512``, ``temperature=0.7`` and ``top_p=0.95``.
    """
    fixed_params = {
        "system": "",
        "max_new_tokens": 512,
        "temperature": 0.7,
        "top_p": 0.95,
    }
    return _generate(prompt=prompt, **fixed_params)

# 2) /generate — más controles (prompt + system + params)
def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Answer ``prompt`` with caller-supplied system message and sampling settings.

    All arguments are forwarded unchanged to ``_generate``.
    """
    return _generate(
        prompt=prompt,
        system=system,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )

# ------------------- HTTP (opcional, clientes puros) -------------------
# Si quieres, puedes añadir un endpoint HTTP POST /generate (FastAPI),
# pero con Gradio Client es suficiente para engine/local.

# ------------------- UI -------------------
# Gradio UI: a two-column form wired to generate_advanced, plus a minimal
# prompt-only form wired to predict_for_engine.
with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU") as demo:
    gr.Markdown("## Salamandra-7B-Instruct · ZeroGPU\nTexto → respuesta instruccional.")
    with gr.Row():
        # Left column: inputs and sampling controls.
        with gr.Column(scale=1):
            in_system = gr.Textbox(label="System (opcional)", value="")
            in_prompt = gr.Textbox(label="Prompt", placeholder="Escribe tu instrucción…", lines=6)
            max_new = gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens")
            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
            btn = gr.Button("Generar", variant="primary")
        # Right column: generated answer.
        with gr.Column(scale=1):
            out = gr.Textbox(label="Respuesta", lines=18)

    # Input order must match generate_advanced's parameter order:
    # (prompt, system, max_new_tokens, temperature, top_p).
    # NOTE(review): concurrency_limit=1 presumably serialises runs of this
    # event — confirm against the Gradio version in use.
    btn.click(generate_advanced, [in_prompt, in_system, max_new, temp, top_p], out, api_name="generate", concurrency_limit=1)

    # Minimal endpoint compatible with the ENGINE (/predict: prompt only)
    in_prompt_engine = gr.Textbox(label="Prompt (ENGINE)", value="Di hola en una frase.")
    out_engine = gr.Textbox(label="Respuesta (ENGINE)")
    gr.Button("Probar /predict").click(predict_for_engine, [in_prompt_engine], out_engine, api_name="predict", concurrency_limit=1)

# Enable the request queue (max_size=16) and start the app at import time.
demo.queue(max_size=16).launch()