File size: 3,261 Bytes
111ceaa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import logging
import gradio as gr
from huggingface_hub import InferenceClient

# ---------------------------------------------------------------------------
# Module-level setup: logging, credentials, inference client, model roster.
# ---------------------------------------------------------------------------

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)

# Environment variables
# HF_TOKEN is optional: without it the client falls back to anonymous,
# rate-limited access to the Inference API.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Lazy %-style logging args instead of f-strings: the message is only
# formatted when the INFO level is actually enabled.
logger.info("HF_TOKEN configured: %s", bool(HF_TOKEN))

client = InferenceClient(token=HF_TOKEN) if HF_TOKEN else InferenceClient()
logger.info("InferenceClient initialized")

# Models to compare (configurable via env vars)
MODELS = {
    "Qwen 2.5 72B": os.environ.get("MODEL_1", "Qwen/Qwen2.5-72B-Instruct"),
    "Llama 3.2 3B": os.environ.get("MODEL_2", "meta-llama/Llama-3.2-3B-Instruct"),
    "Zephyr 7B": os.environ.get("MODEL_3", "HuggingFaceH4/zephyr-7b-beta"),
}
logger.info("Loaded %d models for comparison", len(MODELS))


def query_model(model_id: str, prompt: str, max_tokens: int) -> str:
    """Send *prompt* to a single model via the HF Inference API.

    Args:
        model_id: Fully-qualified model repo id (e.g. "Qwen/Qwen2.5-72B-Instruct").
        prompt: User message, sent as a single-turn chat.
        max_tokens: Cap on the number of tokens in the completion.

    Returns:
        The model's reply text (never None), or a "❌ Error: ..." string when
        the request failed — errors are surfaced to the UI, never raised.
    """
    try:
        # Lazy %-args: skip string formatting when INFO is disabled.
        logger.info("Querying %s...", model_id)
        response = client.chat.completions.create(
            model=model_id,
            messages=[{"role": "user", "content": prompt}],
            max_tokens=max_tokens,
        )
        # message.content may be None for some providers/finish reasons;
        # normalize to "" so the declared str return type always holds.
        return response.choices[0].message.content or ""
    except Exception as e:
        # Broad catch is deliberate: any failure (network, auth, model busy)
        # becomes a visible message in that model's output pane instead of
        # taking down the whole comparison. logger.exception keeps the traceback.
        logger.exception("Error with %s", model_id)
        return f"❌ Error: {e}"


def compare_models(prompt: str, max_tokens: int) -> tuple:
    """Query every model in MODELS with the same prompt.

    Args:
        prompt: User prompt; if blank/whitespace, placeholder text is returned.
        max_tokens: Per-response token cap, forwarded to each model.

    Returns:
        A tuple with one response string per entry in MODELS, in MODELS'
        insertion order (must match the order of the output textboxes).
    """
    logger.info(
        "compare_models() called | prompt_len=%d | max_tokens=%s",
        len(prompt), max_tokens,
    )

    if not prompt.strip():
        # Size the placeholder tuple from MODELS instead of hard-coding 3,
        # so changing the (env-configurable) roster keeps the outputs aligned
        # with the Gradio textboxes.
        return tuple("Enter a prompt!" for _ in MODELS)

    # Sequential queries; tuple order follows MODELS' insertion order.
    return tuple(
        query_model(model_id, prompt, max_tokens)
        for model_id in MODELS.values()
    )


logger.info("Building Gradio interface...")

# Declarative UI: component creation order inside the `with` block determines
# on-screen layout, so statement order below is significant.
with gr.Blocks(title="Model Arena") as demo:
    gr.Markdown("""# 🏟️ Model Arena
Compare responses from multiple LLMs side-by-side!

*All models run via HuggingFace Inference API - no downloads required.*
""")

    # Single shared prompt box that feeds all models.
    prompt = gr.Textbox(
        label="Your prompt",
        placeholder="Explain quantum computing in simple terms...",
        lines=3,
        autofocus=True,
    )

    # Shared per-response token budget, passed to every model.
    max_tokens = gr.Slider(
        minimum=50, maximum=500, value=200, step=50,
        label="Max tokens per response"
    )

    btn = gr.Button("⚔️ Battle!", variant="primary", size="lg")

    # One read-only output pane per model. The iteration order of MODELS must
    # match the order of compare_models()'s returned tuple — both follow dict
    # insertion order.
    with gr.Row(equal_height=True):
        outputs = []
        for name in MODELS.keys():
            outputs.append(gr.Textbox(label=name, lines=10, interactive=False))

    # Both the button click and pressing Enter in the prompt box trigger
    # the same comparison callback.
    btn.click(compare_models, inputs=[prompt, max_tokens], outputs=outputs)
    prompt.submit(compare_models, inputs=[prompt, max_tokens], outputs=outputs)

    # Clickable example prompts that pre-fill both inputs.
    gr.Examples(
        examples=[
            ["Explain quantum computing to a 5-year-old", 200],
            ["Write a haiku about machine learning", 100],
            ["What are 3 creative startup ideas?", 300],
            ["Debate: Is AI good or bad for humanity?", 400],
        ],
        inputs=[prompt, max_tokens],
    )

# Enable request queuing so concurrent users are serialized cleanly.
demo.queue()
logger.info("Starting Gradio server...")
demo.launch()