Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| from dotenv import load_dotenv | |
| from duckduckgo_search import DDGS | |
| from langchain_core.messages.tool import BaseMessage, ToolMessage | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain_core.tools import tool | |
| from langgraph.graph import END, MessageGraph | |
| from langgraph.prebuilt import ToolNode | |
| from typing import TypedDict | |
| from llm import get_text_llm | |
| from log_util import logger | |
| from time_it import time_it | |
| from util import load_prompt | |
load_dotenv()

# Default cap on how many image hits search_images() requests per query.
# Read once at import time (after load_dotenv() has had a chance to populate
# the environment); falls back to 3 when the variable is unset.
MAX_IMAGE_SEARCH_RESULTS = int(os.environ.get('MAX_IMAGE_SEARCH_RESULTS', '3'))
# One image hit as returned by search_images(): a plain dict with a
# human-readable title and the URL of the image itself.
ImageSearchResult = TypedDict('ImageSearchResult', {
    'title': str,  # title reported by the search backend
    'url': str,    # image URL (taken from the backend's 'image' field)
})
def search_meal_image(meal: str) -> str | None:
    """Return the URL of an image of *meal*, or ``None`` if none is found.

    Builds and runs a small LangGraph ``MessageGraph``:

    1. ``validate_is_meal``: the text LLM answers the
       ``validate_is_meal.prompt.txt`` prompt rendered with ``phrase=meal``.
       Only a "yes" answer continues the run. The comparison ignores case,
       surrounding whitespace and a trailing period. Any other answer ends
       the run.
    2. ``is_meal``: the same LLM, bound to the ``search_meal_images`` tool,
       is invoked on the conversation so far.
    3. ``call_tools``: runs only if the LLM actually emitted tool calls.
       The run then ends.

    The first ``ToolMessage`` is decoded as a JSON list of
    ``ImageSearchResult`` and the first hit's URL is returned.

    Args:
        meal: Free-text phrase to validate and search images for.

    Returns:
        The first image URL, or ``None`` in any of these cases:
        the phrase was not judged to be a meal; the LLM made no tool call;
        the search returned nothing; or the tool output was not valid JSON
        (e.g. ToolNode reported a tool error as plain text).
    """
    prompt_template_text = load_prompt('validate_is_meal.prompt.txt')
    prompt = PromptTemplate.from_template(prompt_template_text).format(phrase=meal)

    llm = get_text_llm()
    tools = [search_meal_images]

    def is_meal_router(messages: list[BaseMessage]) -> str:
        # The validator's reply is compared against "yes". Tolerate the usual
        # LLM noise ("Yes.", "yes\n"). Non-text content never counts as yes.
        answer = messages[-1].content
        if isinstance(answer, str) and answer.strip().rstrip('.').lower() == 'yes':
            return 'is_meal'
        return END

    def tool_call_router(messages: list[BaseMessage]) -> str:
        # Only hand off to ToolNode when the model really requested a tool;
        # otherwise there is nothing to execute.
        if getattr(messages[-1], 'tool_calls', None):
            return 'call_tools'
        return END

    graph = MessageGraph()
    graph.add_node('validate_is_meal', llm)
    graph.add_node('is_meal', llm.bind_tools(tools))
    graph.add_node('call_tools', ToolNode(tools))
    graph.set_entry_point('validate_is_meal')
    graph.add_conditional_edges('validate_is_meal', is_meal_router)
    graph.add_conditional_edges('is_meal', tool_call_router)
    graph.add_edge('call_tools', END)
    workflow = graph.compile()

    messages: list[BaseMessage] = workflow.invoke(prompt)
    tool_messages = [message for message in messages if isinstance(message, ToolMessage)]
    if not tool_messages or not tool_messages[0].content:
        return None

    tool_output = tool_messages[0].content
    try:
        meal_images: list[ImageSearchResult] = json.loads(tool_output)
    except (json.JSONDecodeError, TypeError):
        # A failed tool call comes back as a plain-text error message (or a
        # non-string content payload) rather than JSON. Treat it as "no image"
        # instead of crashing the caller.
        logger.warning(f'Could not decode image search output: {tool_output!r}')
        return None

    if not meal_images:
        return None
    meal_image_url = meal_images[0]['url']
    logger.info(f'{meal_image_url=}')
    return meal_image_url
def search_meal_images(meal: str) -> list[ImageSearchResult]:
    '''Searches for images of the given meal.'''
    # NOTE: search_meal_image() passes this function to llm.bind_tools() and
    # ToolNode(). Its name, the `meal` parameter and the docstring above
    # presumably become the tool schema/description the LLM sees, so treat
    # them as runtime-visible and change them with care.
    # NOTE(review): `tool` is imported at file top but not applied here; this
    # relies on bind_tools/ToolNode accepting a plain function — confirm for
    # the pinned langchain/langgraph versions.
    # Thin wrapper: delegates to search_images() with its default result cap.
    return search_images(meal)
def search_images(keywords: str, max_results: int | None = MAX_IMAGE_SEARCH_RESULTS) -> list[ImageSearchResult]:
    """Search DuckDuckGo Images (via ``DDGS().images``) for *keywords*.

    The query is sent with ``region='wt-wt'``, ``safesearch='on'``,
    ``color='color'`` and ``type_image='photo'``. Size, layout and license
    filters are left as ``None``.

    Args:
        keywords: Search query text.
        max_results: Forwarded to DDGS as ``max_results``. Defaults to
            ``MAX_IMAGE_SEARCH_RESULTS`` as evaluated when this module loaded.

    Returns:
        One ``ImageSearchResult`` per hit. The hit's ``'title'`` becomes
        ``title`` and its ``'image'`` becomes ``url``.
    """
    search_filters = {
        'region': 'wt-wt',
        'safesearch': 'on',
        'size': None,
        'color': 'color',
        'type_image': 'photo',
        'layout': None,
        'license_image': None,
    }
    hits = DDGS().images(keywords=keywords, max_results=max_results, **search_filters)
    # Same text as f'{keywords=}: {results=}': the `=` specifier renders repr().
    logger.info(f'keywords={keywords!r}: results={hits!r}')
    return [ImageSearchResult(title=hit['title'], url=hit['image']) for hit in hits]