|
|
|
|
|
import os
from typing import Dict, Any

import cv2

from flow_modules.aiflows.OpenAIChatFlowModule import OpenAIChatAtomicFlow
from flows.utils.general_helpers import encode_image,encode_from_buffer
|
|
|
|
|
|
|
|
class VisionAtomicFlow(OpenAIChatAtomicFlow):
    """Chat atomic flow whose user messages can carry images and video frames.

    Extends ``OpenAIChatAtomicFlow`` by building the user message as a list of
    content parts. The rendered text prompt comes first, then sampled video
    frames (when ``input_data["data"]["video"]`` is given), then ``image_url``
    parts (when ``input_data["data"]["images"]`` is given).
    """

    # Maps lower-cased file extensions to the MIME subtype used in data: URLs.
    _IMAGE_MIME_SUBTYPES = {
        "jpg": "jpeg",
        "jpeg": "jpeg",
        "png": "png",
        "webp": "webp",
        "gif": "gif",
    }

    # Values accepted for an image spec's "type" key.
    _SUPPORTED_IMAGE_TYPES = ("local_path", "url")

    @staticmethod
    def get_image(image):
        """Convert one image spec into an ``image_url`` chat content part.

        :param image: dict with ``"type"`` (``"local_path"`` or ``"url"``) and
            ``"image"`` (a file path or a URL, respectively).
        :return: ``{"type": "image_url", "image_url": {"url": url}}``. A local
            file is inlined as a ``data:image/<subtype>;base64,<data>`` URL, and
            a URL is passed through unchanged.
        :raises ValueError: if the image type is unsupported, or if a local
            file's extension is not one of jpg/jpeg/png/webp/gif (case-insensitive).
        """
        image_type = image.get("type", None)
        if image_type not in VisionAtomicFlow._SUPPORTED_IMAGE_TYPES:
            # Raise instead of assert so validation survives `python -O`.
            raise ValueError(
                f"Must define a valid image type for every image. "
                f"Your type: {image_type}; "
                f"supported types: {list(VisionAtomicFlow._SUPPORTED_IMAGE_TYPES)}"
            )

        if image_type == "local_path":
            path = image.get("image")
            # splitext ignores dots in directory names; lower() accepts "photo.JPG".
            extension = os.path.splitext(path)[1].lstrip(".").lower()
            subtype = VisionAtomicFlow._IMAGE_MIME_SUBTYPES.get(extension)
            if subtype is None:
                raise ValueError(
                    f"Unsupported image extension {extension!r} for {path!r}; "
                    f"supported: {sorted(VisionAtomicFlow._IMAGE_MIME_SUBTYPES)}"
                )
            # encode_image is assumed to return base64 text (it is embedded after
            # ";base64,"). No whitespace may follow the comma in a data URL.
            url = f"data:image/{subtype};base64,{encode_image(path)}"
        else:
            url = image.get("image")

        return {"type": "image_url", "image_url": {"url": url}}

    @staticmethod
    def get_video(video):
        """Sample frames from a video file as JPEG-encoded frame entries.

        :param video: dict with the required key ``"video_path"`` and the
            optional keys ``"resize"`` (default 768), ``"frame_step_size"``
            (default 10), ``"start_frame"`` (default 0) and ``"end_frame"``
            (default None). The selected frames are
            ``all_frames[start_frame:end_frame:frame_step_size]``.
        :return: list of ``{"image": <encoded frame>, "resize": resize}`` dicts,
            in frame order.
        :raises ValueError: if the video cannot be opened, or if
            ``frame_step_size`` is 0.
        :raises RuntimeError: if a frame cannot be JPEG-encoded.
        """
        video_path = video["video_path"]
        resize = video.get("resize", 768)
        step = video.get("frame_step_size", 10)
        start = video.get("start_frame", 0)
        end = video.get("end_frame", None)

        capture = cv2.VideoCapture(video_path)
        try:
            if not capture.isOpened():
                raise ValueError(f"Could not open video file: {video_path}")
            if VisionAtomicFlow._is_forward_slice(start, end, step):
                frames = VisionAtomicFlow._read_selected_frames(
                    capture,
                    0 if start is None else start,
                    end,
                    1 if step is None else step,
                )
            else:
                # Negative indices (or other unusual slice values) can only be
                # resolved once the total frame count is known, so encode every
                # frame and apply the slice afterwards.
                frames = VisionAtomicFlow._read_all_frames(capture)[start:end:step]
        finally:
            # Always free the capture, even if decoding or encoding fails.
            capture.release()

        # NOTE(review): {"image", "resize"} looks like the gpt-4-vision-preview
        # frame format; confirm the target chat API still accepts it.
        return [{"image": frame, "resize": resize} for frame in frames]

    @staticmethod
    def _is_forward_slice(start, end, step):
        """Return True if ``[start:end:step]`` can be applied while streaming frames."""

        def _non_negative_int(value):
            return isinstance(value, int) and value >= 0

        return (
            (start is None or _non_negative_int(start))
            and (end is None or _non_negative_int(end))
            and (step is None or (isinstance(step, int) and step > 0))
        )

    @staticmethod
    def _read_selected_frames(capture, start, end, step):
        """Encode only frames start, start+step, ... (< end), skipping others unencoded."""
        frames = []
        index = 0
        # Stop at end_frame instead of reading the rest of the video.
        while end is None or index < end:
            if index >= start and (index - start) % step == 0:
                success, frame = capture.read()
                if not success:
                    break
                frames.append(VisionAtomicFlow._encode_frame(frame))
            elif not capture.grab():
                # grab() advances past an unwanted frame without retrieving it.
                break
            index += 1
        return frames

    @staticmethod
    def _read_all_frames(capture):
        """Encode every remaining frame of ``capture``, in order."""
        frames = []
        while True:
            success, frame = capture.read()
            if not success:
                break
            frames.append(VisionAtomicFlow._encode_frame(frame))
        return frames

    @staticmethod
    def _encode_frame(frame):
        """JPEG-encode one decoded frame and pass the buffer to encode_from_buffer."""
        success, buffer = cv2.imencode(".jpg", frame)
        if not success:
            raise RuntimeError("Failed to JPEG-encode a video frame")
        return encode_from_buffer(buffer)

    @staticmethod
    def get_user_message(prompt_template, input_data: Dict[str, Any]):
        """Build the user message content list for one turn.

        :param prompt_template: template rendered with values from ``input_data``.
        :param input_data: flow input. The optional ``input_data["data"]`` may
            hold ``"video"`` (see ``get_video``) and/or ``"images"`` (a list of
            specs, see ``get_image``).
        :return: list of content parts: the text part first, then video frames,
            then image parts. A missing or empty ``"data"`` yields text only.
        """
        content = VisionAtomicFlow._get_message(prompt_template=prompt_template, input_data=input_data)

        # Tolerate text-only turns that carry no media payload.
        media_data = input_data.get("data") or {}

        if "video" in media_data:
            content.extend(VisionAtomicFlow.get_video(media_data["video"]))

        if "images" in media_data:
            content.extend(VisionAtomicFlow.get_image(image) for image in media_data["images"])

        return content

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Render ``prompt_template`` from ``input_data`` as a single text content part.

        :raises KeyError: if a template input variable is missing from ``input_data``.
        """
        template_kwargs = {name: input_data[name] for name in prompt_template.input_variables}
        return [{"type": "text", "text": prompt_template.format(**template_kwargs)}]

    def _process_input(self, input_data: Dict[str, Any]):
        """Build this turn's user message and append it to the chat state.

        On the first call this initializes the conversation. It then prefers
        ``init_human_message_prompt_template`` when the flow defines a non-None
        one, and otherwise uses ``human_message_prompt_template``.
        """
        if self._is_conversation_initialized():
            prompt_template = self.human_message_prompt_template
        else:
            self._initialize_conversation(input_data)
            prompt_template = getattr(self, "init_human_message_prompt_template", None)
            if prompt_template is None:
                prompt_template = self.human_message_prompt_template

        user_message_content = self.get_user_message(prompt_template, input_data)

        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)