VeuReu committed · Commit 73cad8e · verified · 1 parent: cb4cd52

Update app.py

Files changed (1): app.py (+107 −74)
app.py CHANGED
@@ -1,4 +1,4 @@
-# app.py — veureu/schat (Salamandra 7B Instruct · ZeroGPU) — compatible con ENGINE
+# app.py — veureu/schat (Salamandra 7B Instruct · ZeroGPU) — compatible with ENGINE
 from __future__ import annotations
 import os, json
 from typing import List, Dict, Any, Optional, Tuple
@@ -40,8 +40,8 @@ def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
 
 def _build_prompt(prompt: str, system: Optional[str]) -> str:
     """
-    Si el tokenizer posee 'chat_template', lo usamos con mensajes [system?, user].
-    Si no, hacemos un prompt plano con system arriba.
+    If the tokenizer has 'chat_template', use it with messages [system?, user].
+    Otherwise, create a plain prompt with system at the top.
     """
     tok, _ = _lazy_load()
     messages = []
@@ -52,54 +52,56 @@ def _build_prompt(prompt: str, system: Optional[str]) -> str:
     chat_template = getattr(tok, "chat_template", None)
     if chat_template:
         return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    # Fallback sin chat template
+
+    # Fallback without chat template
     sys_part = (f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else "")
-    return sys_part + f"### Instrucción\n{prompt}\n\n### Respuesta\n"
+    return sys_part + f"### Instrucció\n{prompt}\n\n### Resposta\n"
 
-@spaces.GPU  # usa GPU si está disponible (ZeroGPU)
-def _generate_with_tools(
-    messages: List[Dict[str, str]],
-    tools: List[Dict[str, Any]],
-    max_new_tokens: int = 512,
-    temperature: float = 0.7,
-    top_p: float = 0.95,
-) -> Dict[str, Any]:
-    tok, model = _lazy_load()
-    tools_md = _render_tools_md(tools)
-    prompt = _compose_chat_prompt(messages, tools_md)
-
-    inputs = tok(prompt, return_tensors="pt").to(DEVICE)
-    with torch.inference_mode():
-        out = model.generate(
-            **inputs,
-            max_new_tokens=int(max_new_tokens),
-            temperature=float(temperature),
-            top_p=float(top_p),
-            do_sample=True if temperature > 0 else False,
-            pad_token_id=tok.eos_token_id,
-            eos_token_id=tok.eos_token_id,
-        )
-    text = tok.decode(out[0], skip_special_tokens=True).strip()
-
-    # Si el modelo devuelve un bloque JSON con 'tool_calls', lo intentamos extraer.
-    tool_calls: List[Dict[str, Any]] = []
-    try:
-        # busca el último {...} que contenga "tool_calls"
-        matches = list(re.finditer(r"\{.*?\"tool_calls\".*?\}", text, flags=re.S))
-        if matches:
-            block = text[matches[-1].start():matches[-1].end()]
-            obj = json.loads(block)
-            tc = obj.get("tool_calls", [])
-            if isinstance(tc, list):
-                tool_calls = tc
-    except Exception:
-        pass
-
-    tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []
-
-    return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}
-
-@spaces.GPU  # usa GPU si está disponible (ZeroGPU)
+# @spaces.GPU  # use GPU if available (ZeroGPU)
+# def _generate_with_tools(
+#     messages: List[Dict[str, str]],
+#     tools: List[Dict[str, Any]],
+#     max_new_tokens: int = 512,
+#     temperature: float = 0.7,
+#     top_p: float = 0.95,
+# ) -> Dict[str, Any]:
+#     tok, model = _lazy_load()
+#     tools_md = _render_tools_md(tools)
+#     prompt = _compose_chat_prompt(messages, tools_md)
+#
+#     inputs = tok(prompt, return_tensors="pt").to(DEVICE)
+#     with torch.inference_mode():
+#         out = model.generate(
+#             **inputs,
+#             max_new_tokens=int(max_new_tokens),
+#             temperature=float(temperature),
+#             top_p=float(top_p),
+#             do_sample=True if temperature > 0 else False,
+#             pad_token_id=tok.eos_token_id,
+#             eos_token_id=tok.eos_token_id,
+#         )
+#     text = tok.decode(out[0], skip_special_tokens=True).strip()
+#
+#     # If the model returns a JSON block with 'tool_calls', try to extract it
+#     tool_calls: List[Dict[str, Any]] = []
+#     try:
+#         # Search for the last {...} containing "tool_calls"
+#         matches = list(re.finditer(r"\{.*?\"tool_calls\".*?\}", text, flags=re.S))
+#         if matches:
+#             block = text[matches[-1].start():matches[-1].end()]
+#             obj = json.loads(block)
+#             tc = obj.get("tool_calls", [])
+#             if isinstance(tc, list):
+#                 tool_calls = tc
+#     except Exception:
+#         pass
+#
+#     # Execute the extracted tool calls if any
+#     tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []
+#
+#     return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}
+
+@spaces.GPU  # use GPU if available (ZeroGPU)
 def _generate(
     prompt: str,
     system: str = "",
@@ -124,58 +126,89 @@ def _generate(
     return tok.decode(out[0], skip_special_tokens=True).strip()
 
 # ------------------- Gradio Endpoints -------------------
-# 1) /predict — lo que espera el ENGINE (solo 'prompt' → string)
+# 1) /predict — what ENGINE expects (only 'prompt' → string)
 def predict_for_engine(prompt: str) -> str:
     return _generate(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
 
-# 2) /generate — más controles (prompt + system + params)
+# 2) /generate — more controls (prompt + system + params)
 def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
     return _generate(prompt=prompt, system=system, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
 
 def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
     global _salamandra
     if _salamandra is None:
-        _salamandra = SalamandraClient()  # usa tu clase
+        _salamandra = SalamandraClient()  # use your class
 
     try:
         text = _salamandra.chat(prompt)
     except Exception as e:
-        text = f"Error ejecutando SalamandraClient: {str(e)}"
+        text = f"Error running SalamandraClient: {str(e)}"
 
     return {"text": text}
 
-def resumir_frases(frase, num_palabras):
-    num_palabras = int(num_palabras)
-    prompt = f"Instrució: Resumeix la següent frase en {num_palabras} paraules. Input: {frase}"
+def resume_sentence(sentence, num_words):
+    """
+    Summarizes the given sentence in the specified number of words.
+
+    Parameters:
+    - sentence (str): The sentence to summarize.
+    - num_words (int): The number of words for the summary.
+
+    Returns:
+    - str: The summarized sentence.
+    """
+    num_words = int(num_words)
+
+    # Prompt the model to summarize the sentence
+    prompt = f"Instrucció: Resumeix la següent frase en {num_words} paraules. Input: {sentence}"
     result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
+
+    # Clean the output if it contains the 'assistant' role
     if "assistant" in result:
         clean_output = result.split("assistant", 1)[1].strip().split("\n")[0]
     else:
-        clean_output = frase
+        clean_output = sentence
+
     return clean_output
 
-def identity_manager (frase, persona):
+def identity_manager(sentence, person):
+    """
+    Replaces the subject of the sentence with the indicated person, keeping the rest unchanged.
+    """
     prompt = f"""Instrucció: Substitueix el subjecte de la frase per la persona indicada, mantenint la resta igual.
-    Frase: {frase}
-    Substitució: {persona}
+    Frase: {sentence}
+    Substitució: {person}
     Resposta:"""
+
+    # Generate the modified sentence using the advanced generator
     result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
+
+    # Clean the output if it contains the 'assistant' role
     if "assistant" in result:
         clean_output = result.split("assistant", 1)[1].strip().split("\n")[0]
     else:
-        clean_output = frase
+        clean_output = sentence
+
     return clean_output
 
-def free_narration (srt_final):
-    prompt = f"""Instrucció: Converteix aquesta audiodescripció en una narració lliure breu, natural i coherent.,
-    input: {srt_final}
+def free_narration(srt_text):
+    """
+    Converts the given audio description into a short, natural, and coherent free narration.
+    """
+    prompt = f"""Instrucció: Converteix aquesta audiodescripció en una narració lliure breu, natural i coherent.
+    input: {srt_text}
     output:
     """
+
+    # Generate the free narration using the advanced generator
     result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
+
+    # Clean the output if it contains the 'assistant' role
    
     if "assistant" in result:
         clean_output = result.split("assistant", 1)[1].strip().split("\n")[0]
     else:
-        clean_output = frase
+        clean_output = srt_text  # fallback to original input
+
     return clean_output
 
 # ------------------- HTTP (opcional, clientes puros) -------------------
@@ -217,17 +250,6 @@ with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU",css=custom_css,theme=gr
     gr.Button("Probar /predict").click(predict_for_engine, [in_prompt_engine], out_engine, api_name="predict", concurrency_limit=1)
     gr.Markdown("---")
 
-    gr.Markdown('<h2 style="text-align:center">Sortida del model Salamandra a partir d’una petició</h2>')
-    with gr.Row():
-        prompt = gr.Textbox(label="prompt", lines=10)
-    with gr.Row():
-        btn2 = gr.Button("Generar", variant="primary")
-    with gr.Row():
-        out2 = gr.JSON(label="Salida")
-
-    btn2.click(salamandra_chat_endpoint, [prompt], out2, api_name="generate_out_from_prompt", concurrency_limit=1)
-    gr.Markdown("---")
-
     gr.Markdown('<h2 style="text-align:center">Resumir frases</h2>')
     with gr.Row():
         with gr.Column(scale=1):
@@ -239,7 +261,7 @@ with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU",css=custom_css,theme=gr
     btn_resumir = gr.Button("Resumir", variant="primary")
 
     btn_resumir.click(
-        resumir_frases,
+        resume_sentence,
         inputs=[frase, num_paraules],
         outputs=out_resumir,
         api_name="resumir",
@@ -269,7 +291,7 @@ with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU",css=custom_css,theme=gr
     with gr.Row():
        with gr.Column(scale=1):
            srt = gr.Textbox(label="Audiodescripció", value="(AD)\nTOTS CANTANT: avui celebrem la nostra festa major\nAINA: som hi tots a ballar", lines=3)
-           btn_modificar = gr.Button("Generar audiodescripció", variant="primary")
+           btn_modificar = gr.Button("Generar narració lliure", variant="primary")
        with gr.Column(scale=1):
            narració_lliure = gr.Textbox(label="Narració lliure", lines=18)
 
@@ -281,4 +303,15 @@ with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU",css=custom_css,theme=gr
         concurrency_limit=1
     )
 
+    gr.Markdown('<h2 style="text-align:center">Sortida del model Salamandra a partir d’una petició</h2>')
+    with gr.Row():
+        prompt = gr.Textbox(label="prompt", lines=10)
+    with gr.Row():
+        btn2 = gr.Button("Generar", variant="primary")
+    with gr.Row():
+        out2 = gr.JSON(label="Salida")
+
+    btn2.click(salamandra_chat_endpoint, [prompt], out2, api_name="generate_out_from_prompt", concurrency_limit=1)
+    gr.Markdown("---")
+
 demo.queue(max_size=16).launch()
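
For reference, the named endpoints this version registers (api_name="predict", "resumir", "generate_out_from_prompt", and the rest) can be exercised from Python with gradio_client. A minimal sketch, not part of the commit: it assumes the Space is reachable as "veureu/schat" (taken from the file header) and the prompt and sentence values below are placeholders.

# Hypothetical smoke test for the Space's named endpoints (not part of app.py).
from gradio_client import Client

client = Client("veureu/schat")  # assumed Space id

# /predict: the plain prompt -> string endpoint the ENGINE calls.
print(client.predict("Hola, com estàs?", api_name="/predict"))

# /resumir: resume_sentence(frase, num_paraules) -> summarized sentence.
# The handler casts num_paraules with int(), so a number or numeric string works.
print(client.predict("Avui celebrem la nostra festa major al poble.", 5, api_name="/resumir"))

# /generate_out_from_prompt: returns a dict like {"text": ...} from salamandra_chat_endpoint.
print(client.predict("Explica'm un conte curt.", api_name="/generate_out_from_prompt"))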