File size: 3,900 Bytes
c296fdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87

from itertools import islice
from typing import Dict, Any

import cv2

from flow_modules.aiflows.OpenAIChatFlowModule import OpenAIChatAtomicFlow
from flows.utils.general_helpers import encode_image,encode_from_buffer


class VisionAtomicFlow(OpenAIChatAtomicFlow):
    """Chat flow whose user messages may carry images and video frames.

    The user message is built as a list of content parts: one text part
    rendered from a prompt template, optionally followed by video-frame parts
    and image parts described under ``input_data["data"]``.
    """

    @staticmethod
    def get_image(image):
        """Convert one image description into an ``image_url`` content part.

        Args:
            image: Mapping with a ``"type"`` entry (``"local_path"`` or
                ``"url"``) and an ``"image"`` entry holding the file path or
                the URL respectively.

        Returns:
            ``{"type": "image_url", "image_url": {"url": <url>}}``. For
            ``"url"`` images the URL is passed through unchanged; for
            ``"local_path"`` images it is a ``data:image/...;base64`` URL
            built from the output of ``encode_image``.

        Raises:
            AssertionError: If ``image["type"]`` is missing or unsupported.
            KeyError: If a local file's extension is not one of
                jpg, jpeg, png, webp or gif (compared case-insensitively).
        """
        extension_dict = {
            "jpg": "jpeg",
            "jpeg": "jpeg",
            "png": "png",
            "webp": "webp",
            "gif": "gif"
        }
        supported_image_types = ["local_path", "url"]

        image_type = image.get("type", None)
        if image_type not in supported_image_types:
            # Raised explicitly instead of via ``assert`` so the validation
            # still runs under ``python -O``; the exception type is kept so
            # existing error handling is unaffected.
            raise AssertionError(
                f"Must define a valid image type for every image \n your type: {image_type} \n supported types{supported_image_types} "
            )

        if image_type == "url":
            # Remote images are referenced directly, nothing to encode.
            return {"type": "image_url", "image_url": {"url": image.get("image")}}

        # image_type == "local_path"
        image_path = image.get("image")
        # presumably returns the file content as base64 text - verify against
        # flows.utils.general_helpers.encode_image
        processed_image = encode_image(image_path)
        # Lower-case the extension so that e.g. "photo.JPG" is accepted too.
        image_extension_type = image_path.split(".")[-1].lower()
        # NOTE(review): the space after "base64," is kept from the original;
        # the usual data-URL form has none - confirm the API tolerates it.
        url = f"data:image/{extension_dict[image_extension_type]};base64, {processed_image}"
        return {"type": "image_url", "image_url": {"url": url}}

    @staticmethod
    def _iter_frames(capture):
        """Yield frames from ``capture`` until it is closed or a read fails."""
        while capture.isOpened():
            success, frame = capture.read()
            if not success:
                break
            yield frame

    @staticmethod
    def get_video(video):
        """Sample frames from a video file and wrap them as content parts.

        Args:
            video: Mapping with a required ``"video_path"`` entry and the
                optional entries ``"resize"`` (default 768),
                ``"frame_step_size"`` (default 10), ``"start_frame"``
                (default 0) and ``"end_frame"`` (default ``None``). Start,
                end and step select frames with slice semantics.

        Returns:
            A one-shot iterator of ``{"image": <encoded frame>, "resize":
            resize}`` dicts, one per selected frame, where the encoded frame
            is the output of ``encode_from_buffer`` on the JPEG-encoded frame.
            The iterator is empty when no frame could be read.

        Raises:
            KeyError: If ``"video_path"`` is missing.
            ValueError: If ``"frame_step_size"`` is 0.
        """
        video_path = video["video_path"]
        resize = video.get("resize", 768)
        frame_step_size = video.get("frame_step_size", 10)
        start_frame = video.get("start_frame", 0)
        end_frame = video.get("end_frame", None)

        # ``islice`` lets us encode only the selected frames and stop reading
        # at ``end_frame``, but it rejects negative indices. Those need the
        # total frame count, so they fall back to encode-everything-then-slice.
        selection = (start_frame, end_frame, frame_step_size)
        can_stream = not any(value is not None and value < 0 for value in selection)

        capture = cv2.VideoCapture(video_path)
        try:
            frames = VisionAtomicFlow._iter_frames(capture)
            if can_stream:
                frames = islice(frames, start_frame, end_frame, frame_step_size)
            encoded_frames = [
                encode_from_buffer(cv2.imencode(".jpg", frame)[1])
                for frame in frames
            ]
        finally:
            # Always free the capture handle, even if reading/encoding fails.
            capture.release()

        if not can_stream:
            encoded_frames = encoded_frames[start_frame:end_frame:frame_step_size]

        return ({"image": encoded, "resize": resize} for encoded in encoded_frames)

    @staticmethod
    def get_user_message(prompt_template, input_data: Dict[str, Any]):
        """Build the content-part list for a user message.

        Args:
            prompt_template: Template used to render the text part (see
                ``_get_message``).
            input_data: Input of the flow. Its optional ``"data"`` entry may
                hold a ``"video"`` description (see ``get_video``) and/or an
                ``"images"`` list of image descriptions (see ``get_image``).

        Returns:
            A list starting with the text part, followed by the video-frame
            parts (if any) and then the image parts (if any).
        """
        content = VisionAtomicFlow._get_message(prompt_template=prompt_template, input_data=input_data)
        # A missing or ``None`` "data" entry simply means a text-only message.
        media_data = input_data.get("data") or {}
        if "video" in media_data:
            content = [content[0], *VisionAtomicFlow.get_video(media_data["video"])]
        if "images" in media_data:
            content.extend([VisionAtomicFlow.get_image(image) for image in media_data["images"]])
        return content

    @staticmethod
    def _get_message(prompt_template, input_data: Dict[str, Any]):
        """Render ``prompt_template`` into a single-element text content list.

        Args:
            prompt_template: Object exposing ``input_variables`` (the names to
                look up in ``input_data``) and ``format(**kwargs)``.
            input_data: Mapping providing a value for every input variable.

        Returns:
            ``[{"type": "text", "text": <rendered template>}]``.

        Raises:
            KeyError: If an input variable is missing from ``input_data``.
        """
        template_kwargs = {
            input_variable: input_data[input_variable]
            for input_variable in prompt_template.input_variables
        }
        msg_content = prompt_template.format(**template_kwargs)
        return [{"type": "text", "text": msg_content}]

    def _process_input(self, input_data: Dict[str, Any]):
        """Turn ``input_data`` into a user chat message and add it to the state.

        When the conversation has not been initialized yet, it is initialized
        first and ``init_human_message_prompt_template`` is used for the
        message if that attribute is set (not ``None``); in every other case
        ``human_message_prompt_template`` is used.

        Args:
            input_data: Input of the flow, forwarded to
                ``_initialize_conversation`` and ``get_user_message``.
        """
        if self._is_conversation_initialized():
            # Construct the message using the human message prompt template
            prompt_template = self.human_message_prompt_template
        else:
            # Initialize the conversation (add the system message, and potentially the demonstrations)
            self._initialize_conversation(input_data)
            # Prefer the dedicated first-message template when one is configured
            prompt_template = getattr(self, "init_human_message_prompt_template", None)
            if prompt_template is None:
                prompt_template = self.human_message_prompt_template

        user_message_content = self.get_user_message(prompt_template, input_data)

        self._state_update_add_chat_message(role=self.flow_config["user_name"],
                                            content=user_message_content)