sasio06 / gaia_integration.py
Sasa06's picture
Upload 7 files
1fa212d verified
raw
history blame
9.87 kB
"""
Интеграция агента с GAIA benchmark и Gradio интерфейсом.
Этот файл содержит код для запуска бенчмарка и визуализации результатов.
"""
import os
import json
import logging
import time
import gradio as gr
from typing import Dict, List, Any, Optional, Tuple
from agent_implementation import AgentController, Task
# Настройка логирования
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("gaia_benchmark.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger("gaia_benchmark")
class GAIABenchmark:
"""Класс для запуска и оценки GAIA benchmark."""
def __init__(self, agent: Optional[AgentController] = None, hf_token: Optional[str] = None):
"""
Инициализация бенчмарка GAIA.
Args:
agent: Экземпляр агента для тестирования
hf_token: Токен Hugging Face для доступа к API
"""
self.agent = agent or self._create_default_agent()
self.hf_token = hf_token or os.environ.get("HF_TOKEN")
self.results = []
self.stats = {}
# Проверка наличия токена
if not self.hf_token:
logger.warning("HF_TOKEN not found in environment variables")
print("WARNING: HF_TOKEN not found in environment variables")
else:
logger.info("HF_TOKEN found in environment variables")
logger.info("Initialized GAIA benchmark")
def _create_default_agent(self) -> AgentController:
"""
Создание агента по умолчанию.
Returns:
Экземпляр агента
"""
return AgentController(
model_name="gpt-3.5-turbo",
username="default_user"
)
def run_benchmark(self, num_questions: int = 20, level: Optional[int] = None) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""
Запуск бенчмарка.
Args:
num_questions: Количество вопросов
level: Уровень сложности (None для всех уровней)
Returns:
Кортеж из списка результатов и статистики
"""
logger.info(f"Running benchmark with {num_questions} questions, level: {level}")
try:
# Запуск бенчмарка через агента
benchmark_result = self.agent.run_benchmark(
num_questions=num_questions,
level=level,
code_link="https://huggingface.co/spaces/user/gaia-agent"
)
# Сохранение результатов
self.results = benchmark_result.get("results", [])
# Расчет статистики
total_questions = benchmark_result.get("total_questions", 0)
correct_answers = benchmark_result.get("correct_answers", 0)
# Расчет точности по уровням
level_accuracy = {}
if level is not None:
level_accuracy[level] = {
"total": total_questions,
"correct": correct_answers,
"accuracy": (correct_answers / total_questions * 100) if total_questions > 0 else 0
}
# Формирование статистики
self.stats = {
"overall_accuracy": benchmark_result.get("score", 0) * 100,
"total_questions": total_questions,
"correct_answers": correct_answers,
"level_accuracy": level_accuracy
}
logger.info(f"Benchmark completed, overall accuracy: {self.stats['overall_accuracy']:.2f}%")
return self.results, self.stats
except Exception as e:
logger.error(f"Error running benchmark: {e}")
# Возврат пустых результатов в случае ошибки
self.results = []
self.stats = {
"overall_accuracy": 0.0,
"total_questions": 0,
"correct_answers": 0,
"level_accuracy": {}
}
return self.results, self.stats
def save_results(self, filename: str) -> bool:
"""
Сохранение результатов в файл.
Args:
filename: Имя файла для сохранения
Returns:
True в случае успеха, False в случае ошибки
"""
logger.info(f"Saving results to {filename}")
try:
# Подготовка данных для сохранения
data = {
"results": self.results,
"stats": self.stats,
"timestamp": time.time()
}
# Сохранение в файл
with open(filename, "w") as f:
json.dump(data, f, indent=2)
logger.info(f"Results saved to {filename}")
return True
except Exception as e:
logger.error(f"Error saving results: {e}")
return False
def create_gaia_gradio_interface(agent: Optional[AgentController] = None, hf_token: Optional[str] = None):
"""
Создание Gradio интерфейса для GAIA benchmark.
Args:
agent: Экземпляр агента для тестирования
hf_token: Токен Hugging Face для доступа к API
Returns:
Gradio интерфейс
"""
# Инициализация GAIA benchmark
gaia = GAIABenchmark(agent, hf_token)
# Функция для запуска бенчмарка
def run_benchmark(username: str, num_questions: int, level: str) -> Tuple[str, Optional[str]]:
"""
Запуск бенчмарка через Gradio интерфейс.
Args:
username: Имя пользователя Hugging Face
num_questions: Количество вопросов
level: Уровень сложности
Returns:
Кортеж из текста результатов и пути к файлу результатов
"""
# Обновление имени пользователя в агенте
if gaia.agent:
gaia.agent.gaia_client.username = username
# Преобразование уровня
if level == "All":
level_int = None
else:
level_int = int(level)
try:
# Запуск бенчмарка
results, stats = gaia.run_benchmark(num_questions=int(num_questions), level=level_int)
# Форматирование результатов
results_text = f"### GAIA Benchmark Results\n\n"
results_text += f"**Overall Accuracy:** {stats['overall_accuracy']:.2f}%\n"
results_text += f"**Total Questions:** {stats['total_questions']}\n"
results_text += f"**Correct Answers:** {stats['correct_answers']}\n\n"
# Точность по уровням
results_text += "**Accuracy by Level:**\n"
for level, level_stats in stats['level_accuracy'].items():
results_text += f"- Level {level}: {level_stats['accuracy']:.2f}% ({level_stats['correct']}/{level_stats['total']})\n"
# Сохранение результатов
filename = f"gaia_results_{username}.json"
gaia.save_results(filename)
return results_text, filename
except Exception as e:
error_message = f"Error running benchmark: {str(e)}"
logger.error(error_message)
return error_message, None
# Создание интерфейса Gradio
with gr.Blocks() as demo:
gr.Markdown("# Hugging Face Agent with Tools")
with gr.Tab("GAIA Benchmark"):
with gr.Row():
with gr.Column():
username_input = gr.Textbox(label="Hugging Face Username", placeholder="Enter your Hugging Face username")
num_questions_input = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="Number of Questions")
level_input = gr.Dropdown(choices=["All", "1", "2", "3"], value="1", label="Level")
run_button = gr.Button("Run Benchmark")
with gr.Column():
results_output = gr.Markdown(label="Results")
file_output = gr.File(label="Results File")
run_button.click(fn=run_benchmark, inputs=[username_input, num_questions_input, level_input], outputs=[results_output, file_output])
return demo
# Пример использования
if __name__ == "__main__":
# Создание агента
agent = AgentController(
model_name="gpt-3.5-turbo",
username="your_username"
)
# Создание интерфейса
demo = create_gaia_gradio_interface(agent)
# Запуск интерфейса
demo.launch()