File size: 6,662 Bytes
10e9b7d
 
eccf8e4
3c4371f
f05a776
341e0dc
b6de177
466c18b
10e9b7d
3db6293
e80aab9
341e0dc
6119660
341e0dc
 
6119660
 
 
 
 
341e0dc
 
 
466c18b
6119660
 
 
 
 
341e0dc
 
6119660
 
b6de177
6119660
b6de177
6119660
341e0dc
 
b6de177
341e0dc
6119660
466c18b
b6de177
6119660
b6de177
 
 
341e0dc
466c18b
 
 
 
6119660
b6de177
6119660
b6de177
341e0dc
6119660
31243f4
466c18b
 
 
 
 
 
 
 
 
 
6119660
 
 
b6de177
466c18b
6119660
 
b6de177
 
 
466c18b
 
 
 
 
 
 
 
b6de177
6119660
b6de177
 
 
 
466c18b
 
 
 
 
 
 
 
 
b6de177
44cff3e
b6de177
466c18b
 
 
b6de177
4021bf3
b6de177
17723a8
44cff3e
 
466c18b
3c4371f
6119660
466c18b
 
 
 
 
 
 
6119660
b6de177
466c18b
b6de177
6119660
466c18b
 
 
e80aab9
7d65c66
17723a8
466c18b
b6de177
 
 
466c18b
 
 
 
b6de177
466c18b
 
b6de177
466c18b
f05a776
466c18b
 
 
 
b6de177
6119660
 
b6de177
466c18b
6119660
e80aab9
 
b6de177
 
 
466c18b
 
 
b6de177
466c18b
33d74bd
b6de177
 
 
 
 
466c18b
 
 
 
 
 
b6de177
 
 
466c18b
b6de177
466c18b
b6de177
466c18b
ae0d42c
 
6119660
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import os
import gradio as gr
import requests
import pandas as pd
import time
import io
import re
from smolagents import LiteLLMModel, tool, CodeAgent

# Base URL of the Unit 4 scoring service: questions are fetched from /questions,
# task attachments from /files/<task_id>, and answers are posted to /submit.
DEFAULT_API_URL: str = "https://agents-course-unit4-scoring.hf.space"

# ====================== TOOLS ======================

@tool
def web_search(query: str) -> str:
    """
    Search the web using DuckDuckGo.
    Args:
        query: The search query string.
    """
    # The DuckDuckGo import sits inside the try on purpose: a missing package or
    # a network/API error comes back to the agent as a readable
    # "Search failed: ..." string instead of aborting the agent step.
    try:
        from duckduckgo_search import DDGS

        with DDGS() as client:
            hits = list(client.text(query, max_results=5))
            if hits:
                # One "title: body" line per hit, top 5 hits at most.
                lines = (f"{hit.get('title')}: {hit.get('body')}" for hit in hits)
                return "\n".join(lines)
            return "No results found."
    except Exception as exc:
        return f"Search failed: {exc}"

def _filename_from_disposition(disposition: str) -> str:
    """Return the filename declared in a Content-Disposition header value, or "".

    Accepts both ``filename="x.csv"`` and the extended ``filename*=UTF-8''x.csv``
    form. Percent-encoding is not decoded; callers only need the extension.
    """
    match = re.search(
        r"filename\*?=(?:[\w-]+'[\w-]*')?\"?([^\";]+)\"?",
        disposition or "",
        flags=re.IGNORECASE,
    )
    return match.group(1).strip() if match else ""


@tool
def download_and_read_file(task_id: str) -> str:
    """
    Downloads the file for a task and returns its content.
    CSV and Excel files are returned as a preview of the first 15 rows, text-like
    files as a 2000-character snippet; other files (images, PDFs, audio) are only
    reported by size.
    Args:
        task_id: The unique ID for the task file.
    """
    url = f"{DEFAULT_API_URL}/files/{task_id}"
    try:
        resp = requests.get(url, timeout=30)
        resp.raise_for_status()
    except Exception as e:
        # Network/HTTP failures only; parsing problems are reported separately below.
        return f"Download failed: {str(e)}"

    content_type = resp.headers.get("content-type", "").lower()
    # task_id is an opaque task identifier, so extension checks against it alone
    # almost never match. Prefer the server-supplied filename when present.
    # NOTE(review): assumes the scoring API sets Content-Disposition on /files/ — confirm.
    filename = _filename_from_disposition(resp.headers.get("content-disposition", ""))
    name = (filename or task_id).lower()
    size = len(resp.content)

    try:
        if "csv" in content_type or name.endswith(".csv"):
            df = pd.read_csv(io.BytesIO(resp.content))
            return f"CSV Content (First 15 rows):\n{df.head(15).to_string()}\n\nColumns: {df.columns.tolist()}"
        if "spreadsheet" in content_type or "excel" in content_type or name.endswith((".xlsx", ".xls")):
            # read_excel loads only the first sheet by default and needs an
            # Excel engine (e.g. openpyxl); a missing engine lands in the except below.
            df = pd.read_excel(io.BytesIO(resp.content))
            return f"Excel Content (First 15 rows):\n{df.head(15).to_string()}\n\nColumns: {df.columns.tolist()}"
        if (
            "text" in content_type
            or "json" in content_type
            or name.endswith((".txt", ".py", ".json", ".md", ".xml", ".html"))
        ):
            return f"Text Content (Snippet):\n{resp.text[:2000]}"
    except Exception as e:
        # The download itself succeeded; don't misreport a parse error as a failed download.
        return f"File downloaded ({size} bytes) but could not be parsed: {e}"

    return (
        f"File downloaded. Size: {size} bytes. "
        f"If this is an image/pdf, use web_search to find related facts about task {task_id}."
    )

# ====================== AGENT ======================

class GaiaAgent:
    """Answer GAIA questions with a smolagents CodeAgent on Groq's llama-3.3-70b-versatile.

    Calling an instance with ``(question, task_id)`` returns a short answer string
    cleaned for exact-match grading, or ``"Unknown"`` if the agent raises.
    """

    # Leading filler phrases the grader would reject. The trailing \b keeps
    # "answer" from eating the start of a real answer such as "answers".
    _ANSWER_PREFIX_RE = re.compile(
        r"^(?:the answer is|final answer(?: is)?|result is|answer(?: is)?)\b[:\s]*",
        re.IGNORECASE,
    )

    def __init__(self):
        """Build the LiteLLM/Groq model and the tool-equipped CodeAgent.

        Raises:
            ValueError: if the ``GROQ_API_KEY`` environment variable is unset or empty.
        """
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            raise ValueError("❌ GROQ_API_KEY secret is not set! Add it in HF Spaces → Settings → Secrets.")

        # llama-3.3-70b-versatile: a strong reasoning model on Groq's free tier.
        self.model = LiteLLMModel(
            model_id="groq/llama-3.3-70b-versatile",
            api_key=groq_api_key,
        )

        # NOTE(review): add_base_tools=True also registers smolagents' built-in
        # tools; confirm none of them shares the name "web_search" and replaces
        # the custom tool above.
        self.agent = CodeAgent(
            tools=[web_search, download_and_read_file],
            model=self.model,
            add_base_tools=True,
            max_steps=12,
        )

    def clean_answer(self, raw_result: object) -> str:
        """Normalize the agent's final answer so it passes the GAIA exact-match grader.

        Steps:
        - Lists and tuples are joined as ``"a, b, c"``, the comma-separated form
          the prompt requests; any other value goes through ``str``.
        - A leading filler phrase such as "The answer is:" is removed.
        - Trailing periods are dropped. Leading ones are kept, so ".5" stays ".5".
        - One pair of matching surrounding quotes is unwrapped.

        Args:
            raw_result: the value returned by the agent (usually a ``str``).

        Returns:
            The cleaned answer string.
        """
        if isinstance(raw_result, (list, tuple)):
            text = ", ".join(str(item).strip() for item in raw_result)
        else:
            text = str(raw_result).strip()
        text = self._ANSWER_PREFIX_RE.sub("", text, count=1)
        # Only trailing periods: a leading "." can be significant (e.g. ".5").
        text = text.rstrip(".").strip()
        # The prompt's format examples are single-quoted, and the model sometimes echoes the quotes.
        if len(text) >= 2 and text[0] == text[-1] and text[0] in "'\"":
            text = text[1:-1].strip().rstrip(".").strip()
        return text

    def __call__(self, question: str, task_id: str) -> str:
        """Run the agent on one question and return its cleaned short answer.

        Args:
            question: the GAIA question text.
            task_id: the task's ID, embedded in the prompt so the agent can fetch
                any attachment with ``download_and_read_file``.

        Returns:
            The cleaned answer, or ``"Unknown"`` if the agent (or cleaning) raises.
        """
        prompt = f"""Task ID: {task_id}
Question: {question}

INSTRUCTIONS:
- Use your tools to find the exact factual answer.
- If the question mentions a file or attachment, call download_and_read_file("{task_id}") first.
- If you need up-to-date facts, use web_search.
- YOUR FINAL ANSWER MUST BE EXTREMELY BRIEF AND EXACT:
  * Numbers: just the number, e.g. '42' or '4.52'
  * Names: just the name, e.g. 'Marie Curie'
  * Dates: just the date, e.g. '1923' or 'July 4, 1776'
  * Lists: comma-separated, e.g. 'apple, banana, cherry'
- Do NOT write sentences. Do NOT explain. Just the answer.
"""
        try:
            result = self.agent.run(prompt)
            # Pass the raw result (not str(result)) so list answers are joined properly.
            return self.clean_answer(result)
        except Exception as e:
            # Keep the evaluation loop alive: one failed task must not abort the run.
            print(f"Agent error on task {task_id}: {e}")
            return "Unknown"

# ====================== MAIN LOGIC ======================

def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Answer every question from the scoring API with GaiaAgent and submit the answers.

    Args:
        profile: the logged-in Hugging Face user's OAuth profile, or None when
            nobody is logged in.

    Returns:
        A ``(status_message, results_dataframe)`` tuple for the two UI outputs.
        The dataframe is None when the run stops before submission: not logged
        in, agent setup failed, or no usable questions.
    """
    if not profile:
        return "❌ Please Login with Hugging Face first!", None

    username = profile.username
    print(f"✅ Logged in as: {username}")

    try:
        agent = GaiaAgent()
    except ValueError as e:
        return str(e), None

    try:
        resp = requests.get(f"{DEFAULT_API_URL}/questions", timeout=30)
        resp.raise_for_status()
        questions = resp.json()
    except Exception as e:
        return f"❌ Failed to fetch questions: {e}", None

    if not questions:
        return "❌ The questions endpoint returned no questions.", None

    total = len(questions)
    print(f"📋 Fetched {total} questions.")

    answers_payload = []
    results_log = []

    for i, item in enumerate(questions, start=1):
        t_id = item.get("task_id")
        q_text = item.get("question")
        # Skip malformed entries instead of crashing on q_text[:120] or
        # submitting an answer that has no task_id.
        if not t_id or q_text is None:
            print(f"Skipping item {i}/{total}: missing task_id or question.")
            continue

        print(f"\n--- [{i}/{total}] Task: {t_id} ---")
        print(f"Q: {q_text[:120]}...")

        answer = str(agent(q_text, t_id))
        print(f"A: {answer}")

        answers_payload.append({"task_id": t_id, "submitted_answer": answer})
        results_log.append({"Task ID": t_id, "Question": q_text[:80], "Answer": answer})

        # Small pause to stay under Groq's free-tier rate limit (~30 req/min);
        # there is nothing to wait for after the last question.
        if i < total:
            time.sleep(3)

    if not answers_payload:
        return "❌ No valid questions to answer.", None

    # ===== SUBMIT =====
    space_id = os.getenv("SPACE_ID", "unknown")
    submission_data = {
        "username": username,
        "agent_code": f"https://huggingface.co/spaces/{space_id}",
        "answers": answers_payload,
    }

    try:
        r = requests.post(f"{DEFAULT_API_URL}/submit", json=submission_data, timeout=300)
        if r.status_code == 200:
            res = r.json()
            score = res.get("score", 0)
            message = res.get("message", "")
            return f"✅ SCORE: {score}% | {message}", pd.DataFrame(results_log)
        else:
            return f"❌ Submission Error {r.status_code}: {r.text}", pd.DataFrame(results_log)
    except Exception as e:
        # Also covers a 200 response whose body is not JSON.
        return f"❌ Submission Failed: {str(e)}", pd.DataFrame(results_log)

# ====================== UI ======================

with gr.Blocks(theme=gr.themes.Default()) as demo:
    # Layout: header and instructions, a row holding the login and start buttons,
    # then the status text and the per-question answer table.
    gr.Markdown("# 🏆 GAIA Certificate Agent (Unit 4 Final)")
    gr.Markdown(
        "**Steps:** 1) Login with HF below → 2) Click Start → 3) Wait ~5 mins → 4) Check your score!\n\n"
        "> Make sure `GROQ_API_KEY` is set in your Space **Settings → Secrets**."
    )

    with gr.Row():
        gr.LoginButton()
        run_btn = gr.Button("🚀 Start Evaluation", variant="primary")

    status_output = gr.Textbox(label="Final Result", lines=3)
    table_output = gr.DataFrame(label="Answer Log")

    # No `inputs` are passed. Presumably Gradio fills run_and_submit_all's
    # gr.OAuthProfile-annotated parameter from the login session (see Gradio OAuth docs).
    run_btn.click(fn=run_and_submit_all, outputs=[status_output, table_output])

if __name__ == "__main__":
    demo.launch()