VeuReu committed · verified
Commit 7b4bcba · Parent(s): 77da823

Update app.py

Files changed (1): app.py (+116, -113)
app.py CHANGED
@@ -1,113 +1,116 @@
- # app.py — veureu/schat (Salamandra 7B Instruct · ZeroGPU) — ENGINE-compatible
- from __future__ import annotations
- import os, json
- from typing import List, Dict, Optional, Tuple
-
- import gradio as gr
- import spaces
- import torch
- from transformers import (
-     AutoTokenizer,
-     AutoModelForCausalLM,
-     TextIteratorStreamer,
- )
-
- # ===== Config =====
- MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-instruct")
- DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-
- _tok = None
- _model = None
-
- def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
-     global _tok, _model
-     if _tok is None or _model is None:
-         _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
-         _model = AutoModelForCausalLM.from_pretrained(
-             MODEL_ID,
-             torch_dtype=DTYPE,
-             low_cpu_mem_usage=True,
-             use_safetensors=True,
-             trust_remote_code=True,
-             device_map=None,
-         ).to(DEVICE)
-     return _tok, _model
-
- def _build_prompt(prompt: str, system: Optional[str]) -> str:
-     """
-     If the tokenizer defines a 'chat_template', use it with [system?, user] messages.
-     Otherwise, build a plain prompt with the system text on top.
-     """
-     tok, _ = _lazy_load()
-     messages = []
-     if system and system.strip():
-         messages.append({"role": "system", "content": system.strip()})
-     messages.append({"role": "user", "content": prompt})
-
-     chat_template = getattr(tok, "chat_template", None)
-     if chat_template:
-         return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     # Fallback without a chat template
-     sys_part = (f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else "")
-     return sys_part + f"### Instrucción\n{prompt}\n\n### Respuesta\n"
-
- @spaces.GPU  # use the GPU when available (ZeroGPU)
- def _generate(
-     prompt: str,
-     system: str = "",
-     max_new_tokens: int = 512,
-     temperature: float = 0.7,
-     top_p: float = 0.95,
- ) -> str:
-     tok, model = _lazy_load()
-     text = _build_prompt(prompt, system or "")
-     inputs = tok(text, return_tensors="pt").to(DEVICE)
-
-     with torch.inference_mode():
-         out = model.generate(
-             **inputs,
-             max_new_tokens=int(max_new_tokens),
-             temperature=float(temperature),
-             top_p=float(top_p),
-             do_sample=temperature > 0,
-             pad_token_id=tok.eos_token_id,
-             eos_token_id=tok.eos_token_id,
-         )
-     # Return only the newly generated tokens (skip the echoed prompt)
-     gen = out[0][inputs["input_ids"].shape[1]:]
-     return tok.decode(gen, skip_special_tokens=True).strip()
-
- # ------------------- Gradio Endpoints -------------------
- # 1) /predict — what the ENGINE expects ('prompt' only → string)
- def predict_for_engine(prompt: str) -> str:
-     return _generate(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
-
- # 2) /generate — more controls (prompt + system + params)
- def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
-     return _generate(prompt=prompt, system=system, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
-
- # ------------------- HTTP (optional, plain HTTP clients) -------------------
- # An HTTP POST /generate endpoint (FastAPI) could be added here,
- # but the Gradio Client is enough for engine/local use.
-
- # ------------------- UI -------------------
- with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU") as demo:
-     gr.Markdown("## Salamandra-7B-Instruct · ZeroGPU\nTexto → respuesta instruccional.")
-     with gr.Row():
-         with gr.Column(scale=1):
-             in_system = gr.Textbox(label="System (opcional)", value="")
-             in_prompt = gr.Textbox(label="Prompt", placeholder="Escribe tu instrucción…", lines=6)
-             max_new = gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens")
-             temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
-             top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
-             btn = gr.Button("Generar", variant="primary")
-         with gr.Column(scale=1):
-             out = gr.Textbox(label="Respuesta", lines=18)
-
-     btn.click(generate_advanced, [in_prompt, in_system, max_new, temp, top_p], out, api_name="generate", concurrency_limit=1)
-
-     # Minimal ENGINE-compatible endpoint (/predict: prompt only)
-     in_prompt_engine = gr.Textbox(label="Prompt (ENGINE)", value="Di hola en una frase.")
-     out_engine = gr.Textbox(label="Respuesta (ENGINE)")
-     gr.Button("Probar /predict").click(predict_for_engine, [in_prompt_engine], out_engine, api_name="predict", concurrency_limit=1)
-
- demo.queue(max_size=16).launch()
+ # app.py — veureu/schat (Salamandra 7B Instruct · ZeroGPU) — ENGINE-compatible
+ from __future__ import annotations
+ import os, json
+ from typing import List, Dict, Optional, Tuple
+
+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TextIteratorStreamer,
+ )
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from moe_tools import SalamandraClient
+
+ # ===== Config =====
+ MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-instruct")
+ DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ _tok = None
+ _model = None
+
+ def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
+     global _tok, _model
+     if _tok is None or _model is None:
+         _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
+         _model = AutoModelForCausalLM.from_pretrained(
+             MODEL_ID,
+             torch_dtype=DTYPE,
+             low_cpu_mem_usage=True,
+             use_safetensors=True,
+             trust_remote_code=True,
+             device_map=None,
+         ).to(DEVICE)
+     return _tok, _model
+
+ def _build_prompt(prompt: str, system: Optional[str]) -> str:
+     """
+     If the tokenizer defines a 'chat_template', use it with [system?, user] messages.
+     Otherwise, build a plain prompt with the system text on top.
+     """
+     tok, _ = _lazy_load()
+     messages = []
+     if system and system.strip():
+         messages.append({"role": "system", "content": system.strip()})
+     messages.append({"role": "user", "content": prompt})
+
+     chat_template = getattr(tok, "chat_template", None)
+     if chat_template:
+         return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     # Fallback without a chat template
+     sys_part = (f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else "")
+     return sys_part + f"### Instrucción\n{prompt}\n\n### Respuesta\n"
+
+ @spaces.GPU  # use the GPU when available (ZeroGPU)
+ def _generate(
+     prompt: str,
+     system: str = "",
+     max_new_tokens: int = 512,
+     temperature: float = 0.7,
+     top_p: float = 0.95,
+ ) -> str:
+     tok, model = _lazy_load()
+     text = _build_prompt(prompt, system or "")
+     inputs = tok(text, return_tensors="pt").to(DEVICE)
+
+     with torch.inference_mode():
+         out = model.generate(
+             **inputs,
+             max_new_tokens=int(max_new_tokens),
+             temperature=float(temperature),
+             top_p=float(top_p),
+             do_sample=temperature > 0,
+             pad_token_id=tok.eos_token_id,
+             eos_token_id=tok.eos_token_id,
+         )
+     # Return only the newly generated tokens (skip the echoed prompt)
+     gen = out[0][inputs["input_ids"].shape[1]:]
+     return tok.decode(gen, skip_special_tokens=True).strip()
+
+ # ------------------- Gradio Endpoints -------------------
+ # 1) /predict — what the ENGINE expects ('prompt' only → string)
+ def predict_for_engine(prompt: str) -> str:
+     return _generate(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
+
+ # 2) /generate — more controls (prompt + system + params)
+ def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
+     return _generate(prompt=prompt, system=system, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
+
+ # ------------------- HTTP (optional, plain HTTP clients) -------------------
+ # An HTTP POST /generate endpoint (FastAPI) could be added here,
+ # but the Gradio Client is enough for engine/local use.
+
+ # ------------------- UI -------------------
+ with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU") as demo:
+     gr.Markdown("## Salamandra-7B-Instruct · ZeroGPU\nTexto → respuesta instruccional.")
+     with gr.Row():
+         with gr.Column(scale=1):
+             in_system = gr.Textbox(label="System (opcional)", value="")
+             in_prompt = gr.Textbox(label="Prompt", placeholder="Escribe tu instrucción…", lines=6)
+             max_new = gr.Slider(16, 2048, value=512, step=16, label="max_new_tokens")
+             temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="temperature")
+             top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="top_p")
+             btn = gr.Button("Generar", variant="primary")
+         with gr.Column(scale=1):
+             out = gr.Textbox(label="Respuesta", lines=18)
+
+     btn.click(generate_advanced, [in_prompt, in_system, max_new, temp, top_p], out, api_name="generate", concurrency_limit=1)
+
+     # Minimal ENGINE-compatible endpoint (/predict: prompt only)
+     in_prompt_engine = gr.Textbox(label="Prompt (ENGINE)", value="Di hola en una frase.")
+     out_engine = gr.Textbox(label="Respuesta (ENGINE)")
+     gr.Button("Probar /predict").click(predict_for_engine, [in_prompt_engine], out_engine, api_name="predict", concurrency_limit=1)
+
+ demo.queue(max_size=16).launch()
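
For reference, the two named endpoints above (api_name="predict" and api_name="generate") can be called from Python with the Gradio client. A minimal sketch, assuming the Space is published as "veureu/schat" (the id in the file header); the /generate argument order follows the btn.click wiring above, and the example prompts are illustrative:

# pip install gradio_client
from gradio_client import Client

client = Client("veureu/schat")

# ENGINE-style call (/predict): one prompt string in, one string out
print(client.predict("Di hola en una frase.", api_name="/predict"))

# Advanced call (/generate): prompt, system, max_new_tokens, temperature, top_p
print(client.predict(
    "Resume la idea principal en dos frases.",  # prompt
    "",      # system
    256,     # max_new_tokens
    0.7,     # temperature
    0.95,    # top_p
    api_name="/generate",
))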
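
The HTTP section's comment leaves a plain POST /generate endpoint as an option. A hedged sketch of what that could look like with FastAPI, reusing the _generate function and Blocks app defined in app.py; the route, payload fields, and serving setup here are assumptions, not part of this commit:

# Hypothetical FastAPI wrapper; run with `uvicorn app:api` instead of demo.launch()
from fastapi import FastAPI
from pydantic import BaseModel
import gradio as gr

class GenerateRequest(BaseModel):
    prompt: str
    system: str = ""
    max_new_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.95

api = FastAPI()

@api.post("/generate")
def http_generate(req: GenerateRequest) -> dict:
    # _generate is the function defined in app.py above
    text = _generate(req.prompt, req.system, req.max_new_tokens, req.temperature, req.top_p)
    return {"text": text}

# Serve the existing Gradio Blocks UI from the same server at the root path
api = gr.mount_gradio_app(api, demo, path="/")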