File size: 5,999 Bytes
b8c9254
 
 
 
 
 
 
 
 
 
 
 
 
2554896
 
 
b8c9254
 
2554896
b8c9254
 
 
 
 
2554896
 
 
 
 
 
 
 
 
 
 
 
b8c9254
2554896
b8c9254
2554896
b8c9254
2554896
b8c9254
 
 
2554896
b8c9254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8756846
 
 
b8c9254
29c51b6
5f6f019
 
 
 
 
8756846
5f6f019
29c51b6
8756846
29c51b6
16e7684
29c51b6
b8c9254
16e7684
b8c9254
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69d5204
 
 
b8c9254
16e7684
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from typing import Dict, List, Any
import torch
from transformers import LlavaNextVideoForConditionalGeneration, LlavaNextVideoProcessor
from peft import PeftModel
import base64
import numpy as np

def base64_to_numpy(base64_str, shape, dtype=np.uint8):
    """Decode a base64 payload of raw array bytes into a NumPy array.

    Args:
        base64_str: Base64-encoded raw buffer (``str`` or ``bytes``), e.g.
            ``base64.b64encode(arr.tobytes())`` of a C-ordered array.
        shape: Target shape. Its element count times ``dtype``'s item size
            must equal the number of decoded bytes.
        dtype: Element type of the encoded buffer. Defaults to ``np.uint8``,
            the only type this decoder originally supported, so existing
            callers are unaffected.

    Returns:
        An array of ``shape``/``dtype`` viewing the decoded bytes.
        ``np.frombuffer`` does not copy, so the result is read-only.

    Raises:
        binascii.Error: if the payload is malformed base64 (e.g. bad padding).
        ValueError: if the decoded byte count does not fit ``shape``/``dtype``.
    """
    raw = base64.b64decode(base64_str)
    flat = np.frombuffer(raw, dtype=dtype)
    return flat.reshape(shape)

class EndpointHandler:
    """Inference handler that critiques a surfer's pop-up move from a video clip.

    On construction it loads the LLaVA-NeXT-Video base model, merges the
    fine-tuned PEFT adapter into it, and loads the adapter's processor.
    Each call decodes one base64-encoded clip and answers a fixed coaching
    prompt about it.
    """

    def __init__(self, model_dir: str = None):
        """Load the base model, merge the PEFT adapter, and load the processor.

        Args:
            model_dir: Unused. The base model and adapter are fetched by
                the hard-coded hub names below. Presumably kept because the
                hosting runtime passes it (TODO confirm).

        Raises:
            Exception: whatever ``PeftModel.from_pretrained`` raises if the
                adapter cannot be loaded (logged, then re-raised).
        """
        self.base_model_name = "llava-hf/LLaVA-NeXT-Video-7B-hf"
        self.adapter_model_name = "EnariGmbH/surftown-1.0"

        # Load the base model
        print("Loading base model:", self.base_model_name)
        self.model = LlavaNextVideoForConditionalGeneration.from_pretrained(
            self.base_model_name,
            torch_dtype=torch.float16,
            device_map="auto"
        )
        print("Base model successfully loaded.")

        # Load the adapter model into the base model
        print("Loading adapter model:", self.adapter_model_name)
        try:
            self.model = PeftModel.from_pretrained(self.model, self.adapter_model_name)
            print("Adapter model successfully loaded.")
        except Exception as e:
            print(f"Failed to load adapter model: {e}")
            raise  # bare raise keeps the original traceback

        # Merge the adapter weights into the base model so inference runs on a
        # plain model rather than through the PEFT wrapper.
        self.model = self.model.merge_and_unload()
        print("Adapter model merged and unloaded.")

        # The processor (tokenizer + chat template) comes from the adapter repo.
        self.processor = LlavaNextVideoProcessor.from_pretrained(self.adapter_model_name)
        print("Processor loaded.")

        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a pop-up analysis for one base64-encoded video clip.

        Args:
            data: Request payload. It must contain "clip" (base64 of the raw
                uint8 bytes of the clip) and "clip_shape" (the shape to
                reshape those bytes into). Assumed to be
                (frames, height, width, channels) as the processor expects
                (TODO confirm with the client).

        Returns:
            A one-element list: ``[{"generated_text": answer}]`` on success,
            or ``[{"error": message}]`` when input is missing or undecodable.
        """
        # Extract inputs from the data dictionary
        clip_base64 = data.get("clip")
        clip_shape = data.get("clip_shape")  # Expect the shape to be passed

        if clip_base64 is None or clip_shape is None:
            return [{"error": "Missing 'clip' or 'clip_shape' in input data"}]

        # Decode the base64 back to a numpy array and reshape it. Malformed
        # base64 (binascii.Error subclasses ValueError), a shape that does not
        # match the byte count (ValueError), or a non-iterable shape
        # (TypeError) are client errors. Report them like missing input
        # instead of failing the whole request.
        try:
            clip = base64_to_numpy(clip_base64, tuple(clip_shape))
        except (ValueError, TypeError) as e:
            return [{"error": f"Could not decode 'clip' with 'clip_shape' {clip_shape}: {e}"}]

        prompt = """
        You are a surfing coach specialized on perfecting surfer's pop-up move. Please analyze the surfer's pop-up move in detail from the video. 
                    In your detailed analysis you should always mention: Wave Position and paddling, Pushing Phase, Transition, Reaching Phase and finnaly Balance and Control.
                    At the end of your answer you must provide suggestions on how the surfer can improve in the next pop-up.
                    Your answers should ALWAYS follow this structure:
                    
                        Description:
                            Wave Position and paddling: text
                            Pushing Phase: text
                            Transition: text
                            Reaching Phase: text
                            Balance and Control: text 
                        Summary:
                            Suggestions for improvement: text
                    NEVER MENTION ANY INFORMATION THAT IS NOT RELEVANT FOR THE SURFER.
                    KEEP YOUR ANSWERS SHORT AND DIRECT AND DO NOT MENTION ANY INFORMATION OUTSIDE OF THE BEFORE MENTIONED STRUCTURE.
                    IMPORTANT: In the Balance and Control section you should also explain how the surfer performs their twists and turns after the pop-up is done.
                """

        # Single-turn conversation: the coaching instructions plus a video slot.
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "video"},
                ],
            },
        ]

        # Apply the chat template to create the prompt for the model
        prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)

        # Feed the processor output to generate() unchanged.
        # LlavaNextVideoForConditionalGeneration reads video frames from
        # "pixel_values_videos". "pixel_values" is its *image* input (paired
        # with "image_sizes"), so the key must not be renamed.
        inputs_video = self.processor(
            text=prompt, videos=clip, padding=True, return_tensors="pt"
        ).to(self.model.device)
        print(f"Keys in inputs_video: {inputs_video.keys()}")

        # Generate output from the model (sampled, so answers vary per call).
        generate_kwargs = {"max_new_tokens": 512, "do_sample": True, "top_p": 0.9}
        output = self.model.generate(**inputs_video, **generate_kwargs)
        generated_text = self.processor.batch_decode(output, skip_special_tokens=True)

        assistant_answer = self._extract_answer(generated_text[0])

        print("model answer", assistant_answer)

        return [{"generated_text": assistant_answer}]

    @staticmethod
    def _extract_answer(decoded: str, marker: str = "ASSISTANT:") -> str:
        """Return the stripped text after the first ``marker`` in ``decoded``.

        The decoded sequence presumably echoes the templated prompt before the
        reply, which is why the marker is searched for. If the marker is
        absent, the whole decoded text is returned. A ``find() + len(marker)``
        offset would instead silently drop its first ``len(marker) - 1``
        characters.
        """
        _, sep, answer = decoded.partition(marker)
        return (answer if sep else decoded).strip()