# VisionFlowModule / VisionAtomicFlow.py
# Author: nbaldwin — "vision module first version" (commit c296fdd)
from typing import Dict, Any
from flow_modules.aiflows.OpenAIChatFlowModule import OpenAIChatAtomicFlow
from flows.utils.general_helpers import encode_image,encode_from_buffer
import cv2
class VisionAtomicFlow(OpenAIChatAtomicFlow):
    """An atomic flow extending OpenAIChatAtomicFlow with support for image
    and video inputs (OpenAI vision-style multimodal chat messages)."""

    @staticmethod
    def get_image(image):
        """Convert an image spec into an OpenAI chat "image_url" content part.

        :param image: dict with keys "type" ("local_path" or "url") and
            "image" (a file path or an http(s) URL respectively).
        :return: dict of the form {"type": "image_url", "image_url": {"url": ...}}
        :raises AssertionError: if "type" is not a supported image type.
        :raises KeyError: if a local file's extension is not a supported format.
        """
        # Maps file extensions to the MIME subtype used in the data URI.
        extension_dict = {
            "jpg": "jpeg",
            "jpeg": "jpeg",
            "png": "png",
            "webp": "webp",
            "gif": "gif"
        }
        supported_image_types = ["local_path", "url"]
        assert image.get("type", None) in supported_image_types, f"Must define a valid image type for every image \n your type: {image.get('type',None)} \n supported types{supported_image_types} "
        if image["type"] == "local_path":
            encoded = encode_image(image.get("image"))
            # Lowercase so e.g. "photo.PNG" resolves in extension_dict too.
            image_extension_type = image.get("image").split(".")[-1].lower()
            # FIX: a data URI must not contain whitespace after "base64,";
            # the original inserted a stray space, producing an invalid URI.
            url = f"data:image/{extension_dict[image_extension_type]};base64,{encoded}"
        else:  # "url": pass the remote URL through unchanged
            url = image.get("image")
        return {"type": "image_url", "image_url": {"url": url}}
@staticmethod
def get_video(video):
video_path = video["video_path"]
resize = video.get("resize",768)
frame_step_size = video.get("frame_step_size",10)
start_frame = video.get("start_frame",0)
end_frame = video.get("end_frame",None)
base64Frames = []
video = cv2.VideoCapture(video_path)
while video.isOpened():
success,frame = video.read()
if not success:
break
_,buffer = cv2.imencode(".jpg",frame)
base64Frames.append(encode_from_buffer(buffer))
video.release()
return map(lambda x: {"image": x, "resize": resize},base64Frames[start_frame:end_frame:frame_step_size])
@staticmethod
def get_user_message(prompt_template, input_data: Dict[str, Any]):
content = VisionAtomicFlow._get_message(prompt_template=prompt_template,input_data=input_data)
media_data = input_data["data"]
if "video" in media_data:
content = [ content[0], *VisionAtomicFlow.get_video(media_data["video"])]
if "images" in media_data:
images = [VisionAtomicFlow.get_image(image) for image in media_data["images"]]
content.extend(images)
return content
@staticmethod
def _get_message(prompt_template, input_data: Dict[str, Any]):
template_kwargs = {}
for input_variable in prompt_template.input_variables:
template_kwargs[input_variable] = input_data[input_variable]
msg_content = prompt_template.format(**template_kwargs)
return [{"type": "text", "text": msg_content}]
def _process_input(self, input_data: Dict[str, Any]):
if self._is_conversation_initialized():
# Construct the message using the human message prompt template
user_message_content = self.get_user_message(self.human_message_prompt_template, input_data)
else:
# Initialize the conversation (add the system message, and potentially the demonstrations)
self._initialize_conversation(input_data)
if getattr(self, "init_human_message_prompt_template", None) is not None:
# Construct the message using the query message prompt template
user_message_content = self.get_user_message(self.init_human_message_prompt_template, input_data)
else:
user_message_content = self.get_user_message(self.human_message_prompt_template, input_data)
self._state_update_add_chat_message(role=self.flow_config["user_name"],
content=user_message_content)