Agent_Final_Assignment / my_agent.py
derek
use google gemini
e485756
import os
import requests
from smolagents import LiteLLMModel, ToolCallingAgent, Tool
from typing import Optional
from google import genai
from google.genai import types
import wikipedia as wiki
from markdownify import markdownify as to_markdown
# --- Tools ---
class VideoWatchingTool(Tool):
    """Tool that asks a Gemini model a question about a video given by URL.

    The request is sent directly to the Gemini REST ``generateContent``
    endpoint, with the video URL passed as ``fileData.fileUri``. The API key
    is read from the ``GOOGLE_API_KEY`` environment variable on every call.
    Failures are reported as ``"Error: ..."`` strings rather than raised, so
    the calling agent can read them and react.
    """

    name = "watch_video"
    description ="""
A tool for watching videos and answering questions about them.
"""
    inputs = {
        "video_url": {
            "type": "string",
            "description": "The URL of the video to watch."
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the video."
        }
    }
    output_type = "string"

    # Seconds to wait for the API. Generous because the model has to process
    # the whole video, but finite so a stalled connection cannot hang the agent.
    request_timeout = 300

    def __init__(self, model_name, *args, **kwargs):
        """Store the Gemini model name used for video questions.

        Args:
            model_name: Bare model identifier (without the ``models/`` prefix).
            *args, **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, video_url: str, user_query: str) -> str:
        """Ask the model ``user_query`` about the video at ``video_url``.

        Args:
            video_url: URL of the video to analyse.
            user_query: Question to answer about the video.

        Returns:
            The concatenated text parts of the first candidate, or a string
            starting with ``"Error:"`` describing what went wrong.
        """
        api_key = os.getenv("GOOGLE_API_KEY")
        if not api_key:
            return "Error: the GOOGLE_API_KEY environment variable is not set."

        request_json = {
            'model': f'models/{self.model_name}',
            'contents': [{
                "parts": [
                    {
                        'fileData': {
                            'fileUri': video_url
                        }
                    },
                    {
                        'text': f"Please watch the video and answer the following question: {user_query}"
                    }
                ]
            }]
        }
        api_url = (
            'https://generativelanguage.googleapis.com/v1beta/models/'
            f'{self.model_name}:generateContent'
        )
        try:
            response = requests.post(
                api_url,
                json=request_json,
                headers={
                    'Content-Type': 'application/json',
                    # Send the key as a header instead of a query parameter so
                    # it never appears in URLs, logs, or exception messages.
                    'x-goog-api-key': api_key,
                },
                timeout=self.request_timeout,
            )
        except requests.RequestException as exc:
            return f"Error: request to the Gemini API failed - {exc}"

        if response.status_code != 200:
            return f"Error: {response.status_code} - {response.text}"

        try:
            response_json = response.json()
        except ValueError:
            return "Error: the Gemini API returned a non-JSON response."

        # A blocked or empty generation has no candidates; report it instead
        # of failing with KeyError/IndexError.
        candidates = response_json.get('candidates') or []
        if not candidates:
            feedback = response_json.get('promptFeedback')
            return f"Error: the model returned no candidates (promptFeedback: {feedback})"

        content = candidates[0].get('content') or {}
        result_parts = content.get('parts') or []
        return "".join(part.get('text', '') for part in result_parts)
class GoogleSearchTool(Tool):
    """Tool that answers a search query via Gemini with Google Search enabled.

    The query is sent to ``client.models.generate_content`` with the
    ``google_search`` tool attached, and the text of the model's response is
    returned.
    """

    name = "google_search"
    description = """
Performs a Google search and returns the results.
"""
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query."
        }
    }
    output_type = "string"

    def __init__(self, client, model_name, *args, **kwargs):
        """Store the API client and model name.

        Args:
            client: Client object exposing ``models.generate_content``.
            model_name: Model identifier passed as ``model=`` on each call.
            *args, **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.client = client
        self.model_name = model_name

    def forward(self, query: str) -> str:
        """Run ``query`` through the search-enabled model and return its text.

        Args:
            query: The search query.

        Returns:
            The response text, or ``"No results found."`` when the response
            contains no text.
        """
        google_search_tool = types.Tool(
            google_search=types.GoogleSearch()
        )
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=f"Please search the internet for: {query}",
            config=types.GenerateContentConfig(
                tools=[google_search_tool],
                response_modalities=['TEXT'],
            )
        )
        # `response.text` may be None when the response has no text parts;
        # the tool is declared to return a string, so never hand back None.
        return response.text or "No results found."
class WikipediaTitleSearchTool(Tool):
    """Tool that lists the Wikipedia page titles matching a free-text query."""

    name = "check_wikipedia_page_titles"
    description = """
Searches for Wikipedia pages related to the query and returns the canonical titles of the related pages.
"""
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query."
        }
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the matching titles as one comma-separated string.

        Args:
            query: The search query.

        Returns:
            The titles joined with ``", "``, or ``"No results found."`` when
            the search yields nothing.
        """
        titles = wiki.search(query)
        if not titles:
            return "No results found."
        return ", ".join(titles)
class WikipediaPageTool(Tool):
    """Tool that fetches a Wikipedia page and returns it as Markdown."""

    name = "get_wikipedia_page"
    description = """
Gets the content of a Wikipedia page.
"""
    inputs = {
        "page_title": {
            "type": "string",
            "description": "The canonical title of the Wikipedia page."
        }
    }
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the page's HTML converted to Markdown.

        Args:
            page_title: The canonical title of the Wikipedia page.

        Returns:
            The Markdown rendering of the page, a "not found" message when the
            page does not exist, or a list of candidate titles when
            ``page_title`` is ambiguous.
        """
        # TODO: may need to do caching of the HTML ourselves?
        try:
            page = wiki.page(page_title)
        except wiki.exceptions.DisambiguationError as err:
            # Ambiguous titles are common; return the candidates so the agent
            # can retry with a specific one instead of crashing the tool call.
            options = ", ".join(str(option) for option in err.options)
            return f"Page '{page_title}' is ambiguous. It may refer to: {options}"
        except wiki.exceptions.PageError:
            return f"Page '{page_title}' not found."
        return to_markdown(page.html())
class FileAttachmentQueryTool(Tool):
    """Tool that downloads a task's attached file and asks the model about it.

    The file is fetched from the scoring service by task id and sent inline,
    together with the user's question, to ``client.models.generate_content``.
    """

    name = "run_query_with_file"
    description = """
Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
This assumes the file is 20MB or less.
"""
    inputs = {
        "task_id": {
            "type": "string",
            "description": "A unique identifier for the task related to this file, used to download it."
        },
        "mime_type": {
            "type": "string",
            "nullable": True,
            "description": "The MIME type of the file, or the best guess if unknown."
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the file."
        }
    }
    output_type = "string"

    # Seconds to wait for the file download before giving up, so a stalled
    # connection cannot hang the agent.
    download_timeout = 60

    def __init__(self, client, model_name, *args, **kwargs):
        """Store the API client and model name.

        Args:
            client: Client object exposing ``models.generate_content``.
            model_name: Model identifier passed as ``model=`` on each call.
            *args, **kwargs: Forwarded unchanged to ``Tool.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.client = client
        self.model_name = model_name

    def forward(self, task_id: str, mime_type: str | None, user_query: str) -> str:
        """Download the file for ``task_id`` and answer ``user_query`` about it.

        Args:
            task_id: Identifier used to build the download URL.
            mime_type: MIME type of the file; when empty or ``None`` it is
                taken from the download's ``Content-Type`` header, falling
                back to ``application/octet-stream``.
            user_query: The question to answer about the file.

        Returns:
            The model's text answer, or a short notice when the response
            contains no text.

        Raises:
            Exception: If the download does not return HTTP 200.
        """
        # Download the file
        file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        file_response = requests.get(file_url, timeout=self.download_timeout)
        if file_response.status_code != 200:
            raise Exception(f"Failed to download file: {file_response.status_code} - {file_response.text}")
        file_data = file_response.content

        if not mime_type:
            # Content-Type may carry parameters ("text/plain; charset=utf-8");
            # only the media type itself is a valid MIME type for the upload.
            content_type = file_response.headers.get('Content-Type', '')
            mime_type = content_type.split(';', 1)[0].strip() or 'application/octet-stream'

        response = self.client.models.generate_content(
            model=self.model_name,
            contents=[
                types.Part.from_bytes(
                    data=file_data,
                    mime_type=mime_type,
                ),
                user_query,
            ]
        )
        # `response.text` may be None when the response has no text parts;
        # the tool is declared to return a string.
        return response.text or "The model returned no text for this file."
# --- Agent Management ---
class GeminiAgentContainer:
    """
    A container for the Gemini agent.

    Builds the LiteLLM model, the google-genai client and a ToolCallingAgent
    wired with the video, search, Wikipedia and file-attachment tools. Calling
    the container runs the agent on a question prefixed with the system prompt.

    Note: VideoWatchingTool is constructed with only the model name, so an
    explicit ``api_key`` passed here is not forwarded to it.
    """
    # TODO: make it easier to change the model
    MODEL_NAME = "gemini-2.0-flash"

    def __init__(self, api_key: Optional[str] = None):
        """Create the model, client, tools and agent.

        Args:
            api_key: Google API key. When omitted or empty, the
                ``GOOGLE_API_KEY`` environment variable is used.
        """
        api_key = api_key or os.getenv("GOOGLE_API_KEY")
        self.model = LiteLLMModel(model_id=f"gemini/{self.MODEL_NAME}", api_key=api_key)
        # Use the resolved key for the client too, so an explicitly passed
        # api_key is honoured rather than silently replaced by the env value.
        self.client = genai.Client(api_key=api_key)
        system_prompt = """
You are a general AI assistant. I will ask you a question.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If your answer is a number and you are not explicitly asked for a string, write it in numerals instead of words, and don't use comma to write your number nor use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
Answer questions as literally as you can, making as few assumptions as possible. Restrict the answer to the narrowest definition that still satifies the question.
If you are provied with a video, please watch and summarize the entire video before answering the question. The correct answer may be present only in a few frames of the video.
If you have difficulty finding an answer on Wikipedia, you may search the internet using Google Search.
If you are asked to prove something, first state your assumptions and think step by step before giving your final answer.
"""
        self.agent = ToolCallingAgent(
            model=self.model,
            tools=[
                VideoWatchingTool(model_name=self.MODEL_NAME),
                GoogleSearchTool(client=self.client, model_name=self.MODEL_NAME),
                WikipediaTitleSearchTool(),
                WikipediaPageTool(),
                FileAttachmentQueryTool(client=self.client, model_name=self.MODEL_NAME),
            ],
            max_steps=14,
            planning_interval=2,
        )
        self.system_prompt = system_prompt

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its result.

        The system prompt is prepended to the question, separated by a newline.
        """
        return self.agent.run(f"{self.system_prompt}\n{question}")
if __name__ == "__main__":
    # Manual smoke test: build the container and run one sample question.
    container = GeminiAgentContainer()

    # Other sample questions kept for manual experiments:
    # "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
    # "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?"
    question = (
        "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\n"
        "What does Teal'c say in response to the question \"Isn't that hot?\""
    )

    # Call the underlying agent directly (not the container's __call__) so the
    # step budget can be capped for this run.
    answer = container.agent.run(container.system_prompt + question, max_steps=5)
    print(answer)