daiep-workshop / app.py
elledilara's picture
Update app.py
e9c0d93 verified
import os
import requests
import gradio as gr
# Inference endpoint URLs -- replace these with your actual endpoint URLs.
_BASE_ENDPOINT_URL = "https://u3fjx3mvvnn0jlne.us-east-1.aws.endpoints.huggingface.cloud"
_FINE_TUNED_ENDPOINT_URL = "https://vcugblrdxbk79vsd.us-east-1.aws.endpoints.huggingface.cloud"

# Endpoints to compare, keyed by model display name.
# Insertion order matters: query_models returns one output per entry,
# in this order.
ENDPOINTS = {
    "Base LLaMA 3.1 8B": _BASE_ENDPOINT_URL,
    "Fine-tuned LLaMA 3.1 8B": _FINE_TUNED_ENDPOINT_URL,
}
# Access token for the endpoints, provided as a secret in the Space settings.
# Indexing os.environ (instead of using .get) makes a missing secret fail
# immediately at startup with a KeyError.
HF_TOKEN = os.environ["HF_TOKEN"]

# HTTP headers sent with every endpoint request: bearer authentication plus
# a JSON request body.
headers = {
    "Authorization": "Bearer " + HF_TOKEN,
    "Content-Type": "application/json",
}
def _extract_generated_text(result_json):
    """Pull the generated text out of a decoded endpoint response.

    Two response shapes are recognised: a list whose first item is a dict
    containing "generated_text", or such a dict on its own.

    Args:
        result_json: The decoded JSON body of the endpoint response.

    Returns:
        The generated text, or an "Unexpected response format: ..." message
        when the body has neither recognised shape.
    """
    candidate = result_json
    # Unwrap the list form only when it really has a first item, so an empty
    # list is reported as an unexpected format instead of raising IndexError.
    if isinstance(result_json, list) and result_json:
        candidate = result_json[0]
    # Require a dict before the key lookup: on a string, `in` would be a
    # substring test and the subscript below would raise TypeError.
    if isinstance(candidate, dict) and "generated_text" in candidate:
        return candidate["generated_text"]
    return f"Unexpected response format: {result_json}"


def query_models(prompt):
    """Send *prompt* to every endpoint in ENDPOINTS and collect the replies.

    Args:
        prompt: Text sent as the "inputs" field of each request.

    Returns:
        A tuple with one string per ENDPOINTS entry, in ENDPOINTS order.
        Each string is either the generated text or a short error message;
        request, HTTP and parsing failures are reported as messages rather
        than raised.
    """
    # The request body is identical for every endpoint, so build it once.
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": 80,
            "temperature": 0.5,
        },
    }
    outputs = []
    for api_url in ENDPOINTS.values():
        try:
            response = requests.post(api_url, headers=headers, json=payload, timeout=60)
            response.raise_for_status()  # turns 4xx/5xx replies into HTTPError
            outputs.append(_extract_generated_text(response.json()))
        except requests.exceptions.Timeout:
            outputs.append("Error: request timed out")
        except requests.exceptions.HTTPError as e:
            outputs.append(f"HTTP error: {e.response.status_code}")
        except Exception as e:
            # Boundary catch-all (connection errors, invalid JSON, ...): one
            # failing endpoint must not prevent the other from being shown.
            outputs.append(f"Other error: {e}")
    return tuple(outputs)  # Gradio expects a tuple for multiple outputs
# Gradio UI: a single prompt box in, two response boxes out.
_prompt_input = gr.Textbox(lines=5, label="Enter your prompt")

# query_models returns its results in ENDPOINTS order, so the first box shows
# the first endpoint's reply and the second box the second endpoint's reply.
# NOTE(review): the first box is labelled "Model 2" and the second "Model 1";
# confirm this labelling is intentional (e.g. for a blind comparison).
_response_outputs = [
    gr.Textbox(lines=10, label="Model 2"),
    gr.Textbox(lines=10, label="Model 1"),
]

demo = gr.Interface(
    fn=query_models,
    inputs=_prompt_input,
    outputs=_response_outputs,
    title="LLM Comparison",
    description="Enter a prompt and see how both models respond side by side.",
)

# Start the web app.
demo.launch()