Spaces status: Runtime error
| from llama_index.core.query_engine import CustomQueryEngine | |
| from llama_index.core.retrievers import BaseRetriever | |
| from llama_index.multi_modal_llms.openai import OpenAIMultiModal | |
| from llama_index.core.schema import ImageNode, NodeWithScore, MetadataMode | |
| from llama_index.core.prompts import PromptTemplate | |
| from llama_index.core.base.response.schema import Response | |
| from typing import Optional | |
| from core.prompt import MULTOMODAL_QUERY_TEMPLATE | |
# Multimodal LLM used elsewhere as an answer synthesizer.
# NOTE(review): the variable is named gpt_4o but the configured model is
# "gpt-4o-mini" — confirm which model is actually intended.
gpt_4o = OpenAIMultiModal(model="gpt-4o-mini", max_new_tokens=4096)
# Default QA prompt for MultimodalQueryEngine; the template must expose
# {context_str} and {query_str} placeholders (see custom_query's format call).
QA_PROMPT = PromptTemplate(MULTOMODAL_QUERY_TEMPLATE)
class MultimodalQueryEngine(CustomQueryEngine):
    """Custom multimodal Query Engine.

    Takes in a retriever to retrieve a set of document nodes.
    Also takes in a prompt template and multimodal model.
    """

    qa_prompt: PromptTemplate  # template with {context_str} and {query_str} placeholders
    retriever: BaseRetriever  # supplies the text nodes for a query
    multi_modal_llm: OpenAIMultiModal  # synthesizes the final answer from text + images

    def __init__(self, qa_prompt: Optional[PromptTemplate] = None, **kwargs) -> None:
        """Initialize, falling back to the module-level QA_PROMPT when none is given."""
        super().__init__(qa_prompt=qa_prompt or QA_PROMPT, **kwargs)

    def _image_nodes_from(self, nodes) -> list:
        """Build ImageNode wrappers from the 'image_link' metadata of *nodes*.

        The 'image_link' value may be a single URL string or a list of URL
        strings; missing keys, None, empty strings, and empty lists are
        skipped (the original code let a None value through as an
        ImageNode(image_url=None)).
        """
        image_nodes = []
        for n in nodes:
            links = n.metadata.get("image_link")
            if links in (None, "", []):
                continue
            # Normalize: a bare string becomes a one-element list.
            if not isinstance(links, list):
                links = [links]
            for link in links:
                if link not in ("", []):
                    image_nodes.append(NodeWithScore(node=ImageNode(image_url=link)))
        return image_nodes

    def custom_query(self, query_str: str) -> Response:
        """Retrieve nodes, format the multimodal prompt, and synthesize an answer.

        Returns a Response whose metadata carries both the retrieved text
        nodes and the derived image nodes.
        """
        # retrieve text nodes
        nodes = self.retriever.retrieve(query_str)
        # create ImageNode items from the retrieved nodes' metadata
        # (debug print of image_nodes removed from the hot path)
        image_nodes = self._image_nodes_from(nodes)
        # create context string from text nodes, dump into the prompt
        context_str = "\n\n".join(
            r.get_content(metadata_mode=MetadataMode.LLM) for r in nodes
        )
        fmt_prompt = self.qa_prompt.format(context_str=context_str, query_str=query_str)
        # synthesize an answer from formatted text and images
        llm_response = self.multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": nodes, "image_nodes": image_nodes},
        )