| import os |
| import time |
|
|
| import spaces |
| import torch |
| import gradio as gr |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Registry of the checkpoints being compared. Each entry holds the repository
# id handed to `from_pretrained` plus the title and subtitle rendered above
# that model's output column.
MODELS: dict[str, dict[str, str]] = {
    "global": dict(
        repo_id="CohereLabs/tiny-aya-global",
        title="Tiny Aya Global",
        subtitle="Broad multilingual coverage (67+ languages).",
    ),
    "earth": dict(
        repo_id="CohereLabs/tiny-aya-earth",
        title="Tiny Aya Earth",
        subtitle="Grounded variant tuned for real-world context.",
    ),
}
|
|
# Access token for the model repositories; None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")

# Upper bound of the "Max new tokens" slider (evaluates to 7072).
# NOTE(review): presumably a total token budget minus room kept for the
# prompt; 8096 may be a typo for 8192 — confirm the intended limit.
MAX_NEW_TOKENS = 8096 - 1024
|
|
# Loaded tokenizers and models, keyed like MODELS ("global", "earth").
tokenizers: dict[str, AutoTokenizer] = {}
models: dict[str, AutoModelForCausalLM] = {}


# Load every registered checkpoint once, at import time, so request handlers
# only perform dictionary lookups. Weights are loaded in bfloat16 and then
# moved to the CUDA device.
for key, meta in MODELS.items():
    repo_id = meta["repo_id"]
    tokenizers[key] = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
    models[key] = AutoModelForCausalLM.from_pretrained(
        repo_id,
        token=HF_TOKEN,
        dtype=torch.bfloat16,
    )
    models[key] = models[key].to("cuda")
|
|
# Sample prompts offered in the UI. Each row is a one-element list because the
# examples widget fills a single input component (the prompt box).
EXAMPLES = [
    [sample]
    for sample in (
        "Explain photosynthesis in simple terms.",
        "¿Cuál es la capital de Perú y qué la hace especial?",
        "Write a short haiku about rain.",
        "Quelles sont les différences entre le français québécois et le français de France?",
    )
]
|
|
|
|
def _build_inputs(tokenizer: AutoTokenizer, prompt: str, device: torch.device):
    """Render *prompt* as a single-turn chat and tokenize it for generation.

    Args:
        tokenizer: Tokenizer whose chat template formats the conversation.
        prompt: Raw user text; it becomes the only ``user`` message.
        device: Device the encoded tensors are moved to.

    Returns:
        The tokenizer's PyTorch-tensor encoding of the templated chat, moved
        to ``device``.
    """
    messages = [{"role": "user", "content": prompt}]
    chat = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    # The rendered template already contains whatever special tokens the model
    # expects. Tokenizing it with the default add_special_tokens=True can
    # prepend a second BOS token, so special-token insertion is disabled here.
    encoded = tokenizer(
        chat,
        return_tensors="pt",
        add_special_tokens=False,
    )
    return encoded.to(device)
|
|
|
|
def _generate_one(
    key: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> tuple[str, float]:
    """Generate a reply to *prompt* with one loaded model and time the call.

    Args:
        key: Key into the module-level ``tokenizers`` / ``models`` registries.
        prompt: User text, wrapped into a chat prompt by ``_build_inputs``.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature; a value ``<= 0`` selects greedy
            decoding instead of sampling.
        top_p: Nucleus-sampling threshold; only used when sampling.

    Returns:
        ``(response, elapsed)`` where ``response`` is the decoded completion
        (prompt tokens removed, special tokens skipped, surrounding whitespace
        stripped) and ``elapsed`` is the wall-clock time in seconds spent in
        ``model.generate``.

    Raises:
        KeyError: If ``key`` is not present in the registries.
    """
    tokenizer = tokenizers[key]
    model = models[key]
    inputs = _build_inputs(tokenizer, prompt, model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    gen_kwargs: dict = {
        # Defensive cast: the value originates from a UI slider and may arrive
        # as a float.
        "max_new_tokens": int(max_new_tokens),
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        # Sampling-only knobs are passed only when sampling is enabled; sending
        # top_p alongside do_sample=False is ignored and triggers
        # generation-config validation warnings.
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False

    started = time.perf_counter()
    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)
    elapsed = time.perf_counter() - started

    # generate() returns prompt + completion; keep only the newly generated part.
    completion_ids = output[0, prompt_length:]
    response = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    return response, elapsed
|
|
|
|
def estimate_duration(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> int:
    """Return the duration budget requested for one ``compare`` call.

    The parameter list mirrors ``compare`` so this function can be passed as
    the ``duration`` callable of ``spaces.GPU``. Only ``max_new_tokens``
    influences the result: a fixed base of 10 plus 0.2 per requested token,
    truncated to an integer.
    NOTE(review): the unit is presumably seconds — confirm against the
    ``spaces.GPU`` documentation.
    """
    del prompt, temperature, top_p  # accepted for signature parity; unused
    base = 10
    per_token = 0.2
    return int(base + max_new_tokens * per_token)
|
|
|
|
@spaces.GPU(duration=estimate_duration)
def _compare_on_gpu(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> tuple[str, str, str, str]:
    """Run both models on *prompt* inside a single GPU allocation."""
    # The models run one after the other: "global" first, then "earth".
    global_text, global_time = _generate_one(
        "global", prompt, max_new_tokens, temperature, top_p
    )
    earth_text, earth_time = _generate_one(
        "earth", prompt, max_new_tokens, temperature, top_p
    )

    global_stats = f"{global_time:.2f}s"
    earth_stats = f"{earth_time:.2f}s"
    return global_text, earth_text, global_stats, earth_stats


def compare(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> tuple[str, str, str, str]:
    """Answer *prompt* with both models and report each one's latency.

    Args:
        prompt: User text sent unchanged to both models.
        max_new_tokens: Maximum number of tokens each model may generate.
        temperature: Sampling temperature forwarded to generation.
        top_p: Nucleus-sampling threshold forwarded to generation.

    Returns:
        ``(global_text, earth_text, global_stats, earth_stats)``: the two
        responses followed by their latencies formatted like ``"1.23s"``.
        For a missing or whitespace-only prompt, both response slots carry a
        hint message and both latency slots are empty strings.
    """
    # Validate before entering the GPU-decorated helper so that a blank
    # submission never requests a GPU slot; `not prompt` also covers None.
    if not prompt or not prompt.strip():
        message = "Enter a prompt to compare both models."
        return message, message, "", ""

    return _compare_on_gpu(prompt, max_new_tokens, temperature, top_p)
|
|
|
|
# Build the comparison page: intro text, prompt box, generation controls, one
# output column per model, example prompts, and the button wiring.
with gr.Blocks(title="Tiny Aya Compare") as demo:
    gr.Markdown(
        """
        # Tiny Aya: Global vs Earth
        Side-by-side comparison of [tiny-aya-global](https://huggingface.co/CohereLabs/tiny-aya-global)
        and [tiny-aya-earth](https://huggingface.co/CohereLabs/tiny-aya-earth) on ZeroGPU.

        Both models are gated. Set an `HF_TOKEN` Space secret with access to CohereLabs models.
        """
    )

    prompt = gr.Textbox(
        lines=3,
        label="Prompt",
        placeholder="Ask the same question in any supported language…",
    )

    # Generation controls shared by both models.
    with gr.Row():
        max_new_tokens = gr.Slider(
            minimum=16,
            maximum=MAX_NEW_TOKENS,
            value=256,
            step=16,
            label="Max new tokens",
        )
        temperature = gr.Slider(
            minimum=0,
            maximum=1.5,
            value=0.7,
            step=0.05,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top-p",
        )

    compare_btn = gr.Button(value="Compare", variant="primary")

    # One column per model: heading, blurb, response box, latency readout.
    with gr.Row():
        with gr.Column():
            gr.Markdown(f"### {MODELS['global']['title']}")
            gr.Markdown(MODELS["global"]["subtitle"])
            global_out = gr.Textbox(lines=12, label="Response")
            global_stats = gr.Textbox(interactive=False, label="Latency")
        with gr.Column():
            gr.Markdown(f"### {MODELS['earth']['title']}")
            gr.Markdown(MODELS["earth"]["subtitle"])
            earth_out = gr.Textbox(lines=12, label="Response")
            earth_stats = gr.Textbox(interactive=False, label="Latency")

    # Example prompts only populate the prompt box; results are not cached.
    gr.Examples(EXAMPLES, [prompt], cache_examples=False)

    compare_btn.click(
        compare,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=[global_out, earth_out, global_stats, earth_stats],
    )


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|
|