schat / app.py
VeuReu's picture
Update app.py
18cf1c5 verified
# app.py — veureu/schat (Salamandra 7B Instruct · ZeroGPU) — compatible with ENGINE
from __future__ import annotations
import os, json
from typing import List, Dict, Any, Optional, Tuple
import gradio as gr
import spaces
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
)
from transformers import AutoTokenizer, AutoModelForCausalLM
from moe_tools import SalamandraClient
# ===== Config =====
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-instruct")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
_tok = None
_model = None
_salamandra = None
def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
global _tok, _model
if _tok is None or _model is None:
_tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
_model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=DTYPE,
low_cpu_mem_usage=True,
use_safetensors=True,
trust_remote_code=True,
device_map=None,
).to(DEVICE)
return _tok, _model
def _build_prompt(prompt: str, system: Optional[str]) -> str:
"""
If the tokenizer has 'chat_template', use it with messages [system?, user].
Otherwise, create a plain prompt with system at the top.
"""
tok, _ = _lazy_load()
messages = []
if system and system.strip():
messages.append({"role": "system", "content": system.strip()})
messages.append({"role": "user", "content": prompt})
chat_template = getattr(tok, "chat_template", None)
if chat_template:
return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Fallback without chat template
sys_part = (f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else "")
return sys_part + f"### Instrucció\n{prompt}\n\n### Resposta\n"
#@spaces.GPU # use GPU if available (ZeroGPU)
#def _generate_with_tools(
# messages: List[Dict[str, str]],
# tools: List[Dict[str, Any]],
# max_new_tokens: int = 512,
# temperature: float = 0.7,
# top_p: float = 0.95,
#) -> Dict[str, Any]:
# tok, model = _lazy_load()
# tools_md = _render_tools_md(tools)
# prompt = _compose_chat_prompt(messages, tools_md)
# inputs = tok(prompt, return_tensors="pt").to(DEVICE)
# with torch.inference_mode():
# out = model.generate(
# **inputs,
# max_new_tokens=int(max_new_tokens),
# temperature=float(temperature),
# top_p=float(top_p),
# do_sample=True if temperature > 0 else False,
# pad_token_id=tok.eos_token_id,
# eos_token_id=tok.eos_token_id,
# )
# text = tok.decode(out[0], skip_special_tokens=True).strip()
# # If the model returns a JSON block with 'tool_calls', try to extract it
# tool_calls: List[Dict[str, Any]] = []
# try:
# # Search for the last {...} containing "tool_calls"
# matches = list(re.finditer(r"\{.*?\"tool_calls\".*?\}", text, flags=re.S))
# if matches:
# block = text[matches[-1].start():matches[-1].end()]
# obj = json.loads(block)
# tc = obj.get("tool_calls", [])
# if isinstance(tc, list):
# tool_calls = tc
# except Exception:
# pass
# Execute the extracted tool calls if any
# tool_results = maybe_execute_tool_calls(tool_calls) if tool_calls else []
# return {"text": text, "tool_calls": tool_calls, "tool_results": tool_results}
@spaces.GPU # use GPU if available (ZeroGPU)
def _generate(
prompt: str,
system: str = "",
max_new_tokens: int = 512,
temperature: float = 0.7,
top_p: float = 0.95,
) -> str:
tok, model = _lazy_load()
text = _build_prompt(prompt, system or "")
inputs = tok(text, return_tensors="pt").to(DEVICE)
with torch.inference_mode():
out = model.generate(
**inputs,
max_new_tokens=int(max_new_tokens),
temperature=float(temperature),
top_p=float(top_p),
do_sample=True if temperature > 0 else False,
pad_token_id=tok.eos_token_id,
eos_token_id=tok.eos_token_id,
)
return tok.decode(out[0], skip_special_tokens=True).strip()
# ------------------- Gradio Endpoints -------------------
# 1) /predict — what ENGINE expects (only 'prompt' → string)
def predict_for_engine(prompt: str) -> str:
return _generate(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
# 2) /generate — more controls (prompt + system + params)
def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
return _generate(prompt=prompt, system=system, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
global _salamandra
if _salamandra is None:
_salamandra = SalamandraClient() # use your class
try:
text = _salamandra.chat(prompt)
except Exception as e:
text = f"Error running SalamandraClient: {str(e)}"
return {"text": text}
def resume_sentence(sentence, num_words):
"""
Summarizes the given sentence in the specified number of words.
Parameters:
- sentence (str): The sentence to summarize.
- num_words (int): The number of words for the summary.
Returns:
- str: The summarized sentence.
"""
num_words = int(num_words)
# Prompt the model to summarize the sentence
prompt = f"Instrució: Resumeix la següent frase en {num_words} paraules. Input: {sentence}"
result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
# Clean the output if it contains 'assistant' role
if "assistant" in result:
clean_output = result.split("assistant", 1)[1].strip().split("\n")[0]
else:
clean_output = sentence
return clean_output
def identity_manager(sentence, person):
"""
Replaces the subject of the sentence with the indicated person, keeping the rest unchanged.
"""
prompt = f"""Instrucció: Substitueix el subjecte de la frase per la persona indicada, mantenint la resta igual.
Frase: {sentence}
Substitució: {person}
Resposta:"""
# Generate the modified sentence using the advanced generator
result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
# Clean the output if it contains 'assistant' role
if "assistant" in result:
clean_output = result.split("assistant", 1)[1].strip().split("\n")[0]
else:
clean_output = sentence
return clean_output
def free_narration(srt_text):
"""
Converts the given audio description into a short, natural, and coherent free narration.
"""
prompt = f"""Instrucció: Converteix aquesta audiodescripció en una narració lliure breu, natural i coherent.,
input: {srt_text}
output:
"""
# Generate the free narration using the advanced generator
result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
# Clean the output if it contains 'assistant' role
if "assistant" in result:
clean_output = result.split("assistant", 1)[1].strip().split("\n")[0]
else:
clean_output = srt_text # fallback to original input
return clean_output
# ------------------- HTTP (opcional, clientes puros) -------------------
# Si quieres, puedes añadir un endpoint HTTP POST /generate (FastAPI),
# pero con Gradio Client es suficiente para engine/local.
# ------------------- UI -------------------
custom_css = """
h2 {
background: #e3e4e6 !important;
padding: 14px 22px !important;
border-radius: 14px !important;
box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
display: block !important; /* ocupa tot l'ample */
width: 100% !important; /* assegura 100% */
margin: 20px auto !important;
text-align:center;
}
"""
# App UI built with Gradio. This interface exposes several model utilities.
with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:
# Section: Instruction-based text generation
gr.Markdown("## Salamandra-7B-Instruct · ZeroGPU\nText → resposta instruccional.")
with gr.Row():
with gr.Column(scale=1):
# System prompt (optional internal conditioning)
in_system = gr.Textbox(label="Sistema (opcional)", value="")
# User prompt to instruct the model
in_prompt = gr.Textbox(label="Instrucció", placeholder="Escriu la teva instrucció…", lines=6)
# Maximum number of new tokens to generate
max_new = gr.Slider(16, 2048, value=512, step=16, label="Màxim de tokens nous")
# Diversity parameter for randomness
temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperatura")
# Nucleus sampling threshold
top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
# Button to trigger text generation
btn = gr.Button("Generar", variant="primary")
with gr.Column(scale=1):
# Output box for generated text
out = gr.Textbox(label="Resposta", lines=18)
# Bind main generation function
btn.click(
generate_advanced,
[in_prompt, in_system, max_new, temp, top_p],
out,
api_name="generate",
concurrency_limit=1
)
# --------------------------------------------------------------
gr.Markdown("---")
# --------------------------------------------------------------
# Minimal endpoint for ENGINE compatibility (/predict)
# Only requires a prompt, returns generated text
in_prompt_engine = gr.Textbox(label="Instrucció (ENGINE)", value="Digues hola en una frase.")
out_engine = gr.Textbox(label="Resposta (ENGINE)")
gr.Button("Provar /predict").click(
predict_for_engine,
[in_prompt_engine],
out_engine,
api_name="predict",
concurrency_limit=1
)
# --------------------------------------------------------------
gr.Markdown("---")
# --------------------------------------------------------------
# Section: Sentence summarization
gr.Markdown('<h2 style="text-align:center">Resumir frases</h2>')
with gr.Row():
with gr.Column(scale=1):
# Text to summarize
sentence = gr.Textbox(label="Frase a resumir", value="", lines=3)
# Desired number of words in the summary
num_words = gr.Textbox(label="Nombre de paraules del resum", value="4")
with gr.Column(scale=1):
# Output summary
out_resume = gr.Textbox(label="Resposta", lines=18)
with gr.Row():
# Button to produce a summary
btn_resume = gr.Button("Resumir", variant="primary")
btn_resume.click(
resume_sentence,
inputs=[sentence, num_words],
outputs=out_resume,
api_name="resume",
concurrency_limit=1
)
# --------------------------------------------------------------
gr.Markdown("---")
# --------------------------------------------------------------
# Section: Inclusion of identities inside text
gr.Markdown('<h2 style="text-align:center">Inclusió d’identitats</h2>')
with gr.Row():
with gr.Column(scale=1):
# Sentence to modify
sentence = gr.Textbox(label="Frase a modificar", value="", lines=3)
# Identity mapping provided by the user
person = gr.Textbox(label="Persones reconegudes", value='"Mireia Martí": 4, "Xavier Busquets": 5')
with gr.Column(scale=1):
out_modificat = gr.Textbox(label="Resposta", lines=18)
with gr.Row():
btn_modify = gr.Button("Modificar frase", variant="primary")
btn_modify.click(
identity_manager,
inputs=[sentence, person],
outputs=out_modificat,
api_name="modificat",
concurrency_limit=1
)
# --------------------------------------------------------------
gr.Markdown("---")
# --------------------------------------------------------------
# Section: Free narration generation from SRT-like audio description
gr.Markdown('<h2 style="text-align:center">Narració lliure</h2>')
with gr.Row():
with gr.Column(scale=1):
# SRT-like structured description
srt = gr.Textbox(
label="Audiodescripció",
value="(AD)\nTOTS CANTANT: avui celebrem la nostra festa major\nAINA: som hi tots a ballar",
lines=3
)
btn_modify = gr.Button("Generar narració lliure", variant="primary")
with gr.Column(scale=1):
narració_lliure = gr.Textbox(label="Narració lliure", lines=18)
btn_modify.click(
free_narration,
inputs=[srt],
outputs=narració_lliure,
api_name="narració",
concurrency_limit=1
)
# --------------------------------------------------------------
gr.Markdown("---")
# --------------------------------------------------------------
# Section: Raw model output from a prompt (JSON)
gr.Markdown('<h2 style="text-align:center">Sortida del model Salamandra a partir d’una petició</h2>')
with gr.Row():
prompt = gr.Textbox(label="Prompt", lines=10)
with gr.Row():
btn2 = gr.Button("Generar", variant="primary")
with gr.Row():
out2 = gr.JSON(label="Sortida")
btn2.click(
salamandra_chat_endpoint,
[prompt],
out2,
api_name="generate_out_from_prompt",
concurrency_limit=1
)
# --------------------------------------------------------------
gr.Markdown("---")
# --------------------------------------------------------------
# Queue to handle multiple requests safely
demo.queue(max_size=16).launch()