Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -161,17 +161,74 @@ QA_PROMPT = PromptTemplate(QA_PROMPT_TMPL)
|
|
| 161 |
gpt_4o_mm = OpenAIMultiModal(model="gpt-4o-mini-2024-07-18")
|
| 162 |
|
| 163 |
class MultimodalQueryEngine(CustomQueryEngine):
|
| 164 |
-
def __init__(self, qa_prompt, retriever, multi_modal_llm, node_postprocessors=[]):
|
| 165 |
-
|
| 166 |
|
| 167 |
-
def custom_query(self, query_str):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
nodes = self.retriever.retrieve(query_str)
|
| 169 |
-
image_nodes = [NodeWithScore(node=ImageNode(image_path=n.node.metadata["image_path"])) for n in nodes]
|
| 170 |
-
ctx_str = "\n\n".join([r.node.get_content().strip() for r in nodes])
|
| 171 |
-
fmt_prompt = self.qa_prompt.format(context_str=ctx_str, query_str=query_str, encoded_image_url=encoded_image_url)
|
| 172 |
-
llm_response = self.multi_modal_llm.complete(prompt=fmt_prompt, image_documents=[image_node.node for image_node in image_nodes])
|
| 173 |
-
return Response(response=str(llm_response), source_nodes=nodes, metadata={"text_nodes": text_nodes, "image_nodes": image_nodes})
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
query_engine = MultimodalQueryEngine(QA_PROMPT, retriever, gpt_4o_mm)
|
| 176 |
|
| 177 |
# Handle query
|
|
|
|
| 161 |
# Multimodal LLM used to interpret retrieved page images alongside text context.
gpt_4o_mm = OpenAIMultiModal(model="gpt-4o-mini-2024-07-18")
|
| 162 |
|
| 163 |
class MultimodalQueryEngine(CustomQueryEngine):
|
| 164 |
+
# def __init__(self, qa_prompt, retriever, multi_modal_llm, node_postprocessors=[]):
|
| 165 |
+
# super().__init__(qa_prompt=qa_prompt, retriever=retriever, multi_modal_llm=multi_modal_llm, node_postprocessors=node_postprocessors)
|
| 166 |
|
| 167 |
+
# def custom_query(self, query_str):
|
| 168 |
+
# nodes = self.retriever.retrieve(query_str)
|
| 169 |
+
# image_nodes = [NodeWithScore(node=ImageNode(image_path=n.node.metadata["image_path"])) for n in nodes]
|
| 170 |
+
# ctx_str = "\n\n".join([r.node.get_content().strip() for r in nodes])
|
| 171 |
+
# fmt_prompt = self.qa_prompt.format(context_str=ctx_str, query_str=query_str, encoded_image_url=encoded_image_url)
|
| 172 |
+
# llm_response = self.multi_modal_llm.complete(prompt=fmt_prompt, image_documents=[image_node.node for image_node in image_nodes])
|
| 173 |
+
# return Response(response=str(llm_response), source_nodes=nodes, metadata={"text_nodes": text_nodes, "image_nodes": image_nodes})
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class MultimodalQueryEngine(CustomQueryEngine):
    """Query engine that answers questions using both text and page images.

    For each query it (1) retrieves relevant text nodes, (2) runs any
    configured node post-processors, (3) builds image nodes from the
    ``image_path`` stored in each node's metadata, and (4) asks a
    multimodal LLM to answer from the combined text context and images.
    """

    # Declarative fields (CustomQueryEngine is pydantic-based, so these
    # double as the validated constructor schema).
    qa_prompt: PromptTemplate
    retriever: BaseRetriever
    multi_modal_llm: OpenAIMultiModal
    node_postprocessors: Optional[List[BaseNodePostprocessor]]

    def __init__(
        self,
        qa_prompt: PromptTemplate,
        retriever: BaseRetriever,
        multi_modal_llm: OpenAIMultiModal,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
    ):
        # FIX: default was a shared mutable list (`= []`); use None and
        # normalize below. Passing None or [] explicitly behaves the same
        # as before, so the change is backward-compatible.
        super().__init__(
            qa_prompt=qa_prompt,
            retriever=retriever,
            multi_modal_llm=multi_modal_llm,
            node_postprocessors=node_postprocessors or [],
        )

    def custom_query(self, query_str: str):
        """Retrieve, post-process, and answer *query_str* with the multimodal LLM.

        Returns a ``Response`` whose metadata carries the retrieved text
        nodes and the derived image nodes.
        """
        # Retrieve the most relevant text nodes.
        nodes = self.retriever.retrieve(query_str)

        # Apply any configured post-processors (e.g. rerankers) in order.
        for postprocessor in self.node_postprocessors:
            nodes = postprocessor.postprocess_nodes(
                nodes, query_bundle=QueryBundle(query_str)
            )

        # Create image nodes from the image associated with each retrieved node.
        # Assumes every node's metadata has an "image_path" key — set at index time.
        image_nodes = [
            NodeWithScore(node=ImageNode(image_path=n.node.metadata["image_path"]))
            for n in nodes
        ]

        # Build the context string from the parsed markdown text.
        ctx_str = "\n\n".join(
            [r.node.get_content(metadata_mode=MetadataMode.LLM).strip() for r in nodes]
        )

        # Format the prompt for the LLM.
        # NOTE(review): `encoded_image_url` is not defined in this method or
        # class — it must be a module-level global; verify it exists before
        # the first query, otherwise this line raises NameError.
        fmt_prompt = self.qa_prompt.format(
            context_str=ctx_str,
            query_str=query_str,
            encoded_image_url=encoded_image_url,
        )

        # Use the multimodal LLM to interpret the images and answer the prompt.
        # (Also fixes the misspelled local `llm_repsonse`.)
        llm_response = self.multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=[image_node.node for image_node in image_nodes],
        )
        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            # BUG FIX: original referenced an undefined `text_nodes`, which
            # raised NameError on every query; the retrieved text nodes are
            # `nodes`.
            metadata={"text_nodes": nodes, "image_nodes": image_nodes},
        )
|
| 231 |
+
|
| 232 |
# Wire the engine to the module-level prompt, retriever, and multimodal LLM;
# node_postprocessors is left at its default (none).
query_engine = MultimodalQueryEngine(QA_PROMPT, retriever, gpt_4o_mm)
|
| 233 |
|
| 234 |
# Handle query
|