# gaia / app.py — HuggingFace Space by bstraehle (commit 05f6f94, verified)
# References:
# https://www.gradio.app/guides/quickstart
import gradio.utils, os, sys
import gradio as gr
import threading
import queue
import time
from agents.crew import run_crew
from huggingface_hub import whoami
from utils.utils import (
DATASET_TYPE_GAIA,
DATASET_TYPE_HLE,
get_dataset,
validate_input
)
# Configuration
# SPACE_ID is set automatically by the HuggingFace Spaces runtime.
# NOTE(review): if SPACE_ID is unset (local run), BASE_URL embeds the literal
# "None" — links built from it will be broken; confirm local-run expectations.
SPACE_ID = os.environ.get("SPACE_ID")
BASE_URL = f"https://huggingface.co/spaces/{SPACE_ID}/blob/main"
# Streaming console output capture
class StreamingCapture:
    """Tee sys.stdout/sys.stderr into a queue so console output can be streamed.

    While active, everything written to stdout/stderr is forwarded to the
    original stream AND buffered in a thread-safe queue that another thread
    drains via get_new_output().
    """

    def __init__(self):
        self.queue = queue.Queue()  # buffered output chunks, drained by get_new_output()
        self.old_stdout = None      # original stream, restored by stop()
        self.old_stderr = None
        self.stopped = False        # set once stop() has run

    def start(self):
        """Replace sys.stdout and sys.stderr with tee writers."""
        self.old_stdout = sys.stdout
        self.old_stderr = sys.stderr

        class _QueueWriter:
            """Minimal file-like object: records writes and forwards them."""

            def __init__(self, original, q):
                self.original = original
                self.q = q

            def write(self, data):
                if data:
                    self.q.put(data)
                    self.original.write(data)

            def flush(self):
                self.original.flush()

        sys.stdout = _QueueWriter(self.old_stdout, self.queue)
        sys.stderr = _QueueWriter(self.old_stderr, self.queue)

    def stop(self):
        """Restore the original streams. Safe to call more than once."""
        self.stopped = True
        if self.old_stdout:
            sys.stdout = self.old_stdout
        if self.old_stderr:
            sys.stderr = self.old_stderr

    def get_new_output(self):
        """Drain and return everything captured since the last call.

        Returns:
            str: concatenated output chunks ("" if nothing new).
        """
        output = []
        # Drain until Empty instead of polling queue.empty() first: the
        # empty() check is racy when writers run on another thread.
        while True:
            try:
                output.append(self.queue.get_nowait())
            except queue.Empty:
                break
        return "".join(output)
# MCP server functions
def ask(oauth_token: gr.OAuthToken, question: str, openai_api_key: str, gemini_api_key: str, anthropic_api_key: str, file_name: str = ""):
    """
    Ask General AI Assistant a question to answer.

    Args:
        oauth_token (gr.OAuthToken): OAuth token injected by Gradio's login flow
        question (str): The question to answer
        openai_api_key (str): OpenAI API key (always used)
        gemini_api_key (str): Gemini API key (always used)
        anthropic_api_key (str): Anthropic API key (only used by Stagehand tool)
        file_name (str): Optional file name

    Yields:
        tuple: (answer, console_logs) - The answer to the question and captured console output (streamed)
    """
    # Required-field checks: warn, surface the message in the logs pane, stop.
    required = (
        (question, "Question is required."),
        (openai_api_key, "OpenAI API Key is required."),
        (gemini_api_key, "Gemini API Key is required."),
        (anthropic_api_key, "Anthropic API Key is required."),
    )
    for value, msg in required:
        if not value:
            gr.Warning(msg)
            yield None, msg
            return

    if not validate_input(question, openai_api_key, gemini_api_key, anthropic_api_key):
        msg = "Invalid input"
        gr.Warning(msg)
        yield None, msg
        return

    # Task files live under files/ in the Space repo.
    task_file_name = f"files/{file_name}" if file_name else file_name

    capture = None
    try:
        # "*" is a sentinel meaning "use the sponsored key from the environment".
        openai_key = openai_api_key if openai_api_key and openai_api_key != "*" else os.environ.get("OPENAI_API_KEY")
        gemini_key = gemini_api_key if gemini_api_key and gemini_api_key != "*" else os.environ.get("GEMINI_API_KEY")
        anthropic_key = anthropic_api_key if anthropic_api_key and anthropic_api_key != "*" else os.environ.get("ANTHROPIC_API_KEY")
        for env_name, key in (("OPENAI_API_KEY", openai_key),
                              ("GEMINI_API_KEY", gemini_key),
                              ("ANTHROPIC_API_KEY", anthropic_key)):
            if not key:
                # Previously os.environ[...] = None raised an opaque TypeError;
                # raise a clear message instead (handled by the except below).
                raise ValueError(f"{env_name} is not configured (no user key and no sponsor credit).")
            os.environ[env_name] = key

        # Streaming capture of console output produced by the crew.
        capture = StreamingCapture()
        capture.start()
        result = {"answer": None, "error": None}
        all_logs = ""

        def run_task():
            # Runs on a worker thread so this generator can stream logs meanwhile.
            try:
                result["answer"] = run_crew(question, task_file_name)
            except Exception as e:
                result["error"] = str(e)

        thread = threading.Thread(target=run_task)
        thread.start()

        # Stream captured logs to the UI while the crew is working.
        while thread.is_alive():
            new_output = capture.get_new_output()
            if new_output:
                all_logs += new_output
                yield None, all_logs  # update console logs; answer still None
            time.sleep(0.1)  # poll interval
        thread.join()
        capture.stop()

        # Flush anything emitted between the last poll and thread exit.
        new_output = capture.get_new_output()
        if new_output:
            all_logs += new_output

        if result["error"]:
            gr.Warning(result["error"])
            yield None, all_logs + f"\n\nError: {result['error']}"
        else:
            yield result["answer"], all_logs
    except Exception as e:
        msg = str(e)
        gr.Warning(msg)
        yield None, f"Error: {msg}"
    finally:
        # Always restore stdout/stderr, even if streaming failed midway
        # (previously an exception here left the streams hijacked).
        if capture is not None:
            capture.stop()
# Helper functions
def update_file_link(file_name):
    """Return an HTML anchor to the task file in the Space repo, or "" if none."""
    if not file_name:
        return ""
    return f"<a href='{BASE_URL}/files/{file_name}' target='_blank'>Open File</a>"
def watchfn(*args, **kwargs):
    """No-op stand-in: accepts any arguments and does nothing."""
    return None

# Patch out Gradio's Spaces file-watcher (hot reload) with the no-op above.
gradio.utils.watchfn_spaces = watchfn
# Graphical user interface
# Intro blurb rendered as Markdown/HTML at the top of the app.
# Fix: dropped the stray unmatched "</p>" that trailed the Documentation link,
# and removed f-prefixes from lines with no interpolation.
DESCRIPTION = (
    "Prototype <strong>multi-agent AI platform</strong> with high autonomy, "
    "including code generation & execution, browser automation, and multi-modal reasoning. "
    "The system can solve multiple <a href='https://arxiv.org/pdf/2311.12983'>GAIA Benchmark</a> "
    "Level 1, 2, 3 and even <a href='https://arxiv.org/pdf/2501.14249'>Humanity's Last Exam</a> "
    "problems. To get started, select from the examples below. "
    "Processing can take minutes depending on question complexity. "
    "Console logs are provided below for transparency. "
    "API keys are provided temporarily thanks to sponsor credit. "
    f"<a href='{BASE_URL}/README.md'>Documentation</a>"
)
# Example question (and its known answer) pre-filled when the app loads.
DEFAULT_QUESTION = (
    "How many public GitHub repos does the person who submitted the "
    "'General AI Assistant' solution in MCP's 1st Birthday Hackathon have?"
)
DEFAULT_GROUND_TRUTH = "18"
# Inline CSS injected via gr.HTML to stretch the app to the full viewport
# width and strip the default page margins; .content-padding restores a
# little horizontal breathing room for the actual content.
CSS_FULL_WIDTH = """
<style>
html,
body,
main,
.gradio-app {
width: 100% !important;
max-width: 100% !important;
margin: 0 !important;
padding: 0 !important;
overflow-x: hidden !important;
}
.full-width-app {
width: 100% !important;
max-width: 100% !important;
margin: 0 !important;
padding: 0 !important;
}
.content-padding {
padding: 0 1.5rem 0;
}
</style>
"""
# Build the Gradio UI: inputs on the left, answer on the right, console logs
# and benchmark example tabs underneath, then launch as an MCP server.
# NOTE(review): nesting below was reconstructed from a whitespace-stripped
# source — confirm against the original layout.
with gr.Blocks(elem_classes=["full-width-app"]) as gaia:
    gr.HTML(CSS_FULL_WIDTH)  # inject the full-width CSS overrides
    with gr.Column(elem_classes=["content-padding"]):
        gr.Markdown("## General AI Assistant")
        gr.Markdown(DESCRIPTION)
        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row():
                    question = gr.Textbox(
                        label="Question *",
                        value=DEFAULT_QUESTION,
                        interactive=True,
                        max_length = 500,
                        lines=1,
                        max_lines=5
                    )
                with gr.Row():
                    # Ground truth / file metadata are populated by the example
                    # datasets below; read-only for the user.
                    ground_truth = gr.Textbox(
                        label="Ground Truth",
                        value=DEFAULT_GROUND_TRUTH,
                        interactive=False,
                        lines=1,
                        max_lines=2
                    )
                    file_name = gr.Textbox(
                        label="File Name",
                        interactive=False,
                        lines=1,
                        max_lines=2,
                        scale=2
                    )
                    file_link = gr.HTML(
                        label="File Link",
                        value=""
                    )
                with gr.Row():
                    # "*" is a sentinel meaning "use the sponsored key from env".
                    openai_api_key = gr.Textbox(
                        label="OpenAI API Key *",
                        type="password",
                        placeholder="sk‑...",
                        value="*", # API keys are provided thanks to sponsor credit
                        interactive=True,
                        max_length = 150
                    )
                    gemini_api_key = gr.Textbox(
                        label="Gemini API Key *",
                        type="password",
                        value="*", # API keys are provided thanks to sponsor credit
                        interactive=True,
                        max_length = 150
                    )
                    anthropic_api_key = gr.Textbox(
                        label="Anthropic API Key *",
                        type="password",
                        placeholder="sk-ant-...",
                        value="*", # API keys are provided thanks to sponsor credit
                        interactive=True,
                        max_length = 150
                    )
                with gr.Row():
                    gr.LoginButton()
                    submit_btn = gr.Button("Submit", variant="primary")
            with gr.Column(scale=1):
                answer = gr.Textbox(
                    label="Answer",
                    interactive=False,
                    lines=1,
                    max_lines=5
                )
        with gr.Accordion("Console Logs", open=True):
            console_logs = gr.Textbox(
                label="Output",
                interactive=False,
                lines=25,
                max_lines=25,
                autoscroll=True
            )
        with gr.Row():
            clear_btn = gr.ClearButton(
                components=[question, ground_truth, file_name, file_link, answer, console_logs]
            )
        # NOTE(review): ask's first parameter (oauth_token: gr.OAuthToken) is
        # not listed in inputs; Gradio injects it from the login session based
        # on the annotation — confirm the installed Gradio version supports this.
        submit_btn.click(
            fn=ask,
            inputs=[question, openai_api_key, gemini_api_key, anthropic_api_key, file_name],
            outputs=[answer, console_logs]
        )
        # Keep the "Open File" link in sync with the selected example's file.
        file_name.change(
            fn=update_file_link,
            inputs=[file_name],
            outputs=[file_link]
        )
        # Example tabs: clicking a row fills question / ground truth / file name.
        # NOTE(review): the "" entries in inputs look like placeholders for the
        # three API-key fields; gr.Examples normally expects components — verify.
        with gr.Tabs():
            with gr.TabItem("GAIA Benchmark Level 1"):
                gr.Examples(
                    examples=get_dataset(DATASET_TYPE_GAIA, 1),
                    inputs=[question, ground_truth, file_name, "", "", ""],
                    examples_per_page=5,
                    cache_examples=False
                )
            with gr.TabItem("GAIA Benchmark Level 2"):
                gr.Examples(
                    examples=get_dataset(DATASET_TYPE_GAIA, 2),
                    inputs=[question, ground_truth, file_name, "", "", ""],
                    examples_per_page=5,
                    cache_examples=False
                )
            with gr.TabItem("GAIA Benchmark Level 3"):
                gr.Examples(
                    examples=get_dataset(DATASET_TYPE_GAIA, 3),
                    inputs=[question, ground_truth, file_name, "", "", ""],
                    examples_per_page=5,
                    cache_examples=False
                )
            with gr.TabItem("Humanity's Last Exam"):
                gr.Examples(
                    examples=get_dataset(DATASET_TYPE_HLE, 0),
                    inputs=[question, ground_truth, file_name, "", "", ""],
                    examples_per_page=5,
                    cache_examples=False
                )
# Launch with MCP server enabled so `ask` is also exposed as an MCP tool.
gaia.launch(mcp_server=True, ssr_mode=False)