Commit ·
0ce2853
1
Parent(s): fd3fe1e
Change provider to OpenAI.
Browse files- app.py +15 -8
- requirements.txt +1 -1
app.py
CHANGED
|
@@ -6,7 +6,8 @@ import mimetypes
|
|
| 6 |
import requests
|
| 7 |
import pandas as pd
|
| 8 |
from llama_index.core.llms import ChatMessage, TextBlock, ImageBlock, AudioBlock
|
| 9 |
-
from llama_index.llms.google_genai import GoogleGenAI
|
|
|
|
| 10 |
from llama_index.core.agent.workflow import ReActAgent, AgentOutput
|
| 11 |
from llama_index.core.tools import FunctionTool
|
| 12 |
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
|
@@ -27,7 +28,7 @@ load_dotenv()
|
|
| 27 |
# --- Constants ---
|
| 28 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 29 |
SYSTEM_PROMPT = (Path(__file__).parent / 'system_prompt.txt').read_text()
|
| 30 |
-
GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
|
| 31 |
|
| 32 |
# --- Basic Agent Definition ---
|
| 33 |
class BasicAgent:
|
|
@@ -37,7 +38,8 @@ class BasicAgent:
|
|
| 37 |
wikipedia_search_tool = FunctionTool.from_defaults(WikipediaToolSpec().search_data)
|
| 38 |
self._tools = [search_tool, wikipedia_load_tool, wikipedia_search_tool]
|
| 39 |
|
| 40 |
-
self._llm = GoogleGenAI(api_key=GOOGLE_API_KEY, model="gemini-2.0-flash", max_tokens=1600)
|
|
|
|
| 41 |
self._agent = ReActAgent(tools=self._tools, llm=self._llm)
|
| 42 |
# Modify the react prompt.
|
| 43 |
self._agent.update_prompts({"react_header": SYSTEM_PROMPT})
|
|
@@ -72,23 +74,28 @@ def fetch_questions(api_url: str = DEFAULT_API_URL):
|
|
| 72 |
|
| 73 |
|
| 74 |
def get_media_type(filename: str):
|
| 75 |
-
|
| 76 |
-
if
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
|
| 80 |
def get_media_content(item):
|
| 81 |
if item.get('file_name'):
|
| 82 |
file_response = requests.get(f"{DEFAULT_API_URL}/files/{item.get('task_id')}")
|
| 83 |
if file_response:
|
| 84 |
-
media_type = get_media_type(item.get('file_name'))
|
| 85 |
if media_type == 'image':
|
| 86 |
return ImageBlock(image=file_response.content)
|
| 87 |
elif media_type == 'text':
|
| 88 |
return TextBlock(text=file_response.content)
|
| 89 |
# Audio currently not supported?
|
| 90 |
elif media_type == 'audio':
|
| 91 |
-
return AudioBlock(audio=file_response.content)
|
| 92 |
|
| 93 |
|
| 94 |
def create_question_message(item):
|
|
|
|
| 6 |
import requests
|
| 7 |
import pandas as pd
|
| 8 |
from llama_index.core.llms import ChatMessage, TextBlock, ImageBlock, AudioBlock
|
| 9 |
+
# from llama_index.llms.google_genai import GoogleGenAI
|
| 10 |
+
from llama_index.llms.openai import OpenAI
|
| 11 |
from llama_index.core.agent.workflow import ReActAgent, AgentOutput
|
| 12 |
from llama_index.core.tools import FunctionTool
|
| 13 |
from llama_index.tools.duckduckgo import DuckDuckGoSearchToolSpec
|
|
|
|
| 28 |
# --- Constants ---
|
| 29 |
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 30 |
SYSTEM_PROMPT = (Path(__file__).parent / 'system_prompt.txt').read_text()
|
| 31 |
+
# GOOGLE_API_KEY = os.environ['GOOGLE_API_KEY']
|
| 32 |
|
| 33 |
# --- Basic Agent Definition ---
|
| 34 |
class BasicAgent:
|
|
|
|
| 38 |
wikipedia_search_tool = FunctionTool.from_defaults(WikipediaToolSpec().search_data)
|
| 39 |
self._tools = [search_tool, wikipedia_load_tool, wikipedia_search_tool]
|
| 40 |
|
| 41 |
+
#self._llm = GoogleGenAI(api_key=GOOGLE_API_KEY, model="gemini-2.0-flash", max_tokens=1600)
|
| 42 |
+
self._llm = OpenAI(model="gpt-4o-mini")
|
| 43 |
self._agent = ReActAgent(tools=self._tools, llm=self._llm)
|
| 44 |
# Modify the react prompt.
|
| 45 |
self._agent.update_prompts({"react_header": SYSTEM_PROMPT})
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def get_media_type(filename: str):
|
| 77 |
+
media_type_and_format = mimetypes.guess_type(filename)[0]
|
| 78 |
+
if media_type_and_format is not None:
|
| 79 |
+
media_type, media_format = media_type_and_format.split('/')
|
| 80 |
+
if media_type == "audio" and media_format == "mpeg":
|
| 81 |
+
media_format = "mp3"
|
| 82 |
+
return media_type, media_format
|
| 83 |
+
else:
|
| 84 |
+
return None, None
|
| 85 |
|
| 86 |
|
| 87 |
def get_media_content(item):
|
| 88 |
if item.get('file_name'):
|
| 89 |
file_response = requests.get(f"{DEFAULT_API_URL}/files/{item.get('task_id')}")
|
| 90 |
if file_response:
|
| 91 |
+
media_type, media_format = get_media_type(item.get('file_name'))
|
| 92 |
if media_type == 'image':
|
| 93 |
return ImageBlock(image=file_response.content)
|
| 94 |
elif media_type == 'text':
|
| 95 |
return TextBlock(text=file_response.content)
|
| 96 |
# Audio currently not supported?
|
| 97 |
elif media_type == 'audio':
|
| 98 |
+
return AudioBlock(audio=file_response.content, format=media_format)
|
| 99 |
|
| 100 |
|
| 101 |
def create_question_message(item):
|
requirements.txt
CHANGED
|
@@ -3,6 +3,6 @@ gradio
|
|
| 3 |
requests
|
| 4 |
pandas
|
| 5 |
llama-index
|
| 6 |
-
llama-index-llms-
|
| 7 |
llama-index-tools-duckduckgo
|
| 8 |
llama-index-tools-wikipedia
|
|
|
|
| 3 |
requests
|
| 4 |
pandas
|
| 5 |
llama-index
|
| 6 |
+
llama-index-llms-openai
|
| 7 |
llama-index-tools-duckduckgo
|
| 8 |
llama-index-tools-wikipedia
|