from __future__ import annotations

import os
from typing import Any, Dict, Optional, Tuple

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from moe_tools import SalamandraClient
|
|
MODEL_ID = os.environ.get("MODEL_ID", "BSC-LT/salamandra-7b-instruct")
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Lazily-initialised singletons: the 7B model is expensive to load, so it is
# deferred until the first request.
_tok = None
_model = None
_salamandra = None
|
|
def _lazy_load() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Load the tokenizer and model once, then reuse them on later calls."""
    global _tok, _model
    if _tok is None or _model is None:
        _tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=DTYPE,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            trust_remote_code=True,
            device_map=None,  # no auto-sharding; moved to DEVICE explicitly below
        ).to(DEVICE)
    return _tok, _model
|
|
def _build_prompt(prompt: str, system: Optional[str]) -> str:
    """
    If the tokenizer defines a `chat_template`, apply it to the messages
    [system?, user]. Otherwise, fall back to a plain prompt with the system
    text at the top.
    """
    tok, _ = _lazy_load()
    messages = []
    if system and system.strip():
        messages.append({"role": "system", "content": system.strip()})
    messages.append({"role": "user", "content": prompt})

    chat_template = getattr(tok, "chat_template", None)
    if chat_template:
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    # Fallback: a simple instruction-style prompt. The headings are in Catalan
    # to match the model's instruction-tuning language.
    sys_part = f"<<SYS>>\n{system.strip()}\n<</SYS>>\n\n" if system and system.strip() else ""
    return sys_part + f"### Instrucció\n{prompt}\n\n### Resposta\n"
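
# Sketch of the fallback path only (with a chat template, the exact rendering
# depends on the model's own template):
#
#   _build_prompt("Qui ets?", None)
#   -> "### Instrucció\nQui ets?\n\n### Resposta\n"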
|
|
@spaces.GPU
def _generate(
    prompt: str,
    system: str = "",
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> str:
    tok, model = _lazy_load()
    text = _build_prompt(prompt, system or "")
    inputs = tok(text, return_tensors="pt").to(DEVICE)

    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            do_sample=temperature > 0,  # greedy decoding when temperature is 0
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )
    # Decodes the full sequence, prompt included; callers strip the prompt.
    return tok.decode(out[0], skip_special_tokens=True).strip()
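
# Minimal usage sketch (output varies with sampling and, as noted above,
# includes the rendered prompt):
#
#   _generate("Digues hola en una frase.", max_new_tokens=32)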
|
|
def predict_for_engine(prompt: str) -> str:
    """Fixed-parameter entry point exposed as the /predict API below."""
    return _generate(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)
|
|
def generate_advanced(prompt: str, system: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Full-parameter entry point exposed as the /generate API below."""
    return _generate(prompt=prompt, system=system, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p)
|
|
def salamandra_chat_endpoint(prompt: str) -> Dict[str, Any]:
    """Run the prompt through the external SalamandraClient (created lazily)."""
    global _salamandra
    if _salamandra is None:
        _salamandra = SalamandraClient()

    try:
        text = _salamandra.chat(prompt)
    except Exception as e:
        text = f"Error running SalamandraClient: {e}"

    return {"text": text}
|
|
def resume_sentence(sentence, num_words):
    """
    Summarizes the given sentence in the specified number of words.

    Parameters:
    - sentence (str): The sentence to summarize.
    - num_words (int): The number of words for the summary.

    Returns:
    - str: The summarized sentence.
    """
    num_words = int(num_words)

    prompt = f"Instrucció: Resumeix la següent frase en {num_words} paraules. Input: {sentence}"
    result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)

    return _extract_answer(result, fallback=sentence)
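
# e.g. resume_sentence("Els veïns han organitzat una gran festa al barri.", "4")
# should yield a roughly four-word summary (model output is not guaranteed).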
|
|
def identity_manager(sentence, person):
    """
    Replaces the subject of the sentence with the indicated person, keeping the rest unchanged.
    """
    prompt = f"""Instrucció: Substitueix el subjecte de la frase per la persona indicada, mantenint la resta igual.
Frase: {sentence}
Substitució: {person}
Resposta:"""

    result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)

    return _extract_answer(result, fallback=sentence)
|
|
def free_narration(srt_text):
    """
    Converts the given audio description into a short, natural, and coherent free narration.
    """
    prompt = f"""Instrucció: Converteix aquesta audiodescripció en una narració lliure breu, natural i coherent.
input: {srt_text}
output:
"""

    result = generate_advanced(prompt=prompt, system="", max_new_tokens=512, temperature=0.7, top_p=0.95)

    return _extract_answer(result, fallback=srt_text)
|
|
custom_css = """
h2 {
    background: #e3e4e6 !important;
    padding: 14px 22px !important;
    border-radius: 14px !important;
    box-shadow: 0 4px 12px rgba(0,0,0,0.08) !important;
    display: block !important;  /* span the full width */
    width: 100% !important;     /* make sure it is 100% */
    margin: 20px auto !important;
    text-align: center;
}
"""
|
|
with gr.Blocks(title="Salamandra 7B Instruct · ZeroGPU", css=custom_css, theme=gr.themes.Soft()) as demo:

    gr.Markdown("## Salamandra-7B-Instruct · ZeroGPU\nText → resposta instruccional.")

    with gr.Row():
        with gr.Column(scale=1):
            in_system = gr.Textbox(label="Sistema (opcional)", value="")
            in_prompt = gr.Textbox(label="Instrucció", placeholder="Escriu la teva instrucció…", lines=6)
            max_new = gr.Slider(16, 2048, value=512, step=16, label="Màxim de tokens nous")
            temp = gr.Slider(0.0, 1.5, value=0.7, step=0.05, label="Temperatura")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.01, label="Top-p")
            btn = gr.Button("Generar", variant="primary")
        with gr.Column(scale=1):
            out = gr.Textbox(label="Resposta", lines=18)

    btn.click(
        generate_advanced,
        [in_prompt, in_system, max_new, temp, top_p],
        out,
        api_name="generate",
        concurrency_limit=1,
    )
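
    # The named endpoints can also be called programmatically, e.g. (a sketch;
    # "owner/space" is a placeholder for the deployed Space id):
    #
    #   from gradio_client import Client
    #   client = Client("owner/space")
    #   client.predict("Hola!", "", 512, 0.7, 0.95, api_name="/generate")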
|
|
    gr.Markdown("---")

    in_prompt_engine = gr.Textbox(label="Instrucció (ENGINE)", value="Digues hola en una frase.")
    out_engine = gr.Textbox(label="Resposta (ENGINE)")

    gr.Button("Provar /predict").click(
        predict_for_engine,
        [in_prompt_engine],
        out_engine,
        api_name="predict",
        concurrency_limit=1,
    )
|
|
    gr.Markdown("---")

    gr.Markdown('<h2 style="text-align:center">Resumir frases</h2>')

    with gr.Row():
        with gr.Column(scale=1):
            sentence = gr.Textbox(label="Frase a resumir", value="", lines=3)
            num_words = gr.Textbox(label="Nombre de paraules del resum", value="4")
        with gr.Column(scale=1):
            out_resume = gr.Textbox(label="Resposta", lines=18)

    with gr.Row():
        btn_resume = gr.Button("Resumir", variant="primary")

    btn_resume.click(
        resume_sentence,
        inputs=[sentence, num_words],
        outputs=out_resume,
        api_name="resume",
        concurrency_limit=1,
    )
|
|
    gr.Markdown("---")

    gr.Markdown('<h2 style="text-align:center">Inclusió d’identitats</h2>')

    with gr.Row():
        with gr.Column(scale=1):
            sentence = gr.Textbox(label="Frase a modificar", value="", lines=3)
            person = gr.Textbox(label="Persones reconegudes", value='"Mireia Martí": 4, "Xavier Busquets": 5')
        with gr.Column(scale=1):
            out_modificat = gr.Textbox(label="Resposta", lines=18)

    with gr.Row():
        btn_modify = gr.Button("Modificar frase", variant="primary")

    btn_modify.click(
        identity_manager,
        inputs=[sentence, person],
        outputs=out_modificat,
        api_name="modificat",
        concurrency_limit=1,
    )
|
|
    gr.Markdown("---")

    gr.Markdown('<h2 style="text-align:center">Narració lliure</h2>')

    with gr.Row():
        with gr.Column(scale=1):
            srt = gr.Textbox(
                label="Audiodescripció",
                value="(AD)\nTOTS CANTANT: avui celebrem la nostra festa major\nAINA: som-hi tots a ballar",
                lines=3,
            )
            btn_modify = gr.Button("Generar narració lliure", variant="primary")
        with gr.Column(scale=1):
            narració_lliure = gr.Textbox(label="Narració lliure", lines=18)

    btn_modify.click(
        free_narration,
        inputs=[srt],
        outputs=narració_lliure,
        api_name="narració",
        concurrency_limit=1,
    )
|
|
    gr.Markdown("---")

    gr.Markdown('<h2 style="text-align:center">Sortida del model Salamandra a partir d’una petició</h2>')

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", lines=10)

    with gr.Row():
        btn2 = gr.Button("Generar", variant="primary")

    with gr.Row():
        out2 = gr.JSON(label="Sortida")

    btn2.click(
        salamandra_chat_endpoint,
        [prompt],
        out2,
        api_name="generate_out_from_prompt",
        concurrency_limit=1,
    )
|
|
    gr.Markdown("---")
|
|
if __name__ == "__main__":
    demo.queue(max_size=16).launch()