Spaces:
Restarting
Restarting
| import os | |
| import requests | |
| import pandas as pd | |
| import pyarrow.parquet as pq | |
| import json | |
| import time | |
| from langchain_core.messages import HumanMessage | |
| from agent import build_graph | |
| from huggingface_hub import hf_hub_download | |
| from dotenv import load_dotenv | |
# Load environment variables from .env; override=True lets .env values win
# over anything already set in the process environment.
load_dotenv(override=True)
# Base URL of the HF Agents-course scoring service (serves the /questions endpoint).
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
class BasicAgent:
    """Thin callable wrapper around the compiled LangGraph agent graph.

    Construction builds the graph once; each call runs one question
    through it and returns the final message's text content.
    """

    def __init__(self):
        print("BasicAgent initialized.")
        self.graph = build_graph()

    def __call__(self, question: str) -> str:
        """Run *question* through the graph and return the last message's content."""
        state = self.graph.invoke({"messages": [HumanMessage(content=question)]})
        return state["messages"][-1].content
def file_extract(local_file_path, task_id):
    """Resolve a GAIA attachment filename to a local file path via the HF Hub.

    GAIA stores task attachments under split-specific folders; each candidate
    prefix is probed in turn and the first successful download wins.
    Returns the local path, or None when the name is falsy or nothing resolves.
    (task_id is currently unused but kept for interface stability.)
    """
    if not local_file_path:
        return None
    hub_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
    for folder in ("2023/validation/", "2023/test/", "2023/train/", ""):
        try:
            return hf_hub_download(
                repo_id="gaia-benchmark/GAIA",
                filename=f"{folder}{local_file_path}",
                repo_type="dataset",
                token=hub_token,
            )
        except Exception:
            # This prefix didn't match (or download failed) — try the next one.
            continue
    return None
def _fetch_questions():
    """Download the evaluation question list from the scoring service.

    Raises requests.HTTPError on a non-2xx response instead of letting
    response.json() fail confusingly on an HTML/error body.
    """
    print("Fetching questions...")
    response = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
    response.raise_for_status()  # fail loudly on HTTP errors
    questions_data = response.json()
    print(f"Fetched {len(questions_data)} questions")
    return questions_data


def _load_answer_map():
    """Load GAIA validation ground truth as a {task_id: final_answer} dict."""
    print("Loading ground truth...")
    token = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    path = hf_hub_download(
        repo_id='gaia-benchmark/GAIA',
        filename='2023/validation/metadata.parquet',
        repo_type='dataset',
        token=token,
    )
    df = pq.read_table(path).to_pandas()
    return dict(zip(df['task_id'], df['Final answer']))


def main():
    """Run the agent over every GAIA question, score it, and save results.

    Writes gaia_results.json and gaia_results.csv to the working directory.
    """
    # 1-2. Fetch questions and load ground truth.
    questions_data = _fetch_questions()
    answer_map = _load_answer_map()

    # 3. Initialize agent.
    agent = BasicAgent()

    # 4. Run on all questions.
    results = []
    for i, item in enumerate(questions_data):
        task_id = item.get("task_id")
        question_text = item.get("question")
        file_name = item.get("file_name")
        if not task_id or question_text is None:
            # Malformed entry — skip rather than crash the whole run.
            continue
        # Attach a local copy of any referenced file so the agent can read it.
        if file_name:
            resolved_path = file_extract(file_name, task_id)
            if resolved_path:
                question_text += f"\n\n[Attached File Local Path: {resolved_path}]"
        print(f"\n[{i+1}/{len(questions_data)}] Task: {task_id[:20]}...")
        try:
            answer = agent(question_text)
        except Exception as e:
            # Per-question failures are recorded as answers so the run continues.
            answer = f"ERROR: {e}"
        ground_truth = answer_map.get(task_id, "NOT FOUND")
        # Exact match, case-insensitive and whitespace-trimmed.
        is_correct = str(answer).strip().lower() == str(ground_truth).strip().lower()
        results.append({
            "task_id": task_id,
            "question": item.get("question"),
            "submitted_answer": answer,
            "ground_truth": ground_truth,
            "correct": is_correct
        })
        status = "✅" if is_correct else "❌"
        print(f" {status} Submitted: {str(answer)[:40]}")
        print(f" Ground: {str(ground_truth)[:40]}")
        time.sleep(1.5)  # gentle pacing between agent runs

    # 5. Calculate score.
    correct_count = sum(1 for r in results if r["correct"])
    total = len(results)
    score_pct = correct_count / total * 100 if total > 0 else 0
    print("\n" + "="*60)
    print(f"FINAL SCORE: {correct_count}/{total} = {score_pct:.0f}%")
    print("="*60)

    # 6. Save results (JSON with score summary, plus a flat CSV).
    output = {"score": score_pct, "correct": correct_count, "total": total, "results": results}
    with open("gaia_results.json", "w") as f:
        json.dump(output, f, indent=2)
    pd.DataFrame(results).to_csv("gaia_results.csv", index=False)
    print("Results saved!")
| if __name__ == "__main__": | |
| main() | |