"""Gradio Space that runs one prompt through two Tiny Aya variants side by side."""

import os
import time

import spaces
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Registry of the models being compared; keys are used to index the
# ``tokenizers`` / ``models`` dicts and the UI columns below.
MODELS = {
    "global": {
        "repo_id": "CohereLabs/tiny-aya-global",
        "title": "Tiny Aya Global",
        "subtitle": "Broad multilingual coverage (67+ languages).",
    },
    "earth": {
        "repo_id": "CohereLabs/tiny-aya-earth",
        "title": "Tiny Aya Earth",
        "subtitle": "Grounded variant tuned for real-world context.",
    },
}

HF_TOKEN = os.environ.get("HF_TOKEN")
MAX_NEW_TOKENS = 8096 - 1024  # leave room for the prompt in ~8k context

INTRO_MARKDOWN = """
# Tiny Aya: Global vs Earth

Side-by-side comparison of
[tiny-aya-global](https://huggingface.co/CohereLabs/tiny-aya-global) and
[tiny-aya-earth](https://huggingface.co/CohereLabs/tiny-aya-earth) on ZeroGPU.

Both models are gated. Set an `HF_TOKEN` Space secret with access to CohereLabs models.
"""

tokenizers: dict[str, AutoTokenizer] = {}
models: dict[str, AutoModelForCausalLM] = {}

# Load every tokenizer/model once at import time so requests only pay for
# generation, not for loading weights.
for key, meta in MODELS.items():
    repo_id = meta["repo_id"]
    tokenizers[key] = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
    models[key] = AutoModelForCausalLM.from_pretrained(
        repo_id,
        dtype=torch.bfloat16,
        token=HF_TOKEN,
    ).to("cuda")

EXAMPLES = [
    ["Explain photosynthesis in simple terms."],
    ["¿Cuál es la capital de Perú y qué la hace especial?"],
    ["Write a short haiku about rain."],
    ["Quelles sont les différences entre le français québécois et le français de France?"],
]


def _build_inputs(tokenizer: AutoTokenizer, prompt: str, device: torch.device):
    """Render *prompt* as a single-turn chat and tokenize it onto *device*.

    Args:
        tokenizer: Tokenizer whose chat template is applied.
        prompt: Raw user message.
        device: Device the returned tensors are moved to.

    Returns:
        The tokenizer output (``input_ids`` etc.) as PyTorch tensors.
    """
    messages = [{"role": "user", "content": prompt}]
    chat = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    # The rendered template already carries the special tokens it needs;
    # letting the tokenizer add its own again would duplicate them (e.g. BOS).
    return tokenizer(
        chat,
        return_tensors="pt",
        add_special_tokens=False,
    ).to(device)


def _generate_one(
    key: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> tuple[str, float]:
    """Generate a reply from the model registered under *key*.

    Args:
        key: Entry of ``MODELS`` selecting the tokenizer/model pair.
        prompt: Raw user message.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; ``0`` (or less) selects greedy decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        ``(response_text, elapsed_seconds)`` where the time covers only the
        ``generate`` call.
    """
    tokenizer = tokenizers[key]
    model = models[key]

    inputs = _build_inputs(tokenizer, prompt, model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    gen_kwargs: dict = {
        # UI sliders may deliver the value as a float; generate needs an int.
        "max_new_tokens": int(max_new_tokens),
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p
    else:
        # Greedy decoding: sampling-only flags such as top_p are left out so
        # they are not reported as unused generation parameters.
        gen_kwargs["do_sample"] = False

    started = time.perf_counter()
    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)
    elapsed = time.perf_counter() - started

    # Drop the echoed prompt tokens and decode only the newly generated tail.
    response = tokenizer.decode(
        output[0, prompt_length:],
        skip_special_tokens=True,
    ).strip()
    return response, elapsed


def estimate_duration(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> int:
    """Return the GPU time (in seconds) to request for one ``compare`` call.

    The signature mirrors ``compare`` because it is passed as the ``duration``
    callable of ``spaces.GPU``; only *max_new_tokens* influences the estimate.
    """
    _ = (prompt, temperature, top_p)
    return int(10 + max_new_tokens * 0.2)


@spaces.GPU(duration=estimate_duration)
def compare(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> tuple[str, str, str, str]:
    """Run *prompt* through both models, one after the other.

    Returns:
        ``(global_text, earth_text, global_latency, earth_latency)``; for a
        blank prompt both texts are a hint message and both latencies are empty.
    """
    # ``not prompt`` also covers None, which has no .strip().
    if not prompt or not prompt.strip():
        message = "Enter a prompt to compare both models."
        return message, message, "", ""

    global_text, global_time = _generate_one(
        "global", prompt, max_new_tokens, temperature, top_p
    )
    earth_text, earth_time = _generate_one(
        "earth", prompt, max_new_tokens, temperature, top_p
    )

    global_stats = f"{global_time:.2f}s"
    earth_stats = f"{earth_time:.2f}s"
    return global_text, earth_text, global_stats, earth_stats


with gr.Blocks(title="Tiny Aya Compare") as demo:
    gr.Markdown(INTRO_MARKDOWN)

    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Ask the same question in any supported language…",
        lines=3,
    )

    with gr.Row():
        max_new_tokens = gr.Slider(
            16, MAX_NEW_TOKENS, value=256, step=16, label="Max new tokens"
        )
        temperature = gr.Slider(0, 1.5, value=0.7, step=0.05, label="Temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")

    compare_btn = gr.Button("Compare", variant="primary")

    # One column per model: title, subtitle, response and latency.
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"### {MODELS['global']['title']}")
            gr.Markdown(MODELS["global"]["subtitle"])
            global_out = gr.Textbox(label="Response", lines=12)
            global_stats = gr.Textbox(label="Latency", interactive=False)
        with gr.Column():
            gr.Markdown(f"### {MODELS['earth']['title']}")
            gr.Markdown(MODELS["earth"]["subtitle"])
            earth_out = gr.Textbox(label="Response", lines=12)
            earth_stats = gr.Textbox(label="Latency", interactive=False)

    gr.Examples(
        examples=EXAMPLES,
        inputs=[prompt],
        cache_examples=False,
    )

    compare_btn.click(
        fn=compare,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=[global_out, earth_out, global_stats, earth_stats],
    )


if __name__ == "__main__":
    demo.launch()