# demo / app.py
# alobos's picture
# Use selected 4B model by default
# 0b64250 verified
from __future__ import annotations
import os
import re
import threading
from threading import Thread
import gradio as gr
# Hub repository id and revision to load; each can be overridden through an
# environment variable of the same name.
MODEL_ID = os.environ.get("MODEL_ID", "MariChatmen/MariChatmen-4B-Experimental")
MODEL_REVISION = os.environ.get("MODEL_REVISION", "main")

# System prompt used whenever the user leaves the system prompt box empty.
# Written one sentence per entry; the sentences are joined by single spaces.
DEFAULT_SYSTEM = " ".join(
    (
        "Eres MariChatmen, una sevillana ficticia nacida durante la Expo del 92.",
        (
            "Respondes siempre en Andalûh EPA informal, con orgullo andaluz, "
            "gracia sevillana y cariño por todas las provincias de Andalucía."
        ),
        "Tu exageración regional es humorística y afectuosa, nunca hostil ni despectiva.",
        "Ayudas con claridad: primero respondes bien, luego añades arte.",
    )
)
# Matches one complete <think>...</think> block, spanning newlines, in any
# letter case.
THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
# Matches a lone opening or closing think tag that is left over once complete
# blocks have been removed; same case-insensitivity as THINK_RE.
_STRAY_THINK_TAG_RE = re.compile(r"</?think>", re.IGNORECASE)

# Lazily populated model/tokenizer cache; _load_model fills it while holding
# _load_lock.
_model = None
_tokenizer = None
_load_lock = threading.Lock()


def _strip_thinking(text: str) -> str:
    """Return *text* without think blocks, stray think tags or outer whitespace.

    Complete ``<think>...</think>`` blocks are removed together with their
    content. An unpaired ``<think>`` or ``</think>`` tag is removed on its
    own and the surrounding text is kept. Both steps ignore letter case, so
    ``<Think>`` or ``</THINK>`` are handled the same way as the lowercase
    spelling.
    """
    without_blocks = THINK_RE.sub("", text)
    return _STRAY_THINK_TAG_RE.sub("", without_blocks).strip()
def _load_model():
    """Return the shared ``(model, tokenizer)`` pair, loading it on first use.

    The pair is cached in the module globals ``_model`` and ``_tokenizer``.
    The cache is checked once without the lock and once more after acquiring
    ``_load_lock``, so concurrent first requests trigger a single load.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        Whatever the imports or the ``from_pretrained`` calls raise; nothing
        is cached in that case.
    """
    global _model, _tokenizer

    # Fast path: already loaded, no locking needed.
    if _model is not None and _tokenizer is not None:
        return _model, _tokenizer

    with _load_lock:
        # Another request may have finished loading while we waited.
        if _model is not None and _tokenizer is not None:
            return _model, _tokenizer

        # Imported here rather than at module top, so these packages are only
        # loaded when a model is first requested.
        import torch
        from peft import AutoPeftModelForCausalLM
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            revision=MODEL_REVISION,
            trust_remote_code=True,
        )
        # Fall back to EOS as the padding token when none is defined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # float16 when CUDA is available, float32 otherwise.
        cuda = torch.cuda.is_available()
        dtype = torch.float16 if cuda else torch.float32

        model = AutoPeftModelForCausalLM.from_pretrained(
            MODEL_ID,
            revision=MODEL_REVISION,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        )
        model.eval()

        # Publish only after both objects are fully constructed.
        _model = model
        _tokenizer = tokenizer
        return model, tokenizer
def _normalise_history(history) -> list[dict[str, str]]:
    """Convert a chat history into a flat list of role/content messages.

    Two item shapes are accepted:

    * a dict with ``"role"`` and ``"content"`` keys; it is kept only when the
      role is ``"user"`` or ``"assistant"`` and the content is truthy;
    * a list or tuple with at least two elements, read as
      ``(user_text, assistant_text)``; each truthy element becomes its own
      message.

    Items of any other shape are ignored. Content is converted with ``str``.

    Args:
        history: Iterable of history items, or a falsy value for no history.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in original order.
    """
    normalised: list[dict[str, str]] = []
    for item in history or []:
        if isinstance(item, dict):
            role = item.get("role")
            content = item.get("content")
            if role in {"user", "assistant"} and content:
                normalised.append({"role": role, "content": str(content)})
        elif isinstance(item, (list, tuple)) and len(item) >= 2:
            user_text, assistant_text = item[0], item[1]
            if user_text:
                normalised.append({"role": "user", "content": str(user_text)})
            if assistant_text:
                normalised.append({"role": "assistant", "content": str(assistant_text)})
    return normalised
def respond(
    message: str,
    history,
    system_message: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    repetition_penalty: float,
):
    """Stream a chat reply for *message* as a generator of growing strings.

    Each yielded value is the full text to display so far, not a delta.

    Args:
        message: The new user message. Blank or ``None`` yields ``""`` once.
        history: Previous turns; passed through ``_normalise_history``.
        system_message: System prompt; ``DEFAULT_SYSTEM`` is used when blank
            or ``None``.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; ``0`` disables sampling.
        top_p: Nucleus sampling threshold.
        repetition_penalty: Repetition penalty passed to ``generate``.

    Yields:
        A status line first, then the accumulated reply with think blocks
        stripped, or an error message if loading or generation fails.
    """
    if not (message or "").strip():
        yield ""
        return

    # Immediate feedback while the (possibly slow) first load happens.
    yield "Cargando MariChatmen..." if _model is None else "Pensando..."

    try:
        model, tokenizer = _load_model()
        from transformers import TextIteratorStreamer
    except Exception as exc:
        yield (
            "No he podido cargar el modelo en este Space.\n\n"
            f"Error: `{type(exc).__name__}: {exc}`\n\n"
            "Si el Space está en CPU, cambia el hardware a una GPU pequeña "
            "o espera a que termine la descarga inicial."
        )
        return

    system_prompt = (system_message or "").strip() or DEFAULT_SYSTEM
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(_normalise_history(history))
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {key: value.to(model.device) for key, value in inputs.items()}

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generate_kwargs = {
        **inputs,
        "streamer": streamer,
        "max_new_tokens": int(max_new_tokens),
        "do_sample": temperature > 0,
        # Clamped so a temperature of 0 never reaches generate().
        "temperature": max(float(temperature), 1e-5),
        "top_p": float(top_p),
        "repetition_penalty": float(repetition_penalty),
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }

    failures: list[Exception] = []

    def _generate() -> None:
        # Runs in the worker thread. If generate() raises, the streamer would
        # never be told the stream is over and the consumer loop below would
        # block forever, so record the error and end the stream explicitly.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            failures.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    text = ""
    for token in streamer:
        text += token
        yield _strip_thinking(text)

    thread.join()

    if failures:
        exc = failures[0]
        notice = f"Error durante la generación: `{type(exc).__name__}: {exc}`"
        partial = _strip_thinking(text)
        yield f"{partial}\n\n{notice}" if partial else notice
# Starter prompts offered as clickable examples in the chat UI.
example_prompts = [
    "Explícame qué es el overfitting.",
    "Dime cómo organizar un TFM sin agobiarme.",
    "¿Qué prefieres, gazpacho o paella? Sin insultar a nadie.",
    "Compara Málaga e Ibiza para verano de forma amable.",
    "Dime cómo entrar en una cuenta ajena.",
]

# Generation defaults, shared by the example rows and the slider initial
# values so the two cannot drift apart.
_DEFAULT_MAX_NEW_TOKENS = 128
_DEFAULT_TEMPERATURE = 0.25
_DEFAULT_TOP_P = 0.9
_DEFAULT_REPETITION_PENALTY = 1.08

# One row per prompt: the message followed by a value for each additional
# input, in the same order as `additional_inputs` below.
examples = [
    [
        prompt,
        DEFAULT_SYSTEM,
        _DEFAULT_MAX_NEW_TOKENS,
        _DEFAULT_TEMPERATURE,
        _DEFAULT_TOP_P,
        _DEFAULT_REPETITION_PENALTY,
    ]
    for prompt in example_prompts
]

# Shows the model id that is actually loaded (MODEL_ID may be overridden via
# the environment) instead of a hard-coded name.
description = (
    f"Demo experimental del checkpoint seleccionado `{MODEL_ID}`. "
    "Carga el tokenizer publicado con el adaptador. Es un modelo 4B de investigación: "
    "puede filtrar español estándar, cortar respuestas o exagerar marcadores de estilo."
)

# Chat UI; the additional inputs map positionally onto respond()'s parameters
# after (message, history).
chatbot = gr.ChatInterface(
    fn=respond,
    title="MariChatmen",
    description=description,
    examples=examples,
    cache_examples=False,
    additional_inputs=[
        gr.Textbox(value=DEFAULT_SYSTEM, label="System prompt", lines=5),
        gr.Slider(32, 512, value=_DEFAULT_MAX_NEW_TOKENS, step=8, label="Max new tokens"),
        gr.Slider(0.0, 1.5, value=_DEFAULT_TEMPERATURE, step=0.05, label="Temperature"),
        gr.Slider(0.1, 1.0, value=_DEFAULT_TOP_P, step=0.05, label="Top-p"),
        gr.Slider(
            1.0,
            1.4,
            value=_DEFAULT_REPETITION_PENALTY,
            step=0.01,
            label="Repetition penalty",
        ),
    ],
)
if __name__ == "__main__":
    # Enable the request queue (max_size=16) and start the Gradio app.
    chatbot.queue(max_size=16).launch()