# Web file-viewer residue captured along with the source (not part of the program):
#   mlandia's picture
#   max
#   f7ea7c6
#   Raw
#   History Blame Contribute Delete
#   5.19 kB
import os
import time
import spaces
import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
# Registry of the two checkpoints being compared. Each entry holds the Hub
# repository that gets loaded plus the heading and tagline rendered above the
# model's output column.
MODELS = {
    "global": dict(
        repo_id="CohereLabs/tiny-aya-global",
        title="Tiny Aya Global",
        subtitle="Broad multilingual coverage (67+ languages).",
    ),
    "earth": dict(
        repo_id="CohereLabs/tiny-aya-earth",
        title="Tiny Aya Earth",
        subtitle="Grounded variant tuned for real-world context.",
    ),
}
# Hugging Face Hub access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Upper bound of the "Max new tokens" slider: leave room for the prompt in ~8k context.
# NOTE(review): 8096 may have been intended as 8192 (8 * 1024) -- confirm; the
# value is deliberately left unchanged here.
MAX_NEW_TOKENS = 8096 - 1024

# Loaded tokenizers and models, keyed by the same names as MODELS.
tokenizers: dict[str, AutoTokenizer] = dict()
models: dict[str, AutoModelForCausalLM] = dict()
# Eagerly load every tokenizer/model pair listed in MODELS at import time, so the
# request path only has to look them up by key.
for key, meta in MODELS.items():
    repo_id = meta["repo_id"]
    tokenizers[key] = AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN)
    # Weights are requested in bfloat16 and the model is then moved to CUDA.
    models[key] = AutoModelForCausalLM.from_pretrained(
        repo_id, dtype=torch.bfloat16, token=HF_TOKEN
    ).to("cuda")
# Sample prompts passed to gr.Examples. Each row is a one-element list because
# the examples feed a single input component (the prompt textbox).
EXAMPLES = [
    [text]
    for text in (
        "Explain photosynthesis in simple terms.",
        "¿Cuál es la capital de Perú y qué la hace especial?",
        "Write a short haiku about rain.",
        "Quelles sont les différences entre le français québécois et le français de France?",
    )
]
def _build_inputs(tokenizer: AutoTokenizer, prompt: str, device: torch.device):
    """Render *prompt* as a single-turn chat and tokenize it for generation.

    Args:
        tokenizer: Tokenizer whose chat template formats the conversation.
        prompt: User message; sent as the only turn, with role ``"user"``.
        device: Device the encoded tensors are moved to.

    Returns:
        The tokenizer's encoding of the rendered chat as PyTorch tensors,
        moved to *device*.
    """
    messages = [{"role": "user", "content": prompt}]
    chat = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    # The rendered template text is expected to already contain the special
    # tokens the model needs; letting the tokenizer add them again on this
    # second pass can duplicate them (e.g. a doubled BOS token).
    encoded = tokenizer(chat, return_tensors="pt", add_special_tokens=False)
    return encoded.to(device)
def _generate_one(
    key: str,
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> tuple[str, float]:
    """Generate one reply from the model registered under *key* and time it.

    Args:
        key: Name of the entry to use in the module-level ``tokenizers`` and
            ``models`` dicts.
        prompt: User message to answer.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.
        temperature: Sampling temperature. A value greater than zero enables
            sampling; otherwise sampling is disabled (``do_sample=False``).
        top_p: Top-p value; forwarded only when sampling is enabled.

    Returns:
        A ``(response, elapsed)`` pair: the decoded text of the tokens that
        follow the prompt (special tokens skipped, surrounding whitespace
        stripped) and the wall-clock seconds spent inside ``model.generate``.

    Raises:
        KeyError: If *key* is not present in ``tokenizers`` / ``models``.
    """
    tokenizer = tokenizers[key]
    model = models[key]
    inputs = _build_inputs(tokenizer, prompt, model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    gen_kwargs: dict = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        # Sampling-only flags are passed together; sending ``top_p`` while
        # ``do_sample`` is False has no effect and makes transformers warn
        # about an unused generation flag.
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        gen_kwargs["do_sample"] = False

    started = time.perf_counter()
    with torch.inference_mode():
        output = model.generate(**inputs, **gen_kwargs)
    elapsed = time.perf_counter() - started

    # The output sequence starts with the prompt tokens; decode only what
    # comes after them.
    response = tokenizer.decode(
        output[0, prompt_length:],
        skip_special_tokens=True,
    ).strip()
    return response, elapsed
def estimate_duration(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> int:
    """Return the ``duration`` value that ``spaces.GPU`` uses for ``compare``.

    The parameters mirror those of ``compare``; only *max_new_tokens*
    influences the result, which is a fixed overhead of 10 plus 0.2 per
    requested token, truncated to an integer. The unit is presumably
    seconds -- confirm against the ``spaces`` documentation.
    """
    # Accepted for signature compatibility only; none of these affect the estimate.
    del prompt, temperature, top_p

    fixed_overhead = 10
    per_token_cost = 0.2
    return int(fixed_overhead + per_token_cost * max_new_tokens)
@spaces.GPU(duration=estimate_duration)
def compare(
    prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
) -> tuple[str, str, str, str]:
    """Answer *prompt* with both models and report each reply with its latency.

    Args:
        prompt: User message sent unchanged to both models.
        max_new_tokens: Passed through to ``_generate_one`` for each model.
        temperature: Passed through to ``_generate_one`` for each model.
        top_p: Passed through to ``_generate_one`` for each model.

    Returns:
        ``(global_text, earth_text, global_latency, earth_latency)``. The
        latencies are strings such as ``"1.23s"``. When *prompt* is empty or
        only whitespace, both text slots carry a hint message and both
        latency slots are empty strings.
    """
    if prompt.strip() == "":
        hint = "Enter a prompt to compare both models."
        return hint, hint, "", ""

    shared_args = (prompt, max_new_tokens, temperature, top_p)
    # The models run one after the other: "global" first, then "earth".
    global_text, global_time = _generate_one("global", *shared_args)
    earth_text, earth_time = _generate_one("earth", *shared_args)

    return (
        global_text,
        earth_text,
        f"{global_time:.2f}s",
        f"{earth_time:.2f}s",
    )
# Page layout: intro text, prompt box, a row of generation controls, the compare
# button, two side-by-side result columns, and clickable example prompts.
# NOTE(review): the nesting below is reconstructed from the statement order --
# confirm it matches the intended layout.
with gr.Blocks(title="Tiny Aya Compare") as demo:
    gr.Markdown(
        "\n"
        "# Tiny Aya: Global vs Earth\n"
        "Side-by-side comparison of [tiny-aya-global](https://huggingface.co/CohereLabs/tiny-aya-global)\n"
        "and [tiny-aya-earth](https://huggingface.co/CohereLabs/tiny-aya-earth) on ZeroGPU.\n"
        "Both models are gated. Set an `HF_TOKEN` Space secret with access to CohereLabs models.\n"
    )

    prompt = gr.Textbox(
        label="Prompt",
        placeholder="Ask the same question in any supported language…",
        lines=3,
    )

    # Generation controls shared by both models.
    with gr.Row():
        max_new_tokens = gr.Slider(
            minimum=16,
            maximum=MAX_NEW_TOKENS,
            value=256,
            step=16,
            label="Max new tokens",
        )
        temperature = gr.Slider(
            minimum=0, maximum=1.5, value=0.7, step=0.05, label="Temperature"
        )
        top_p = gr.Slider(
            minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p"
        )

    compare_btn = gr.Button("Compare", variant="primary")

    # One column per model: heading, tagline, response text, latency readout.
    with gr.Row():
        with gr.Column():
            gr.Markdown("### " + MODELS["global"]["title"])
            gr.Markdown(MODELS["global"]["subtitle"])
            global_out = gr.Textbox(label="Response", lines=12)
            global_stats = gr.Textbox(label="Latency", interactive=False)
        with gr.Column():
            gr.Markdown("### " + MODELS["earth"]["title"])
            gr.Markdown(MODELS["earth"]["subtitle"])
            earth_out = gr.Textbox(label="Response", lines=12)
            earth_stats = gr.Textbox(label="Latency", interactive=False)

    # Examples only populate the prompt box; their outputs are not cached.
    gr.Examples(
        examples=EXAMPLES,
        inputs=[prompt],
        cache_examples=False,
    )

    # The output order matches the tuple returned by compare().
    compare_btn.click(
        fn=compare,
        inputs=[prompt, max_new_tokens, temperature, top_p],
        outputs=[global_out, earth_out, global_stats, earth_stats],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()