# model-arena / app.py
# Author: jonathanagustin
# Sync from deploy tool: tutorials/07-model-arena (commit 111ceaa, verified)
import logging
import os
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
from huggingface_hub import InferenceClient
# Configure logging
# NOTE: basicConfig must run before getLogger so the handler/format apply.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Environment variables
# HF_TOKEN is optional: without it the InferenceClient runs anonymously
# (subject to stricter rate limits). Only the token's presence is logged,
# never its value.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
logger.info(f"HF_TOKEN configured: {bool(HF_TOKEN)}")
client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else InferenceClient()
logger.info("InferenceClient initialized")

# Models to compare (configurable via env vars)
# Maps UI display label -> HF model repo id; MODEL_1..MODEL_3 env vars
# override the defaults without code changes.
MODELS = {
    "Qwen 2.5 72B": os.environ.get("MODEL_1", "Qwen/Qwen2.5-72B-Instruct"),
    "Llama 3.2 3B": os.environ.get("MODEL_2", "meta-llama/Llama-3.2-3B-Instruct"),
    "Zephyr 7B": os.environ.get("MODEL_3", "HuggingFaceH4/zephyr-7b-beta"),
}
logger.info(f"Loaded {len(MODELS)} models for comparison")
def query_model(model_id: str, prompt: str, max_tokens: int) -> str:
    """Query a single model through the Inference API chat endpoint.

    Args:
        model_id: HF model repo id (e.g. "Qwen/Qwen2.5-72B-Instruct").
        prompt: User prompt, sent as a single chat message.
        max_tokens: Generation cap forwarded to the API.

    Returns:
        The assistant reply text, "" for an empty completion, or a
        "❌ Error: ..." string on failure so the UI surfaces the problem
        instead of crashing.
    """
    try:
        # Lazy %-style args: the string is only built if the record is emitted.
        logger.info("Querying %s...", model_id)
        response = client.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
        )
        # message.content may be None (empty completion); normalize to ""
        # so the declared -> str contract always holds.
        return response.choices[0].message.content or ""
    except Exception as e:
        # Broad catch is deliberate: any API/network failure becomes a
        # visible error string in the output box. logger.exception keeps
        # the traceback in the logs (logger.error would drop it).
        logger.exception("Error with %s", model_id)
        return f"❌ Error: {e}"
def compare_models(prompt: str, max_tokens: int) -> tuple:
    """Query every configured model and return their replies for the UI.

    Args:
        prompt: User prompt; blank/whitespace-only input short-circuits.
        max_tokens: Per-model generation cap.

    Returns:
        A tuple of response strings, one per entry in MODELS, in MODELS'
        insertion order (matching the output textboxes).
    """
    logger.info(
        "compare_models() called | prompt_len=%d | max_tokens=%s",
        len(prompt),
        max_tokens,
    )
    if not prompt.strip():
        # One placeholder per output box; sized from MODELS rather than
        # hard-coded to 3, since MODELS is env-configurable.
        return tuple("Enter a prompt!" for _ in MODELS)
    # The per-model calls are independent network requests, so fan them
    # out concurrently; executor.map preserves input order.
    with ThreadPoolExecutor(max_workers=len(MODELS)) as executor:
        results = executor.map(
            lambda model_id: query_model(model_id, prompt, max_tokens),
            MODELS.values(),
        )
        return tuple(results)
logger.info("Building Gradio interface...")
# Layout: prompt + token slider on top, one read-only box per model below.
with gr.Blocks(title="Model Arena") as demo:
    gr.Markdown("""# 🏟️ Model Arena
Compare responses from multiple LLMs side-by-side!
*All models run via HuggingFace Inference API - no downloads required.*
""")
    prompt_input = gr.Textbox(
        label="Your prompt",
        placeholder="Explain quantum computing in simple terms...",
        lines=3,
        autofocus=True,
    )
    tokens_slider = gr.Slider(
        minimum=50,
        maximum=500,
        value=200,
        step=50,
        label="Max tokens per response",
    )
    battle_button = gr.Button("⚔️ Battle!", variant="primary", size="lg")
    # One output textbox per configured model, laid out side by side.
    with gr.Row(equal_height=True):
        response_boxes = [
            gr.Textbox(label=model_name, lines=10, interactive=False)
            for model_name in MODELS
        ]
    # Both the button and pressing Enter in the prompt run the comparison.
    event_kwargs = dict(
        fn=compare_models,
        inputs=[prompt_input, tokens_slider],
        outputs=response_boxes,
    )
    battle_button.click(**event_kwargs)
    prompt_input.submit(**event_kwargs)
    gr.Examples(
        examples=[
            ["Explain quantum computing to a 5-year-old", 200],
            ["Write a haiku about machine learning", 100],
            ["What are 3 creative startup ideas?", 300],
            ["Debate: Is AI good or bad for humanity?", 400],
        ],
        inputs=[prompt_input, tokens_slider],
    )

demo.queue()
logger.info("Starting Gradio server...")
demo.launch()