# Source: app.py by Annessha18 — commit 01ce061 ("Update app.py", verified), 1.65 kB.
import requests
from datasets import load_dataset
from transformers import pipeline
# ---------------------------
# CONFIG
# ---------------------------
# Base URL of the agents-course Unit 4 scoring service; the questions are
# fetched from its /questions endpoint further down.
SCORING_API: str = "https://agents-course-unit4-scoring.hf.space"
# Hugging Face Hub checkpoint handed to the text2text-generation pipeline.
MODEL_NAME: str = "google/flan-t5-base"
# ---------------------------
# Load model
# ---------------------------
# `qa` is a text2text-generation pipeline over MODEL_NAME; max_new_tokens=64
# is passed through as a generation setting, capping the length of each answer.
print("Loading model...")
qa = pipeline(
    task="text2text-generation",
    model=MODEL_NAME,
    max_new_tokens=64,
)
# ---------------------------
# Fetch the evaluation questions
# ---------------------------
# The number of questions comes from the server response, so nothing below
# should assume a fixed count. The timeout stops the script from hanging on an
# unresponsive server, and raise_for_status() reports HTTP errors at their
# source instead of as a confusing JSON decode / KeyError later on.
print("Fetching GAIA questions...")
response = requests.get(f"{SCORING_API}/questions", timeout=30)
response.raise_for_status()
questions = response.json()
task_ids = [q["task_id"] for q in questions]
# ---------------------------
# Load GAIA validation dataset
# ---------------------------
print("Loading GAIA validation set...")
dataset = load_dataset(
    "gaia-benchmark/GAIA",
    "2023_level1",
    split="validation",
)
# Map task_id -> correct answer, keeping only the tasks that were fetched.
# A set gives O(1) membership per dataset row instead of scanning task_ids.
wanted_ids = set(task_ids)
ground_truth = {
    item["task_id"]: item["Final answer"]
    for item in dataset
    if item["task_id"] in wanted_ids
}
# Only the level-1 config is loaded above; surface any fetched task that it
# does not contain, otherwise it would be scored with no expected answer and
# no hint as to why.
missing_ids = wanted_ids.difference(ground_truth)
if missing_ids:
    print(
        f"WARNING: no ground truth found for {len(missing_ids)} task(s): "
        f"{sorted(missing_ids)}"
    )
# ---------------------------
# Evaluate
# ---------------------------
# Exact-match scoring: the expected answer and the model output are both
# stripped and lower-cased before comparison. A task with no ground-truth
# answer can never count as correct; otherwise an empty model output would
# "match" an empty placeholder.
separator = "=" * 80
total = len(questions)
correct = 0
for q in questions:
    task_id = q["task_id"]
    question = q["question"]
    raw_truth = ground_truth.get(task_id)
    has_truth = raw_truth is not None
    true_answer = str(raw_truth).strip().lower() if has_truth else ""
    model_output = qa(question)[0]["generated_text"].strip().lower()
    match = has_truth and model_output == true_answer
    correct += int(match)
    expected_display = true_answer if has_truth else "<no ground truth found>"
    print("\n" + separator)
    print(f"QUESTION:\n{question}")
    print(f"\nEXPECTED:\n{expected_display}")
    print(f"\nMODEL:\n{model_output}")
    print(f"\nMATCH: {'✅' if match else '❌'}")
print("\n" + separator)
# The denominator is the number of fetched questions, not a hard-coded 20.
print(f"FINAL SCORE: {correct}/{total}")