File size: 8,152 Bytes
10e9b7d
 
eccf8e4
0b26a35
 
ee57d8e
 
 
10e9b7d
30ab757
3db6293
0b26a35
dceeb49
 
ee57d8e
 
dceeb49
0b26a35
30ab757
ee57d8e
39af7e5
30ab757
dceeb49
ee57d8e
 
dceeb49
 
 
 
 
ee57d8e
 
 
 
 
 
 
dceeb49
 
 
 
 
 
 
 
 
 
 
 
 
 
ee57d8e
 
dceeb49
ee57d8e
dceeb49
ee57d8e
dceeb49
 
ee57d8e
4e55bbe
0b26a35
ee57d8e
0b26a35
ee57d8e
 
dceeb49
 
 
 
 
 
 
 
 
ee57d8e
dceeb49
ee57d8e
 
dceeb49
 
 
 
 
 
0b26a35
ee57d8e
dceeb49
 
 
 
 
 
 
 
 
 
ee57d8e
dceeb49
 
 
 
 
 
 
 
 
ee57d8e
dceeb49
 
 
0b26a35
578f455
0b26a35
 
e3c5ce5
 
3c4371f
578f455
ee57d8e
0b26a35
7e4a06b
31243f4
ee57d8e
e3c5ce5
 
e80aab9
e3c5ce5
31243f4
e3c5ce5
 
31243f4
0b26a35
7d65c66
 
31243f4
ee57d8e
 
 
4e55bbe
ee57d8e
 
31243f4
e3c5ce5
 
 
7d65c66
e3c5ce5
e80aab9
e3c5ce5
e80aab9
 
e3c5ce5
 
 
 
 
 
 
 
ee57d8e
e80aab9
dceeb49
 
7e4a06b
31243f4
9088b99
7d65c66
e3c5ce5
e80aab9
 
3c4371f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import gradio as gr
import requests
import re
import time
import pandas as pd
import google.generativeai as genai
from google.generativeai.types import HarmCategory, HarmBlockThreshold

# --- Constants ---
# Base URL of the evaluation API. The code in this file appends
# "/questions", "/files/<task_id>" and "/submit" to it.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# --- NativeGeminiAgent Class ---
# Gemini-based agent that run_and_submit_all calls once per question.
class NativeGeminiAgent:
    """Callable agent that answers a question with a Gemini model.

    The model is created with the ``google_search_retrieval`` tool and with
    all four configured harm categories set to ``BLOCK_NONE``. Calling the
    instance builds a prompt from the question, tries to attach every URL
    found in the question plus the task's file (when the API reports one),
    and returns the cleaned model text. Failures during generation are
    reported as strings starting with ``AGENT_ERROR:`` rather than raised.
    """

    # http(s) URL up to the first whitespace or URL-unsafe character.
    _URL_PATTERN = re.compile(r'https?://[^\s<>"{}|\\^`\[\]]+')
    # Bracketed numeric citation markers such as "[3]".
    _CITATION_PATTERN = re.compile(r'\[\d+\]')
    _WHITESPACE_PATTERN = re.compile(r'\s+')
    # Sentence punctuation the URL pattern swallows when a link ends a clause.
    _TRAILING_PUNCTUATION = ".,;:!?'"
    # (suffixes, MIME type) pairs, checked in order against the URL path.
    # NOTE(review): .avi and .mov are reported as "video/mp4", as before;
    # confirm the model API accepts that label for those containers.
    _MIME_BY_SUFFIX = (
        (('.jpg', '.jpeg'), "image/jpeg"),
        (('.png',), "image/png"),
        (('.gif',), "image/gif"),
        (('.pdf',), "application/pdf"),
        (('.txt',), "text/plain"),
        (('.csv',), "text/csv"),
        (('.mp4', '.avi', '.mov'), "video/mp4"),
        (('.json',), "application/json"),
    )
    _DEFAULT_MIME_TYPE = "application/octet-stream"

    def __init__(self, gemini_api_key: str, api_url: str):
        """Configure the Gemini SDK and build the generative model.

        Args:
            gemini_api_key: API key passed to ``genai.configure``.
            api_url: Base URL of the evaluation API; used to look up the file
                associated with a task (``<api_url>/files/<task_id>``).
        """
        print("Initializing NativeGeminiAgent with corrected configuration...")
        genai.configure(api_key=gemini_api_key)

        self.api_url = api_url
        # This is a dated *preview* model identifier, not a stable release.
        # NOTE(review): confirm it is still served before relying on it.
        self.model_name = 'gemini-2.5-flash-preview-05-20'

        # Tool is selected by its string name.
        self.model = genai.GenerativeModel(
            model_name=self.model_name,
            tools=['google_search_retrieval'],
            system_instruction="""You are a world-class problem solver and researcher. 
            Analyze the question carefully, use available tools to gather information, 
            and provide accurate, concise answers. Focus on factual information and 
            avoid speculation.""",
            safety_settings={
                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            }
        )
        print(f"Agent initialized with {self.model_name} and Google Search grounding.")

    def _get_mime_type(self, url: str) -> str:
        """Guess a MIME type from the file extension in *url*.

        The query string and fragment are ignored, so
        ``https://host/a.png?size=2`` is recognised as ``image/png``.
        Matching is case-insensitive. Unknown or missing extensions yield
        ``application/octet-stream``.
        """
        # Keep only the part before any "#fragment" or "?query".
        path = url.split('#', 1)[0].split('?', 1)[0].lower()
        for suffixes, mime_type in self._MIME_BY_SUFFIX:
            if path.endswith(suffixes):
                return mime_type
        return self._DEFAULT_MIME_TYPE

    def _check_if_file_exists(self, url: str) -> bool:
        """Return True when a HEAD request to *url* ends in HTTP 200.

        Redirects are followed. Any ``requests`` error is logged and treated
        as "file does not exist".
        """
        try:
            response = requests.head(url, timeout=15, allow_redirects=True)
            return response.status_code == 200
        except requests.exceptions.RequestException as e:
            print(f"File check failed for {url}: {e}")
            return False

    @classmethod
    def _extract_urls(cls, text: str) -> list[str]:
        """Return the distinct http(s) URLs in *text*, in order of appearance.

        Trailing sentence punctuation is removed, and a closing parenthesis is
        removed only when it has no matching opening parenthesis inside the
        URL (so ``.../Foo_(bar)`` is kept intact).
        """
        unique_urls: dict[str, None] = {}
        for raw in cls._URL_PATTERN.findall(text):
            url = raw.rstrip(cls._TRAILING_PUNCTUATION)
            while url.endswith(')') and url.count(')') > url.count('('):
                url = url[:-1].rstrip(cls._TRAILING_PUNCTUATION)
            unique_urls.setdefault(url, None)
        return list(unique_urls)

    @classmethod
    def _clean_answer(cls, text: str) -> str:
        """Strip citation markers like ``[1]`` and collapse all whitespace runs."""
        without_citations = cls._CITATION_PATTERN.sub('', text.strip())
        return cls._WHITESPACE_PATTERN.sub(' ', without_citations).strip()

    def _attach(self, prompt_parts: list, uri: str, label: str) -> None:
        """Best-effort: append *uri* to *prompt_parts* as a model input part.

        *label* ("URL" or "file") only affects the log messages. Any failure
        is logged and swallowed so the question is still asked without the
        attachment.
        """
        try:
            mime_type = self._get_mime_type(uri)
            # NOTE(review): confirm `genai.Part.from_uri` exists in the
            # installed google-generativeai version; if it does not, the
            # except below logs the error and the attachment is skipped.
            prompt_parts.append(genai.Part.from_uri(uri=uri, mime_type=mime_type))
            print(f"Added {label}: {uri} (MIME: {mime_type})")
        except Exception as e:
            print(f"Failed to add {label} {uri}: {e}")

    def __call__(self, question: str, task_id: str) -> str:
        """Answer *question* and return the cleaned text.

        Args:
            question: Question text; any URLs inside it are offered to the
                model as attachments.
            task_id: Task identifier, used to probe ``<api_url>/files/<task_id>``
                for an associated file.

        Returns:
            The model's answer with citation markers removed and whitespace
            normalised, or a string starting with ``AGENT_ERROR:`` when the
            model call fails or yields no usable text.
        """
        print(f"\n{'='*20}\nProcessing Task ID: {task_id}")

        prompt_parts = [question]

        for url in self._extract_urls(question):
            self._attach(prompt_parts, url, "URL")

        # Check for a file associated with this task.
        # NOTE(review): this URL has no extension, so its MIME type is always
        # the generic default — confirm whether the API needs a specific one.
        file_url = f"{self.api_url}/files/{task_id}"
        if self._check_if_file_exists(file_url):
            self._attach(prompt_parts, file_url, "file")

        try:
            # Low temperature / restricted top_p for more repeatable outputs.
            response = self.model.generate_content(
                prompt_parts,
                request_options={'timeout': 120},
                generation_config=genai.types.GenerationConfig(
                    temperature=0.1,
                    top_p=0.8,
                    max_output_tokens=2048
                )
            )

            raw_text = response.text
            final_answer = self._clean_answer(raw_text) if raw_text else ""
            if final_answer:
                return final_answer
            # Covers both an empty reply and one that held only citations/whitespace.
            return "AGENT_ERROR: Empty response from model"

        except Exception as e:
            error_msg = f"AGENT_ERROR: {str(e)}"
            print(error_msg)
            return error_msg

# --- Main run_and_submit_all function ---
def run_and_submit_all(profile: gr.OAuthProfile | None) -> tuple[str, pd.DataFrame | None]:
    """Run the agent on every evaluation question and submit the answers.

    Args:
        profile: Hugging Face OAuth profile of the logged-in user, or None
            when nobody is logged in.

    Returns:
        A ``(status_message, results)`` tuple. ``results`` is a DataFrame with
        one row per processed question, or None when the run stops before any
        question is processed (not logged in, missing API key, setup failure).
    """
    if not profile:
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"

    gemini_key = os.getenv("GEMINI_API_KEY")
    if not gemini_key:
        return "CRITICAL ERROR: GEMINI_API_KEY not found in Space secrets.", None

    # NOTE(review): when SPACE_ID is unset the submitted code link contains
    # the literal "None" — confirm whether the scoring API tolerates that.
    space_id = os.getenv("SPACE_ID")

    api_url = DEFAULT_API_URL
    try:
        agent = NativeGeminiAgent(gemini_api_key=gemini_key, api_url=api_url)
        questions_response = requests.get(f"{api_url}/questions", timeout=15)
        # Fail here on HTTP errors instead of iterating over an error body.
        questions_response.raise_for_status()
        questions_data = questions_response.json()
    except Exception as e:
        return f"Error during setup: {e}", None

    if not isinstance(questions_data, list):
        return (
            "Error during setup: expected a list of questions, "
            f"got {type(questions_data).__name__}",
            None,
        )

    results_log, answers_payload = [], []
    is_first_question = True
    for item in questions_data:
        # Skip malformed entries rather than crashing the whole run.
        if not isinstance(item, dict):
            continue
        task_id, question_text = item.get("task_id"), item.get("question")
        if not task_id or question_text is None:
            continue

        # Pause between consecutive agent calls only; there is nothing to
        # wait for before the first question or after the last one.
        if not is_first_question:
            print("--- Waiting for 10 seconds before next question... ---")
            time.sleep(10)
        is_first_question = False

        try:
            submitted_answer = agent(question_text, task_id)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
        except Exception as e:
            # A crashed question is logged for display but not submitted.
            error_message = f"AGENT CRASH: {e}"
            print(error_message)
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": error_message})

    if not answers_payload:
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}

    try:
        response = requests.post(f"{api_url}/submit", json=submission_data, timeout=120)
        response.raise_for_status()
        result_data = response.json()
        final_status = (f"Submission Successful!\nUser: {result_data.get('username')}\n"
                        f"Overall Score: {result_data.get('score', 'N/A')}% "
                        f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
                        f"Message: {result_data.get('message', 'No message received.')}")
        return final_status, pd.DataFrame(results_log)
    except requests.exceptions.RequestException as e:
        return f"Submission Failed: {e}", pd.DataFrame(results_log)

# --- Gradio Interface ---
# Page layout: title and description, login button, run button, then the
# status text box and the results table.
with gr.Blocks() as demo:
    gr.Markdown("# Native Multi-Modal GAIA Agent (Corrected)")
    gr.Markdown("This agent uses the improved architecture with proper tool configuration, MIME type detection, and error handling.")
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    # Read-only box showing the status string returned by run_and_submit_all.
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    # Table showing the per-question results returned by run_and_submit_all.
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
    # No explicit inputs are wired; the two return values fill the outputs in order.
    # NOTE(review): the `profile` argument is presumably injected by Gradio
    # from the gr.OAuthProfile annotation — confirm against the Gradio docs.
    run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch(debug=True, share=False)