# Source: Hugging Face Space (page status at capture time: Sleeping)
| import os | |
| import requests | |
| from smolagents import LiteLLMModel, ToolCallingAgent, Tool | |
| from typing import Optional | |
| from google import genai | |
| from google.genai import types | |
| import wikipedia as wiki | |
| from markdownify import markdownify as to_markdown | |
| # --- Tools --- | |
class VideoWatchingTool(Tool):
    """Asks a Gemini model to watch a video at a URL and answer a question about it.

    Calls the Generative Language REST API directly (not the genai client) so the
    video can be passed by URI via `fileData`.
    """

    name = "watch_video"
    description = """
    A tool for watching videos and answering questions about them.
    """
    inputs = {
        "video_url": {
            "type": "string",
            "description": "The URL of the video to watch."
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the video."
        }
    }
    output_type = "string"

    def __init__(self, model_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Bare model name (e.g. "gemini-2.0-flash"); prefixed with "models/" below.
        self.model_name = model_name

    def forward(self, video_url: str, user_query: str) -> str:
        """Return the model's answer as text, or an "Error: ..." string on failure."""
        request_json = {
            'model': f'models/{self.model_name}',
            'contents': [{
                "parts": [
                    {
                        'fileData': {
                            'fileUri': video_url
                        }
                    },
                    {
                        'text': f"Please watch the video and answer the following question: {user_query}"
                    }
                ]
            }]
        }
        api_url = f'https://generativelanguage.googleapis.com/v1beta/models/{self.model_name}:generateContent'
        response = requests.post(
            api_url,
            json=request_json,
            headers={
                'Content-Type': 'application/json',
                # Send the key in a header instead of the URL query string so it
                # does not leak into server logs or proxies.
                'x-goog-api-key': os.getenv("GOOGLE_API_KEY") or "",
            },
            # Video analysis is slow, but never hang indefinitely.
            timeout=300,
        )
        if response.status_code != 200:
            return f"Error: {response.status_code} - {response.text}"
        response_json = response.json()
        candidates = response_json.get('candidates')
        if not candidates:
            # e.g. prompt blocked by safety filters -> payload has no candidates.
            return f"Error: model returned no candidates: {response_json}"
        result_parts = candidates[0].get('content', {}).get('parts', [])
        return "".join(part.get('text', '') for part in result_parts)
class GoogleSearchTool(Tool):
    """Exposes Gemini's built-in Google Search grounding as a smolagents tool."""

    name = "google_search"
    description = """
    Performs a Google search and returns the results.
    """
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query.",
        },
    }
    output_type = "string"

    def __init__(self, client, model_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # genai.Client instance and the model that will perform the search.
        self.client = client
        self.model_name = model_name

    def forward(self, query: str) -> str:
        """Delegate the search to Gemini with its GoogleSearch tool enabled."""
        search_config = types.GenerateContentConfig(
            tools=[types.Tool(google_search=types.GoogleSearch())],
            response_modalities=['TEXT'],
        )
        reply = self.client.models.generate_content(
            model=self.model_name,
            contents=f"Please search the internet for: {query}",
            config=search_config,
        )
        return reply.text
class WikipediaTitleSearchTool(Tool):
    """Finds Wikipedia page titles matching a free-text query."""

    name = "check_wikipedia_page_titles"
    description = """
    Searches for Wikipedia pages related to the query and returns the canonical titles of the related pages.
    """
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query.",
        },
    }
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return matching titles joined by ", ", or a not-found message."""
        # wiki.search returns a (possibly empty) list of title strings.
        titles = wiki.search(query)
        return ", ".join(titles) if titles else "No results found."
class WikipediaPageTool(Tool):
    """Fetches a Wikipedia page and returns its content converted to Markdown."""

    name = "get_wikipedia_page"
    description = """
    Gets the content of a Wikipedia page.
    """
    inputs = {
        "page_title": {
            "type": "string",
            "description": "The canonical title of the Wikipedia page."
        }
    }
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the page rendered as Markdown, or an explanatory message on failure."""
        # TODO: may need to do caching of the HTML ourselves?
        try:
            page = wiki.page(page_title)
        except wiki.exceptions.PageError:
            return f"Page '{page_title}' not found."
        except wiki.exceptions.DisambiguationError as err:
            # Ambiguous titles raise instead of returning a page; surface the
            # candidate titles so the agent can retry with a canonical one.
            options = ", ".join(err.options[:20])
            return f"Page '{page_title}' is ambiguous. Candidate pages: {options}"
        # Markdown is easier for the LLM to consume than raw HTML.
        md_content = to_markdown(page.html())
        return md_content
class FileAttachmentQueryTool(Tool):
    """Downloads a task attachment and asks a Gemini model a question about it.

    The file bytes are sent inline via `types.Part.from_bytes`, which is why the
    description caps the size at 20MB (the inline-data request limit).
    """

    name = "run_query_with_file"
    description = """
    Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
    This assumes the file is 20MB or less.
    """
    inputs = {
        "task_id": {
            "type": "string",
            "description": "A unique identifier for the task related to this file, used to download it."
        },
        "mime_type": {
            "type": "string",
            "nullable": True,
            "description": "The MIME type of the file, or the best guess if unknown."
        },
        "user_query": {
            "type": "string",
            "description": "The question to answer about the file."
        }
    }
    output_type = "string"

    def __init__(self, client, model_name, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # genai.Client instance and the model that will answer the query.
        self.client = client
        self.model_name = model_name

    def forward(self, task_id: str, mime_type: str | None, user_query: str) -> str:
        """Download the file for `task_id` and return the model's answer text.

        Raises:
            Exception: if the file download does not return HTTP 200.
        """
        # Download the file
        file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        # Bound the download so a dead endpoint cannot hang the agent step.
        file_response = requests.get(file_url, timeout=60)
        if file_response.status_code != 200:
            raise Exception(f"Failed to download file: {file_response.status_code} - {file_response.text}")
        file_data = file_response.content
        # Fall back to the server-reported Content-Type when the caller has no guess.
        mime_type = mime_type or file_response.headers.get('Content-Type', 'application/octet-stream')
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=[
                types.Part.from_bytes(
                    data=file_data,
                    mime_type=mime_type,
                ),
                user_query,
            ]
        )
        return response.text
| # --- Agent Management --- | |
class GeminiAgentContainer:
    """
    A container for the Gemini agent.

    Builds a ToolCallingAgent backed by a Gemini model (via LiteLLM) and wires up
    the video, search, Wikipedia, and file-attachment tools.
    """
    # TODO: make it easier to change the model
    MODEL_NAME = "gemini-2.0-flash"

    def __init__(self, api_key: Optional[str] = None):
        """Create the agent.

        Args:
            api_key: Google API key; falls back to the GOOGLE_API_KEY env var.
        """
        api_key = api_key or os.getenv("GOOGLE_API_KEY")
        self.model = LiteLLMModel(model_id=f"gemini/{self.MODEL_NAME}", api_key=api_key)
        # Bug fix: previously re-read GOOGLE_API_KEY from the environment here,
        # silently ignoring an explicitly supplied api_key argument.
        self.client = genai.Client(api_key=api_key)
        system_prompt = """
    You are a general AI assistant. I will ask you a question.
    YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
    If your answer is a number and you are not explicitly asked for a string, write it in numerals instead of words, and don't use comma to write your number nor use units such as $ or percent sign unless specified otherwise.
    If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
    If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
    Answer questions as literally as you can, making as few assumptions as possible. Restrict the answer to the narrowest definition that still satifies the question.
    If you are provied with a video, please watch and summarize the entire video before answering the question. The correct answer may be present only in a few frames of the video.
    If you have difficulty finding an answer on Wikipedia, you may search the internet using Google Search.
    If you are asked to prove something, first state your assumptions and think step by step before giving your final answer.
    """
        self.agent = ToolCallingAgent(
            model=self.model,
            tools=[
                VideoWatchingTool(model_name=self.MODEL_NAME),
                GoogleSearchTool(client=self.client, model_name=self.MODEL_NAME),
                WikipediaTitleSearchTool(),
                WikipediaPageTool(),
                FileAttachmentQueryTool(client=self.client, model_name=self.MODEL_NAME),
            ],
            max_steps=14,
            planning_interval=2,
        )
        self.system_prompt = system_prompt

    def __call__(self, question: str) -> str:
        """Run the agent on `question`, prefixed with the system prompt."""
        response = self.agent.run(f"{self.system_prompt}\n{question}")
        return response
if __name__ == "__main__":
    # Manual smoke test: run the agent on a single hard-coded question.
    agent_container = GeminiAgentContainer()
    agent = agent_container.agent
    # Alternative queries kept for experimentation:
    #my_query = "How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia."
    #my_query = "In the video https://www.youtube.com/watch?v=L1vXCYZAYYM, what is the highest number of bird species to be on camera simultaneously?"
    my_query = "Examine the video at https://www.youtube.com/watch?v=1htKBjuUWec.\n\nWhat does Teal'c say in response to the question \"Isn't that hot?\""
    answer = agent.run(agent_container.system_prompt + my_query, max_steps=5)
    print(answer)
    #print(my_query)