abliznyuk's picture
switch to 4o-mini, add VQA tool
a73f583
raw
history blame
2.75 kB
from smolagents import CodeAgent, OpenAIServerModel, tool, Tool
from smolagents import WikipediaSearchTool, GoogleSearchTool, VisitWebpageTool, PythonInterpreterTool
def get_prompt() -> str:
    """Return the full text of ``prompt.txt``.

    The file name is relative, so it is looked up in the process's current
    working directory.

    Returns:
        The entire file contents as a single string.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 so the prompt reads the same on every
    # platform instead of depending on the locale's default encoding.
    with open("prompt.txt", "r", encoding="utf-8") as prompt_file:
        return prompt_file.read()
@tool
def visual_qa(image_url: str, question: str) -> str:
    """
    Provides functionality to perform visual question answering (VQA) by processing an image and a natural language question.

    Args:
        image_url: str
            A URL pointing to the location of the image to be analyzed. The URL
            should be accessible and point to a valid image file.
        question: str
            A natural language string containing the question to be answered
            based on the provided image.

    Returns:
        str
            The model-generated answer to the provided question based on the
            analysis of the image.

    Raises:
        Exception
            If there is any issue with the API request, such as connection
            errors or invalid inputs.
    """
    # NOTE(review): the docstring above is presumably what the @tool decorator
    # turns into the tool/argument descriptions shown to the agent, so its
    # wording is kept as-is -- confirm before rephrasing it.

    # Imported inside the function so the tool body carries its own dependency.
    from openai import OpenAI

    # A single user turn made of two parts: the question text and the image
    # reference (requested with the "low" detail setting).
    text_part = {"type": "text", "text": question}
    image_part = {
        "type": "image_url",
        "image_url": {"url": image_url, "detail": "low"},
    }
    user_turn = {"role": "user", "content": [text_part, image_part]}

    # The client is built with no arguments, so all of its configuration comes
    # from the library's own defaults.
    completion = OpenAI().chat.completions.create(
        model="gpt-4o-mini",
        messages=[user_turn],
    )

    # Only the first returned choice is used.
    first_choice = completion.choices[0]
    return first_choice.message.content
class FinalAnswerTool(Tool):
    """Tool the agent calls to hand back its final answer.

    String answers that contain a ``final answer:`` marker (matched
    case-insensitively) are reduced to the text after the first marker;
    every other answer is returned untouched.
    """

    name = "final_answer"
    description = "Provides a final answer to the given problem."
    inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
    output_type = "any"

    def forward(self, answer: object) -> object:
        """Return the final answer, dropping any ``final answer:`` prefix.

        Args:
            answer: The value produced by the agent. Declared as "any" in
                ``inputs``, so it may be a string, a number, a list, etc.

        Returns:
            For a string containing the marker: the text following the first
            marker, stripped of surrounding whitespace and with its original
            casing preserved. Otherwise ``answer`` itself, unchanged.
        """
        # Local import keeps the tool self-contained.
        import re

        # Non-string answers (numbers, lists, ...) have no marker to strip and
        # no ``.lower()`` method; pass them through instead of raising.
        if not isinstance(answer, str):
            return answer

        marker = re.search(r"final answer:", answer, flags=re.IGNORECASE)
        if marker is None:
            return answer

        # Slice the ORIGINAL string so the answer keeps its casing; only the
        # marker lookup is case-insensitive.
        return answer[marker.end():].strip()
class GAIAAgent:
    """Callable wrapper around a ``CodeAgent`` and the prompt text it is run with."""

    def __init__(self):
        """Build the agent with its tools and model, then load the prompt text."""
        # Tools are instantiated first and in this exact order, followed by
        # the model, matching the order in which they are handed to the agent.
        toolbox = [
            GoogleSearchTool(provider="serper"),
            VisitWebpageTool(),
            WikipediaSearchTool(),
            PythonInterpreterTool(),
            FinalAnswerTool(),
            visual_qa,
        ]
        llm = OpenAIServerModel(model_id='gpt-4o-mini', max_tokens=4096, temperature=0)

        # Only the tools listed above are used: base tools are switched off.
        self.agent = CodeAgent(
            tools=toolbox,
            model=llm,
            add_base_tools=False,
            max_steps=10,
        )
        self.prompt = get_prompt()

    def __call__(self, question: str) -> str:
        """Run the agent on the stored prompt, passing ``question`` as an extra argument.

        Args:
            question: The question text, forwarded under the ``"question"``
                key of ``additional_args``.

        Returns:
            Whatever ``self.agent.run`` returns for this run.
        """
        extra_args = {"question": question}
        outcome = self.agent.run(self.prompt, additional_args=extra_args)
        return outcome
# Script entry point: when the file is executed directly (not imported),
# construct the agent. Nothing is asked of it here; the instance is only built.
if __name__ == '__main__':
    agent = GAIAAgent()