jaronfei committed on
Commit
53cd606
·
1 Parent(s): c9a55ab

clean unnecessary files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/example.mp4 filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -14,62 +14,23 @@ torch==2.1.0
14
  torchvision==0.16.0
15
  transformers==4.40.2
16
  peft==0.10.0
17
- pyarrow==13.0.0 # load parquet
18
- decord==0.6.0 # load video
19
- pysubs2==1.7.2 # load subtitle
20
  ```
21
 
22
- ### Sample Inference Code
23
 
24
- ```
25
- import torch
26
-
27
- from eval import load_video
28
- from videoccam import VideoCCAM
29
-
30
- video_path = 'assets/example.mp4'
31
- question = 'Can you please describe what happens in the video in detail?'
32
 
33
- sample_config = dict(
34
- sample_type='uniform',
35
- num_frames=32
36
- )
37
 
38
- mllm = VideoCCAM(
39
- model_path='.',
40
- chat_template='<|user|>\n{input}<|end|>\n<|assistant|>\n',
41
- generation_args=dict(
42
- stop_tokens=['<|end|>', '<|endoftext|>'],
43
- max_new_tokens=512,
44
- do_sample=False,
45
- num_beams=5,
46
- ),
47
- llm_name_or_path='microsoft/Phi-3-mini-4k-instruct', # you can replace this with local directory if the model has been downloaded before
48
- visual_encoder_name_or_path='google/siglip-so400m-patch14-384', # you can replace this with local directory if the model has been downloaded before
49
- special_tokens=['<time>', '</time>'],
50
- visual_select_layer=-2,
51
- torch_dtype=torch.bfloat16,
52
- device_map='cuda:0'
53
- )
54
-
55
- frames, = load_video(video_path, **sample_config)
56
- response = mllm.generate(texts=[question], videos=[frames])[0]
57
-
58
- print(response)
59
- ```
60
-
61
- ### Video-MME Evaluation
62
-
63
- You are expected to reproduce the results of 48.2 (without subtitle) and 51.7 (with subtitle) by running the following command. By default, the results are saved as `output_w_sub.json` and `output_wo_sub.json` in local directory. We provide our results in `ref_results` directory.
64
-
65
- ```
66
- python eval.py
67
- ```
68
 
69
  ## Acknowledgement
70
 
71
- * [xtuner](https://github.com/InternLM/xtuner): Video-CCAM-4B is trained using the xtuner framework. Thanks for their excellent works!
72
- * [Phi-3-Mini-4K-Instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct): Great small language models developed by Microsoft.
73
  * [SigLIP SO400M](https://huggingface.co/google/siglip-so400m-patch14-384): Outstanding vision encoder developed by Google.
74
 
75
  ## License
 
14
  torchvision==0.16.0
15
  transformers==4.40.2
16
  peft==0.10.0
 
 
 
17
  ```
18
 
19
+ ## Inference
20
 
21
+ Please refer to the [Video-CCAM](https://github.com/QQ-MM/Video-CCAM) repository for inference and evaluation code.
 
 
 
 
 
 
 
22
 
23
+ ### Video-MME
 
 
 
24
 
25
+ |#Frames.|32|96|
26
+ |:-:|:-:|:-:|
27
+ |w/o subs|48.2|49.6|
28
+ |w/ subs|51.7|53.0|
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  ## Acknowledgement
31
 
32
+ * [xtuner](https://github.com/InternLM/xtuner): Video-CCAM-9B is trained using the xtuner framework. Thanks for their excellent work!
33
+ * [Phi-3-Mini-4K-Instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct): Powerful language models developed by Microsoft.
34
  * [SigLIP SO400M](https://huggingface.co/google/siglip-so400m-patch14-384): Outstanding vision encoder developed by Google.
35
 
36
  ## License
assets/example.mp4 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2c9ce295c4c154bdbc266c2333b18710796ff1d151623447664730aae25a461c
3
- size 3283880
 
 
 
 
eval.py DELETED
@@ -1,257 +0,0 @@
1
- #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- """
4
- ================================================
5
- @author: Jaron
6
- @time: 2024/06/23 12:59:38
7
- @email: fjjth98@163.com
8
- @description: Evaluate MLLM on Video-MME Benchmark
9
- ================================================
10
- """
11
-
12
- import json
13
- import torch
14
- import pysubs2
15
- import os.path as osp
16
-
17
- from PIL import Image
18
- from tqdm import tqdm
19
- from typing import Any
20
- from copy import deepcopy
21
- from pandas import read_parquet
22
- from decord import VideoReader, cpu
23
- from torch.utils.data import Dataset, DataLoader, default_collate
24
-
25
-
26
- def video_collate_fn(batch: Any) -> Any:
27
- """this collate function address dict video inputs, support to process variable number of frames for different inputs
28
-
29
- Args:
30
- batch (_type_): _description_
31
-
32
- Returns:
33
- _type_: _description_
34
- """
35
- if isinstance(batch[0], dict) and 'video' in batch[0]:
36
- video = [b.pop('video') for b in batch]
37
- batch = default_collate(batch)
38
- batch['video'] = video
39
- else:
40
- batch = default_collate(batch)
41
- return batch
42
-
43
-
44
- def uniform_indices(num_frames: int, total_frames: int) -> list[int]:
45
- """Get uniform indices
46
-
47
- Args:
48
- num_frames (int): number of frames
49
- total_frames (int): total number of frames
50
-
51
- Returns:
52
- list[int]: Output frame indices
53
- """
54
- if num_frames < total_frames:
55
- splits = torch.linspace(0, total_frames, num_frames+1, dtype=int)
56
- indices = ((splits[:-1] + splits[1:]) // 2).tolist()
57
- else:
58
- indices = list(range(total_frames))
59
-
60
- return indices
61
-
62
-
63
- def fps_indices(input_fps: float, total_frames: int, output_fps: float = None, max_num_frames: int = -1) -> list[int]:
64
- """Get indices according to the output_fps
65
-
66
- Args:
67
- input_fps (float): input fps
68
- total_frames (int): total number of frames
69
- output_fps (float, optional): output fps. Defaults to None, means output_fps==input_fps.
70
- max_num_frames (int, optional): max number of frames. Defaults to -1, means no limitation.
71
-
72
- Returns:
73
- list[int]: Output frame indices
74
- """
75
- delta = 1 if output_fps is None else input_fps / output_fps
76
- indices = torch.arange(0, total_frames, delta).round().to(int)
77
- indices = [e for e in indices if e < total_frames]
78
- if 0 < max_num_frames < len(indices):
79
- indices = indices[:max_num_frames]
80
-
81
- return indices
82
-
83
-
84
- def load_video(src_path: str, sample_type: str, sub_path: str = None, **kwargs) -> list[Image.Image]:# | tuple[list[Image.Image], str]:
85
- """Load video using decord, optionally load subtitles
86
-
87
- Args:
88
- src_path (str): video path
89
- sample_type (str): 'uniform' or 'fps'
90
- sub_path (str): subtitle path, .srt
91
- kwargs: for 'uniform', require 'num_frames'; for 'fps', optionally require 'output_fps' and 'max_num_frames'
92
-
93
- Returns:
94
- list[Image.Image]: frame list
95
- """
96
- vr = VideoReader(src_path, ctx=cpu(0), num_threads=1)
97
- total_frames = len(vr)
98
- if sample_type == 'uniform':
99
- num_frames = kwargs.pop('num_frames')
100
- indices = uniform_indices(num_frames, total_frames)
101
- elif sample_type == 'fps':
102
- input_fps = float(vr.get_avg_fps())
103
- output_fps = kwargs.pop('output_fps', None)
104
- max_num_frames = kwargs.pop('max_num_frames', -1)
105
- indices = fps_indices(input_fps, total_frames, output_fps, max_num_frames)
106
- else:
107
- raise ValueError(f'Do not support {sample_type} sample type')
108
- frames = vr.get_batch(indices).asnumpy() # (T, H, W, C), np.uint8
109
- frames = [Image.fromarray(frame) for frame in frames]
110
-
111
- if sub_path is None:
112
- return frames
113
- elif osp.exists(sub_path):
114
- subs = pysubs2.load(sub_path, encoding='utf-8')
115
- subtitles = []
116
- for idx in indices:
117
- sub_text = []
118
- cur_time = pysubs2.make_time(fps=float(vr.get_avg_fps()), frames=idx)
119
- for sub in subs:
120
- if sub.end < cur_time:
121
- continue
122
- elif sub.start < cur_time:
123
- sub_text.append(sub.text.replace('\\N', ' '))
124
- break # in accordance to the official benchmark
125
- else:
126
- break
127
- sub_text = ' '.join(sub_text)
128
- if sub_text.strip():
129
- subtitles.append(sub_text)
130
- subtitles = '\n'.join(subtitles)
131
- return frames, subtitles
132
- else:
133
- return frames, ''
134
-
135
-
136
- class VideoMMEDataset(Dataset):
137
-
138
- def __init__(self, dataset_path: str, sample_config: dict, use_subtitle: bool = False):
139
- super().__init__()
140
- self.dataset_path = dataset_path
141
- self.sample_config = sample_config
142
- self.use_subtitle = use_subtitle
143
-
144
- data_dict = {}
145
- index_keys = ['video_id', 'duration', 'domain', 'sub_category', 'videoID']
146
- value_keys = ['question_id', 'task_type', 'question', 'options', 'answer']
147
- df = read_parquet(osp.join(dataset_path, 'videomme', 'test-00000-of-00001.parquet'))
148
- df['options'] = df['options'].apply(list)
149
- for _, data in df.iterrows():
150
- key = tuple(data[k] for k in index_keys)
151
- value = data[value_keys].to_dict()
152
- if key in data_dict:
153
- data_dict[key].append(value)
154
- else:
155
- data_dict[key] = [value]
156
- self.data_list = [dict(zip(index_keys + ['questions'], list(k) + [v])) for k, v in data_dict.items()]
157
-
158
- def __len__(self):
159
- return len(self.data_list)
160
-
161
- def __getitem__(self, idx) -> dict:
162
- if self.use_subtitle:
163
- frames, subtitles = load_video(
164
- src_path=osp.join(self.dataset_path, 'video', self.data_list[idx]['videoID'] + '.mp4'),
165
- sub_path=osp.join(self.dataset_path, 'subtitle', self.data_list[idx]['videoID'] + '.srt'),
166
- **self.sample_config
167
- )
168
- text = ['\n'.join([
169
- "This video's subtitles are listed below:",
170
- subtitles,
171
- 'Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, or D) of the correct option.',
172
- i['question']
173
- ] + i['options']) for i in self.data_list[idx]['questions']]
174
- else:
175
- frames = load_video(
176
- src_path=osp.join(self.dataset_path, 'video', self.data_list[idx]['videoID'] + '.mp4'),
177
- **self.sample_config
178
- )
179
- text = ['\n'.join([
180
- 'Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, or D) of the correct option.',
181
- i['question']
182
- ] + i['options']) for i in self.data_list[idx]['questions']]
183
- subtitles = ''
184
-
185
- return dict(
186
- video=frames,
187
- text=text
188
- )
189
-
190
-
191
- if __name__ == '__main__':
192
-
193
- from videoccam import VideoCCAM, DEFAULT_VIDEO_TOKEN
194
-
195
- mllm = VideoCCAM(
196
- model_path='.',
197
- chat_template='<|user|>\n{input}<|end|>\n<|assistant|>\n',
198
- generation_args=dict(
199
- stop_tokens=['<|end|>', '<|endoftext|>'],
200
- max_new_tokens=512,
201
- do_sample=False
202
- ),
203
- llm_name_or_path='microsoft/Phi-3-mini-4k-instruct',
204
- visual_encoder_name_or_path='google/siglip-so400m-patch14-384',
205
- special_tokens=['<time>', '</time>'],
206
- visual_select_layer=-2,
207
- torch_dtype=torch.bfloat16,
208
- device_map='cuda:0'
209
- )
210
- mllm.eval()
211
-
212
- dataset = VideoMMEDataset(
213
- dataset_path='',
214
- sample_config=dict(
215
- sample_type='uniform',
216
- num_frames=32
217
- )
218
- )
219
-
220
- with torch.inference_mode():
221
- for use_subtitle in (True,):
222
- dataset.use_subtitle = use_subtitle
223
- dataloader = DataLoader(
224
- dataset,
225
- batch_size=4,
226
- num_workers=8,
227
- shuffle=False,
228
- pin_memory=True,
229
- collate_fn=video_collate_fn
230
- )
231
- results = []
232
- for data in tqdm(dataloader):
233
- response, pixel_values = mllm.generate(
234
- texts=['\n'.join([DEFAULT_VIDEO_TOKEN, t]) for t in data['text'][0]],
235
- videos=data['video'],
236
- return_pixel_values=True
237
- )
238
- response = [response]
239
- for i in range(1, len(data['text'])):
240
- response.append(mllm.generate(
241
- texts=['\n'.join([DEFAULT_VIDEO_TOKEN, t]) for t in data['text'][i]],
242
- pixel_values=pixel_values
243
- ))
244
- response = [[response[i][j] for i in range(len(response))] for j in range(len(response[0]))]
245
- results.extend(response)
246
-
247
- outputs = []
248
- for data, responses in zip(dataset.data_list, results):
249
- data = deepcopy(data)
250
- data.pop('videoID')
251
- for question, response in zip(data['questions'], responses):
252
- question['response'] = response
253
- outputs.append(data)
254
-
255
- suffix = 'w_sub' if use_subtitle else 'wo_sub'
256
- with open(f'output_{suffix}.json', 'w') as f:
257
- json.dump(outputs, f, indent=4, ensure_ascii=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ref_results/output_w_sub.json DELETED
The diff for this file is too large to render. See raw diff
 
ref_results/output_wo_sub.json DELETED
The diff for this file is too large to render. See raw diff
 
videoccam.py DELETED
@@ -1,312 +0,0 @@
1
- #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
- """
4
- ================================================
5
- @author: Jaron
6
- @time: 2024/06/23 09:52:24
7
- @email: fjjth98@163.com
8
- @description:
9
- ================================================
10
- """
11
-
12
- import torch
13
- import os.path as osp
14
- import torch.nn as nn
15
-
16
- from PIL import Image
17
- from peft import PeftModel
18
- from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, SiglipVisionModel, SiglipImageProcessor
19
-
20
-
21
- IGNORE_INDEX = -100
22
- IMAGE_TOKEN_INDEX = -200
23
- DEFAULT_IMAGE_TOKEN = '<image>'
24
- DEFAULT_VIDEO_TOKEN = '<video>'
25
-
26
-
27
- class VideoCCAM(nn.Module):
28
-
29
- def __init__(
30
- self,
31
- model_path: str,
32
- chat_template: str,
33
- generation_args: dict,
34
- llm_name_or_path: str = None,
35
- visual_encoder_name_or_path: str = None,
36
- special_tokens: list[str] = None,
37
- visual_select_layer: int = -2,
38
- torch_dtype: torch.dtype = torch.float16,
39
- device_map: str = 'cuda:0'
40
- ):
41
- super().__init__()
42
- self.chat_template = chat_template
43
- self.generation_args = generation_args
44
- self.visual_select_layer = visual_select_layer
45
- self.torch_dtype = torch_dtype
46
- self.device_map = device_map
47
-
48
- if llm_name_or_path is None:
49
- llm_name_or_path = model_path
50
- if visual_encoder_name_or_path is None:
51
- visual_encoder_name_or_path = osp.join(model_path, 'visual_encoder')
52
- assert osp.exists(visual_encoder_name_or_path), f'{visual_encoder_name_or_path} does not exist, you have to specify `visual_encoder_name_or_path`'
53
- projector_path = osp.join(model_path, 'projector')
54
- assert osp.exists(projector_path), f'{projector_path} does not exist, you have to change `model_path`'
55
-
56
- self.llm = AutoModelForCausalLM.from_pretrained(
57
- llm_name_or_path,
58
- trust_remote_code=True,
59
- torch_dtype=torch_dtype,
60
- device_map=device_map
61
- )
62
- self.tokenizer = AutoTokenizer.from_pretrained(
63
- llm_name_or_path,
64
- trust_remote_code=True
65
- )
66
- print(f'Load LLM from {llm_name_or_path}')
67
- if special_tokens is not None:
68
- self.llm.resize_token_embeddings(self.llm.get_input_embeddings().weight.size(0) + len(special_tokens))
69
- self.llm.requires_grad_(False)
70
- self.llm.get_input_embeddings().weight[-len(special_tokens):].zero_()
71
- self.tokenizer.add_tokens(special_tokens, special_tokens=True)
72
- print(f'Add special_tokens {special_tokens} to LLM and tokenizer')
73
- if osp.exists(adapter_path := osp.join(model_path, 'llm_adapter')):
74
- self.llm = PeftModel.from_pretrained(self.llm, adapter_path)
75
- print(f'Load LLM adapter from {adapter_path}')
76
- self.generation_args['eos_token_id'] = self.tokenizer.convert_tokens_to_ids(self.generation_args.pop('stop_tokens'))
77
-
78
- self.visual_encoder = SiglipVisionModel.from_pretrained(
79
- visual_encoder_name_or_path,
80
- torch_dtype=torch_dtype,
81
- device_map=device_map
82
- )
83
- self.image_processor = SiglipImageProcessor.from_pretrained(visual_encoder_name_or_path)
84
- print(f'Load SigLIP visual encoder from {visual_encoder_name_or_path}')
85
- if osp.exists(adapter_path := osp.join(model_path, 'visual_encoder_adapter')):
86
- self.visual_encoder = PeftModel.from_pretrained(self.visual_encoder, adapter_path)
87
- print(f'Load visual_encoder adapter from {adapter_path}')
88
-
89
- self.projector = AutoModel.from_pretrained(
90
- projector_path,
91
- torch_dtype=torch_dtype,
92
- device_map=device_map,
93
- trust_remote_code=True
94
- )
95
- print(f'Load projector from {projector_path}')
96
-
97
- # Modified from https://github.com/InternLM/xtuner/blob/main/xtuner/model/utils.py#L138
98
- def prepare_inputs_labels_for_multimodal(
99
- self,
100
- input_ids: torch.LongTensor = None,
101
- position_ids: torch.LongTensor = None,
102
- attention_mask: torch.Tensor = None,
103
- past_key_values: list[torch.FloatTensor] = None,
104
- labels: torch.LongTensor = None,
105
- pixel_values: torch.FloatTensor = None
106
- ):
107
- if pixel_values is None:
108
- return {
109
- 'input_ids': input_ids,
110
- 'position_ids': position_ids,
111
- 'attention_mask': attention_mask,
112
- 'past_key_values': past_key_values,
113
- 'inputs_embeds': None,
114
- 'labels': labels
115
- }
116
-
117
- _labels = labels
118
- _position_ids = position_ids
119
- _attention_mask = attention_mask
120
- if attention_mask is None:
121
- if isinstance(input_ids, torch.Tensor):
122
- attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
123
- elif isinstance(input_ids, list):
124
- attention_mask = [torch.ones_like(i, dtype=torch.bool) for i in input_ids]
125
- _attention_mask = attention_mask
126
- else:
127
- raise ValueError(f'Do not support {type(input_ids)} type as input_ids')
128
- else:
129
- attention_mask = attention_mask.bool()
130
- if position_ids is None:
131
- position_ids = torch.arange(
132
- 0, input_ids[0].shape[0], dtype=torch.long, device=input_ids[0].device)
133
- if labels is None:
134
- if isinstance(input_ids, torch.Tensor):
135
- labels = torch.full_like(input_ids, IGNORE_INDEX)
136
- elif isinstance(input_ids, list):
137
- labels = [torch.full_like(i, IGNORE_INDEX) for i in input_ids]
138
- else:
139
- raise ValueError(f'Do not support {type(input_ids)} type as input_ids')
140
-
141
- # remove the padding using attention_mask -- TODO: double check
142
- input_ids = [
143
- cur_input_ids[cur_attention_mask]
144
- for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
145
- ]
146
- labels = [
147
- cur_labels[cur_attention_mask]
148
- for cur_labels, cur_attention_mask in zip(labels, attention_mask)
149
- ]
150
-
151
- new_inputs_embeds = []
152
- new_labels = []
153
- cur_image_idx = 0
154
- for batch_idx, cur_input_ids in enumerate(input_ids):
155
- num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
156
- if num_images == 0:
157
- cur_pixel_values = pixel_values[cur_image_idx]
158
- cur_inputs_embeds_1 = self.llm.get_input_embeddings()(cur_input_ids)
159
- cur_inputs_embeds = torch.cat(
160
- [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
161
- new_inputs_embeds.append(cur_inputs_embeds)
162
- new_labels.append(labels[batch_idx])
163
- cur_image_idx += 1
164
- continue
165
-
166
- image_token_indices = [-1] + torch.where(
167
- cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
168
- cur_input_ids.shape[0]
169
- ]
170
- cur_input_ids_noim = []
171
- cur_labels = labels[batch_idx]
172
- cur_labels_noim = []
173
- for i in range(len(image_token_indices) - 1):
174
- cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] +
175
- 1:image_token_indices[i +
176
- 1]])
177
- cur_labels_noim.append(cur_labels[image_token_indices[i] +
178
- 1:image_token_indices[i + 1]])
179
- split_sizes = [x.shape[0] for x in cur_labels_noim]
180
- cur_inputs_embeds = self.llm.get_input_embeddings()(
181
- torch.cat(cur_input_ids_noim))
182
- cur_inputs_embeds_no_im = torch.split(
183
- cur_inputs_embeds, split_sizes, dim=0)
184
- cur_new_inputs_embeds = []
185
- cur_new_labels = []
186
-
187
- for i in range(num_images + 1):
188
- cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
189
- cur_new_labels.append(cur_labels_noim[i])
190
- if i < num_images:
191
- cur_pixel_values = pixel_values[cur_image_idx]
192
- cur_image_idx += 1
193
- cur_new_inputs_embeds.append(cur_pixel_values)
194
- cur_new_labels.append(
195
- torch.full((cur_pixel_values.shape[0], ),
196
- IGNORE_INDEX,
197
- device=cur_labels.device,
198
- dtype=cur_labels.dtype))
199
-
200
- cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
201
- cur_new_labels = torch.cat(cur_new_labels)
202
-
203
- new_inputs_embeds.append(cur_new_inputs_embeds)
204
- new_labels.append(cur_new_labels)
205
-
206
- # Combine them
207
- max_len = max(x.shape[0] for x in new_inputs_embeds)
208
- batch_size = len(new_inputs_embeds)
209
-
210
- new_inputs_embeds_padded = []
211
- new_labels_padded = torch.full((batch_size, max_len),
212
- IGNORE_INDEX,
213
- dtype=new_labels[0].dtype,
214
- device=new_labels[0].device)
215
- attention_mask = torch.zeros((batch_size, max_len),
216
- dtype=attention_mask[0].dtype,
217
- device=attention_mask[0].device)
218
- position_ids = torch.zeros((batch_size, max_len),
219
- dtype=position_ids.dtype,
220
- device=position_ids.device)
221
-
222
- for i, (cur_new_embed,
223
- cur_new_labels) in enumerate(zip(new_inputs_embeds, new_labels)):
224
- cur_len = cur_new_embed.shape[0]
225
- new_inputs_embeds_padded.append(
226
- torch.cat((cur_new_embed,
227
- torch.zeros((max_len - cur_len, cur_new_embed.shape[1]),
228
- dtype=cur_new_embed.dtype,
229
- device=cur_new_embed.device)),
230
- dim=0))
231
- if cur_len > 0:
232
- new_labels_padded[i, :cur_len] = cur_new_labels
233
- attention_mask[i, :cur_len] = True
234
- position_ids[i, :cur_len] = torch.arange(
235
- 0,
236
- cur_len,
237
- dtype=position_ids.dtype,
238
- device=position_ids.device)
239
-
240
- new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)
241
-
242
- if _labels is None:
243
- new_labels = None
244
- else:
245
- new_labels = new_labels_padded
246
-
247
- if _attention_mask is None:
248
- attention_mask = None
249
- elif isinstance(_attention_mask, list):
250
- attention_mask = attention_mask.to(dtype=_attention_mask[0].dtype)
251
- else:
252
- attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
253
-
254
- if _position_ids is None:
255
- position_ids = None
256
-
257
- return {
258
- 'input_ids': None,
259
- 'position_ids': position_ids,
260
- 'attention_mask': attention_mask,
261
- 'past_key_values': past_key_values,
262
- 'inputs_embeds': new_inputs_embeds,
263
- 'labels': new_labels
264
- }
265
-
266
- def generate(
267
- self,
268
- texts: list[str],
269
- videos: list[list[Image.Image]] = None,
270
- pixel_values: torch.Tensor = None,
271
- return_pixel_values: bool = False
272
- ) -> list[str] | tuple[list[str], torch.Tensor]:
273
- """Genrate respoonse for video and text inputs.
274
-
275
- Args:
276
- text (list[str]): list of text inputs
277
- video (list[list[Image.Image]], optional): list of frame list. Defaults to None.
278
- pixel_values (torch.Tensor, optional): precomputed pixel_values. Defaults to None.
279
- return_pixel_values (bool, optional): whether return pixel values or not. Defaults to False.
280
-
281
- Returns:
282
- list[str]: _description_
283
- """
284
- prediction = []
285
- # Get visual embeddings
286
- if pixel_values is None:
287
- frames, split_sizes = [], []
288
- for i in videos:
289
- frames += i
290
- split_sizes.append(len(i))
291
- pixel_values = self.image_processor(frames, return_tensors='pt')['pixel_values'].to(self.torch_dtype).to(self.device_map)
292
- pixel_values = self.visual_encoder(pixel_values, output_hidden_states=True).hidden_states[self.visual_select_layer]
293
- pixel_values = self.projector(pixel_values, split_sizes)
294
-
295
- for i, t in enumerate(texts):
296
- et = self.chat_template.format(input=t).replace(DEFAULT_VIDEO_TOKEN, DEFAULT_IMAGE_TOKEN).split(DEFAULT_IMAGE_TOKEN)
297
- assert len(et) == 2, f'Wrong input formats for {t}'
298
- input_ids = [torch.tensor(self.tokenizer.encode(et[0]) + [IMAGE_TOKEN_INDEX] + self.tokenizer.encode(et[1], add_special_tokens=False), device=self.device_map)]
299
- mm_inputs = self.prepare_inputs_labels_for_multimodal(
300
- input_ids=input_ids,
301
- pixel_values=pixel_values[i:i+1]
302
- )
303
- generate_output = self.llm.generate(
304
- **mm_inputs,
305
- **self.generation_args
306
- )[0]
307
- prediction.append(self.tokenizer.decode(generate_output, skip_special_tokens=True))
308
-
309
- if return_pixel_values:
310
- return prediction, pixel_values
311
- else:
312
- return prediction