from __future__ import annotations import os import re import threading from threading import Thread import gradio as gr MODEL_ID = os.getenv("MODEL_ID", "MariChatmen/MariChatmen-4B-Experimental") MODEL_REVISION = os.getenv("MODEL_REVISION", "main") DEFAULT_SYSTEM = ( "Eres MariChatmen, una sevillana ficticia nacida durante la Expo del 92. " "Respondes siempre en Andalûh EPA informal, con orgullo andaluz, gracia " "sevillana y cariño por todas las provincias de Andalucía. Tu exageración " "regional es humorística y afectuosa, nunca hostil ni despectiva. Ayudas " "con claridad: primero respondes bien, luego añades arte." ) THINK_RE = re.compile(r".*?", re.DOTALL | re.IGNORECASE) _model = None _tokenizer = None _load_lock = threading.Lock() def _strip_thinking(text: str) -> str: return THINK_RE.sub("", text).replace("", "").replace("", "").strip() def _load_model(): global _model, _tokenizer if _model is not None and _tokenizer is not None: return _model, _tokenizer with _load_lock: if _model is not None and _tokenizer is not None: return _model, _tokenizer import torch from peft import AutoPeftModelForCausalLM from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained( MODEL_ID, revision=MODEL_REVISION, trust_remote_code=True, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token cuda = torch.cuda.is_available() dtype = torch.float16 if cuda else torch.float32 model = AutoPeftModelForCausalLM.from_pretrained( MODEL_ID, revision=MODEL_REVISION, torch_dtype=dtype, device_map="auto", trust_remote_code=True, ) model.eval() _model = model _tokenizer = tokenizer return model, tokenizer def _normalise_history(history) -> list[dict[str, str]]: normalised: list[dict[str, str]] = [] for item in history or []: if isinstance(item, dict): role = item.get("role") content = item.get("content") if role in {"user", "assistant"} and content: normalised.append({"role": role, "content": str(content)}) continue if isinstance(item, (list, tuple)) and len(item) >= 2: user_text, assistant_text = item[0], item[1] if user_text: normalised.append({"role": "user", "content": str(user_text)}) if assistant_text: normalised.append({"role": "assistant", "content": str(assistant_text)}) return normalised def respond( message: str, history, system_message: str, max_new_tokens: int, temperature: float, top_p: float, repetition_penalty: float, ): if not message.strip(): yield "" return yield "Cargando MariChatmen..." if _model is None else "Pensando..." try: model, tokenizer = _load_model() from transformers import TextIteratorStreamer except Exception as exc: yield ( "No he podido cargar el modelo en este Space.\n\n" f"Error: `{type(exc).__name__}: {exc}`\n\n" "Si el Space está en CPU, cambia el hardware a una GPU pequeña " "o espera a que termine la descarga inicial." ) return messages = [{"role": "system", "content": system_message.strip() or DEFAULT_SYSTEM}] messages.extend(_normalise_history(history)) messages.append({"role": "user", "content": message}) prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, ) inputs = tokenizer(prompt, return_tensors="pt") inputs = {key: value.to(model.device) for key, value in inputs.items()} streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True, ) generate_kwargs = { **inputs, "streamer": streamer, "max_new_tokens": int(max_new_tokens), "do_sample": temperature > 0, "temperature": max(float(temperature), 1e-5), "top_p": float(top_p), "repetition_penalty": float(repetition_penalty), "pad_token_id": tokenizer.pad_token_id, "eos_token_id": tokenizer.eos_token_id, } thread = Thread(target=model.generate, kwargs=generate_kwargs) thread.start() text = "" for token in streamer: text += token yield _strip_thinking(text) example_prompts = [ "Explícame qué es el overfitting.", "Dime cómo organizar un TFM sin agobiarme.", "¿Qué prefieres, gazpacho o paella? Sin insultar a nadie.", "Compara Málaga e Ibiza para verano de forma amable.", "Dime cómo entrar en una cuenta ajena.", ] examples = [[prompt, DEFAULT_SYSTEM, 128, 0.25, 0.9, 1.08] for prompt in example_prompts] description = ( "Demo experimental del checkpoint seleccionado `MariChatmen/MariChatmen-4B-Experimental`. " "Carga el tokenizer publicado con el adaptador. Es un modelo 4B de investigación: " "puede filtrar español estándar, cortar respuestas o exagerar marcadores de estilo." ) chatbot = gr.ChatInterface( fn=respond, title="MariChatmen", description=description, examples=examples, cache_examples=False, additional_inputs=[ gr.Textbox(value=DEFAULT_SYSTEM, label="System prompt", lines=5), gr.Slider(32, 512, value=128, step=8, label="Max new tokens"), gr.Slider(0.0, 1.5, value=0.25, step=0.05, label="Temperature"), gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"), gr.Slider(1.0, 1.4, value=1.08, step=0.01, label="Repetition penalty"), ], ) if __name__ == "__main__": chatbot.queue(max_size=16).launch()