ccclemenfff commited on
Commit
cf932d8
·
1 Parent(s): 9e99e54

Add model code support

Browse files
Files changed (38) hide show
  1. .idea/.gitignore +8 -0
  2. .idea/embodied_explainer.iml +12 -0
  3. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  4. .idea/modules.xml +8 -0
  5. .idea/vcs.xml +6 -0
  6. handler.py +52 -0
  7. inference.py +415 -0
  8. requirements.txt +13 -0
  9. robohusky/.DS_Store +0 -0
  10. robohusky/base_dataset.py +226 -0
  11. robohusky/base_dataset_uni.py +434 -0
  12. robohusky/compression.py +230 -0
  13. robohusky/configuration_husky.py +326 -0
  14. robohusky/constants.py +47 -0
  15. robohusky/conversation.py +511 -0
  16. robohusky/convert_fp16.py +27 -0
  17. robohusky/convert_husky_fp16.py +28 -0
  18. robohusky/convert_reward_fp16.py +27 -0
  19. robohusky/dist_utils.py +100 -0
  20. robohusky/llama2_flash_attn_monkey_patch.py +232 -0
  21. robohusky/model/__init__.py +70 -0
  22. robohusky/model/__pycache__/__init__.cpython-38.pyc +0 -0
  23. robohusky/model/__pycache__/configuration_husky.cpython-38.pyc +0 -0
  24. robohusky/model/__pycache__/modeling_husky_embody2.cpython-38.pyc +0 -0
  25. robohusky/model/compression.py +0 -0
  26. robohusky/model/configuration_husky.py +331 -0
  27. robohusky/model/configuration_husky_ori.py +327 -0
  28. robohusky/model/modeling_husky.py +1820 -0
  29. robohusky/model/modeling_husky_embody2.py +1962 -0
  30. robohusky/model/modeling_husky_embody2_ori.py +1821 -0
  31. robohusky/model/processing_husky.py +178 -0
  32. robohusky/train/.DS_Store +0 -0
  33. robohusky/train/llama_flash_attn_monkey_patch.py +232 -0
  34. robohusky/train/llama_rmsnorm_monkey_patch.py +15 -0
  35. robohusky/train/train.py +597 -0
  36. robohusky/train/train_uni.py +603 -0
  37. robohusky/utils.py +238 -0
  38. robohusky/video_transformers.py +406 -0
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # 默认忽略的文件
2
+ /shelf/
3
+ /workspace.xml
4
+ # 基于编辑器的 HTTP 客户端请求
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/embodied_explainer.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="GOOGLE" />
10
+ <option name="myDocStringFormat" value="Google" />
11
+ </component>
12
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/embodied_explainer.iml" filepath="$PROJECT_DIR$/.idea/embodied_explainer.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
handler.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Dict, Any
3
+ from PIL import Image
4
+ from io import BytesIO
5
+
6
+ from inference import Chat # 直接import你放的inference.py里Chat类
7
+ from robohusky.conversation import get_conv_template
8
+
9
class EndpointHandler:
    """Inference-endpoint entry point wrapping the Husky ``Chat`` helper.

    One handler instance keeps a single ``Chat`` (model + tokenizer), the
    conversation template, and the most recent vision feature; every call
    starts a fresh conversation in ``preprocess``.
    """

    def __init__(self, path: str = "."):
        # BUG FIX: the original module referenced ``torch`` without importing it,
        # so ``torch.cuda.is_available()`` raised NameError on startup.  A
        # function-scope import keeps the fix local to this handler.
        import torch

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.chat = Chat(
            model_path=path,
            device=self.device,
            num_gpus=1,
            max_new_tokens=1024,
            load_8bit=False
        )
        # Cached embedding of the most recently uploaded image/video
        # (None when the request is text-only).
        self.vision_feature = None
        self.modal_type = "text"
        self.conv = get_conv_template("husky").copy()

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Reset the conversation and extract a vision feature if media is attached.

        ``inputs`` carries the user text under "inputs" and, optionally, raw
        bytes under "image" or "video".
        """
        query = inputs.get("inputs", "")
        self.conv = get_conv_template("husky").copy()
        self.vision_feature = None
        self.modal_type = "text"

        if "image" in inputs:
            image_bytes = inputs["image"]
            image = Image.open(BytesIO(image_bytes)).convert("RGB")
            # The embedding helper expects a file path, so round-trip via disk.
            image.save("temp.jpg")
            self.vision_feature = self.chat.get_image_embedding("temp.jpg")
            self.modal_type = "image"

        elif "video" in inputs:
            video_bytes = inputs["video"]
            with open("temp.mp4", "wb") as f:
                f.write(video_bytes)
            self.vision_feature = self.chat.get_video_embedding("temp.mp4")
            self.modal_type = "video"

        return {"query": query}

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Handle one request: build the prompt, generate, record the reply."""
        processed = self.preprocess(inputs)
        query = processed["query"]

        conversations = self.chat.ask(text=query, conv=self.conv, modal_type=self.modal_type)
        outputs = self.chat.answer(conversations, self.vision_feature, modal_type=self.modal_type)
        # NOTE: strip keeps the stored history aligned with the training format.
        self.conv.messages[-1][1] = outputs.strip()
        return {"output": outputs.strip()}
inference.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ srun -p INTERN2 --job-name='husky_multi_test' --gres=gpu:1 --cpus-per-task=8 --quotatype="auto" python -u demo/inference_new.py
3
+ """
4
+
5
+ import abc
6
+ from typing import Optional
7
+
8
+ import os
9
+ import requests
10
+ from PIL import Image
11
+ from io import BytesIO
12
+
13
+ import torch
14
+ import torchvision.transforms as T
15
+ from peft import PeftModel
16
+ from torchvision.transforms.functional import InterpolationMode
17
+
18
+ from transformers import (
19
+ LlamaTokenizer,
20
+ GenerationConfig,
21
+ StoppingCriteria,
22
+ StoppingCriteriaList,
23
+ )
24
+
25
+ from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
26
+
27
+ from robohusky.conversation import (
28
+ conv_templates,
29
+ get_conv_template,
30
+ )
31
+
32
+ from robohusky.video_transformers import (
33
+ GroupNormalize,
34
+ GroupScale,
35
+ GroupCenterCrop,
36
+ Stack,
37
+ ToTorchFormatTensor,
38
+ get_index,
39
+ )
40
+
41
+ from robohusky.compression import compress_module
42
+ from decord import VideoReader, cpu
43
+
44
+ # import deepspeed
45
+
46
+ IGNORE_INDEX = -100
47
+ DEFAULT_UNK_TOKEN = "<unk>"
48
+ DEFAULT_IMG_START_TOKEN = "<img>"
49
+ DEFAULT_IMG_END_TOKEN = "</img>"
50
+
51
+ DEFAULT_VIDEO_START_TOKEN = "<vid>"
52
+ DEFAULT_VIDEO_END_TOKEN = "</vid>"
53
+
54
def get_gpu_memory(max_gpus=None):
    """Return the currently free VRAM, in GiB, of each visible CUDA device.

    Args:
        max_gpus: optional cap on how many devices to inspect; ``None``
            inspects every device PyTorch can see.

    Returns:
        A list of floats (total minus allocated memory, GiB) per GPU.
    """
    if max_gpus is None:
        gpu_count = torch.cuda.device_count()
    else:
        gpu_count = min(max_gpus, torch.cuda.device_count())

    free_per_gpu = []
    for idx in range(gpu_count):
        # Query each device in its own context so memory_allocated() reports
        # the right device.
        with torch.cuda.device(idx):
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            total_gib = props.total_memory / (1024 ** 3)
            used_gib = torch.cuda.memory_allocated() / (1024 ** 3)
            free_per_gpu.append(total_gib - used_gib)
    return free_per_gpu
71
+
72
def load_model(
    model_path, device, num_gpus, max_gpu_memory=None, load_8bit=False, lora_weights=None
):
    """Load the Husky model and Llama tokenizer for inference.

    Args:
        model_path: checkpoint directory for both model and tokenizer.
        device: "cpu" or "cuda" (anything else raises ValueError).
        num_gpus: int, or "auto" to let accelerate shard the model.
        max_gpu_memory: optional per-GPU memory cap for multi-GPU loads.
        load_8bit: apply post-hoc weight compression via compress_module.
        lora_weights: optional PEFT/LoRA adapter path wrapped around the
            language model.

    Returns:
        (model, tokenizer) with the model in eval mode.
    """
    if device == "cpu":
        kwargs = {}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if num_gpus == "auto":
            kwargs["device_map"] = "auto"
        else:
            num_gpus = int(num_gpus)
            if num_gpus != 1:
                kwargs["device_map"] = "auto"
                if max_gpu_memory is None:
                    kwargs[
                        "device_map"
                    ] = "sequential"  # This is important for not the same VRAM sizes
                    # Budget 85% of the currently free memory on each GPU.
                    available_gpu_memory = get_gpu_memory(num_gpus)
                    kwargs["max_memory"] = {
                        i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
                        for i in range(num_gpus)
                    }
                else:
                    kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
    else:
        raise ValueError(f"Invalid device: {device}")

    tokenizer = LlamaTokenizer.from_pretrained(
        model_path, use_fast=False)

    if lora_weights is None:
        model = HuskyForConditionalGeneration.from_pretrained(
            model_path, low_cpu_mem_usage=True, **kwargs
        )
    else:
        # With LoRA adapters, always let accelerate place the weights.
        kwargs["device_map"] = "auto"
        model = HuskyForConditionalGeneration.from_pretrained(
            model_path, low_cpu_mem_usage=True, **kwargs
        )
        model.language_model = PeftModel.from_pretrained(
            model.language_model,
            lora_weights,
            **kwargs
        )

    if load_8bit:
        compress_module(model, device)

    # Single-GPU (or Apple MPS) loads are moved explicitly; sharded loads are
    # already placed by the device map.
    if (device == "cuda" and num_gpus == 1) or device == "mps":
        model.to(device)

    model = model.eval()
    return model, tokenizer
125
+
126
def load_image(image_file, input_size=224):
    """Load an image from a local path or http(s) URL and preprocess it.

    Applies the standard resize / center-crop / ImageNet-normalize pipeline
    and returns a CHW float tensor suitable for the vision encoder.
    """
    if image_file.startswith(("http", "https")):
        response = requests.get(image_file)
        pil_image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        pil_image = Image.open(image_file).convert('RGB')

    # Resize so that the subsequent center crop keeps the usual 224/256 ratio.
    resize_target = int(input_size / (224 / 256))
    pipeline = T.Compose([
        T.Lambda(lambda img: img if img.mode == 'RGB' else img.convert('RGB')),
        T.Resize(resize_target, interpolation=InterpolationMode.BICUBIC),
        T.CenterCrop(input_size),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    return pipeline(pil_image)
144
+
145
def load_video(video_path, num_segments=8):
    """Sample ``num_segments`` evenly spaced frames from a video and preprocess them.

    Frames are decoded with decord, scaled/cropped to 224, and normalized with
    CLIP statistics.  ``Stack`` concatenates the frames, so the result is a
    single tensor with frames merged along the channel axis — downstream,
    Chat.get_video_embedding reshapes it as (T*C, H, W) -> (C, T, H, W).
    """
    vr = VideoReader(video_path, ctx=cpu(0))
    num_frames = len(vr)
    frame_indices = get_index(num_frames, num_segments)

    # transform (CLIP normalization statistics)
    crop_size = 224
    scale_size = 224
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]

    transform = T.Compose([
        GroupScale(int(scale_size), interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(crop_size),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(input_mean, input_std)
    ])

    images_group = list()
    for frame_index in frame_indices:
        img = Image.fromarray(vr[frame_index].asnumpy())
        images_group.append(img)
    video = transform(images_group)
    return video
170
+
171
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any configured stop-token sequence ends the output.

    ``stops`` is a list of 1-D LongTensors of token ids (one per stop word).
    """

    def __init__(self, stops, encounters=1):
        super().__init__()
        # NOTE: ``encounters`` is accepted for API compatibility but unused —
        # generation halts on the first match.
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        for stop in self.stops:
            # Skip stop sequences longer than what has been generated so far;
            # comparing mismatched lengths would raise a shape error.
            if len(stop) > input_ids.shape[1]:
                continue
            # BUG FIX: stop ids are typically built on CPU while generation
            # runs on GPU; align devices before comparing.
            stop = stop.to(input_ids.device)
            if torch.all(stop == input_ids[0][-len(stop):]).item():
                return True

        return False
183
+
184
@torch.inference_mode()
def generate_stream(
    model, tokenizer, image_processor, params, device
):
    """Run one (non-streaming, despite the name) generation from a params dict.

    Args:
        model: HuskyForConditionalGeneration (multimodal) model.
        tokenizer: matching Llama tokenizer.
        image_processor: unused here; kept for call-site compatibility.
        params: dict with "prompt" (required) and optional "images"/"videos"
            (single file path each), "temperature", "max_new_tokens".
        device: device the pixel tensors are moved to.

    Returns:
        List of decoded output strings (batch of one).
    """
    prompt = params["prompt"]
    images = params.get("images", None)
    videos = params.get("videos", None)
    temperature = float(params.get("temperature", 0.7))
    max_new_tokens = int(params.get("max_new_tokens", 1024))

    # NOTE(review): computed but unused in this function.
    num_queries = model.config.num_query_tokens

    # Stop as soon as the model starts a new dialogue turn.
    stop_words = ["Human: ", "Assistant: ", "###", "\n\n"]
    stop_words_ids = [tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    generation_config = GenerationConfig(
        bos_token_id=1,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping_criteria
    )

    pixel_values = None
    if images is not None:
        pixel_values = load_image(images).to(device)  # only support one image
        # Replace the placeholder with the start/end marker pair the model expects.
        image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
        prompt = prompt.replace("<image>", image_query)

    elif videos is not None:
        pixel_values = load_video(videos).to(device)
        video_query = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN
        prompt = prompt.replace("<video>", video_query)

    model_inputs = tokenizer([prompt], return_tensors="pt")
    model_inputs.pop("token_type_ids", None)

    if pixel_values is not None:
        model_inputs["pixel_values"] = pixel_values

        generation_output = model.generate(
            **model_inputs,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True
        )
    else:
        # Text-only prompt: bypass the vision stack and use the LM directly.
        generation_output = model.language_model.generate(
            **model_inputs,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True
        )

    preds = generation_output.sequences
    outputs = tokenizer.batch_decode(preds, skip_special_tokens=True)
    return outputs
242
+
243
class Chat:
    """Stateful multimodal chat wrapper around HuskyForConditionalGeneration.

    Holds the model/tokenizer pair, a conversation template, and a fixed
    GenerationConfig.  ``ask`` builds the prompt for the next turn and
    ``answer`` runs generation; vision inputs are pre-encoded once via
    ``get_image_embedding`` / ``get_video_embedding`` and reused across turns.
    """

    def __init__(
        self,
        model_path,
        device,
        num_gpus=1,
        load_8bit=False,
        temperature=0.7,
        max_new_tokens=512,
        lora_path=None,
    ):
        model, tokenizer = load_model(
            model_path, device, num_gpus, load_8bit=load_8bit, lora_weights=lora_path
        )

        self.model = model
        # self.model.language_model = deepspeed.init_inference(
        #     self.model.language_model, mp_size=1, dtype=torch.float16, checkpoint=None, replace_with_kernel_inject=True)
        self.tokenizer = tokenizer
        # NOTE(review): computed but unused; kept to mirror generate_stream.
        num_queries = model.config.num_query_tokens

        self.device = device
        self.dtype = model.dtype

        # Stop generation when the model starts a new dialogue turn.
        stop_words = ["Human: ", "Assistant: ", "###", "\n\n"]
        stop_words_ids = [tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

        self.conv = get_conv_template("husky")

        # Placeholder marker pairs that show the model where vision features go.
        self.image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
        self.video_query = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN

        self.generation_config = GenerationConfig(
            bos_token_id=1,
            do_sample=True,
            top_k=20,
            top_p=0.9,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria
        )

    def ask(self, text, conv, modal_type="image"):
        """Append the user turn and an empty assistant slot; return the prompt.

        The media placeholder is only prefixed on the *first* turn of an
        image/video conversation; follow-up turns are plain text.
        """
        assert modal_type in ["text", "image", "video"]
        conversations = []

        if len(conv.messages) > 0 or modal_type == "text":
            conv.append_message(conv.roles[0], text)
        elif modal_type == "image":
            conv.append_message(conv.roles[0], self.image_query + "\n" + text)
        else:
            conv.append_message(conv.roles[0], self.video_query + "\n" + text)

        conv.append_message(conv.roles[1], None)
        conversations.append(conv.get_prompt())
        return conversations

    @torch.no_grad()
    def get_image_embedding(self, image_file):
        """Encode one image file into language-model input features."""
        pixel_values = load_image(image_file)
        pixel_values = pixel_values.unsqueeze(0).to(self.device, dtype=self.dtype)
        language_model_inputs = self.model.extract_feature(pixel_values)
        return language_model_inputs

    @torch.no_grad()
    def get_video_embedding(self, video_file):
        """Encode one video file into language-model input features."""
        pixel_values = load_video(video_file)
        # load_video stacks frames along the channel axis: (T*C, H, W).
        TC, H, W = pixel_values.shape
        pixel_values = pixel_values.reshape(TC // 3, 3, H, W).transpose(0, 1)  # [C, T, H, W]
        pixel_values = pixel_values.unsqueeze(0).to(self.device, dtype=self.dtype)
        assert len(pixel_values.shape) == 5
        language_model_inputs = self.model.extract_feature(pixel_values)
        return language_model_inputs

    @torch.no_grad()
    def answer(self, conversations, language_model_inputs, modal_type="image"):
        """Generate the assistant reply for the prompt built by ``ask``.

        Text-only chats run the bare language model and strip the echoed
        prompt from the decoded output; image/video chats run the full
        multimodal model with the precomputed vision features.
        """
        model_inputs = self.tokenizer(
            conversations,
            return_tensors="pt",
        )
        model_inputs.pop("token_type_ids", None)

        input_ids = model_inputs["input_ids"].to(self.device)
        attention_mask = model_inputs["attention_mask"].to(self.device)

        if modal_type == "text":
            generation_output = self.model.language_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=self.generation_config,
                return_dict_in_generate=True,
                output_scores=True
            )
        else:
            pixel_values = model_inputs.pop("pixel_values", None)
            if pixel_values is not None:
                pixel_values = pixel_values.to(self.device)

            generation_output = self.model.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                language_model_inputs=language_model_inputs,
                generation_config=self.generation_config,
                return_dict_in_generate=True,
                output_scores=True
            )

        preds = generation_output.sequences
        outputs = self.tokenizer.batch_decode(preds, skip_special_tokens=True)[0]

        if modal_type == "text":
            # Drop the echoed prompt; each "</s>" in the prompt string decodes
            # away, shifting the character offset by 3 apiece.
            skip_echo_len = len(conversations[0]) - conversations[0].count("</s>") * 3
            outputs = outputs[skip_echo_len:].strip()

        return outputs
360
+
361
if __name__ == '__main__':
    # Interactive console demo: typing a media file path switches modality,
    # "clear" (or an empty line) resets the conversation, "stop" exits.
    # model_path = "/mnt/petrelfs/zhangqinglong/Documents/Husky/work_dirs/husky_v3/EmbodiedGPT/pretrain_0727"
    model_path = "/mnt/petrelfs/share_data/gvembodied/workdirs/align_new_myyf"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    chat = Chat(model_path, device=device, num_gpus=1, max_new_tokens=1024, load_8bit=False)

    vision_feature = None
    image_state = False
    video_state = False

    while True:
        query = input("\n")
        # A path with an image extension loads the image and resets history.
        if query.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
            if os.path.exists(query):
                print("received.")
                vision_feature = chat.get_image_embedding(query)
                chat.conv = get_conv_template("husky").copy()
                image_state = True
            continue
        # Likewise for video extensions.
        if query.lower().endswith(('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")):
            if os.path.exists(query):
                print("received.")
                vision_feature = chat.get_video_embedding(query)
                chat.conv = get_conv_template("husky").copy()
                video_state = True
            continue

        if query == "stop":
            break
        if query == "clear" or query == "" or query == "\n":
            # Reset conversation and modality state, then reprint the banner.
            chat.conv = get_conv_template("husky").copy()
            image_state = False
            video_state = False
            os.system("clear")
            print("欢迎使用 husky-13b-zh 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue

        if image_state:
            modal_type = "image"
        elif video_state:
            modal_type = "video"
        else:
            modal_type = "text"

        # image_test = "assets/husky.jpg"
        # image_test = "assets/yoga.mp4"
        # video_test = "assets/pretty_girl.mp4"
        # video_test = "assets/stock-footage-billiards-concentrated-young-woman-playing-in-club.webm"
        # video_test = "assets/stock-footage-kherson-ukraine-may-open-free-rock-music-festival-crowd-partying-at-a-rock-concert.webm"
        conversations = chat.ask(text=query, conv=chat.conv, modal_type=modal_type)
        outputs = chat.answer(conversations, vision_feature, modal_type=modal_type)
        # NOTE: strip is important to align with the training data.
        chat.conv.messages[-1][1] = outputs.strip()

        print(f"Husky: \n{outputs}")
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ torchaudio==2.0.2
4
+ transformers==4.34.1
5
+ decord
6
+ peft
7
+ huggingface_hub
8
+ Pillow
9
+ einops
10
+ scipy
11
+ numpy
12
+ tqdm
13
+ flash-attn
robohusky/.DS_Store ADDED
Binary file (6.15 kB). View file
 
robohusky/base_dataset.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ from typing import Dict, Optional, Sequence
5
+ from PIL import PngImagePlugin, Image, ImageFile
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ import torchvision.transforms as T
10
+ from torchvision.transforms.functional import InterpolationMode
11
+
12
+ from robohusky.train.tcsloader import TCSLoader
13
+ from robohusky.conversation import get_conv_template
14
+
15
+ IGNORE_INDEX = -100
16
+
17
+ Image.MAX_IMAGE_PIXELS = None
18
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
19
+ MaximumDecompressedSize = 1024
20
+ MegaByte = 2 ** 20
21
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
22
+
23
+ DEFAULT_IMG_START_TOKEN = "<img>"
24
+ DEFAULT_IMG_END_TOKEN = "</img>"
25
+
26
+ DEFAULT_VIDEO_START_TOKEN = "<vid>"
27
+ DEFAULT_VIDEO_END_TOKEN = "</vid>"
28
+
29
def is_image(image_file):
    """Return True if the path carries a known still-image extension (case-insensitive)."""
    image_suffixes = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
    return image_file.lower().endswith(image_suffixes)
34
+
35
def is_video(image_file):
    """Return True if the path carries a known video extension (case-insensitive)."""
    video_suffixes = ('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")
    return image_file.lower().endswith(video_suffixes)
40
+
41
def build_transform(input_size):
    """Build the standard image pipeline: RGB convert, square resize,
    tensor conversion, ImageNet normalization."""
    return T.Compose([
        T.Lambda(lambda img: img if img.mode == 'RGB' else img.convert('RGB')),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
49
+
50
def format_inputs(sources):
    """Render raw conversation records into husky-template prompt strings.

    Args:
        sources: list of conversations; each is a list of
            {"from": "human"|"gpt", "value": str} turns.

    Returns:
        (conversations, conv): the rendered prompt strings and the (last-used)
        conversation template, whose separators the caller needs for masking.
    """
    # Apply prompt templates
    conv = get_conv_template("husky").copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []

    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            # vision is only supported for the human input
            if role == conv.roles[0]:
                value = sentence["value"]
                if "<image>" in value:
                    # Normalize trailing placeholders to the front, then swap
                    # <image> for the <img></img> marker pair.
                    if value.endswith("\n<image>"):
                        value = "<image>\n" + value.replace("\n<image>", "")
                    image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
                    sentence["value"] = value.replace("<image>", image_query)

                elif "<video>" in value:
                    if value.endswith("\n<video>"):
                        value = "<video>\n" + value.replace("\n<video>", "")
                    video_query = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN
                    sentence["value"] = value.replace("<video>", video_query)

            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    return conversations, conv
84
+
85
def process_func(examples, tokenizer, max_seq_length):
    """Tokenize formatted conversations and build training labels.

    Pads/truncates every conversation to ``max_seq_length`` and masks all
    positions except the assistant responses with IGNORE_INDEX, locating
    instruction spans via the template's turn separators.
    """
    conversations, conv = format_inputs(examples['conversations'])
    model_inputs = tokenizer(
        conversations,
        max_length=max_seq_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    model_inputs.pop("token_type_ids", None)
    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    targets = model_inputs["input_ids"].clone()

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Count of real (non-pad) tokens, for the consistency check below.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                # Tokenization drifted from the expected turn lengths; mask the
                # whole sample rather than train on a misaligned label set.
                target[:] = IGNORE_INDEX

    model_inputs["labels"] = targets
    return model_inputs
141
+
142
class BaseDataset(Dataset):
    """Supervised dataset that reads images from the local filesystem.

    Processed samples are memoised in ``cached_data_dict``; samples whose
    image file is missing are skipped by falling through to the next index.
    """

    def __init__(self, dataset, processor, image_path="", input_size=224):
        super(BaseDataset, self).__init__()
        self.dataset = dataset
        # Optional root directory prepended to each sample's image path.
        self.image_path = image_path

        self.transform = build_transform(input_size)
        self.husky_processor = processor

        # index -> processed sample.  NOTE(review): unbounded; grows to the
        # full dataset size over one epoch.
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        data = self.dataset[i]
        image_file = data.pop("image", None)

        if self.image_path != "":
            image_file = os.path.join(self.image_path, image_file)
            if not os.path.exists(image_file):
                # Missing file: serve the next sample instead (wraps around).
                return self.__getitem__((i + 1) % len(self.dataset))
            image = Image.open(image_file)
        else:
            image = Image.open(image_file)

        # The processor expects batched columns, so wrap then unwrap each field.
        for k, v in data.items():
            data[k] = [v]
        ret = self.husky_processor(data)
        for k, v in ret.items():
            ret[k] = v[0]

        pixel_values = self.transform(image)
        ret["pixel_values"] = pixel_values

        self.cached_data_dict[i] = ret
        return ret
182
+
183
class CephDataset(Dataset):
    """Supervised dataset that reads images from Ceph/petrel storage via TCSLoader.

    Unreadable images are appended to ``error.txt`` and replaced by a
    randomly chosen sample.
    """

    def __init__(self, dataset, processor, input_size=224):
        super(CephDataset, self).__init__()
        self.dataset = dataset

        self.transform = build_transform(input_size)
        self.husky_processor = processor

        conf_path = "./petrelf.conf"
        self.conf_path = os.path.abspath(conf_path)

        self.initialized = False
        self._init_memcached()

    def _init_memcached(self):
        # Create the petrel client exactly once per dataset instance.
        if not self.initialized:
            assert self.conf_path is not None
            self.mt_loader = TCSLoader(self.conf_path)
            self.initialized = True

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        data = self.dataset[i]
        image_file = data.pop("image", None)

        try:
            image = self.mt_loader(image_file).convert('RGB')
        except (AttributeError, OSError):
            # Log the bad path and retry with a random replacement sample.
            with open("error.txt", 'a') as f:
                f.write(image_file + '\n')
            # randint is inclusive of both ends; the modulo below keeps the
            # replacement index in range.
            i = random.randint(0, len(self.dataset))
            return self.__getitem__(i % len(self.dataset))

        # The processor expects batched columns, so wrap then unwrap each field.
        for k, v in data.items():
            data[k] = [v]

        ret = self.husky_processor(data)
        for k, v in ret.items():
            ret[k] = v[0]
        pixel_values = self.transform(image)
        ret["pixel_values"] = pixel_values
        return ret
robohusky/base_dataset_uni.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+
4
+ from typing import Dict, Optional, Sequence, Iterator, List, Iterable, Union
5
+ from PIL import PngImagePlugin, Image, ImageFile, ImageOps
6
+
7
+ import numpy as np
8
+
9
+ import torch
10
+ from torch.utils.data import (
11
+ Dataset,
12
+ ConcatDataset,
13
+ Sampler,
14
+ WeightedRandomSampler
15
+ )
16
+ import torchvision.transforms as T
17
+ from torchvision.transforms.functional import InterpolationMode
18
+
19
+ from robohusky.train.tcsloader import TCSLoader
20
+
21
+ from decord import VideoReader, cpu
22
+ from robohusky.video_transformers import (
23
+ GroupNormalize,
24
+ GroupScale,
25
+ GroupCenterCrop,
26
+ Stack,
27
+ ToTorchFormatTensor,
28
+ get_index,
29
+ )
30
+
31
+ from robohusky.conversation import get_conv_template
32
+
33
+ IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
34
+ IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
35
+ IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
36
+ IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
37
+ OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
38
+ OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
39
+
40
+ IGNORE_INDEX = -100
41
+
42
+ Image.MAX_IMAGE_PIXELS = None
43
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
44
+ MaximumDecompressedSize = 1024
45
+ MegaByte = 2 ** 20
46
+ PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
47
+
48
+ DEFAULT_IMG_START_TOKEN = "<img>"
49
+ DEFAULT_IMG_END_TOKEN = "</img>"
50
+
51
+ DEFAULT_VIDEO_START_TOKEN = "<vid>"
52
+ DEFAULT_VIDEO_END_TOKEN = "</vid>"
53
+
54
+ DEFAULT_EMBED_TOKEN = "<quad>"
55
+
56
+ conf_path = "/your path to/petrelf.conf"
57
+
58
def is_image(image_file):
    """Return True if the path carries a known still-image extension (case-insensitive)."""
    image_suffixes = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
    return image_file.lower().endswith(image_suffixes)
63
+
64
def is_video(image_file):
    """Return True if the path carries a known video extension (case-insensitive)."""
    video_suffixes = ('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")
    return image_file.lower().endswith(video_suffixes)
69
+
70
def is_numpy(image_file):
    """Return True if the path has a NumPy ``.npy`` extension.

    Now matched case-insensitively for consistency with the sibling
    ``is_image`` / ``is_video`` helpers, which also lowercase before
    checking suffixes (backward-compatible: all previously-True inputs
    remain True).
    """
    return image_file.lower().endswith(".npy")
75
+
76
def get_media_type(image_file):
    """Classify a file path as "image", "video", or "numpy"; fall back to "text"."""
    checks = (
        (is_image, "image"),
        (is_video, "video"),
        (is_numpy, "numpy"),
    )
    for predicate, label in checks:
        if predicate(image_file):
            return label
    return "text"
85
+
86
def build_transform(input_size, norm_type="openai", media_type="image"):
    """Build the preprocessing pipeline for one media type.

    norm_type selects the normalization statistics ("openai" = CLIP,
    anything else = ImageNet defaults).  media_type "image" uses the
    per-image pipeline, "video" the grouped-frame pipeline, and any other
    value yields None (no transform).
    """
    if norm_type == "openai":
        mean, std = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
    else:
        # "imagenet" and any unknown value both use ImageNet statistics.
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    if media_type == "image":
        return T.Compose([
            T.Lambda(lambda img: img if img.mode == 'RGB' else img.convert('RGB')),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
    if media_type == "video":
        return T.Compose([
            GroupScale(int(input_size), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(input_size),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(mean=mean, std=std),
        ])
    return None
115
+
116
def check_format(data):
    """Validate one conversation record.

    A valid record carries 'id', 'image' and an even-length
    'conversations' list that alternates human/gpt turns, whose first
    turn holds an "<image>" placeholder (leading or trailing) and whose
    every turn has a non-empty value.  Any violation is reported on
    stdout and makes the function return False.
    """
    has_required_fields = (
        'id' in data
        and 'image' in data
        and 'conversations' in data
        and len(data['conversations']) % 2 == 0
    )
    if not has_required_fields:
        print(f"Lake field: {data}")
        return False

    for idx, message in enumerate(data['conversations']):
        value = message['value']
        if idx == 0 and not (value.startswith("<image>\n") or value.endswith("\n<image>")):
            print(f"No <image>: {data}")
            return False
        # Even positions must come from the human, odd ones from the model.
        expected_role = 'human' if idx % 2 == 0 else 'gpt'
        if message['from'] != expected_role:
            if expected_role == 'human':
                print(f"Not from human: {data}")
            else:
                print(f"Not from gpt: {data}")
            return False
        if value is None or len(value) == 0:
            print(f"No Message: {data}")
            return False
    return True
137
+
138
def format_inputs(sources, conv_tempt="husky", num_query_tokens=256):
    """Render raw conversation lists into prompt strings.

    Each element of ``sources`` is a list of ``{"from": "human"|"gpt",
    "value": str}`` turns.  "<image>"/"<video>" placeholders in human
    turns are normalised to the front of the turn and expanded into
    ``<img><quad>*N</img>`` / ``<vid><quad>*N</vid>`` spans
    (N = ``num_query_tokens``), then all turns are laid out with the
    conversation template named by ``conv_tempt``.

    Returns ``(conversations, conv)``: the rendered prompt strings and
    the template object whose separators the caller uses for masking.

    NOTE: ``sentence["value"]`` is rewritten in place, so ``sources`` is
    mutated by this call.
    """
    # Apply prompt templates
    conv = get_conv_template(conv_tempt).copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []

    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # After the optional skip above, turns must strictly alternate.
            assert role == conv.roles[j % 2], f"{i}"
            # vision is only supported for the human input
            if role == conv.roles[0]:
                value = sentence["value"]
                if "<image>" in value:
                    # Normalise a trailing "<image>" to a leading one.
                    if value.endswith("\n<image>"):
                        value = "<image>\n" + value.replace("\n<image>", "")

                    image_query = DEFAULT_IMG_START_TOKEN + num_query_tokens * DEFAULT_EMBED_TOKEN + DEFAULT_IMG_END_TOKEN
                    sentence["value"] = value.replace("<image>", image_query)

                elif "<video>" in value:
                    # Normalise a trailing "<video>" to a leading one.
                    if value.endswith("\n<video>"):
                        value = "<video>\n" + value.replace("\n<video>", "")

                    video_query = DEFAULT_VIDEO_START_TOKEN + num_query_tokens * DEFAULT_EMBED_TOKEN + DEFAULT_VIDEO_END_TOKEN
                    sentence["value"] = value.replace("<video>", video_query)

            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    return conversations, conv
174
+
175
def process_func(examples, tokenizer, max_seq_length=-1, conv_tempt="husky", num_query_tokens=256):
    """Tokenize rendered conversations and build loss-masked labels.

    Prompts produced by ``format_inputs`` are tokenized — truncated to
    ``tokenizer.model_max_length`` when ``max_seq_length`` < 0, otherwise
    padded/truncated to ``max_seq_length`` — and ``labels`` are derived
    from ``input_ids`` with everything except the assistant replies set
    to IGNORE_INDEX, using the separators of the ``conv`` template.

    NOTE(review): the masking assumes a two-separator template
    (role prompt joined with ``conv.sep``, turns delimited by
    ``conv.sep2``) and Llama-style tokenization — the hard-coded "-2"
    offset and the ``tokenizer.legacy`` adjustments below. Confirm
    before reusing with another tokenizer family.
    """
    conversations, conv = format_inputs(examples['conversations'], conv_tempt, num_query_tokens)
    if max_seq_length < 0:
        model_inputs = tokenizer(
            conversations,
            return_tensors="pt",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
    else:
        model_inputs = tokenizer(
            conversations,
            max_length=max_seq_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    # Some tokenizers emit token_type_ids; the model does not consume them.
    model_inputs.pop("token_type_ids", None)
    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    targets = model_inputs["input_ids"].clone()

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Number of real (non-pad) tokens, used for the sanity check below.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        # Mask everything past the last processed turn (padding / truncation tail).
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            # Offset bookkeeping went wrong for this sample: drop it from the loss
            # entirely rather than train on misaligned labels.
            if cur_len != total_len:
                target[:] = IGNORE_INDEX

    model_inputs["labels"] = targets
    return model_inputs
239
+
240
class BaseDataset(Dataset):
    """Map-style dataset yielding processed text (and optional visual) samples.

    Visual payloads are loaded from local disk or from Ceph ("s3://" paths,
    via ``TCSLoader``), preprocessed with the transform built by
    ``build_transform``, and merged with the tokenized text produced by
    ``processor``.  Unreadable or missing files are logged to ``error.txt``
    and replaced by a random substitute sample.

    NOTE(review): every returned item is memoised in ``cached_data_dict``,
    which grows without bound over an epoch — confirm the dataset fits in
    memory before enabling this for large corpora.
    """

    def __init__(
        self,
        dataset,
        processor,
        image_path="",
        input_size=224,
        num_segments=8,
        norm_type="openai",
        media_type="image"
    ):
        super(BaseDataset, self).__init__()
        self.dataset = dataset          # sequence of records with "image"/"video" + text fields
        self.image_path = image_path    # optional root prepended to relative media paths
        self.input_size = input_size    # spatial size fed to the vision transform
        self.num_segments = num_segments  # frames sampled per video

        self.media_type = media_type
        self.transform = build_transform(input_size, norm_type, media_type)
        self.husky_processor = processor  # tokenizes/packs the text fields
        self.tcs_loader = TCSLoader(os.path.abspath(conf_path), media_type=media_type)

        # index -> processed sample; unbounded memoisation (see class note).
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        data = self.dataset[i]
        image_file = data["image"] if "image" in data else data["video"]

        if self.media_type == "llm" or image_file == "":
            # Text-only sample: no visual payload is attached.
            # Pseudo pixel_values
            # pixel_values = torch.zeros(size=(3, self.input_size, self.input_size))
            pixel_values = None
        else:
            if self.image_path != "":
                image_file = os.path.join(self.image_path, image_file)
            if "s3://" not in image_file and not os.path.exists(image_file):
                # Missing local file: fall back to a random substitute sample.
                # (randint's upper bound is inclusive; the modulo below keeps it in range.)
                i = random.randint(0, len(self.dataset))
                return self.__getitem__(i % len(self.dataset))

            try:
                if self.media_type == "image":
                    # load from ceph
                    if "s3://" in image_file:
                        image = self.tcs_loader(image_file)
                    else:
                        image = Image.open(image_file).convert('RGB')

                    # process image with extreme aspect ratios
                    # NOTE(review): PIL's Image.size is (width, height), so these two
                    # locals are name-swapped; the padding branches are consistent with
                    # the swap (wide images get top/bottom padding, tall ones get
                    # left/right), so behaviour is correct — only the names mislead.
                    height, width = image.size
                    if height / width >= 1.8:
                        delta = height - width
                        padding = (0, delta // 2, 0, delta - delta // 2)
                        image = ImageOps.expand(image, padding)
                    elif height / width <= 0.56:
                        delta = width - height
                        padding = (delta // 2, 0, delta - delta // 2, 0)
                        image = ImageOps.expand(image, padding)
                    pixel_values = self.transform(image)
                elif self.media_type == "video":
                    if "s3://" in image_file:
                        vr = self.tcs_loader(image_file)
                    else:
                        vr = VideoReader(image_file, ctx=cpu(0))

                    # Uniformly sample num_segments frames across the clip.
                    num_frames = len(vr)
                    frame_indices = get_index(num_frames, self.num_segments)
                    images_group = list()
                    for frame_index in frame_indices:
                        img = Image.fromarray(vr[frame_index].asnumpy())
                        images_group.append(img)
                    pixel_values = self.transform(images_group)
                    # The group transform stacks frames along the channel axis.
                    TC, H, W = pixel_values.shape
                    pixel_values = pixel_values.reshape(TC // 3, 3, H, W).transpose(0, 1)  # [C, T, H, W]
                else:
                    # load numpy
                    if "s3://" in image_file:
                        pixel_values = self.tcs_loader(image_file)
                    else:
                        pixel_values = np.load(image_file)
                    pixel_values = torch.tensor(pixel_values).transpose(0, 1)
            except (AttributeError, OSError):
                # Corrupt/unreadable media: record the path and substitute a
                # random sample so training is not interrupted.
                with open("error.txt", 'a') as f:
                    f.write(image_file + '\n')
                i = random.randint(0, len(self.dataset))
                return self.__getitem__(i % len(self.dataset))

        # The processor expects batched (list-valued) fields; wrap, process,
        # then unwrap the singleton batch.
        for k, v in data.items():
            data[k] = [v]
        ret = self.husky_processor(data)
        for k, v in ret.items():
            ret[k] = v[0]

        if pixel_values is not None:
            ret["pixel_values"] = pixel_values

        self.cached_data_dict[i] = ret
        return ret
343
+
344
class WeightedConcatDataset(ConcatDataset):
    """Concatenation of datasets with weighted random index sampling.

    With ``batch_size <= 0`` (the default) indices are drawn through a
    plain ``WeightedRandomSampler``; with a positive ``batch_size`` whole
    batches are drawn via ``WeightedBatchSampler`` so each batch repeats
    a single weighted draw.

    NOTE(review): ``weights`` holds one entry per *dataset*, but
    ``WeightedRandomSampler`` treats its weights as one per *sample*, so
    in the default path only indices ``0 .. len(weights)-1`` can ever be
    produced — confirm this is the intended behaviour.
    """

    def __init__(
        self,
        datasets: List[Dataset],
        weights: Sequence[float] = None,
        replacement: bool = True,
        batch_size: int = -1,
        generator=None
    ) -> None:
        super().__init__(datasets)
        if weights is None:
            weights = [1.0] * len(self.datasets)
        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if weights_tensor.dim() != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             "weights have shape {}".format(tuple(weights_tensor.shape)))
        self.weights = weights_tensor
        self.batch_size = batch_size

        self.replacement = replacement
        self.generator = generator

        if batch_size > 0:
            # Batch mode: each dataset contributes only whole batches.
            self.task_batches = [len(d) // batch_size for d in datasets]
            self.num_samples = sum(self.task_batches) * batch_size
            self.sampler = WeightedBatchSampler(
                weights=self.weights,
                num_samples=self.num_samples,
                batch_size=self.batch_size,
                replacement=self.replacement
            )
        else:
            self.num_samples = sum(len(d) for d in datasets)
            self.sampler = WeightedRandomSampler(
                weights=self.weights,
                num_samples=self.num_samples,
                replacement=self.replacement
            )

    def __iter__(self) -> Iterator[int]:
        return iter(self.sampler)

    def __len__(self) -> int:
        return self.num_samples
388
+
389
class WeightedBatchSampler(Sampler[int]):
    """Yield sample indices in contiguous, weight-drawn batches.

    One source index is drawn per batch via ``torch.multinomial`` over
    ``weights`` and repeated ``batch_size`` times, so every batch
    consists of ``batch_size`` copies of the same drawn index.
    """

    weights: torch.Tensor
    num_samples: int
    batch_size: int
    replacement: bool

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        batch_size: int,
        replacement: bool = True,
        generator=None
    ) -> None:
        def _is_positive_int(v):
            # bool is a subclass of int, so reject it explicitly.
            return isinstance(v, int) and not isinstance(v, bool) and v > 0

        if not _is_positive_int(batch_size):
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not _is_positive_int(num_samples):
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if weights_tensor.dim() != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             "weights have shape {}".format(tuple(weights_tensor.shape)))

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_batches = num_samples // batch_size
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        # One multinomial draw per batch, each expanded to a full batch.
        drawn = torch.multinomial(self.weights, self.num_batches, self.replacement, generator=self.generator)
        yield from drawn.repeat_interleave(self.batch_size).tolist()

    def __len__(self) -> int:
        return self.num_samples
robohusky/compression.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import gc
3
+ import glob
4
+ import os
5
+
6
+ from accelerate import init_empty_weights
7
+ from accelerate.utils import set_module_tensor_to_device
8
+ import torch
9
+ from torch import Tensor
10
+ from torch.nn import functional as F
11
+ import torch.nn as nn
12
+ from tqdm import tqdm
13
+ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
14
+
15
+
16
@dataclasses.dataclass
class CompressionConfig:
    """Group-wise quantization."""

    num_bits: int       # bit-width of the quantized values (compress() asserts <= 8)
    group_size: int     # number of elements per quantization group
    group_dim: int      # tensor dimension along which groups are formed
    symmetric: bool     # True: signed symmetric scheme; False: affine (min/max) scheme
    enabled: bool = True  # False makes compress()/decompress() pass-throughs
25
+
26
+
27
+ default_compression_config = CompressionConfig(
28
+ num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True
29
+ )
30
+
31
+
32
class CLinear(nn.Module):
    """Compressed Linear Layer.

    Keeps its weight in the group-quantized format produced by
    ``compress`` and dequantizes it on every forward pass; the bias (if
    any) stays dense.
    """

    def __init__(self, weight=None, bias=None, device=None):
        super().__init__()
        if weight is None:
            self.weight = None
        elif isinstance(weight, Tensor):
            # Dense tensor: quantize on the target device.
            self.weight = compress(weight.data.to(device), default_compression_config)
        else:
            # Already-compressed payload: keep as-is.
            self.weight = weight
        self.bias = bias

    def forward(self, input: Tensor) -> Tensor:
        weight = decompress(self.weight, default_compression_config)
        x = input.to(weight.dtype)
        if self.bias is None:
            return F.linear(x, weight)
        return F.linear(x, weight, self.bias.to(weight.dtype))
50
+
51
+
52
def compress_module(module, target_device):
    """Recursively replace every ``nn.Linear`` attribute of *module* and
    its submodules with a quantized ``CLinear`` living on *target_device*.

    The exact-type check (not isinstance) deliberately skips Linear
    subclasses, mirroring the original behaviour.
    """
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        if type(attr) == torch.nn.Linear:
            replacement = CLinear(attr.weight, attr.bias, target_device)
            setattr(module, attr_name, replacement)
    for _, child in module.named_children():
        compress_module(child, target_device)
63
+
64
+
65
+ def get_compressed_list(module, prefix=""):
66
+ compressed_list = []
67
+ for attr_str in dir(module):
68
+ target_attr = getattr(module, attr_str)
69
+ if type(target_attr) == torch.nn.Linear:
70
+ full_name = (
71
+ f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
72
+ )
73
+ compressed_list.append(full_name)
74
+ for name, child in module.named_children():
75
+ child_prefix = f"{prefix}.{name}" if prefix else name
76
+ for each in get_compressed_list(child, child_prefix):
77
+ compressed_list.append(each)
78
+ return compressed_list
79
+
80
+
81
def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""):
    """Swap every ``nn.Linear`` under *module* for a ``CLinear`` built
    from the pre-compressed weight stored in *compressed_state_dict*
    (keyed by the same qualified names ``get_compressed_list`` emits)."""
    for attr_name in dir(module):
        attr = getattr(module, attr_name)
        if type(attr) == torch.nn.Linear:
            qualified = f"{prefix}.{attr_name}.weight" if prefix else f"{attr_name}.weight"
            setattr(
                module,
                attr_name,
                CLinear(compressed_state_dict[qualified], attr.bias, target_device),
            )
    for child_name, child in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        apply_compressed_weight(child, compressed_state_dict, target_device, child_prefix)
100
+
101
+
102
def load_compress_model(model_path, device, torch_dtype, use_fast=False):
    """Load a causal LM from *model_path* with its Linear weights quantized.

    The model skeleton is built on the meta device, each checkpoint shard
    is streamed in, Linear weights are compressed on the fly and all other
    tensors are moved to *device*, keeping peak memory low.

    Returns ``(model, tokenizer)``.

    NOTE(review): only ``pytorch_model*.bin`` shards are globbed —
    safetensors checkpoints would be silently skipped; confirm the
    checkpoints this is used with.  Also, ``low_cpu_mem_usage`` is a
    model-loading kwarg, not a config field — passing it to
    ``AutoConfig.from_pretrained`` likely has no effect; verify.
    """
    # partially load model
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)
    base_pattern = os.path.join(model_path, "pytorch_model*.bin")
    files = glob.glob(base_pattern)

    # Build the architecture without allocating real weight storage.
    with init_empty_weights():
        config = AutoConfig.from_pretrained(
            model_path, low_cpu_mem_usage=True, torch_dtype=torch_dtype
        )
        model = AutoModelForCausalLM.from_config(config)
        linear_weights = get_compressed_list(model)

    compressed_state_dict = {}

    for filename in tqdm(files):
        tmp_state_dict = torch.load(filename)
        for name in tmp_state_dict:
            if name in linear_weights:
                # Linear weight: quantize immediately on the target device.
                tensor = tmp_state_dict[name].to(device).data.to(torch_dtype)
                compressed_state_dict[name] = compress(
                    tensor, default_compression_config
                )
            else:
                compressed_state_dict[name] = tmp_state_dict[name].to(device)
            # Drop references eagerly so each shard's memory is reclaimed.
            tmp_state_dict[name] = None
            tensor = None
            gc.collect()
            torch.cuda.empty_cache()

    # Materialize every non-Linear tensor onto the real device...
    for name in model.state_dict():
        if name not in linear_weights:
            set_module_tensor_to_device(
                model, name, device, value=compressed_state_dict[name]
            )
    # ...and replace the Linear modules with their compressed counterparts.
    apply_compressed_weight(model, compressed_state_dict, device)

    model.to(device)

    return model, tokenizer
142
+
143
+
144
def compress(tensor, config):
    """Simulate group-wise quantization.

    Splits *tensor* into groups of ``config.group_size`` along
    ``config.group_dim`` (zero-padding the tail group) and quantizes each
    group independently.  Returns ``(data, scale, original_shape)`` for
    the symmetric scheme or ``(data, mn, scale, original_shape)`` for the
    affine one; returns *tensor* unchanged when compression is disabled.
    """
    if not config.enabled:
        return tensor

    group_size = config.group_size
    num_bits = config.num_bits
    group_dim = config.group_dim
    symmetric = config.symmetric
    assert num_bits <= 8

    original_shape = tensor.shape
    dim_len = original_shape[group_dim]
    num_groups = (dim_len + group_size - 1) // group_size
    new_shape = (
        original_shape[:group_dim]
        + (num_groups, group_size)
        + original_shape[group_dim + 1:]
    )

    # Zero-pad so the grouped dimension divides evenly into groups.
    pad_len = (group_size - dim_len % group_size) % group_size
    if pad_len != 0:
        pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1:]
        zeros = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
        tensor = torch.cat([tensor, zeros], dim=group_dim)
    data = tensor.view(new_shape)

    if symmetric:
        # Signed symmetric scheme: scale each group so its max |value| hits B.
        B = 2 ** (num_bits - 1) - 1
        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
        data = (data * scale).clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape

    # Affine scheme: shift by the group minimum, scale the range onto [0, B].
    B = 2 ** num_bits - 1
    mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
    mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]

    scale = B / (mx - mn)
    data = data - mn
    data.mul_(scale)

    data = data.clamp_(0, B).round_().to(torch.uint8)
    return data, mn, scale, original_shape
195
+
196
+
197
def decompress(packed_data, config):
    """Simulate group-wise dequantization.

    Inverse of ``compress``: rescales the quantized groups back to
    floating point, strips any zero-padding, and restores the original
    shape.  Returns the input unchanged when compression is disabled.
    """
    if not config.enabled:
        return packed_data

    group_size = config.group_size
    group_dim = config.group_dim

    if config.symmetric:
        data, scale, original_shape = packed_data
        data = data / scale
    else:
        data, mn, scale, original_shape = packed_data
        data = data / scale
        data.add_(mn)

    # Strip the zero-padding added by compress, if any.
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if not pad_len:
        return data.view(original_shape)

    padded_shape = (
        original_shape[:group_dim]
        + (original_shape[group_dim] + pad_len,)
        + original_shape[group_dim + 1:]
    )
    data = data.reshape(padded_shape)
    trim = [slice(0, n) for n in original_shape]
    return data[trim].contiguous()
robohusky/configuration_husky.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Husky model configuration"""
16
+
17
+ import copy
18
+ import os
19
+ from typing import Union
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
23
+ from transformers.utils import logging
24
+ from transformers.models.auto import CONFIG_MAPPING
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+ HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
29
+ "wofmanaf/husky-7b": "https://huggingface.co/wofmanaf/husky-7b/resolve/main/config.json",
30
+ }
31
+
32
class HuskyVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`HuskyVisionModel`]. It is used to
    instantiate a Husky vision encoder according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the Husky
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1408):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        projection_dim (`int`, *optional*, defaults to 512):
            Dimensionality of the projection embeddings (stored on the config; not otherwise read here).
        num_hidden_layers (`int`, *optional*, defaults to 39):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            Number of input image channels.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries and values in the self-attention layers.
    """

    model_type = "husky_vision_model"

    def __init__(
        self,
        hidden_size=1408,
        intermediate_size=6144,
        projection_dim=512,
        num_hidden_layers=39,
        num_attention_heads=16,
        num_channels=3,
        image_size=224,
        patch_size=14,
        hidden_act="gelu",
        layer_norm_eps=0.00001,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=1e-10,
        initializer_factor=1.0,
        qkv_bias=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.dropout = dropout
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.qkv_bias = qkv_bias

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load this config, extracting the nested vision section when the
        checkpoint holds a full composite `HuskyConfig`."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the vision config dict if we are loading from HuskyConfig
        if config_dict.get("model_type") == "husky":
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
+ return cls.from_dict(config_dict, **kwargs)
125
+
126
class HuskyQFormerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`HuskyQFormerModel`]. It is used to
    instantiate a Husky Querying Transformer (Q-Former) model according to the specified arguments, defining the
    model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
    the Husky [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
    architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    Read the documentation from [`PretrainedConfig`] for more information.

    Note that [`HuskyQFormerModel`] is very similar to [`BertLMHeadModel`] with interleaved cross-attention.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the Q-Former model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling the model.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.
        cross_attention_frequency (`int`, *optional*, defaults to 2):
            The frequency of adding cross-attention to the Transformer layers.
        encoder_hidden_size (`int`, *optional*, defaults to 1408):
            The hidden size of the hidden states for cross-attention.
    """
    model_type = "husky_qformer"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        classifier_dropout=None,
        cross_attention_frequency=2,
        encoder_hidden_size=1408,
        **kwargs,
    ):
        # pad_token_id is forwarded so the base class records it as a special token id.
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.classifier_dropout = classifier_dropout
        self.cross_attention_frequency = cross_attention_frequency
        self.encoder_hidden_size = encoder_hidden_size

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load this config, extracting the nested Q-Former section when the
        checkpoint holds a full composite `HuskyConfig`."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # get the qformer config dict if we are loading from HuskyConfig
        if config_dict.get("model_type") == "husky":
            config_dict = config_dict["qformer_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
+ return cls.from_dict(config_dict, **kwargs)
230
+
231
class HuskyConfig(PretrainedConfig):
    r"""
    [`HuskyConfig`] is the configuration class to store the configuration of a
    [`HuskyForConditionalGeneration`]. It is used to instantiate a Husky model according to the specified
    arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the Husky
    [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`HuskyVisionConfig`].
        qformer_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`HuskyQFormerConfig`].
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize any [`PretrainedConfig`].
        num_query_tokens (`int`, *optional*, defaults to 32):
            The number of query tokens passed through the Transformer.

        kwargs (*optional*):
            Dictionary of keyword arguments.
    """

    model_type = "husky"
    # Composite config: serialized as nested vision/qformer/text sections.
    is_composition = True

    def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. initializing the HuskyVisionConfig with default values.")

        if qformer_config is None:
            qformer_config = {}
            logger.info("qformer_config is None. Initializing the HuskyQFormerConfig with default values.")

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")

        self.vision_config = HuskyVisionConfig(**vision_config)
        self.qformer_config = HuskyQFormerConfig(**qformer_config)
        # The language model class is resolved dynamically; "opt" is the fallback.
        text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder

        self.num_query_tokens = num_query_tokens
        # Cross-attention in the Q-Former attends over the vision encoder states.
        self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
        self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_qformer_text_configs(
        cls,
        vision_config: HuskyVisionConfig,
        qformer_config: HuskyQFormerConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`HuskyConfig`] (or a derived class) from a Husky vision model, Q-Former and
        language model configurations.

        Returns:
            [`HuskyConfig`]: An instance of a configuration object
        """

        return cls(
            vision_config=vision_config.to_dict(),
            qformer_config=qformer_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["qformer_config"] = self.qformer_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
324
+
325
if __name__ == '__main__':
    # Smoke test: build a composite config from default sub-configs.
    # The previous line was truncated mid-expression
    # (`config = HuskyConfig.from_pretrain`), which raised AttributeError
    # when the module was run as a script -- `from_pretrain` is not an
    # attribute and the call was never completed.
    config = HuskyConfig()
    print(config)
robohusky/constants.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import IntEnum
2
+ import os
3
+
4
# For the gradio web server
SERVER_ERROR_MSG = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN."
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
# Maximum number of characters accepted in a single user input.
INPUT_CHAR_LEN_LIMIT = 2560
# Maximum number of turns kept in a single conversation.
CONVERSATION_LEN_LIMIT = 50
# Directory where server logs are written.
LOGDIR = "."

# For the controller and workers(could be overwritten through ENV variables.)
CONTROLLER_HEART_BEAT_EXPIRATION = int(
    os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90)
)
WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 30))
WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100))
WORKER_API_EMBEDDING_BATCH_SIZE = int(os.getenv("WORKER_API_EMBEDDING_BATCH_SIZE", 4))


class ErrorCode(IntEnum):
    """
    https://platform.openai.com/docs/guides/error-codes/api-errors

    Numeric scheme: the leading digits mirror the corresponding HTTP status
    class (400 validation, 401 auth, 403 forbidden, 429 rate limiting,
    500 server errors); the trailing two digits enumerate specific causes.
    """

    VALIDATION_TYPE_ERROR = 40001

    INVALID_AUTH_KEY = 40101
    INCORRECT_AUTH_KEY = 40102
    NO_PERMISSION = 40103

    INVALID_MODEL = 40301
    PARAM_OUT_OF_RANGE = 40302
    CONTEXT_OVERFLOW = 40303

    RATE_LIMIT = 42901
    QUOTA_EXCEEDED = 42902
    ENGINE_OVERLOADED = 42903

    INTERNAL_ERROR = 50001
    CUDA_OUT_OF_MEMORY = 50002
    GRADIO_REQUEST_ERROR = 50003
    GRADIO_STREAM_UNKNOWN_ERROR = 50004
    CONTROLLER_NO_WORKER = 50005
    CONTROLLER_WORKER_TIMEOUT = 50006
robohusky/conversation.py ADDED
@@ -0,0 +1,511 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Conversation prompt templates.
3
+ """
4
+
5
+ import dataclasses
6
+ from enum import auto, Enum
7
+ from typing import List, Any, Dict
8
+
9
+
10
class SeparatorStyle(Enum):
    """Separator styles.

    Each member selects a different concatenation scheme used by
    ``Conversation.get_prompt`` when rendering the dialogue into a string.
    """

    ADD_COLON_SINGLE = auto()        # "role: message<sep>" with one separator
    ADD_COLON_TWO = auto()           # alternates <sep>/<sep2> between turns
    ADD_COLON_SPACE_SINGLE = auto()  # like ADD_COLON_SINGLE; empty turn ends "role: "
    NO_COLON_SINGLE = auto()         # "rolemessage<sep>" (role tags carry their own markup)
    ADD_NEW_LINE_SINGLE = auto()     # "role\nmessage<sep>"
    DOLLY = auto()                   # Dolly-v2 scheme with extra blank lines after responses
    RWKV = auto()                    # RWKV scheme; normalizes CRLF and double newlines
    PHOENIX = auto()                 # wraps each message in <s>...</s>
22
+
23
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Holds the system prompt, the ordered (role, message) pairs, and the
    separator scheme used to render the dialogue into a single prompt string.
    A ``None``/empty message produces the bare role header, cueing the model
    to generate the next turn.
    """

    # The name of this template
    name: str
    # The system prompt
    system: str
    # Two roles
    roles: List[str]
    # All messages. Each item is (role, message).
    messages: List[List[str]]
    # The number of few shot examples
    offset: int
    # Separators
    sep_style: SeparatorStyle
    sep: str
    sep2: str = None
    # Stop criteria (the default one is EOS token)
    stop_str: str = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation.

        The exact concatenation (colons, separators, sentinel tokens) is
        selected by ``self.sep_style``; see ``SeparatorStyle`` members.
        """
        if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
            # Two alternating separators: seps[0] after even turns (user),
            # seps[1] after odd turns (assistant; typically the EOS token).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ": "  # must be end with a space
            return ret
        elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
            # Role strings already carry their own markup (e.g. "<|user|>").
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = self.system + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + "\n" + message + self.sep
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    # Extra blank lines after each completed assistant turn.
                    if i % 2 == 1:
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.RWKV:
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    # RWKV expects normalized newlines: no CRLF, no blank lines
                    # inside a message.
                    ret += (
                        role
                        + ": "
                        + message.replace("\r\n", "\n").replace("\n\n", "\n")
                    )
                    ret += "\n\n"
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.PHOENIX:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + ": " + "<s>" + message + "</s>"
                else:
                    ret += role + ": " + "<s>"
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Skips the few-shot examples (``self.offset``) and pairs each user
        message with the following assistant message.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        ret = [{"role": "system", "content": self.system}]

        for i, (_, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                # Skip the trailing None placeholder of an unanswered turn.
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        """Return a copy with an independent messages list (turn pairs are re-built)."""
        return Conversation(
            name=self.name,
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Serialize the user-visible fields (separator config is omitted)."""
        return {
            "name": self.name,
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
        }
180
+
181
+
182
# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template under ``template.name``.

    Duplicate names are rejected unless ``override`` is True.
    """
    assert (
        override or template.name not in conv_templates
    ), f"{template.name} has been registered."
    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template (a fresh copy, safe to mutate)."""
    registered = conv_templates[name]
    return registered.copy()
196
+
197
+
198
# ---------------------------------------------------------------------------
# Built-in conversation templates.
#
# All templates below are registered into the module-level registry at import
# time; fetch a mutable working copy with get_conv_template(name). The exact
# role names, separators and stop criteria are part of each model's prompt
# contract and must not be edited casually.
# ---------------------------------------------------------------------------

# A template with one conversation example
register_conv_template(
    Conversation(
        name="one_shot",
        system="A chat between a curious human and an artificial intelligence assistant. "
               "The assistant gives helpful, detailed, and polite answers to the human's questions.",
        roles=("Human", "Assistant"),
        messages=(
            (
                "Human",
                "Got any creative ideas for a 10 year old’s birthday?",
            ),
            (
                "Assistant",
                """Of course! Here are some creative ideas for a 10-year-old's birthday party:
1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.
2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.
3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.
4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.
5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.
6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.
7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.
8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.
Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""",
            ),
        ),
        offset=2,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n### ",
        stop_str="###",
    )
)

# Vicuna v1.1 template
register_conv_template(
    Conversation(
        name="vicuna_v1.1",
        system="A chat between a curious user and an artificial intelligence assistant. "
               "The assistant gives helpful, detailed, and polite answers to the user's questions.",
        roles=("USER", "ASSISTANT"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)

# Husky template
register_conv_template(
    Conversation(
        name="husky",
        system="",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)

# Koala default template
register_conv_template(
    Conversation(
        name="koala_v1",
        system="BEGINNING OF CONVERSATION:",
        roles=("USER", "GPT"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_TWO,
        sep=" ",
        sep2="</s>",
    )
)

# Alpaca default template
register_conv_template(
    Conversation(
        name="alpaca",
        system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
        roles=("### Instruction:", "### Response:"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        sep="\n\n",
    )
)

# Dolly V2 default template
register_conv_template(
    Conversation(
        name="dolly_v2",
        system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
        roles=("### Instruction", "### Response"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.DOLLY,
        sep="\n\n",
        sep2="### End",
    )
)

# OpenAssistant Pythia default template
register_conv_template(
    Conversation(
        name="oasst_pythia",
        system="",
        roles=("<|prompter|>", "<|assistant|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="<|endoftext|>",
    )
)

# StableLM Alpha default template
register_conv_template(
    Conversation(
        name="stablelm",
        system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
        roles=("<|USER|>", "<|ASSISTANT|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="",
        stop_token_ids=[50278, 50279, 50277, 1, 0],
    )
)

# Baize default template
register_conv_template(
    Conversation(
        name="baize",
        system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n",
        roles=("[|Human|]", "[|AI|]"),
        messages=(
            ("[|Human|]", "Hello!"),
            ("[|AI|]", "Hi!"),
        ),
        offset=2,
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="\n",
        stop_str="[|Human|]",
    )
)

# RWKV-4-Raven default template
register_conv_template(
    Conversation(
        name="rwkv",
        system="",
        roles=("Bob", "Alice"),
        messages=(
            ("Bob", "hi"),
            (
                "Alice",
                "Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.",
            ),
        ),
        offset=2,
        sep_style=SeparatorStyle.RWKV,
        sep="",
        stop_str="\n\n",
    )
)

# Buddy default template
register_conv_template(
    Conversation(
        name="openbuddy",
        system="""Consider a conversation between User (a human) and Assistant (named Buddy).
Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
Buddy cannot access the Internet.
Buddy can fluently speak the user's language (e.g. English, Chinese).
Buddy can generate poems, stories, code, essays, songs, parodies, and more.
Buddy possesses vast knowledge about the world, history, and culture.
Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.

User: Hi.
Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
        roles=("User", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n",
    )
)

# Phoenix default template
register_conv_template(
    Conversation(
        name="phoenix",
        system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.PHOENIX,
        sep="</s>",
    )
)

# ChatGPT default template
# NOTE: sep_style=None — rendered via the OpenAI API, never via get_prompt().
register_conv_template(
    Conversation(
        name="chatgpt",
        system="You are a helpful assistant.",
        roles=("user", "assistant"),
        messages=(),
        offset=0,
        sep_style=None,
        sep=None,
    )
)

# Claude default template
register_conv_template(
    Conversation(
        name="claude",
        system="",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n\n",
    )
)

# MPT default template
register_conv_template(
    Conversation(
        name="mpt",
        system="""<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.
""",
        roles=("<|im_start|>user", "<|im_start|>assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        sep="<|im_end|>",
        stop_token_ids=[50278, 0],
    )
)

# Bard default template
# NOTE: sep_style=None — rendered via the Bard API, never via get_prompt().
register_conv_template(
    Conversation(
        name="bard",
        system="",
        roles=("0", "1"),
        messages=(),
        offset=0,
        sep_style=None,
        sep=None,
    )
)

# BiLLa default template
register_conv_template(
    Conversation(
        name="billa",
        system="",
        roles=("Human", "Assistant"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
        sep="\n",
        stop_str="Human:",
    )
)

# RedPajama INCITE default template
register_conv_template(
    Conversation(
        name="redpajama-incite",
        system="",
        roles=("<human>", "<bot>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.ADD_COLON_SINGLE,
        sep="\n",
        stop_str="<human>",
    )
)

# h2oGPT default template
register_conv_template(
    Conversation(
        name="h2ogpt",
        system="",
        roles=("<|prompt|>", "<|answer|>"),
        messages=(),
        offset=0,
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="</s>",
    )
)
504
+
505
+ if __name__ == "__main__":
506
+ conv = get_conv_template("husky")
507
+ conv.append_message(conv.roles[0], "Hello!")
508
+ conv.append_message(conv.roles[1], "Hi!")
509
+ conv.append_message(conv.roles[0], "How are you?")
510
+ conv.append_message(conv.roles[1], None)
511
+ print(conv.get_prompt())
robohusky/convert_fp16.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ srun -p INTERN2 --job-name='convert_2_fp16' --gres=gpu:0 --cpus-per-task=8 --quotatype="auto" python -u husky/convert_fp16.py --in-checkpoint work_dirs/llm/husky-13b/zh_bell/checkpoint-9500 --out-checkpoint work_dirs/llm/husky-13b/zh_bell/
4
+ """
5
+ import argparse
6
+ import os.path
7
+
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM
9
+ import torch
10
+
11
def convert_fp16(in_checkpoint, out_checkpoint):
    """Load a causal-LM checkpoint, cast it to fp16, and save model + tokenizer.

    Args:
        in_checkpoint: Path or hub identifier of the source checkpoint.
        out_checkpoint: Directory to write the fp16 model and tokenizer to
            (created, including parents, if missing).
    """
    tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=False
    )
    # makedirs(exist_ok=True) also creates missing parent directories and is
    # race-free, unlike the original os.mkdir which fails on nested paths.
    os.makedirs(out_checkpoint, exist_ok=True)
    model.save_pretrained(out_checkpoint)
    tokenizer.save_pretrained(out_checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-checkpoint", type=str, help="Path to the model")
    parser.add_argument("--out-checkpoint", type=str, help="Path to the output model")
    args = parser.parse_args()

    convert_fp16(args.in_checkpoint, args.out_checkpoint)
robohusky/convert_husky_fp16.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ srun -p INTERN2 --job-name='convert_2_fp16' --gres=gpu:0 --cpus-per-task=8 --quotatype="auto" python -u husky/convert_husky_fp16.py --in-checkpoint work_dirs/husky_v3/multi_align/checkpoint-48000 --out-checkpoint work_dirs/husky_v3/multi_align_fp16
4
+ """
5
+ import argparse
6
+ import os.path
7
+
8
+ from transformers import AutoTokenizer
9
+ from husky.model.modeling_husky_multi import HuskyForConditionalGeneration
10
+ import torch
11
+
12
def convert_fp16(in_checkpoint, out_checkpoint):
    """Load a Husky checkpoint, cast it to fp16, and save model + tokenizer.

    Args:
        in_checkpoint: Path of the source Husky checkpoint.
        out_checkpoint: Directory to write the fp16 model and tokenizer to
            (created, including parents, if missing).

    NOTE(review): the module imports `husky.model.modeling_husky_multi`, but
    this package directory is named `robohusky` — verify the import path.
    """
    tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False)
    model = HuskyForConditionalGeneration.from_pretrained(
        in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=False
    )
    # makedirs(exist_ok=True) also creates missing parent directories and is
    # race-free, unlike the original os.mkdir which fails on nested paths.
    os.makedirs(out_checkpoint, exist_ok=True)
    model.save_pretrained(out_checkpoint)
    tokenizer.save_pretrained(out_checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-checkpoint", type=str, help="Path to the model")
    parser.add_argument("--out-checkpoint", type=str, help="Path to the output model")
    args = parser.parse_args()

    convert_fp16(args.in_checkpoint, args.out_checkpoint)
robohusky/convert_reward_fp16.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Usage:
3
+ srun -p INTERN2 --job-name='convert_2_fp16' --gres=gpu:0 --cpus-per-task=8 --quotatype="auto" python -u husky/convert_reward_fp16.py --in-checkpoint work_dirs/llm/Ziya-LLaMA-7B-Reward --out-checkpoint work_dirs/llm/reward_model
4
+ """
5
+ import argparse
6
+ import os.path
7
+
8
+ from transformers import LlamaTokenizer, AutoModelForSequenceClassification
9
+ import torch
10
+
11
def convert_fp16(in_checkpoint, out_checkpoint):
    """Load a reward-model checkpoint, cast it to fp16, and save model + tokenizer.

    Args:
        in_checkpoint: Path of the source sequence-classification checkpoint.
        out_checkpoint: Directory to write the fp16 model and tokenizer to
            (created, including parents, if missing).
    """
    tokenizer = LlamaTokenizer.from_pretrained(in_checkpoint, use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained(
        in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=False
    )
    # makedirs(exist_ok=True) also creates missing parent directories and is
    # race-free, unlike the original os.mkdir which fails on nested paths.
    os.makedirs(out_checkpoint, exist_ok=True)
    model.save_pretrained(out_checkpoint)
    tokenizer.save_pretrained(out_checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-checkpoint", type=str, help="Path to the model")
    parser.add_argument("--out-checkpoint", type=str, help="Path to the output model")
    args = parser.parse_args()

    convert_fp16(args.in_checkpoint, args.out_checkpoint)
robohusky/dist_utils.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import socket
4
+ import subprocess
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.multiprocessing as mp
9
+ from torch import distributed as dist
10
+
11
+
12
+ def _find_free_port():
13
+ # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
14
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
15
+ # Binding to port 0 will cause the OS to find an available port for us
16
+ sock.bind(('', 0))
17
+ port = sock.getsockname()[1]
18
+ sock.close()
19
+ # NOTE: there is still a chance the port could be taken by other processes.
20
+ return port
21
+
22
+
23
+ def _is_free_port(port):
24
+ ips = socket.gethostbyname_ex(socket.gethostname())[-1]
25
+ ips.append('localhost')
26
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
27
+ return all(s.connect_ex((ip, port)) != 0 for ip in ips)
28
+
29
+
30
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize torch.distributed for the given launcher type.

    Supported launchers: 'pytorch', 'mpi', 'slurm'. Anything else raises
    ValueError. The multiprocessing start method is forced to 'spawn' if it
    has not been set yet.
    """
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
        return
    if launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
        return
    if launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
        return
    raise ValueError(f'Invalid launcher type: {launcher}')
41
+
42
+
43
def _init_dist_pytorch(backend, **kwargs):
    """Initialize distributed training launched via torchrun/torch.distributed."""
    # TODO: use local_rank instead of rank % num_gpus
    global_rank = int(os.environ['RANK'])
    gpus_per_node = torch.cuda.device_count()
    torch.cuda.set_device(global_rank % gpus_per_node)
    dist.init_process_group(backend=backend, **kwargs)
49
+
50
+
51
def _init_dist_mpi(backend, **kwargs):
    """Initialize distributed training launched via OpenMPI (mpirun)."""
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    # 29500 is torch.distributed default port
    os.environ.setdefault('MASTER_PORT', '29500')
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    # Translate OpenMPI's world description into the variables torch expects.
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    dist.init_process_group(backend=backend, **kwargs)
62
+
63
+
64
def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    # Spread the node-local processes across the node's GPUs.
    torch.cuda.set_device(proc_id % num_gpus)
    # Resolve the first hostname in the SLURM node list; it becomes the
    # rendezvous master address.
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # if torch.distributed default port(29500) is available
        # then use it, else find a free port
        if _is_free_port(29500):
            os.environ['MASTER_PORT'] = '29500'
        else:
            os.environ['MASTER_PORT'] = str(_find_free_port())
    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
robohusky/llama2_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from flash_attn import __version__ as flash_attn_version
6
+ from flash_attn.bert_padding import pad_input, unpad_input
7
+ from flash_attn.flash_attn_interface import (
8
+ flash_attn_func,
9
+ flash_attn_varlen_kvpacked_func,
10
+ )
11
+ from transformers.models.llama.modeling_llama import (
12
+ LlamaAttention,
13
+ LlamaModel,
14
+ rotate_half,
15
+ )
16
+
17
+ def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
18
+ gather_indices = position_ids[:, :, None, None] # [bsz, seq_len, 1, 1]
19
+ gather_indices = gather_indices.repeat(
20
+ 1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
21
+ )
22
+ bsz = gather_indices.shape[0]
23
+ cos, sin = (
24
+ torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
25
+ for x in cos_sin
26
+ )
27
+ q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
28
+ return q, k
29
+
30
+ def forward(
31
+ self,
32
+ hidden_states: torch.Tensor,
33
+ attention_mask: Optional[torch.Tensor] = None,
34
+ position_ids: Optional[torch.Tensor] = None,
35
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
36
+ output_attentions: bool = False,
37
+ use_cache: bool = False,
38
+ padding_mask: Optional[torch.Tensor] = None,
39
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
40
+ if output_attentions:
41
+ warnings.warn(
42
+ "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
43
+ )
44
+
45
+ bsz, q_len, _ = hidden_states.size()
46
+ kv_heads = getattr(self, "num_key_value_heads", self.num_heads)
47
+
48
+ q, k, v = (
49
+ op(hidden_states).view(bsz, q_len, nh, self.head_dim)
50
+ for op, nh in (
51
+ (self.q_proj, self.num_heads),
52
+ (self.k_proj, kv_heads),
53
+ (self.v_proj, kv_heads),
54
+ )
55
+ )
56
+ # shape: (b, s, num_heads, head_dim)
57
+
58
+ kv_seq_len = k.shape[1]
59
+ past_kv_len = 0
60
+ if past_key_value is not None:
61
+ past_kv_len = past_key_value[0].shape[2]
62
+ kv_seq_len += past_kv_len
63
+
64
+ cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
65
+ q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)
66
+
67
+ if past_key_value is not None:
68
+ assert (
69
+ flash_attn_version >= "2.1.0"
70
+ ), "past_key_value support requires flash-attn >= 2.1.0"
71
+ # reuse k, v
72
+ k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
73
+ v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)
74
+
75
+ past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None
76
+
77
+ if attention_mask is None:
78
+ output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
79
+ bsz, q_len, -1
80
+ )
81
+ else:
82
+ q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
83
+ # We can skip concat and call unpad twice but seems better to call unpad only once.
84
+ kv, _, cu_k_lens, max_k = unpad_input(
85
+ torch.stack((k, v), dim=2), attention_mask
86
+ )
87
+ output_unpad = flash_attn_varlen_kvpacked_func(
88
+ q,
89
+ kv,
90
+ cu_q_lens,
91
+ cu_k_lens,
92
+ max_s,
93
+ max_k,
94
+ 0.0,
95
+ softmax_scale=None,
96
+ causal=True,
97
+ )
98
+ output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
99
+ output = pad_input(output_unpad, indices, bsz, q_len)
100
+
101
+ return self.o_proj(output), None, past_key_value
102
+
103
+ # Disable the transformation of the attention mask in LlamaModel as flash attention
104
+ # takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
105
+ def _prepare_decoder_attention_mask(
106
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
107
+ ):
108
+ # [bsz, seq_len]
109
+ if past_key_values_length > 0 and attention_mask is not None:
110
+ attention_mask = torch.cat(
111
+ (
112
+ torch.full(
113
+ (input_shape[0], past_key_values_length),
114
+ True,
115
+ dtype=attention_mask.dtype,
116
+ device=attention_mask.device,
117
+ ),
118
+ attention_mask,
119
+ ),
120
+ dim=-1,
121
+ )
122
+
123
+ if attention_mask is not None and torch.all(attention_mask):
124
+ return None # This uses the faster call when training with full samples
125
+
126
+ return attention_mask
127
+
128
+ def replace_llama_attn_with_flash_attn():
129
+ cuda_major, cuda_minor = torch.cuda.get_device_capability()
130
+ if cuda_major < 8:
131
+ warnings.warn(
132
+ "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
133
+ "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
134
+ )
135
+
136
+ LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
137
+ LlamaAttention.forward = forward
138
+
139
+ def test():
140
+ from robohusky.train.llama_flash_attn_monkey_patch import forward as fastchat_forward
141
+ from transformers.models.llama.configuration_llama import LlamaConfig
142
+
143
+ config = LlamaConfig(
144
+ hidden_size=1024,
145
+ intermediate_size=128,
146
+ num_hidden_layers=1,
147
+ num_attention_heads=8,
148
+ max_position_embeddings=16,
149
+ )
150
+ device = torch.device("cuda")
151
+ model = LlamaModel(config)
152
+ attn = LlamaAttention(config).to(device).half()
153
+ bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
154
+ position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
155
+ -1, seqlen
156
+ )
157
+
158
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
159
+ for i in range(4):
160
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
161
+ if i:
162
+ mask[0, -i:] = False
163
+ mask[1, :i] = False
164
+
165
+ lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
166
+ ref, _, _ = attn.forward(
167
+ hidden, attention_mask=lmask, position_ids=position_ids
168
+ )
169
+
170
+ fast, _, _ = fastchat_forward(
171
+ attn, hidden, attention_mask=mask, position_ids=position_ids
172
+ )
173
+
174
+ lmask = _prepare_decoder_attention_mask(
175
+ model, mask, hidden.shape[:2], hidden, 0
176
+ )
177
+ test, _, _ = forward(
178
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
179
+ )
180
+
181
+ print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}")
182
+ print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}")
183
+ print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}")
184
+ print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
185
+ print(f"allclose(fast, test) = {torch.allclose(fast, test)}")
186
+
187
+ with torch.no_grad():
188
+ # Also check that past_kv is handled properly
189
+ hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
190
+ part_len = seqlen // 4
191
+ assert part_len * 4 == seqlen
192
+ mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
193
+ mask[0, -2:] = False
194
+ lmask = _prepare_decoder_attention_mask(
195
+ model, mask, hidden.shape[:2], hidden, 0
196
+ )
197
+ oneshot, _, _ = forward(
198
+ attn, hidden, attention_mask=lmask, position_ids=position_ids
199
+ )
200
+ parts = []
201
+ past_kv, past_kv_len = None, 0
202
+ for i in range(4):
203
+ start = part_len * i
204
+ end = start + part_len
205
+ hidden_part = hidden[:, start:end, ...]
206
+ lmask = _prepare_decoder_attention_mask(
207
+ model,
208
+ mask[:, start:end],
209
+ hidden_part.shape[:2],
210
+ hidden_part,
211
+ past_kv_len,
212
+ )
213
+ part, _, past_kv = forward(
214
+ attn,
215
+ hidden_part.clone(),
216
+ attention_mask=lmask,
217
+ position_ids=position_ids[:, start:end],
218
+ past_key_value=past_kv,
219
+ use_cache=True,
220
+ )
221
+ parts.append(part)
222
+ past_kv_len = past_kv[0].shape[2]
223
+
224
+ print(
225
+ f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
226
+ )
227
+ print(
228
+ f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}"
229
+ )
230
+
231
+ if __name__ == "__main__":
232
+ test()
robohusky/model/__init__.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import TYPE_CHECKING
15
+
16
+ from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
17
+
18
+ _import_structure = {
19
+ "configuration_husky": [
20
+ "HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP",
21
+ "HuskyConfig",
22
+ "HuskyQFormerConfig",
23
+ "HuskyVisionConfig",
24
+ ],
25
+ "processing_husky": ["HuskyProcessor"],
26
+ }
27
+
28
+ try:
29
+ if not is_torch_available():
30
+ raise OptionalDependencyNotAvailable()
31
+ except OptionalDependencyNotAvailable:
32
+ pass
33
+ else:
34
+ _import_structure["modeling_husky"] = [
35
+ "HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST",
36
+ "HuskyModel",
37
+ "HuskyQFormerModel",
38
+ "HuskyPreTrainedModel",
39
+ "HuskyForConditionalGeneration",
40
+ "HuskyVisionModel",
41
+ ]
42
+
43
+ if TYPE_CHECKING:
44
+ from .configuration_husky import (
45
+ HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP,
46
+ HuskyConfig,
47
+ HuskyVisionConfig,
48
+ HuskyQFormerConfig
49
+ )
50
+ from .processing_husky import HuskyProcessor
51
+
52
+ try:
53
+ if not is_torch_available():
54
+ raise OptionalDependencyNotAvailable()
55
+ except OptionalDependencyNotAvailable:
56
+ pass
57
+ else:
58
+ from .modeling_husky import (
59
+ HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST,
60
+ HuskyForConditionalGeneration,
61
+ HuskyModel,
62
+ HuskyPreTrainedModel,
63
+ HuskyQFormerModel,
64
+ HuskyVisionModel,
65
+ )
66
+
67
+ else:
68
+ import sys
69
+
70
+ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
robohusky/model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.14 kB). View file
 
robohusky/model/__pycache__/configuration_husky.cpython-38.pyc ADDED
Binary file (13.4 kB). View file
 
robohusky/model/__pycache__/modeling_husky_embody2.cpython-38.pyc ADDED
Binary file (54.6 kB). View file
 
robohusky/model/compression.py ADDED
File without changes
robohusky/model/configuration_husky.py ADDED
@@ -0,0 +1,331 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Husky model configuration"""
16
+
17
+ import copy
18
+ import os
19
+ from typing import Union
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
23
+ from transformers.utils import logging
24
+
25
+ from transformers.models.auto import CONFIG_MAPPING
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
30
+ "wofmanaf/husky-7b": "https://huggingface.co/wofmanaf/husky-7b/resolve/main/config.json",
31
+ }
32
+
33
+ class HuskyVisionConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`HuskyVisionModel`]. It is used to
36
+ instantiate a Husky vision encoder according to the specified arguments, defining the model architecture.
37
+ Instantiating a configuration defaults will yield a similar configuration to that of the Husky architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+ Args:
43
+ hidden_size (`int`, *optional*, defaults to 1408):
44
+ Dimensionality of the encoder layers and the pooler layer.
45
+ intermediate_size (`int`, *optional*, defaults to 6144):
46
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ num_hidden_layers (`int`, *optional*, defaults to 39):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ image_size (`int`, *optional*, defaults to 224):
52
+ The size (resolution) of each image.
53
+ patch_size (`int`, *optional*, defaults to 14):
54
+ The size (resolution) of each patch.
55
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
56
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
57
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. layer_norm_eps (`float`, *optional*, defaults
58
+ to 1e-5): The epsilon used by the layer normalization layers.
59
+ dropout (`float`, *optional*, defaults to 0.0):
60
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
61
+ attention_dropout (`float`, *optional*, defaults to 0.0):
62
+ The dropout ratio for the attention probabilities.
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ initializer_factor (`float``, *optional*, defaults to 1):
66
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
67
+ testing).
68
+ qkv_bias (`bool`, *optional*, defaults to `True`):
69
+ Whether to add a bias to the queries and values in the self-attention layers.
70
+ """
71
+
72
+ model_type = "husky_vision_model"
73
+
74
+ def __init__(
75
+ self,
76
+ hidden_size=1408,
77
+ intermediate_size=6144,
78
+ projection_dim=512,
79
+ num_hidden_layers=39,
80
+ num_attention_heads=16,
81
+ num_channels=3,
82
+ image_size=224,
83
+ patch_size=14,
84
+ hidden_act="gelu",
85
+ layer_norm_eps=0.00001,
86
+ dropout=0.0,
87
+ attention_dropout=0.0,
88
+ initializer_range=1e-10,
89
+ initializer_factor=1.0,
90
+ qkv_bias=True,
91
+ _flash_attn_2_enabled=True,
92
+ **kwargs,
93
+ ):
94
+ super().__init__(**kwargs)
95
+
96
+ self.hidden_size = hidden_size
97
+ self.intermediate_size = intermediate_size
98
+ self.projection_dim = projection_dim
99
+ self.dropout = dropout
100
+ self.num_hidden_layers = num_hidden_layers
101
+ self.num_attention_heads = num_attention_heads
102
+ self.num_channels = num_channels
103
+ self.patch_size = patch_size
104
+ self.image_size = image_size
105
+ self.initializer_range = initializer_range
106
+ self.initializer_factor = initializer_factor
107
+ self.attention_dropout = attention_dropout
108
+ self.layer_norm_eps = layer_norm_eps
109
+ self.hidden_act = hidden_act
110
+ self.qkv_bias = qkv_bias
111
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
112
+
113
+ @classmethod
114
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
115
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
116
+
117
+ # get the vision config dict if we are loading from HuskyConfig
118
+ if config_dict.get("model_type") == "husky":
119
+ config_dict = config_dict["vision_config"]
120
+
121
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
122
+ logger.warning(
123
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
124
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
125
+ )
126
+
127
+ return cls.from_dict(config_dict, **kwargs)
128
+
129
+ class HuskyQFormerConfig(PretrainedConfig):
130
+ r"""
131
+ This is the configuration class to store the configuration of a [`HuskyQFormerModel`]. It is used to
132
+ instantiate a Husky Querying Transformer (Q-Former) model according to the specified arguments, defining the
133
+ model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
134
+ the Husky [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
135
+ architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
136
+ Read the documentation from [`PretrainedConfig`] for more information.
137
+
138
+ Note that [`HuskyQFormerModel`] is very similar to [`BertLMHeadModel`] with interleaved cross-attention.
139
+
140
+ Args:
141
+ vocab_size (`int`, *optional*, defaults to 30522):
142
+ Vocabulary size of the Q-Former model. Defines the number of different tokens that can be represented by
143
+ the `inputs_ids` passed when calling the model.
144
+ hidden_size (`int`, *optional*, defaults to 768):
145
+ Dimensionality of the encoder layers and the pooler layer.
146
+ num_hidden_layers (`int`, *optional*, defaults to 12):
147
+ Number of hidden layers in the Transformer encoder.
148
+ num_attention_heads (`int`, *optional*, defaults to 12):
149
+ Number of attention heads for each attention layer in the Transformer encoder.
150
+ intermediate_size (`int`, *optional*, defaults to 3072):
151
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
152
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
153
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
154
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
155
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
156
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
157
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
158
+ The dropout ratio for the attention probabilities.
159
+ max_position_embeddings (`int`, *optional*, defaults to 512):
160
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
161
+ just in case (e.g., 512 or 1024 or 2048).
162
+ initializer_range (`float`, *optional*, defaults to 0.02):
163
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
164
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
165
+ The epsilon used by the layer normalization layers.
166
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
167
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
168
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
169
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
170
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
171
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
172
+ classifier_dropout (`float`, *optional*):
173
+ The dropout ratio for the classification head.
174
+ cross_attention_frequency (`int`, *optional*, defaults to 2):
175
+ The frequency of adding cross-attention to the Transformer layers.
176
+ encoder_hidden_size (`int`, *optional*, defaults to 1408):
177
+ The hidden size of the hidden states for cross-attention.
178
+ """
179
+ model_type = "husky_qformer"
180
+
181
+ def __init__(
182
+ self,
183
+ vocab_size=30522,
184
+ hidden_size=768,
185
+ num_hidden_layers=12,
186
+ num_attention_heads=12,
187
+ intermediate_size=3072,
188
+ hidden_act="gelu",
189
+ hidden_dropout_prob=0.1,
190
+ attention_probs_dropout_prob=0.1,
191
+ max_position_embeddings=512,
192
+ initializer_range=0.02,
193
+ layer_norm_eps=1e-12,
194
+ pad_token_id=0,
195
+ position_embedding_type="absolute",
196
+ classifier_dropout=None,
197
+ cross_attention_frequency=2,
198
+ encoder_hidden_size=1408,
199
+ _flash_attn_2_enabled=True,
200
+ **kwargs,
201
+ ):
202
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
203
+
204
+ self.vocab_size = vocab_size
205
+ self.hidden_size = hidden_size
206
+ self.num_hidden_layers = num_hidden_layers
207
+ self.num_attention_heads = num_attention_heads
208
+ self.hidden_act = hidden_act
209
+ self.intermediate_size = intermediate_size
210
+ self.hidden_dropout_prob = hidden_dropout_prob
211
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
212
+ self.max_position_embeddings = max_position_embeddings
213
+ self.initializer_range = initializer_range
214
+ self.layer_norm_eps = layer_norm_eps
215
+ self.position_embedding_type = position_embedding_type
216
+ self.classifier_dropout = classifier_dropout
217
+ self.cross_attention_frequency = cross_attention_frequency
218
+ self.encoder_hidden_size = encoder_hidden_size
219
+ self._flash_attn_2_enabled = _flash_attn_2_enabled
220
+
221
+ @classmethod
222
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
223
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
224
+ # get the qformer config dict if we are loading from HuskyConfig
225
+ if config_dict.get("model_type") == "husky":
226
+ config_dict = config_dict["qformer_config"]
227
+
228
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
229
+ logger.warning(
230
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
231
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
232
+ )
233
+
234
+ return cls.from_dict(config_dict, **kwargs)
235
+
236
+ class HuskyConfig(PretrainedConfig):
237
+ r"""
238
+ [`HuskyConfig`] is the configuration class to store the configuration of a
239
+ [`HuskyForConditionalGeneration`]. It is used to instantiate a Husky model according to the specified
240
+ arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
241
+ the defaults will yield a similar configuration to that of the Husky
242
+ [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.
243
+
244
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
245
+ documentation from [`PretrainedConfig`] for more information.
246
+
247
+ Args:
248
+ vision_config (`dict`, *optional*):
249
+ Dictionary of configuration options used to initialize [`HuskyVisionConfig`].
250
+ qformer_config (`dict`, *optional*):
251
+ Dictionary of configuration options used to initialize [`HuskyQFormerConfig`].
252
+ text_config (`dict`, *optional*):
253
+ Dictionary of configuration options used to initialize any [`PretrainedConfig`].
254
+ num_query_tokens (`int`, *optional*, defaults to 32):
255
+ The number of query tokens passed through the Transformer.
256
+
257
+ kwargs (*optional*):
258
+ Dictionary of keyword arguments.
259
+ """
260
+
261
+ model_type = "husky"
262
+ is_composition = True
263
+
264
+ def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
265
+ super().__init__(**kwargs)
266
+
267
+ if vision_config is None:
268
+ vision_config = {}
269
+ logger.info("vision_config is None. initializing the HuskyVisionConfig with default values.")
270
+
271
+ if qformer_config is None:
272
+ qformer_config = {}
273
+ logger.info("qformer_config is None. Initializing the HuskyQFormerConfig with default values.")
274
+
275
+ if text_config is None:
276
+ text_config = {}
277
+ logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")
278
+
279
+ self.vision_config = HuskyVisionConfig(**vision_config)
280
+ self.qformer_config = HuskyQFormerConfig(**qformer_config)
281
+ text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
282
+ self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
283
+
284
+ self.tie_word_embeddings = self.text_config.tie_word_embeddings
285
+ self.is_encoder_decoder = self.text_config.is_encoder_decoder
286
+
287
+ self.num_query_tokens = num_query_tokens
288
+ self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
289
+ self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
290
+ self.initializer_factor = 1.0
291
+ self.initializer_range = 0.02
292
+
293
+ @classmethod
294
+ def from_vision_qformer_text_configs(
295
+ cls,
296
+ vision_config: HuskyVisionConfig,
297
+ qformer_config: HuskyQFormerConfig,
298
+ text_config: PretrainedConfig,
299
+ **kwargs,
300
+ ):
301
+ r"""
302
+ Instantiate a [`HuskyConfig`] (or a derived class) from a Husky vision model, Q-Former and
303
+ language model configurations.
304
+
305
+ Returns:
306
+ [`HuskyConfig`]: An instance of a configuration object
307
+ """
308
+
309
+ return cls(
310
+ vision_config=vision_config.to_dict(),
311
+ qformer_config=qformer_config.to_dict(),
312
+ text_config=text_config.to_dict(),
313
+ **kwargs,
314
+ )
315
+
316
+ def to_dict(self):
317
+ """
318
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
319
+
320
+ Returns:
321
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
322
+ """
323
+ output = copy.deepcopy(self.__dict__)
324
+ output["vision_config"] = self.vision_config.to_dict()
325
+ output["qformer_config"] = self.qformer_config.to_dict()
326
+ output["text_config"] = self.text_config.to_dict()
327
+ output["model_type"] = self.__class__.model_type
328
+ return output
329
+
330
+ if __name__ == '__main__':
331
+ config = HuskyConfig.from_pretrain
robohusky/model/configuration_husky_ori.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Husky model configuration"""
16
+
17
+ import copy
18
+ import os
19
+ from typing import Union
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
23
+ from transformers.utils import logging
24
+
25
+ from transformers.models.auto import CONFIG_MAPPING
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
30
+ "wofmanaf/husky-7b": "https://huggingface.co/wofmanaf/husky-7b/resolve/main/config.json",
31
+ }
32
+
33
+ class HuskyVisionConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`HuskyVisionModel`]. It is used to
36
+ instantiate a Husky vision encoder according to the specified arguments, defining the model architecture.
37
+ Instantiating a configuration defaults will yield a similar configuration to that of the Husky architecture.
38
+
39
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40
+ documentation from [`PretrainedConfig`] for more information.
41
+
42
+ Args:
43
+ hidden_size (`int`, *optional*, defaults to 1408):
44
+ Dimensionality of the encoder layers and the pooler layer.
45
+ intermediate_size (`int`, *optional*, defaults to 6144):
46
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ num_hidden_layers (`int`, *optional*, defaults to 39):
48
+ Number of hidden layers in the Transformer encoder.
49
+ num_attention_heads (`int`, *optional*, defaults to 16):
50
+ Number of attention heads for each attention layer in the Transformer encoder.
51
+ image_size (`int`, *optional*, defaults to 224):
52
+ The size (resolution) of each image.
53
+ patch_size (`int`, *optional*, defaults to 14):
54
+ The size (resolution) of each patch.
55
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
56
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
57
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported. layer_norm_eps (`float`, *optional*, defaults
58
+ to 1e-5): The epsilon used by the layer normalization layers.
59
+ dropout (`float`, *optional*, defaults to 0.0):
60
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
61
+ attention_dropout (`float`, *optional*, defaults to 0.0):
62
+ The dropout ratio for the attention probabilities.
63
+ initializer_range (`float`, *optional*, defaults to 0.02):
64
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
65
+ initializer_factor (`float``, *optional*, defaults to 1):
66
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
67
+ testing).
68
+ qkv_bias (`bool`, *optional*, defaults to `True`):
69
+ Whether to add a bias to the queries and values in the self-attention layers.
70
+ """
71
+
72
+ model_type = "husky_vision_model"
73
+
74
+ def __init__(
75
+ self,
76
+ hidden_size=1408,
77
+ intermediate_size=6144,
78
+ projection_dim=512,
79
+ num_hidden_layers=39,
80
+ num_attention_heads=16,
81
+ num_channels=3,
82
+ image_size=224,
83
+ patch_size=14,
84
+ hidden_act="gelu",
85
+ layer_norm_eps=0.00001,
86
+ dropout=0.0,
87
+ attention_dropout=0.0,
88
+ initializer_range=1e-10,
89
+ initializer_factor=1.0,
90
+ qkv_bias=True,
91
+ **kwargs,
92
+ ):
93
+ super().__init__(**kwargs)
94
+
95
+ self.hidden_size = hidden_size
96
+ self.intermediate_size = intermediate_size
97
+ self.projection_dim = projection_dim
98
+ self.dropout = dropout
99
+ self.num_hidden_layers = num_hidden_layers
100
+ self.num_attention_heads = num_attention_heads
101
+ self.num_channels = num_channels
102
+ self.patch_size = patch_size
103
+ self.image_size = image_size
104
+ self.initializer_range = initializer_range
105
+ self.initializer_factor = initializer_factor
106
+ self.attention_dropout = attention_dropout
107
+ self.layer_norm_eps = layer_norm_eps
108
+ self.hidden_act = hidden_act
109
+ self.qkv_bias = qkv_bias
110
+
111
+ @classmethod
112
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
113
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
114
+
115
+ # get the vision config dict if we are loading from HuskyConfig
116
+ if config_dict.get("model_type") == "husky":
117
+ config_dict = config_dict["vision_config"]
118
+
119
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
120
+ logger.warning(
121
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
122
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
123
+ )
124
+
125
+ return cls.from_dict(config_dict, **kwargs)
126
+
127
class HuskyQFormerConfig(PretrainedConfig):
    r"""
    Configuration class for a [`HuskyQFormerModel`] (the Querying Transformer, or Q-Former,
    used by Husky). Instantiating it with no arguments yields a configuration similar to the
    Husky [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
    architecture. Note that a [`HuskyQFormerModel`] is very similar to a [`BertLMHeadModel`]
    with interleaved cross-attention.

    Configuration objects inherit from [`PretrainedConfig`]; read its documentation for the
    inherited behaviour.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the Q-Former model.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            Non-linear activation function; `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"`
            are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability for fully connected layers in embeddings, encoder and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            Maximum sequence length the model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal initializer for weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            Epsilon used by the layer-normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            One of `"absolute"`, `"relative_key"`, `"relative_key_query"`.
        classifier_dropout (`float`, *optional*):
            Dropout ratio for the classification head.
        cross_attention_frequency (`int`, *optional*, defaults to 2):
            How often a cross-attention layer is inserted between Transformer layers.
        encoder_hidden_size (`int`, *optional*, defaults to 1408):
            Hidden size of the encoder hidden states consumed via cross-attention.
    """

    model_type = "husky_qformer"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        classifier_dropout=None,
        cross_attention_frequency=2,
        encoder_hidden_size=1408,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.classifier_dropout = classifier_dropout
        self.cross_attention_frequency = cross_attention_frequency
        self.encoder_hidden_size = encoder_hidden_size

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load a Q-Former config, unwrapping it from a composite Husky config if needed."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # When loading from a full HuskyConfig, pick out the nested Q-Former section.
        if config_dict.get("model_type") == "husky":
            config_dict = config_dict["qformer_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
231
+
232
class HuskyConfig(PretrainedConfig):
    r"""
    [`HuskyConfig`] stores the configuration of a [`HuskyForConditionalGeneration`] model.
    It bundles three sub-configurations - vision encoder, Q-Former and language model - and
    is used to instantiate a Husky model accordingly. Instantiating a configuration with the
    defaults yields a configuration similar to the Husky
    [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`]; read its documentation for the
    inherited behaviour.

    Args:
        vision_config (`dict`, *optional*):
            Options used to initialize the [`HuskyVisionConfig`].
        qformer_config (`dict`, *optional*):
            Options used to initialize the [`HuskyQFormerConfig`].
        text_config (`dict`, *optional*):
            Options used to initialize any text [`PretrainedConfig`]; defaults to an OPT
            config when no `model_type` key is given.
        num_query_tokens (`int`, *optional*, defaults to 32):
            Number of query tokens passed through the Transformer.
        kwargs (*optional*):
            Dictionary of keyword arguments forwarded to [`PretrainedConfig`].
    """

    model_type = "husky"
    is_composition = True

    def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. initializing the HuskyVisionConfig with default values.")

        if qformer_config is None:
            qformer_config = {}
            logger.info("qformer_config is None. Initializing the HuskyQFormerConfig with default values.")

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")

        self.vision_config = HuskyVisionConfig(**vision_config)
        self.qformer_config = HuskyQFormerConfig(**qformer_config)
        text_model_type = text_config.get("model_type", "opt")
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        # Mirror frequently-consulted flags from the language-model config.
        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder

        self.num_query_tokens = num_query_tokens
        # The Q-Former cross-attends over vision features, so its encoder width must
        # follow the vision tower's hidden size.
        self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
        self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_qformer_text_configs(
        cls,
        vision_config: HuskyVisionConfig,
        qformer_config: HuskyQFormerConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`HuskyConfig`] (or a derived class) from a Husky vision model,
        Q-Former and language model configurations.

        Returns:
            [`HuskyConfig`]: An instance of a configuration object
        """
        return cls(
            vision_config=vision_config.to_dict(),
            qformer_config=qformer_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary, expanding the nested
        sub-configurations. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this
            configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        for name in ("vision_config", "qformer_config", "text_config"):
            output[name] = getattr(self, name).to_dict()
        output["model_type"] = self.__class__.model_type
        return output
325
+
326
+ if __name__ == '__main__':
327
+ config = HuskyConfig.from_pretrain
robohusky/model/modeling_husky.py ADDED
@@ -0,0 +1,1820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Husky model."""
16
+
17
+ import contextlib
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Any, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutput,
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPooling,
32
+ BaseModelOutputWithPoolingAndCrossAttentions,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
36
+ from transformers.utils import (
37
+ ModelOutput,
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ replace_return_docstrings,
42
+ )
43
+ from transformers import AutoModelForCausalLM, GenerationConfig
44
+
45
+ from .configuration_husky import HuskyConfig, HuskyQFormerConfig, HuskyVisionConfig
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _CHECKPOINT_FOR_DOC = "wofmanaf/husky-7b"
50
+
51
+ HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST = [
52
+ "wofmanaf/husky-7b",
53
+ ]
54
+
55
@dataclass
class HuskyForConditionalGenerationModelOutput(ModelOutput):
    """
    Output type of [`HuskyForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        # Nested sub-model outputs are flattened via their own `to_tuple`;
        # plain fields are returned as stored.
        nested = ("vision_outputs", "qformer_outputs", "language_model_outputs")
        return tuple(
            getattr(self, key).to_tuple() if key in nested else self[key]
            for key in self.keys()
        )
86
+
87
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Husky
class HuskyVisionEmbeddings(nn.Module):
    """Patchify an image and prepend a [CLS] token, adding learned position embeddings."""

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learnable [CLS] token prepended to the patch sequence.
        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # Non-overlapping patchification via a strided convolution.
        self.patch_embedding = nn.Conv2d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1  # +1 for the [CLS] slot
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        patches = self.patch_embedding(pixel_values)  # (B, C, H/ps, W/ps)
        patches = patches.flatten(2).transpose(1, 2)  # (B, num_patches, C)

        cls_token = self.class_embedding.expand(batch, 1, -1).to(weight_dtype)
        tokens = torch.cat([cls_token, patches], dim=1)
        # Position table is sliced so shorter sequences (fewer patches) still work.
        positions = self.position_embedding[:, : tokens.size(1), :].to(weight_dtype)
        return tokens + positions
119
+
120
class HuskyVideoEmbeddings(nn.Module):
    """Tubelet-patchify a video clip and prepend a [CLS] token, adding position embeddings."""

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Temporal settings fall back to 8 frames / stride 2 when absent from the config.
        self.num_frames = getattr(self.config, "num_frames", 8)
        self.frame_stride = getattr(self.config, "frame_stride", 2)

        self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))

        # 3D patchification: every `frame_stride` frames form one tubelet.
        self.patch_embedding = nn.Conv3d(
            in_channels=3,
            out_channels=self.embed_dim,
            kernel_size=(self.frame_stride, self.patch_size, self.patch_size),
            stride=(self.frame_stride, self.patch_size, self.patch_size),
        )

        self.num_patches = int(self.num_frames // self.frame_stride) * (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1  # +1 for the [CLS] slot
        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        patches = self.patch_embedding(pixel_values)  # (B, C, T', H', W')
        patches = patches.flatten(2).transpose(1, 2)  # (B, num_patches, C)

        cls_token = self.class_embedding.expand(batch, 1, -1).to(weight_dtype)
        tokens = torch.cat([cls_token, patches], dim=1)
        positions = self.position_embedding[:, : tokens.size(1), :].to(weight_dtype)
        return tokens + positions
155
+
156
class HuskyAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim ** -0.5
        self.dropout = nn.Dropout(config.attention_dropout)

        # Single fused q/k/v projection; unlike CLIP there is no bias by default.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)

        if config.qkv_bias:
            # BLIP-style bias: learnable slots for q and v, fixed zeros for k.
            q_bias = torch.zeros(self.embed_dim)
            v_bias = torch.zeros(self.embed_dim)
            self.qkv.bias = nn.Parameter(torch.cat((q_bias, torch.zeros_like(v_bias), v_bias)))

        self.projection = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape (B, T, C) -> (B, heads, T, head_dim).
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""
        bsz, seq_len, embed_dim = hidden_states.size()

        # Fused projection, then split into per-head q/k/v:
        # (B, T, 3C) -> (3, B, heads, T, head_dim)
        fused = self.qkv(hidden_states)
        fused = fused.reshape(bsz, seq_len, 3, self.num_heads, embed_dim // self.num_heads).permute(2, 0, 3, 1, 4)
        query_states, key_states, value_states = fused.unbind(0)

        # Scaled dot-product scores, normalized to probabilities.
        scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scale
        attention_probs = nn.functional.softmax(scores, dim=-1)

        # Dropping whole tokens to attend to - unusual, but from the original
        # Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # `head_mask` multiplies the probabilities per head when provided.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
        context = context.reshape(context.size()[:-2] + (self.embed_dim,))

        output = self.projection(context)
        return (output, attention_probs) if output_attentions else (output, None)
239
+
240
# Copied from transformers.models.blip.modeling_blip.BlipMLP
class HuskyMLP(nn.Module):
    """Two-layer feed-forward block: expand, apply the configured activation, project back."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
254
+
255
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Husky
class HuskyEncoderLayer(nn.Module):
    """Pre-norm Transformer encoder layer: self-attention then MLP, each with a residual."""

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = HuskyAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = HuskyMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very
                large negative values. NOTE(review): this is passed to the attention module
                as `head_mask`, which *multiplies* the probabilities - confirm callers
                supply a head mask rather than an additive padding mask.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
        """
        # Self-attention sub-block (pre-norm + residual).
        residual = hidden_states
        normed = self.layer_norm1(hidden_states)
        attn_out, attn_weights = self.self_attn(
            hidden_states=normed,
            head_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + attn_out

        # Feed-forward sub-block (pre-norm + residual).
        residual = hidden_states
        hidden_states = residual + self.mlp(self.layer_norm2(hidden_states))

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
302
+
303
class HuskyPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """

    config_class = HuskyConfig
    base_model_prefix = "husky"
    supports_gradient_checkpointing = True
    # Language-model embedding/head weights may be tied or loaded separately.
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"language_model.encoder.embed_tokens.weight",
        r"language_model.decoder.embed_tokens.weight",
        r"language_model.lm_head.weight",
    ]
    _no_split_modules = ["HuskyAttention", "LlamaDecoderLayer", "LlamaForCausalLM"]
    _skip_keys_device_placement = "past_key_values"
    _keep_in_fp32_modules = ["wo"]

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        # Dense-ish modules: normal init for weights, zeros for biases.
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=factor)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

        if isinstance(module, HuskyVisionEmbeddings):
            # Vision embeddings use the vision tower's own initializer range when present.
            if hasattr(self.config, "vision_config"):
                factor = self.config.vision_config.initializer_range
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            # NOTE(review): redundant with the Linear branch above; kept for parity.
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the vision encoder supports toggling gradient checkpointing here.
        if isinstance(module, HuskyEncoder):
            module.gradient_checkpointing = value
345
+
346
# Docstring templates consumed by `add_start_docstrings` /
# `add_start_docstrings_to_model_forward` on the Husky model classes below.
Husky_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`HuskyConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

Husky_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
            details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

Husky_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

Husky_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
            details.

        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
            provided to serve as text prompt, which the language model can continue.

            Indices can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
            encoder-decoder language model (like T5) is used.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)

        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            Only relevant in case an encoder-decoder language model (like T5) is used.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
455
+ """
456
+
457
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Husky
class HuskyEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`HuskyEncoderLayer`].

    Args:
        config (`HuskyConfig`):
            The corresponding vision configuration for the `HuskyEncoder`.
    """

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(HuskyEncoderLayer(config) for _ in range(config.num_hidden_layers))
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        current = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                # Record the input to each layer (and the final output after the loop).
                collected_states = collected_states + (current,)

            if self.gradient_checkpointing and self.training:
                # Recompute activations in the backward pass to save memory.
                def run_layer(*layer_inputs, _layer=encoder_layer):
                    return _layer(*layer_inputs, output_attentions)

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    run_layer,
                    current,
                    attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    current,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            current = layer_outputs[0]
            if output_attentions:
                collected_attentions = collected_attentions + (layer_outputs[1],)

        if output_hidden_states:
            collected_states = collected_states + (current,)

        if not return_dict:
            return tuple(v for v in (current, collected_states, collected_attentions) if v is not None)
        return BaseModelOutput(
            last_hidden_state=current, hidden_states=collected_states, attentions=collected_attentions
        )
549
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Husky, BLIP->Husky
class HuskyVisionModel(HuskyPreTrainedModel):
    """Vision backbone: image/video embeddings, a transformer encoder, and a final LayerNorm."""

    main_input_name = "pixel_values"
    config_class = HuskyVisionConfig

    def __init__(self, config: HuskyVisionConfig):
        super().__init__(config)
        self.config = config

        # Separate embedding modules for still images (4-D input) and video clips (5-D input).
        self.embeddings = HuskyVisionEmbeddings(config)
        self.video_embeddings = HuskyVideoEmbeddings(config)

        self.encoder = HuskyEncoder(config)
        self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.post_init()

    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        cfg = self.config
        output_attentions = cfg.output_attentions if output_attentions is None else output_attentions
        output_hidden_states = cfg.output_hidden_states if output_hidden_states is None else output_hidden_states
        return_dict = cfg.use_return_dict if return_dict is None else return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Rank decides the input modality: 4-D tensors are images, 5-D tensors are videos.
        rank = len(pixel_values.shape)
        if rank == 4:
            hidden_states = self.embeddings(pixel_values)
        elif rank == 5:
            hidden_states = self.video_embeddings(pixel_values)
        else:
            raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.post_layernorm(encoder_outputs[0])

        # Pool the first (CLS-like) token, then normalize it again, mirroring BLIP.
        pooled = self.post_layernorm(sequence_output[:, 0, :])

        if not return_dict:
            return (sequence_output, pooled) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        return self.embeddings

    def get_video_embeddings(self):
        return self.video_embeddings
621
class HuskyQFormerMultiHeadAttention(nn.Module):
    """Multi-head attention for the Q-Former.

    With ``is_cross_attention=True`` the key/value projections read from encoder states
    (``config.encoder_hidden_size``); otherwise they read from ``hidden_states`` itself.
    Supports cached key/value states (``past_key_value``) and optional relative position
    embeddings (``relative_key`` / ``relative_key_query``).
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.config = config
        # `embedding_size` opt-out mirrors BERT: a model may decouple embedding and hidden sizes.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Cross-attention: keys/values are projected from the (vision) encoder's hidden size.
            self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
            self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When True, cross-attention maps and their gradients are stashed via the hooks below.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        # Backward hook target: keeps the gradient of the attention probabilities.
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (batch, seq, all_head) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Self-attention with cache: append new keys/values after the cached ones.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Always returned so callers can reuse keys/values on the next step.
        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift by (max_position_embeddings - 1) so all distances index non-negative rows.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            # Keep the map and register a hook so gradients can be inspected (e.g. for attribution).
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        # The cache tuple is always the last element of the output.
        outputs = outputs + (past_key_value,)
        return outputs
752
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->HuskyQFormer
class HuskyQFormerSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection with the block input, then normalization.
        return self.LayerNorm(projected + input_tensor)
766
class HuskyQFormerAttention(nn.Module):
    """Attention block: multi-head (self- or cross-) attention followed by the residual output layer."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.attention = HuskyQFormerMultiHeadAttention(config, is_cross_attention)
        self.output = HuskyQFormerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads and shrink the projection layers accordingly."""
        if not heads:
            return
        attn = self.attention
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        attn_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # Residual output projection, then pass through any extra outputs (attentions, cache).
        attention_output = self.output(attn_results[0], hidden_states)
        return (attention_output,) + attn_results[1:]
814
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->HuskyQFormer
class HuskyQFormerIntermediate(nn.Module):
    """Feed-forward up-projection followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a string key into ACT2FN or already a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
829
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->HuskyQFormer
class HuskyQFormerOutput(nn.Module):
    """Feed-forward down-projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection with the feed-forward input, then normalization.
        return self.LayerNorm(projected + input_tensor)
843
class HuskyQFormerLayer(nn.Module):
    """One Q-Former block: self-attention, optional cross-attention (only on every
    ``config.cross_attention_frequency``-th layer), and a query feed-forward network.

    The first ``query_length`` positions of the sequence are learned query tokens; any
    remaining positions are treated as text tokens and bypass cross-attention.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = HuskyQFormerAttention(config)

        self.layer_idx = layer_idx

        # Cross-attention to the vision encoder is only inserted periodically.
        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = HuskyQFormerAttention(config, is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        # NOTE(review): only the *_query feed-forward modules are created here, yet
        # `feed_forward_chunk` below references `self.intermediate` / `self.output`,
        # which are never defined. That branch runs when `query_length == 0` or when
        # the sequence is longer than `query_length`, and would raise AttributeError
        # — confirm whether those paths are ever reached in this project.
        self.intermediate_query = HuskyQFormerIntermediate(config)
        self.output_query = HuskyQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Everything between the context output and the trailing cache tuple (i.e. attention maps).
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # Split off the query tokens: only they attend to the encoder states.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError("encoder_hidden_states must be given for cross-attention layers")
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                # add cross attentions if we output attention weights
                outputs = outputs + cross_attention_outputs[1:-1]

            # Query tokens go through the query-specific feed-forward network.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                # Text tokens (past the queries) use the plain feed-forward path; see NOTE above.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # NOTE(review): `self.intermediate` / `self.output` are not created in __init__.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
942
class HuskyQFormerEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` [`HuskyQFormerLayer`] blocks.

    Supports head masking, per-layer key/value caching, optional gradient
    checkpointing, and cross-attention to encoder hidden states.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [HuskyQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # NOTE(review): forward() checks `config.gradient_checkpointing`, not this
        # attribute — confirm which flag the surrounding framework actually toggles.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run all layers; returns either a tuple or a
        [`BaseModelOutputWithPastAndCrossAttentions`] depending on `return_dict`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    # Checkpointing recomputes activations, so cached key/values cannot be kept.
                    # (logger.warn is deprecated; use logger.warning.)
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # Bind the non-tensor arguments; checkpoint() only forwards tensors.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # Each layer returns its (key, value) cache as the last output element.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if layer_module.has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
1043
class HuskyQFormerModel(HuskyPreTrainedModel):
    """
    Querying Transformer (Q-Former), used in Husky.

    Takes learned query embeddings (``query_embeds``) — there is no token-embedding
    layer here — normalizes them, and runs them through [`HuskyQFormerEncoder`],
    cross-attending to ``encoder_hidden_states`` (vision features).
    """

    def __init__(self, config: HuskyQFormerConfig):
        super().__init__(config)
        self.config = config

        # Input queries are layer-normalized and dropped out before the encoder.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = HuskyQFormerEncoder(config)

        self.post_init()

    def get_input_embeddings(self):
        # NOTE(review): `self.embeddings` is never created in __init__ — calling this
        # would raise AttributeError. Presumably dead code inherited from a BERT-style
        # template; confirm nothing calls it (e.g. weight tying / resize utilities).
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # NOTE(review): same missing-`self.embeddings` caveat as get_input_embeddings.
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
        has_query: bool = False,  # NOTE(review): accepted but unused in this override
    ) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
            device (`torch.device`):
                The device of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        query_embeds,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states  (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
            `(batch_size, sequence_length)`.
        use_cache (`bool`, `optional`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # past_key_values_length
        # Cached length excludes the query tokens, which are re-fed every step.
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            # A list of encoder states means one tensor (and mask) per modality/source.
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        # Pool the first query token's final hidden state.
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
1235
class AdapterMLP(nn.Module):
    """Two-layer bottleneck MLP mapping vision features into the Q-Former width.

    Projects from the vision encoder hidden size down to a quarter of it,
    applies SiLU, then projects out to the Q-Former hidden size.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN["silu"]

        in_features = config.vision_config.hidden_size
        bottleneck = in_features // 4
        out_features = config.qformer_config.hidden_size

        self.fc1 = nn.Linear(in_features, bottleneck)
        self.fc2 = nn.Linear(bottleneck, out_features)

        # Custom init kept disabled, as in the original implementation:
        # nn.init.trunc_normal_(self.fc1.weight, std=0.02)
        # nn.init.trunc_normal_(self.fc2.weight, std=0.02)
        # nn.init.constant_(self.fc1.bias, 0)
        # nn.init.constant_(self.fc2.bias, 0)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> SiLU -> fc2 to the last dimension of ``hidden_states``."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
1257
+
1258
@add_start_docstrings(
    """
    Husky Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
    (Q-Former) and a language model.
    """,
    Husky_START_DOCSTRING,
)
class HuskyModel(HuskyPreTrainedModel):
    # BLIP-2-style pipeline: vision encoder -> learnable queries through the
    # Q-Former -> linear projection -> spliced into the causal LM's embeddings.
    config_class = HuskyConfig
    main_input_name = "pixel_values"

    def __init__(self, config: HuskyConfig):
        super().__init__(config)

        self.vision_model = HuskyVisionModel(config.vision_config)

        # Learnable query embeddings fed to the Q-Former,
        # shape (1, num_query_tokens, qformer_hidden_size); expanded per batch in forward.
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        # Maps Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Index in `input_ids` where the image placeholder tokens begin; forward
        # overwrites positions [offset, offset + num_queries) with visual embeddings.
        # Assumes the prompt template reserves exactly num_queries tokens there —
        # TODO confirm against the tokenizer/conversation template.
        self.offset = 5

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegate to the wrapped language model.
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        # Only encoder-decoder LMs (e.g. T5-style) share a single embedding matrix;
        # decoder-only models handle tying themselves.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    @add_start_docstrings_to_model_forward(Husky_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Run only the language model (no vision path).

        Returns:
            text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
                The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
                contains the language model logits, the past key values and the hidden states if
                `output_hidden_states=True`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return text_outputs

    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run only the vision encoder and return its raw outputs."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return vision_outputs

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    def get_qformer_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Encode the image and cross-attend the learnable queries to it.

        Returns:
            vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
                The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that
                contains the image features, the pooled image features and the hidden states if
                `output_hidden_states=True`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Last hidden state of the vision encoder (patch-token sequence).
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return query_outputs

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
        """Full vision -> Q-Former -> LM pass; computes LM loss when `labels` given."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Human: <img><IMAGE></img>. Give the describe Assistant:
        # position of <image>: [offset: offset+num_queries]

        # In-place overwrite of the placeholder slots with projected visual features;
        # sequence length is unchanged, so `attention_mask`/`labels` stay aligned.
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs
        if attention_mask is None:
            attention_mask = torch.ones_like(
                input_ids, dtype=torch.long, device=language_model_inputs.device)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            labels = labels.to(logits.device)
            # Keep only the logits aligned with the provided labels (suffix of the sequence).
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")

            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return HuskyForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
1494
+
1495
@add_start_docstrings(
    """
    Husky Model for generating text given an image and an optional text prompt. The model consists of a vision
    encoder, Querying Transformer (Q-Former) and a language model.

    One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
    the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
    """,
    Husky_START_DOCSTRING,
)
class HuskyForConditionalGeneration(HuskyPreTrainedModel):
    # Unlike HuskyModel, this variant (a) appends 4 multi-level pooled CLS features
    # (via AdapterMLP) after the Q-Former queries, and (b) *inserts* the visual
    # embeddings at `offset` instead of overwriting placeholder tokens, so the LM
    # sequence grows by num_queries + 4.
    config_class = HuskyConfig
    main_input_name = "pixel_values"

    def __init__(self, config: HuskyConfig):
        super().__init__(config)

        self.vision_model = HuskyVisionModel(config.vision_config)
        # Learnable Q-Former queries, shape (1, num_query_tokens, qformer_hidden_size).
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        # Maps Q-Former outputs (and adapted pooled features) into LM embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Insertion point of the visual tokens within the prompt embeddings.
        self.offset = 5

        # Adapter + per-level LayerNorms for 4 evenly-spaced vision hidden states
        # (their CLS tokens) — these produce the "+ 4" extra visual tokens.
        self.vision_adapter = AdapterMLP(config)
        self.layer_norms = nn.ModuleList()
        for i in range(4):
            self.layer_norms.append(
                nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegate to the wrapped language model.
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def extract_feature(
        self,
        pixel_values: torch.FloatTensor,
    ):
        """Compute the projected visual token sequence (num_queries + 4 tokens) for an image batch."""
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
        )
        image_embeds = vision_outputs[0]

        # vision_outputs[2] is the tuple of per-layer hidden states
        # (available because output_hidden_states=True above).
        depth = len(vision_outputs[2])
        # 4 evenly-spaced layer indices ending at the last layer.
        indices = range(depth // 4 - 1, depth, depth // 4)
        pooled_outputs = []
        for idx, layer_norm in zip(indices, self.layer_norms):
            # [:, 0, :] takes the first token of each selected layer
            # (presumably the CLS token — verify against the vision encoder).
            pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
            pool_output = layer_norm(pool_output)
            pooled_outputs.append(pool_output)

        pooled_outputs = torch.cat(pooled_outputs, dim=1)
        pooled_outputs = self.vision_adapter(pooled_outputs)

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask
        )
        query_output = query_outputs[0]
        # Append the 4 adapted pooled features after the query outputs.
        query_output = torch.cat([query_output, pooled_outputs], dim=1)
        language_model_inputs = self.language_projection(query_output)

        return language_model_inputs

    def _tie_weights(self):
        # Only encoder-decoder LMs share one embedding matrix; decoder-only LMs
        # manage tying internally.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + Husky + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
        """Training/eval pass: inserts visual tokens into the prompt and runs the LM."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        batch_size = input_ids.shape[0]
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # Multi-level pooled features: 4 evenly-spaced layers' first token,
        # layer-normed then adapted (same recipe as extract_feature).
        depth = len(vision_outputs[2])
        indices = range(depth // 4 - 1, depth, depth // 4)
        pooled_outputs = []
        for idx, layer_norm in zip(indices, self.layer_norms):
            pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
            pool_output = layer_norm(pool_output)
            pooled_outputs.append(pool_output)

        pooled_outputs = torch.cat(pooled_outputs, dim=1)
        pooled_outputs = self.vision_adapter(pooled_outputs)

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]
        query_output = torch.cat([query_output, pooled_outputs], dim=1)

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Human: <img></img>. Give the describe Assistant:
        # position of <image>: [offset: offset+num_queries]

        # Visual tokens are *inserted* (not overwritten), lengthening the sequence
        # by num_queries + 4; the masks below are extended to match.
        # inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs
        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

        if attention_mask is None:
            # NOTE(review): `ones_like(inputs_embeds)` yields a 3-D float-shaped mask
            # (batch, seq, hidden); attention masks are normally 2-D (batch, seq) —
            # likely should be torch.ones(inputs_embeds.shape[:2], ...). Verify.
            attention_mask = torch.ones_like(
                inputs_embeds, dtype=torch.long, device=language_model_inputs.device)
        else:
            prefix_mask = attention_mask[:, :self.offset]
            postfix_mask = attention_mask[:, self.offset:]
            # "+ 4" accounts for the 4 adapted pooled-feature tokens appended above.
            vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
                                     device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            labels = labels.to(logits.device)
            # Keep only the logits aligned with the provided labels (suffix of the sequence).
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")

            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return HuskyForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        language_model_inputs: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
                Input images to be processed. Ignored when `language_model_inputs` is supplied.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices
            language_model_inputs (`torch.LongTensor` of shape (batch_size, sequence_length, num_channel), *optional*):
                Precomputed visual token embeddings (e.g. from `extract_feature`); when given,
                the vision encoder / Q-Former are skipped.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        if language_model_inputs is None:
            # Same visual pipeline as extract_feature / forward.
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
            )
            image_embeds = vision_outputs[0]

            depth = len(vision_outputs[2])
            indices = range(depth // 4 - 1, depth, depth // 4)
            pooled_outputs = []
            for idx, layer_norm in zip(indices, self.layer_norms):
                pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
                pool_output = layer_norm(pool_output)
                pooled_outputs.append(pool_output)

            pooled_outputs = torch.cat(pooled_outputs, dim=1)
            pooled_outputs = self.vision_adapter(pooled_outputs)

            image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_outputs = self.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_attention_mask,
            )
            query_output = query_outputs[0]
            query_output = torch.cat([query_output, pooled_outputs], dim=1)

            language_model_inputs = self.language_projection(query_output)

        batch_size = language_model_inputs.shape[0]
        # NOTE(review): `input_ids` is embedded here *before* the `input_ids is None`
        # fallback below, so a None prompt raises instead of using the BOS fallback —
        # the fallback block appears unreachable in its intended case. Verify.
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(inputs_embeds.device)
            )

        if attention_mask is None:
            # NOTE(review): this mask covers only `input_ids` length, not the
            # num_queries + 4 inserted visual tokens, so it is shorter than
            # `inputs_embeds` — the else-branch below extends the mask; this
            # branch does not. Verify against `language_model.generate` usage.
            attention_mask = torch.ones_like(
                input_ids, dtype=torch.long, device=language_model_inputs.device)
        else:
            prefix_mask = attention_mask[:, :self.offset]
            postfix_mask = attention_mask[:, self.offset:]
            # "+ 4" matches the 4 adapted pooled-feature tokens.
            vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
                                     device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **generate_kwargs,
        )

        return outputs
robohusky/model/modeling_husky_embody2.py ADDED
@@ -0,0 +1,1962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Husky model."""
16
+
17
+ import contextlib
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Any, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutput,
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPooling,
32
+ BaseModelOutputWithPoolingAndCrossAttentions,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
36
+ from transformers.utils import (
37
+ ModelOutput,
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ replace_return_docstrings,
42
+ is_flash_attn_available
43
+ )
44
+ from transformers import AutoModelForCausalLM, GenerationConfig
45
+
46
+ from .configuration_husky import HuskyConfig, HuskyQFormerConfig, HuskyVisionConfig
47
+
48
+ if is_flash_attn_available():
49
+ from flash_attn import flash_attn_func
50
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
51
+
52
+ try:
53
+ from apex.normalization import FusedLayerNorm as LayerNorm
54
+ except ImportError:
55
+ from torch.nn import LayerNorm as LayerNorm
56
+
57
+ logger = logging.get_logger(__name__)
58
+
59
+ _CHECKPOINT_FOR_DOC = "wofmanaf/husky-7b"
60
+
61
+ HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST = [
62
+ "wofmanaf/husky-7b",
63
+ ]
64
+
65
@dataclass
class HuskyForConditionalGenerationModelOutput(ModelOutput):
    """
    Output container for [`HuskyForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss produced by the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language model head.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        # The three sub-model entries are themselves ModelOutputs, so they are
        # flattened recursively; everything else is passed through unchanged.
        nested_keys = ("vision_outputs", "qformer_outputs", "language_model_outputs")
        flattened = []
        for key in self.keys():
            if key in nested_keys:
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
96
+
97
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Husky
class HuskyVisionEmbeddings(nn.Module):
    """Patchify a batch of images and prepend a learned [CLS] token plus position embeddings."""

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned [CLS] token, broadcast over the batch in forward().
        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        # Non-overlapping patch projection: kernel == stride == patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        num_images = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        # (B, 3, H, W) -> (B, D, H/P, W/P) -> (B, num_patches, D)
        patch_tokens = self.patch_embedding(pixel_values)
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(num_images, 1, -1).to(weight_dtype)
        tokens = torch.cat([cls_tokens, patch_tokens], dim=1)
        # Slice positions so shorter token sequences still line up.
        tokens = tokens + self.position_embedding[:, : tokens.size(1), :].to(weight_dtype)
        return tokens
129
+
130
class HuskyVideoEmbeddings(nn.Module):
    """Tubelet-embed a video clip and prepend a learned [CLS] token plus position embeddings."""

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Fall back to 8 frames with temporal stride 2 when the config omits them.
        self.num_frames = getattr(self.config, "num_frames", 8)
        self.frame_stride = getattr(self.config, "frame_stride", 2)

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        # The 3D conv groups `frame_stride` consecutive frames into one tubelet token.
        self.patch_embedding = nn.Conv3d(
            in_channels=3, out_channels=self.embed_dim,
            kernel_size=(self.frame_stride, self.patch_size, self.patch_size),
            stride=(self.frame_stride, self.patch_size, self.patch_size)
        )

        self.num_patches = int(self.num_frames // self.frame_stride) * (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        num_clips = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        # (B, 3, T, H, W) -> (B, D, T', H', W') -> (B, tokens, D)
        tubelet_tokens = self.patch_embedding(pixel_values)
        tubelet_tokens = tubelet_tokens.flatten(2).transpose(1, 2)

        cls_tokens = self.class_embedding.expand(num_clips, 1, -1).to(weight_dtype)
        tokens = torch.cat([cls_tokens, tubelet_tokens], dim=1)
        tokens = tokens + self.position_embedding[:, : tokens.size(1), :].to(weight_dtype)
        return tokens
165
+
166
class HuskyAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim ** -0.5
        self.dropout = nn.Dropout(config.attention_dropout)

        # Fused q/k/v projection. Unlike CLIP there is no built-in bias; one is
        # attached below only when the config asks for it.
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)

        if config.qkv_bias:
            # BLIP-2 style bias: learnable for q and v, fixed zeros for k.
            q_bias = nn.Parameter(torch.zeros(self.embed_dim))
            v_bias = nn.Parameter(torch.zeros(self.embed_dim))
            self.qkv.bias = nn.Parameter(
                torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
            )

        self.projection = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, bsz: int, seq_len: int):
        # Reshape a (B, T, D) tensor into per-head layout (B, heads, T, head_dim).
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        batch, seq_len, dim = hidden_states.size()

        # One matmul, then split into q/k/v: (B, T, 3D) -> (3, B, heads, T, head_dim).
        fused = (
            self.qkv(hidden_states)
            .reshape(batch, seq_len, 3, self.num_heads, dim // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        queries, keys, values = fused.unbind(0)

        # Scaled dot-product scores, softmax-normalized per query position.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) * self.scale
        probs = nn.functional.softmax(scores, dim=-1)

        # Token-level dropout on the attention map (original Transformer recipe).
        probs = self.dropout(probs)

        # Optional per-head masking.
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, values).permute(0, 2, 1, 3)
        context = context.reshape(context.size()[:-2] + (self.embed_dim,))

        attn_output = self.projection(context)

        if output_attentions:
            return (attn_output, probs)
        return (attn_output, None)
249
+
250
class HuskyFlashAttention2(HuskyAttention):
    """
    Husky flash attention module. This module inherits from `HuskyAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, None]:
        """Flash-attention forward pass.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, seq_len, embed_dim)`): input sequence.
            head_mask: accepted for interface compatibility with `HuskyAttention` but
                ignored here — the flash-attention kernel does not support it.
            output_attentions: must be `False`; the kernel never materializes the
                attention matrix, so weights cannot be returned.

        Returns:
            `(attn_output, None)` where `attn_output` has shape `(batch, seq_len, embed_dim)`.
        """
        # FIX: the return annotation was the builtin generic `tuple[Any, None]`, which
        # raises `TypeError` at class-definition time on Python 3.8 (the version this
        # package ships bytecode for); `typing.Tuple` is used instead, matching the
        # annotation style of the rest of the file.

        # HuskyFlashAttention2 does not support output_attentions
        assert output_attentions is False

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        # Flash attention requires the input to have the shape batch_size x seq_len x num_heads x head_dim
        # therefore we just need to keep the original shape
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
            2, 0, 1, 3, 4
        )

        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # flash-attn only runs in half precision; cast to the dtype the weights
            # were trained in (or the pre-quantization dtype for quantized models).
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.qkv.weight.dtype

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = flash_attn_func(
            query_states, key_states, value_states
        )

        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim).contiguous()
        output = self.projection(attn_output)

        outputs = (output, None)
        return outputs
303
+
304
class HuskyMLP(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Linear."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation is resolved by name through the shared ACT2FN registry.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand to the intermediate size, apply the nonlinearity, project back.
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
317
+
318
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Husky
class HuskyEncoderLayer(nn.Module):
    """Pre-LayerNorm transformer encoder block: self-attention and MLP, each with a residual."""

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Use the flash-attention variant only when it was explicitly enabled on the config.
        if getattr(config, "_flash_attn_2_enabled", False):
            self.self_attn = HuskyFlashAttention2(config=config)
        else:
            self.self_attn = HuskyAttention(config=config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = HuskyMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): layer input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`): mask of size `(batch, 1, tgt_len, src_len)` with
                large negative values at padded positions; forwarded to the attention module as
                its `head_mask` argument (mirroring the BLIP layer this was copied from).
            output_attentions (`bool`, *optional*): when `True`, also return this layer's
                attention weights.
        """
        # Attention sub-block (pre-norm + residual).
        shortcut = hidden_states
        normed = self.layer_norm1(hidden_states)
        attn_out, attn_weights = self.self_attn(
            hidden_states=normed,
            head_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = attn_out + shortcut

        # Feed-forward sub-block (pre-norm + residual).
        shortcut = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states) + shortcut

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
369
+
370
class HuskyPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HuskyConfig
    base_model_prefix = "husky"
    supports_gradient_checkpointing = True
    # Checkpoint keys whose absence is tolerated on load (tied/derived weights).
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"language_model.encoder.embed_tokens.weight",
        r"language_model.decoder.embed_tokens.weight",
        r"language_model.lm_head.weight",
    ]
    # Modules that device-map sharding must keep on a single device.
    _no_split_modules = [
        "HuskyAttention",
        "HuskyFlashAttention2",
        "LlamaDecoderLayer",
    ]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    # Keep these modules in fp32 even under half-precision loading.
    _keep_in_fp32_modules = ["wo"]

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        # NOTE(review): nn.Conv3d (used by HuskyVideoEmbeddings.patch_embedding) is not
        # matched here, so the video patch conv keeps PyTorch's default init — confirm
        # this is intended.
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=factor)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

        if isinstance(module, HuskyVisionEmbeddings):
            # Prefer the vision sub-config's initializer range when available.
            if hasattr(self.config, "vision_config"):
                factor = self.config.vision_config.initializer_range
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)

        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # NOTE(review): this branch is unreachable — nn.Linear (bias included) is
        # already fully handled by the first `if` above.
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Gradient checkpointing is only toggled on the vision encoder stack.
        if isinstance(module, HuskyEncoder):
            module.gradient_checkpointing = value
417
+
418
+ Husky_START_DOCSTRING = r"""
419
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
420
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
421
+ etc.)
422
+
423
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
424
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
425
+ and behavior.
426
+
427
+ Parameters:
428
+ config ([`HuskyConfig`]): Model configuration class with all the parameters of the model.
429
+ Initializing with a config file does not load the weights associated with the model, only the
430
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
431
+ """
432
+
433
+ Husky_VISION_INPUTS_DOCSTRING = r"""
434
+ Args:
435
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
436
+ Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
437
+ details.
438
+ output_attentions (`bool`, *optional*):
439
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
440
+ tensors for more detail.
441
+ output_hidden_states (`bool`, *optional*):
442
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
443
+ more detail.
444
+ return_dict (`bool`, *optional*):
445
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
446
+ """
447
+
448
+ Husky_TEXT_INPUTS_DOCSTRING = r"""
449
+ Args:
450
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
451
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
452
+ it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
453
+ [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
454
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
455
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
456
+ - 1 for tokens that are **not masked**,
457
+ - 0 for tokens that are **masked**.
458
+ [What are attention masks?](../glossary#attention-mask)
459
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
460
+ Indices of decoder input sequence tokens in the vocabulary.
461
+
462
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
463
+ [`PreTrainedTokenizer.__call__`] for details.
464
+
465
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
466
+
467
+ T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
468
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
469
+
470
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
471
+ Training](./t5#training).
472
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
473
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
474
+ be used by default.
475
+ output_attentions (`bool`, *optional*):
476
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
477
+ tensors for more detail.
478
+ output_hidden_states (`bool`, *optional*):
479
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
480
+ more detail.
481
+ return_dict (`bool`, *optional*):
482
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
483
+ """
484
+
485
+ Husky_INPUTS_DOCSTRING = r"""
486
+ Args:
487
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
488
+ Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
489
+ details.
490
+
491
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
492
+ Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
493
+ provided to serve as text prompt, which the language model can continue.
494
+
495
+ Indices can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for details.
496
+
497
+ [What are input IDs?](../glossary#input-ids)
498
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
499
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
500
+
501
+ - 1 for tokens that are **not masked**,
502
+ - 0 for tokens that are **masked**.
503
+
504
+ [What are attention masks?](../glossary#attention-mask)
505
+
506
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
507
+ Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
508
+ encoder-decoder language model (like T5) is used.
509
+
510
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
511
+ [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
512
+
513
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
514
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
515
+ be used by default.
516
+
517
+ Only relevant in case an encoder-decoder language model (like T5) is used.
518
+
519
+ output_attentions (`bool`, *optional*):
520
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
521
+ tensors for more detail.
522
+ output_hidden_states (`bool`, *optional*):
523
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
524
+ more detail.
525
+ return_dict (`bool`, *optional*):
526
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
527
+ """
528
+
529
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Husky
class HuskyEncoder(nn.Module):
    """
    Transformer encoder: a stack of `config.num_hidden_layers` [`HuskyEncoderLayer`] blocks.

    Args:
        config (`HuskyConfig`):
            The corresponding vision configuration for the `HuskyEncoder`.
    """

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([HuskyEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Run the layer stack over pre-computed embeddings.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded inputs (floats, not token ids).
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask forwarded unchanged to every layer (1 = keep, 0 = masked).
            output_attentions (`bool`, *optional*):
                Whether to collect each layer's attention weights.
            output_hidden_states (`bool`, *optional*):
                Whether to collect the hidden states before every layer and after the last.
            return_dict (`bool`, *optional*):
                Whether to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Unset toggles fall back to the config defaults.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for layer_module in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Trade compute for memory: re-run the layer in the backward pass.
                def make_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    make_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions
        )
620
+
621
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Husky, BLIP->Husky
class HuskyVisionModel(HuskyPreTrainedModel):
    """Vision tower: image/video embeddings + transformer encoder + final LayerNorm."""

    main_input_name = "pixel_values"
    config_class = HuskyVisionConfig

    def __init__(self, config: HuskyVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        # Two embedding front-ends share one encoder: 2D patches for images,
        # 3D tubelets for videos. forward() dispatches on input rank.
        self.embeddings = HuskyVisionEmbeddings(config)
        self.video_embeddings = HuskyVideoEmbeddings(config)

        self.encoder = HuskyEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.post_init()

    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images (4D input), videos (5D input), or pre-computed embeddings.

        Raises:
            ValueError: if neither `pixel_values` nor `pixel_embeds` is given, or if
                `pixel_values` is neither 4D nor 5D.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError("You have to specify pixel_values or pixel_embeds")

        # `pixel_embeds` bypasses the embedding front-ends entirely.
        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            # Dispatch on rank: (B, C, H, W) -> image path, (B, C, T, H, W) -> video path.
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            elif len(pixel_values.shape) == 5:
                hidden_states = self.video_embeddings(pixel_values)
            else:
                raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # NOTE(review): the pooled [CLS] token goes through post_layernorm a second
        # time (it was already normalized above). This mirrors the upstream BLIP
        # implementation this class was copied from — confirm before "fixing".
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        # Image (2D) embedding module.
        return self.embeddings

    def get_video_embeddings(self):
        # Video (3D tubelet) embedding module.
        return self.video_embeddings
696
+
697
+ class HuskyQFormerMultiHeadAttention(nn.Module):
698
+ def __init__(self, config, is_cross_attention=False):
699
+ super().__init__()
700
+ self.config = config
701
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
702
+ raise ValueError(
703
+ "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
704
+ % (config.hidden_size, config.num_attention_heads)
705
+ )
706
+
707
+ self.num_attention_heads = config.num_attention_heads
708
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
709
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
710
+
711
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
712
+ if is_cross_attention:
713
+ self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
714
+ self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
715
+ else:
716
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
717
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
718
+
719
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
720
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
721
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
722
+ self.max_position_embeddings = config.max_position_embeddings
723
+ self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
724
+ self.save_attention = False
725
+
726
+ def save_attn_gradients(self, attn_gradients):
727
+ self.attn_gradients = attn_gradients
728
+
729
+ def get_attn_gradients(self):
730
+ return self.attn_gradients
731
+
732
+ def save_attention_map(self, attention_map):
733
+ self.attention_map = attention_map
734
+
735
+ def get_attention_map(self):
736
+ return self.attention_map
737
+
738
+ def transpose_for_scores(self, x):
739
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
740
+ x = x.view(*new_x_shape)
741
+ return x.permute(0, 2, 1, 3)
742
+
743
def forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
):
    """Multi-head self- or cross-attention with optional relative-position scores.

    Returns ``(context, [attention_probs,] present_key_value)`` where the
    optional attention maps are included only when ``output_attentions`` is set.
    """
    cross_attn = encoder_hidden_states is not None

    # For cross-attention the keys/values come from the encoder and the mask
    # must hide the encoder's padding tokens; otherwise they come from
    # ``hidden_states``, optionally prefixed with the cached past states.
    if cross_attn:
        k = self.transpose_for_scores(self.key(encoder_hidden_states))
        v = self.transpose_for_scores(self.value(encoder_hidden_states))
        attention_mask = encoder_attention_mask
    else:
        k = self.transpose_for_scores(self.key(hidden_states))
        v = self.transpose_for_scores(self.value(hidden_states))
        if past_key_value is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

    q = self.transpose_for_scores(self.query(hidden_states))
    present = (k, v)

    # Raw attention scores: dot product between queries and keys.
    scores = torch.matmul(q, k.transpose(-1, -2))

    if self.position_embedding_type in ("relative_key", "relative_key_query"):
        seq_len = hidden_states.size(1)
        pos_l = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device).view(-1, 1)
        pos_r = torch.arange(seq_len, dtype=torch.long, device=hidden_states.device).view(1, -1)
        rel_pos = self.distance_embedding(pos_l - pos_r + self.max_position_embeddings - 1)
        rel_pos = rel_pos.to(dtype=q.dtype)  # fp16 compatibility

        if self.position_embedding_type == "relative_key":
            scores = scores + torch.einsum("bhld,lrd->bhlr", q, rel_pos)
        else:  # "relative_key_query"
            scores = (
                scores
                + torch.einsum("bhld,lrd->bhlr", q, rel_pos)
                + torch.einsum("bhrd,lrd->bhlr", k, rel_pos)
            )

    scores = scores / math.sqrt(self.attention_head_size)

    if attention_mask is not None:
        # Additive mask (precomputed upstream): ~-inf on masked positions.
        scores = scores + attention_mask

    # Normalize the scores to probabilities.
    probs = scores.softmax(dim=-1)

    if cross_attn and self.save_attention:
        self.save_attention_map(probs)
        probs.register_hook(self.save_attn_gradients)

    # Dropping out attention probabilities drops whole tokens to attend to,
    # as in the original Transformer paper.
    dropped = self.dropout(probs)

    if head_mask is not None:
        dropped = dropped * head_mask

    context = torch.matmul(dropped, v).permute(0, 2, 1, 3).contiguous()
    context = context.view(*(context.size()[:-2] + (self.all_head_size,)))

    outputs = (context, probs) if output_attentions else (context,)
    return outputs + (present,)
827
+
828
class HuskyQFormerFlashAttention2(HuskyQFormerMultiHeadAttention):
    """Q-Former attention variant that routes the attention math through flash-attention.

    NOTE(review): ``flash_attn_func`` is invoked without any mask, so both
    ``attention_mask`` and ``encoder_attention_mask`` are silently ignored
    here (padding tokens are attended to) — confirm this is acceptable for
    the inputs this model sees. ``dropout_rate`` is computed but never passed
    to the kernel, and the final reshape reads ``self.embed_size``, which is
    not visibly defined by the parent class (which uses ``all_head_size``)
    — verify against the parent's ``__init__``.
    """

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        bsz, tgt_len, _ = hidden_states.size()
        cross_attn = encoder_hidden_states is not None

        # Keys/values come from the encoder for cross-attention; otherwise
        # from ``hidden_states``, optionally prefixed with the cached states.
        if cross_attn:
            k = self.transpose_for_scores(self.key(encoder_hidden_states))
            v = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        else:
            k = self.transpose_for_scores(self.key(hidden_states))
            v = self.transpose_for_scores(self.value(hidden_states))
            if past_key_value is not None:
                k = torch.cat([past_key_value[0], k], dim=2)
                v = torch.cat([past_key_value[1], v], dim=2)

        q = self.transpose_for_scores(self.query(hidden_states))
        present = (k, v)

        # flash_attn expects [batch, seq, heads, head_dim] rather than the
        # [batch, heads, seq, head_dim] layout produced by transpose_for_scores.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))

        dropout_rate = self.dropout if self.training else 0  # currently unused — see class note

        if q.dtype == torch.float32:
            # flash_attn requires fp16/bf16; follow autocast when it is active,
            # otherwise fall back to the projection weights' dtype.
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.query.weight.dtype
            q = q.to(target_dtype)
            k = k.to(target_dtype)
            v = v.to(target_dtype)

        attn_output = flash_attn_func(q, k, v, causal=False)

        context = attn_output.reshape(bsz, tgt_len, self.embed_size).contiguous()
        return (context,) + (present,)
893
+
894
class HuskyQFormerSelfOutput(nn.Module):
    """Post-attention projection: Linear -> Dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        projected = self.dropout(self.dense(hidden_states))
        # Residual connection followed by LayerNorm.
        return self.LayerNorm(projected + input_tensor)
906
+
907
class HuskyQFormerAttention(nn.Module):
    """Attention block: multi-head (or flash) attention plus the output projection."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        # Switch to the flash-attention implementation only when the config
        # explicitly opts in.
        use_flash = getattr(config, "_flash_attn_2_enabled", False)
        attn_cls = HuskyQFormerFlashAttention2 if use_flash else HuskyQFormerMultiHeadAttention
        self.attention = attn_cls(config, is_cross_attention)
        self.output = HuskyQFormerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads and shrink the projections accordingly."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Drop the pruned rows/columns from the projection matrices.
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head bookkeeping in sync with the new sizes.
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        attn_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(attn_outputs[0], hidden_states)
        # Propagate attention maps / cache entries produced by the inner module.
        return (attention_output,) + attn_outputs[1:]
958
+
959
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->HuskyQFormer
class HuskyQFormerIntermediate(nn.Module):
    """Feed-forward expansion: Linear followed by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # ``hidden_act`` may be either the name of an activation or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.intermediate_act_fn(self.dense(hidden_states))
973
+
974
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->HuskyQFormer
class HuskyQFormerOutput(nn.Module):
    """Feed-forward contraction with dropout and a residual LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        contracted = self.dropout(self.dense(hidden_states))
        # Residual connection followed by LayerNorm.
        return self.LayerNorm(contracted + input_tensor)
987
+
988
class HuskyQFormerLayer(nn.Module):
    """One Q-Former layer: self-attention, optional cross-attention, feed-forward.

    Cross-attention is instantiated only every ``config.cross_attention_frequency``
    layers and is applied to the query-token positions only.

    NOTE(review): ``feed_forward_chunk`` references ``self.intermediate`` /
    ``self.output``, which are never created in ``__init__`` (only the
    ``*_query`` variants are). The paths that use it — ``query_length == 0``
    or text tokens appended after the queries — would raise ``AttributeError``;
    presumably they are never hit in this model (the same quirk exists in the
    upstream BLIP-2 Q-Former). Confirm before relying on those paths.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = HuskyQFormerAttention(config)
        self.layer_idx = layer_idx

        # Only every ``cross_attention_frequency``-th layer attends to the encoder.
        self.has_cross_attention = layer_idx % config.cross_attention_frequency == 0
        if self.has_cross_attention:
            self.crossattention = HuskyQFormerAttention(config, is_cross_attention=True)

        self.intermediate_query = HuskyQFormerIntermediate(config)
        self.output_query = HuskyQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # Cached uni-directional self-attention key/values occupy positions 1,2.
        self_past = past_key_value[:2] if past_key_value is not None else None
        self_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_past,
        )
        attention_output = self_outputs[0]
        outputs = self_outputs[1:-1]
        present_key_value = self_outputs[-1]

        if query_length > 0:
            # The first ``query_length`` positions are the learned query tokens.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError("encoder_hidden_states must be given for cross-attention layers")
                cross_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_outputs[0]
                # Append cross-attention weights when they are requested.
                outputs = outputs + cross_outputs[1:-1]

            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                # Any trailing text tokens go through the plain feed-forward.
                text_output = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, text_output], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )

        return (layer_output,) + outputs + (present_key_value,)

    def feed_forward_chunk(self, attention_output):
        # NOTE(review): ``self.intermediate`` / ``self.output`` are never set
        # in ``__init__`` — see the class docstring.
        return self.output(self.intermediate(attention_output), attention_output)

    def feed_forward_chunk_query(self, attention_output):
        return self.output_query(self.intermediate_query(attention_output), attention_output)
1086
+
1087
class HuskyQFormerEncoder(nn.Module):
    """Stack of ``HuskyQFormerLayer`` modules with optional gradient checkpointing."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [HuskyQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run the layer stack; collect hidden states / attentions / cache on demand."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[idx] if head_mask is not None else None
            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                # Non-tensor arguments cannot flow through ``checkpoint``;
                # close over them instead.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if layer_module.has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in (
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                )
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
1187
+
1188
class HuskyQFormerModel(HuskyPreTrainedModel):
    """
    Querying Transformer (Q-Former), used in Husky.

    Consumes learned query embeddings and (optionally) cross-attends to the
    vision encoder's hidden states.
    """

    def __init__(self, config: HuskyQFormerConfig):
        super().__init__(config)
        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = HuskyQFormerEncoder(config)

        self.post_init()

    def get_input_embeddings(self):
        # NOTE(review): ``self.embeddings`` is never created in ``__init__`` —
        # calling this raises ``AttributeError``. Verify whether an embeddings
        # module was meant to be kept when the text pathway was removed.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # NOTE(review): see ``get_input_embeddings`` — ``self.embeddings`` is undefined.
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_num, heads in heads_to_prune.items():
            self.encoder.layer[layer_num].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
        has_query: bool = False,
    ) -> torch.Tensor:
        """Broadcast a 2D/3D attention mask to an additive [batch, heads, from, to] bias.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model (used only for the error message).
            device (`torch.device`):
                The device of the input to the model.
            has_query (`bool`):
                Accepted for interface compatibility; unused in this implementation.

        Returns:
            `torch.Tensor` with 0.0 on attended positions and -10000.0 on masked ones.
        """
        if attention_mask.dim() == 3:
            # [batch, from_seq, to_seq] supplied directly — broadcast over heads.
            extended = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Padding mask [batch, seq] — broadcast over heads and query positions.
            extended = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Turn {1: attend, 0: masked} into an additive bias: adding -10000.0
        # before the softmax effectively removes the masked positions.
        extended = extended.to(dtype=self.dtype)  # fp16 compatibility
        return (1.0 - extended) * -10000.0

    def forward(
        self,
        query_embeds,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
            Encoder hidden states the queries cross-attend to.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
            Mask over the encoder's padding tokens (1 = keep, 0 = masked).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
            Precomputed key/value states used to speed up decoding.
        use_cache (`bool`, `optional`):
            If `True`, return `past_key_values` for subsequent fast decoding.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Length of the already-cached sequence, queries excluded.
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )
        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.dropout(self.layernorm(query_embeds))

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        # Broadcastable self-attention bias for all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # Broadcast the (possibly per-image list of) encoder masks for cross-attention.
        if encoder_hidden_states is not None:
            if isinstance(encoder_hidden_states, list):
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if isinstance(encoder_attention_mask, list):
                encoder_extended_attention_mask = [self.invert_attention_mask(m) for m in encoder_attention_mask]
            else:
                if encoder_attention_mask is None:
                    encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # head_mask: [num_heads] or [num_hidden_layers x num_heads] converted to
        # [num_hidden_layers x batch x num_heads x seq_length x seq_length].
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = sequence_output[:, 0, :]  # first query position as pooled repr

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
1379
+
1380
class AdapterMLP(nn.Module):
    """Two-layer SiLU MLP mapping the vision hidden size down to the Q-Former size."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN["silu"]
        in_dim = config.vision_config.hidden_size
        mid_dim = in_dim // 4  # bottleneck at a quarter of the vision width
        out_dim = config.qformer_config.hidden_size

        self.fc1 = nn.Linear(in_dim, mid_dim)
        self.fc2 = nn.Linear(mid_dim, out_dim)
        self.layernorm = nn.LayerNorm(out_dim, eps=config.vision_config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        return self.layernorm(self.fc2(hidden_states))
1399
+
1400
+ @add_start_docstrings(
1401
+ """
1402
+ Husky Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
1403
+ (Q-Former) and a language model.
1404
+ """,
1405
+ Husky_START_DOCSTRING,
1406
+ )
1407
+ class HuskyModel(HuskyPreTrainedModel):
1408
+ config_class = HuskyConfig
1409
+ main_input_name = "pixel_values"
1410
+
1411
def __init__(self, config: HuskyConfig):
    """Build the vision tower, Q-Former, projection layer and language model."""
    super().__init__(config)

    self.vision_model = HuskyVisionModel(config.vision_config)

    # Learned query tokens fed to the Q-Former (expanded across the batch).
    self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
    self.qformer = HuskyQFormerModel(config.qformer_config)

    # Project Q-Former outputs into the language model's embedding space.
    self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
    self.language_model = AutoModelForCausalLM.from_config(config.text_config)

    self.config.hidden_size = config.text_config.hidden_size
    self.num_queries = config.num_query_tokens
    # Token offset of the <IMAGE> placeholder in the prompt — presumably tied
    # to the fixed prompt format; verify against the tokenizer/prompt template.
    self.offset = 5

    # Initialize weights and apply final processing.
    self.post_init()
1428
+
1429
def get_input_embeddings(self):
    """Delegate to the wrapped language model's input embeddings."""
    return self.language_model.get_input_embeddings()
1431
+
1432
def set_input_embeddings(self, value):
    """Forward the new input-embedding module to the language model."""
    self.language_model.set_input_embeddings(value)
1434
+
1435
def set_output_embeddings(self, new_embeddings):
    """Forward the new output-embedding module to the language model."""
    self.language_model.set_output_embeddings(new_embeddings)
1437
+
1438
def get_output_embeddings(self) -> nn.Module:
    """Delegate to the wrapped language model's output embeddings (LM head)."""
    return self.language_model.get_output_embeddings()
1440
+
1441
def get_encoder(self):
    """Delegate to the wrapped language model's encoder (encoder-decoder LMs)."""
    return self.language_model.get_encoder()
1443
+
1444
def get_decoder(self):
    """Delegate to the wrapped language model's decoder."""
    return self.language_model.get_decoder()
1446
+
1447
def _tie_weights(self):
    """For encoder-decoder language models, share one token embedding table
    between the encoder and the decoder (no-op for decoder-only models)."""
    if not self.config.use_decoder_only_language_model:
        shared = self.language_model.shared
        self.language_model.encoder.embed_tokens = shared
        self.language_model.decoder.embed_tokens = shared
1451
+
1452
@add_start_docstrings_to_model_forward(Husky_TEXT_INPUTS_DOCSTRING)
def get_text_features(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    Run only the language model and return its outputs.

    Returns:
        text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
            The language model outputs (logits, past key values, and hidden
            states when `output_hidden_states=True`).
    """
    # Any flag left as ``None`` falls back to the config default.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    return self.language_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
1483
+
1484
@add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
def get_image_features(
    self,
    pixel_values: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    """Run only the vision encoder on ``pixel_values`` and return its outputs."""
    # Any flag left as ``None`` falls back to the config default.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    return self.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
1506
+
1507
@add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
def get_qformer_features(
    self,
    pixel_values: Optional[torch.FloatTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    Encode ``pixel_values`` with the vision tower and pass the learned query
    tokens through the Q-Former, cross-attending to the image embeddings.

    Returns:
        query_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
            The Q-Former outputs (pooled/per-query features, plus hidden
            states when `output_hidden_states=True`).
    """
    # Any flag left as ``None`` falls back to the config default.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    vision_outputs = self.vision_model(
        pixel_values=pixel_values,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    image_embeds = vision_outputs[0]

    # The queries may attend to every image token, so the mask is all ones.
    image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
    query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

    return self.qformer(
        query_embeds=query_tokens,
        encoder_hidden_states=image_embeds,
        encoder_attention_mask=image_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
1551
+
1552
+ @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
1553
+ # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
1554
+ def forward(
1555
+ self,
1556
+ pixel_values: torch.FloatTensor,
1557
+ input_ids: torch.FloatTensor,
1558
+ attention_mask: Optional[torch.LongTensor] = None,
1559
+ output_attentions: Optional[bool] = None,
1560
+ output_hidden_states: Optional[bool] = None,
1561
+ labels: Optional[torch.LongTensor] = None,
1562
+ return_dict: Optional[bool] = None,
1563
+ ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
1564
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1565
+
1566
+ # step 1: forward the images through the vision encoder,
1567
+ # to get image embeddings of shape (batch_size, seq_len, hidden_size)
1568
+ vision_outputs = self.vision_model(
1569
+ pixel_values=pixel_values,
1570
+ output_attentions=output_attentions,
1571
+ output_hidden_states=output_hidden_states,
1572
+ return_dict=return_dict,
1573
+ )
1574
+ image_embeds = vision_outputs[0]
1575
+
1576
+ # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
1577
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
1578
+
1579
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
1580
+ query_outputs = self.qformer(
1581
+ query_embeds=query_tokens,
1582
+ encoder_hidden_states=image_embeds,
1583
+ encoder_attention_mask=image_attention_mask,
1584
+ output_attentions=output_attentions,
1585
+ output_hidden_states=output_hidden_states,
1586
+ return_dict=return_dict,
1587
+ )
1588
+ query_output = query_outputs[0]
1589
+
1590
+ # step 3: use the language model, conditioned on the query outputs and the prompt
1591
+ language_model_inputs = self.language_projection(query_output)
1592
+ assert language_model_inputs.shape[1] == self.num_queries
1593
+
1594
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
1595
+ # Human: <img><IMAGE></img>. Give the describe Assistant:
1596
+ # position of <image>: [offset: offset+num_queries]
1597
+
1598
+ inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs
1599
+ if attention_mask is None:
1600
+ attention_mask = torch.ones_like(
1601
+ input_ids, dtype=torch.long, device=language_model_inputs.device)
1602
+
1603
+ outputs = self.language_model(
1604
+ inputs_embeds=inputs_embeds,
1605
+ attention_mask=attention_mask,
1606
+ output_attentions=output_attentions,
1607
+ output_hidden_states=output_hidden_states,
1608
+ return_dict=return_dict,
1609
+ )
1610
+ logits = outputs.logits if return_dict else outputs[0]
1611
+ loss = None
1612
+ # we compute the loss here since we need to take into account the sequence length of the query embeds
1613
+ if labels is not None:
1614
+ labels = labels.to(logits.device)
1615
+ logits = logits[:, -labels.size(1):, :]
1616
+ # Shift so that tokens < n predict n
1617
+ shift_logits = logits[..., :-1, :].contiguous()
1618
+ shift_labels = labels[..., 1:].contiguous().to(logits.device)
1619
+
1620
+ # Flatten the tokens
1621
+ loss_fct = CrossEntropyLoss(reduction="mean")
1622
+
1623
+ loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
1624
+
1625
+ if not return_dict:
1626
+ output = (logits, vision_outputs, query_outputs, outputs)
1627
+ return ((loss,) + output) if loss is not None else output
1628
+
1629
+ return HuskyForConditionalGenerationModelOutput(
1630
+ loss=loss,
1631
+ logits=logits,
1632
+ vision_outputs=vision_outputs,
1633
+ qformer_outputs=query_outputs,
1634
+ language_model_outputs=outputs,
1635
+ )
1636
+
1637
@add_start_docstrings(
    """
    Husky Model for generating text given an image and an optional text prompt. The model consists of a vision
    encoder, Querying Transformer (Q-Former) and a language model.

    One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
    the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
    """,
    Husky_START_DOCSTRING,
)
class HuskyForConditionalGeneration(HuskyPreTrainedModel):
    config_class = HuskyConfig
    main_input_name = "pixel_values"

    def __init__(self, config: HuskyConfig):
        super().__init__(config)

        self.vision_model = HuskyVisionModel(config.vision_config)
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Number of prompt tokens preceding the <image> placeholder ("Human: <img>").
        self.offset = 5

        # Adapter over [CLS] states pooled from 4 evenly spaced vision layers;
        # these pooled tokens are prepended to the Q-Former query outputs.
        self.vision_adapter = AdapterMLP(config)
        self.layer_norms = nn.ModuleList()
        for _ in range(4):
            self.layer_norms.append(
                nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _pool_vision_features(self, vision_outputs):
        """Pool the [CLS] state of 4 evenly spaced vision layers, layer-norm each,
        and run them through the adapter. Returns (batch, 4, hidden) tokens.

        `vision_outputs[2]` is the tuple of per-layer hidden states (requires the
        vision model to be called with `output_hidden_states=True`).
        """
        hidden_states = vision_outputs[2]
        depth = len(hidden_states)
        indices = range(depth // 4 - 1, depth, depth // 4)
        pooled = []
        for idx, layer_norm in zip(indices, self.layer_norms):
            cls_state = hidden_states[idx][:, 0, :].unsqueeze(1)
            pooled.append(layer_norm(cls_state))
        pooled = torch.cat(pooled, dim=1)
        return self.vision_adapter(pooled)

    def extract_feature(
        self,
        pixel_values: torch.FloatTensor,
    ):
        """Encode images into LM-space soft-prompt embeddings
        (pooled vision tokens followed by projected Q-Former queries)."""
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
        )
        image_embeds = vision_outputs[0]
        pooled_outputs = self._pool_vision_features(vision_outputs)

        # Forward the query tokens through the QFormer, cross-attending over the image.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask
        )
        query_output = query_outputs[0]
        # Soft prompting: pooled tokens + query tokens.
        query_output = torch.cat([pooled_outputs, query_output], dim=1)
        language_model_inputs = self.language_projection(query_output)

        return language_model_inputs

    def _tie_weights(self):
        # Encoder-decoder LMs share one embedding matrix across encoder and decoder.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + Husky + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
        """Encode the image (plus pooled multi-layer features), splice the
        resulting soft prompt into the text embeddings after `self.offset`
        tokens, run the language model and optionally compute the LM loss."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Step 1: vision encoder -> image embeddings (batch_size, seq_len, hidden_size).
        batch_size = input_ids.shape[0]
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
            pixel_embeds=pixel_embeds,
        )
        image_embeds = vision_outputs[0]
        pooled_outputs = self._pool_vision_features(vision_outputs)

        # Step 2: forward the query tokens through the QFormer, cross-attending over the image.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]
        query_output = torch.cat([pooled_outputs, query_output], dim=1)  # num_queries + 4 pooled tokens

        # Step 3: condition the language model on query outputs and the prompt.
        language_model_inputs = self.language_projection(query_output)
        num_vision_tokens = language_model_inputs.shape[1]
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Prompt layout: "Human: <img></img>. Give the describe Assistant:";
        # the vision tokens are inserted at position `self.offset`.
        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)
        if attention_mask is None:
            # Bug fix: build a 2-D (batch, seq) mask. The previous
            # `torch.ones_like(inputs_embeds)` produced a 3-D tensor, which is
            # not a valid attention mask.
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], dtype=torch.long, device=language_model_inputs.device)
        else:
            prefix_mask = attention_mask[:, :self.offset]
            postfix_mask = attention_mask[:, self.offset:]
            # Use the actual soft-prompt length instead of the magic `num_queries + 4`.
            vision_mask = torch.ones(size=(batch_size, num_vision_tokens), dtype=torch.long,
                                     device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # We compute the loss here since we need to take into account the
        # sequence length of the query embeds.
        if labels is not None:
            labels = labels.to(logits.device)
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            loss_fct = CrossEntropyLoss(reduction="mean")
            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return HuskyForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        language_model_inputs: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
                Input images to be processed.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices
            language_model_inputs (`torch.FloatTensor` of shape (batch_size, sequence_length, num_channel), *optional*):
                Precomputed soft-prompt embeddings (see `extract_feature`); when
                provided, `pixel_values` is ignored.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """

        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        if language_model_inputs is None:
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
            )
            image_embeds = vision_outputs[0]
            pooled_outputs = self._pool_vision_features(vision_outputs)

            image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_outputs = self.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_attention_mask,
            )
            query_output = query_outputs[0]
            query_output = torch.cat([pooled_outputs, query_output], dim=1)

            language_model_inputs = self.language_projection(query_output)

        batch_size = language_model_inputs.shape[0]
        num_vision_tokens = language_model_inputs.shape[1]

        # Bug fix: default input_ids to a BOS-only prompt *before* embedding.
        # Previously `get_input_embeddings()(input_ids)` ran first and crashed
        # on a None prompt.
        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(language_model_inputs.device)
            )

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

        if attention_mask is None:
            # Bug fix: the mask must cover the spliced-in vision tokens, so it is
            # sized from the expanded embeddings, not from input_ids alone.
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], dtype=torch.long, device=language_model_inputs.device)
        else:
            prefix_mask = attention_mask[:, :self.offset]
            postfix_mask = attention_mask[:, self.offset:]
            vision_mask = torch.ones(size=(batch_size, num_vision_tokens), dtype=torch.long,
                                     device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **generate_kwargs,
        )

        return outputs
robohusky/model/modeling_husky_embody2_ori.py ADDED
@@ -0,0 +1,1821 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch Husky model."""
16
+
17
+ import contextlib
18
+ import math
19
+ from dataclasses import dataclass
20
+ from typing import Any, Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.utils.checkpoint
24
+ from torch import nn
25
+ from torch.nn import CrossEntropyLoss
26
+
27
+ from transformers.activations import ACT2FN
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutput,
30
+ BaseModelOutputWithPastAndCrossAttentions,
31
+ BaseModelOutputWithPooling,
32
+ BaseModelOutputWithPoolingAndCrossAttentions,
33
+ )
34
+ from transformers.modeling_utils import PreTrainedModel
35
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
36
+ from transformers.utils import (
37
+ ModelOutput,
38
+ add_start_docstrings,
39
+ add_start_docstrings_to_model_forward,
40
+ logging,
41
+ replace_return_docstrings,
42
+ )
43
+ from transformers import AutoModelForCausalLM, GenerationConfig
44
+
45
+ from .configuration_husky import HuskyConfig, HuskyQFormerConfig, HuskyVisionConfig
46
+
47
+ logger = logging.get_logger(__name__)
48
+
49
+ _CHECKPOINT_FOR_DOC = "wofmanaf/husky-7b"
50
+
51
+ HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST = [
52
+ "wofmanaf/husky-7b",
53
+ ]
54
+
55
@dataclass
class HuskyForConditionalGenerationModelOutput(ModelOutput):
    """
    Class defining the outputs of [`HuskyForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        # Nested sub-model outputs are flattened via their own to_tuple();
        # plain fields are passed through unchanged.
        nested_keys = ("vision_outputs", "qformer_outputs", "language_model_outputs")
        items = []
        for key in self.keys():
            if key in nested_keys:
                items.append(getattr(self, key).to_tuple())
            else:
                items.append(self[key])
        return tuple(items)
86
+
87
+ # Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Husky
88
class HuskyVisionEmbeddings(nn.Module):
    """Patchify an image with a strided Conv2d, prepend a learned [CLS] token,
    and add learned position embeddings (BLIP-style vision embeddings)."""

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Return embeddings of shape (batch, num_patches + 1, embed_dim)."""
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # Bug fix: cast the input to the conv weight dtype (as upstream BLIP does);
        # otherwise half/bfloat16/double checkpoints fail with a dtype mismatch.
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)  # [*, grid*grid, width]

        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        # Slice positions so shorter sequences (fewer patches) still broadcast.
        embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        return embeddings
119
+
120
class HuskyVideoEmbeddings(nn.Module):
    """Patchify a video clip with a strided Conv3d (temporal stride
    `frame_stride`), prepend a learned [CLS] token, and add learned
    position embeddings."""

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Defaults match the training setup when the config omits these fields.
        self.num_frames = getattr(self.config, "num_frames", 8)
        self.frame_stride = getattr(self.config, "frame_stride", 2)

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv3d(
            in_channels=3, out_channels=self.embed_dim,
            kernel_size=(self.frame_stride, self.patch_size, self.patch_size),
            stride=(self.frame_stride, self.patch_size, self.patch_size)
        )

        # Temporal patches x spatial patches.
        self.num_patches = int(self.num_frames // self.frame_stride) * (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Return embeddings of shape (batch, num_patches + 1, embed_dim) from
        input of shape (batch, 3, num_frames, height, width)."""
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # Bug fix: cast the input to the conv weight dtype (as upstream BLIP does);
        # otherwise half/bfloat16/double checkpoints fail with a dtype mismatch.
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)  # [*, t*h*w patches, width]

        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        # Slice positions so shorter sequences (fewer patches) still broadcast.
        embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        return embeddings
155
+
156
+ class HuskyAttention(nn.Module):
157
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
158
+
159
+ def __init__(self, config):
160
+ super().__init__()
161
+ self.config = config
162
+ self.embed_dim = config.hidden_size
163
+ self.num_heads = config.num_attention_heads
164
+ self.head_dim = self.embed_dim // self.num_heads
165
+ if self.head_dim * self.num_heads != self.embed_dim:
166
+ raise ValueError(
167
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
168
+ f" {self.num_heads})."
169
+ )
170
+ self.scale = self.head_dim ** -0.5
171
+ self.dropout = nn.Dropout(config.attention_dropout)
172
+
173
+ # small tweak here compared to CLIP, no bias here
174
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)
175
+
176
+ if config.qkv_bias:
177
+ q_bias = nn.Parameter(torch.zeros(self.embed_dim))
178
+ v_bias = nn.Parameter(torch.zeros(self.embed_dim))
179
+ else:
180
+ q_bias = None
181
+ v_bias = None
182
+
183
+ if q_bias is not None:
184
+ qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
185
+ self.qkv.bias = nn.Parameter(qkv_bias)
186
+
187
+ self.projection = nn.Linear(self.embed_dim, self.embed_dim)
188
+
189
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
190
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
191
+
192
+ def forward(
193
+ self,
194
+ hidden_states: torch.Tensor,
195
+ head_mask: Optional[torch.Tensor] = None,
196
+ output_attentions: Optional[bool] = False,
197
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
198
+ """Input shape: Batch x Time x Channel"""
199
+
200
+ bsz, tgt_len, embed_dim = hidden_states.size()
201
+
202
+ mixed_qkv = self.qkv(hidden_states)
203
+
204
+ mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
205
+ 2, 0, 3, 1, 4
206
+ )
207
+ query_states, key_states, value_states = (
208
+ mixed_qkv[0],
209
+ mixed_qkv[1],
210
+ mixed_qkv[2],
211
+ )
212
+
213
+ # Take the dot product between "query" and "key" to get the raw attention scores.
214
+ attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
215
+
216
+ attention_scores = attention_scores * self.scale
217
+
218
+ # Normalize the attention scores to probabilities.
219
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
220
+
221
+ # This is actually dropping out entire tokens to attend to, which might
222
+ # seem a bit unusual, but is taken from the original Transformer paper.
223
+ attention_probs = self.dropout(attention_probs)
224
+
225
+ # Mask heads if we want to
226
+ if head_mask is not None:
227
+ attention_probs = attention_probs * head_mask
228
+
229
+ context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)
230
+
231
+ new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
232
+ context_layer = context_layer.reshape(new_context_layer_shape)
233
+
234
+ output = self.projection(context_layer)
235
+
236
+ outputs = (output, attention_probs) if output_attentions else (output, None)
237
+
238
+ return outputs
239
+
240
+ # Copied from transformers.models.blip.modeling_blip.BlipMLP
241
class HuskyMLP(nn.Module):
    """Two-layer feed-forward block: expand to `intermediate_size`, apply the
    configured activation, and project back to `hidden_size`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation chosen by name from the transformers registry.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # expand -> nonlinearity -> project back to the model width
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
254
+
255
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Husky
class HuskyEncoderLayer(nn.Module):
    """One pre-norm transformer encoder layer: self-attention then an MLP,
    each sub-block wrapped in a residual connection."""

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = HuskyAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = HuskyMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # --- self-attention sub-block (pre-norm + residual) ---
        shortcut = hidden_states
        normed = self.layer_norm1(hidden_states)
        # NOTE: as in the upstream BLIP layer, the mask is forwarded to the
        # attention module through its `head_mask` argument.
        attn_out, attn_weights = self.self_attn(
            hidden_states=normed,
            head_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = attn_out + shortcut

        # --- feed-forward sub-block (pre-norm + residual) ---
        shortcut = hidden_states
        hidden_states = self.mlp(self.layer_norm2(hidden_states)) + shortcut

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
302
+
303
class HuskyPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HuskyConfig
    base_model_prefix = "husky"
    supports_gradient_checkpointing = True
    # Checkpoint keys that may legitimately be absent (e.g. tied embeddings).
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"language_model.encoder.embed_tokens.weight",
        r"language_model.decoder.embed_tokens.weight",
        r"language_model.lm_head.weight",
    ]
    # Modules that must not be split across devices during sharded placement.
    _no_split_modules = ["HuskyAttention", "LlamaDecoderLayer", "LlamaForCausalLM"]
    _skip_keys_device_placement = "past_key_values"
    _keep_in_fp32_modules = ["wo"]

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        # Conv / embedding / linear weights: truncated-free normal init, zero bias.
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=factor)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

        if isinstance(module, HuskyVisionEmbeddings):
            # The vision tower may define its own initializer range; prefer it.
            if hasattr(self.config, "vision_config"):
                factor = self.config.vision_config.initializer_range
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)

        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            # NOTE(review): this branch appears unreachable — nn.Linear is already
            # handled (weight + bias) by the first `if` above. Kept as-is for
            # parity with the upstream implementation.
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the vision-side HuskyEncoder supports gradient checkpointing here.
        if isinstance(module, HuskyEncoder):
            module.gradient_checkpointing = value
345
+
346
+ Husky_START_DOCSTRING = r"""
347
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
348
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
349
+ etc.)
350
+
351
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
352
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
353
+ and behavior.
354
+
355
+ Parameters:
356
+ config ([`HuskyConfig`]): Model configuration class with all the parameters of the model.
357
+ Initializing with a config file does not load the weights associated with the model, only the
358
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
359
+ """
360
+
361
+ Husky_VISION_INPUTS_DOCSTRING = r"""
362
+ Args:
363
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
364
+ Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
365
+ details.
366
+ output_attentions (`bool`, *optional*):
367
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
368
+ tensors for more detail.
369
+ output_hidden_states (`bool`, *optional*):
370
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
371
+ more detail.
372
+ return_dict (`bool`, *optional*):
373
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
374
+ """
375
+
376
+ Husky_TEXT_INPUTS_DOCSTRING = r"""
377
+ Args:
378
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
379
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
380
+ it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
381
+ [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
382
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
383
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
384
+ - 1 for tokens that are **not masked**,
385
+ - 0 for tokens that are **masked**.
386
+ [What are attention masks?](../glossary#attention-mask)
387
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
388
+ Indices of decoder input sequence tokens in the vocabulary.
389
+
390
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
391
+ [`PreTrainedTokenizer.__call__`] for details.
392
+
393
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
394
+
395
+ T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
396
+ is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
397
+
398
+ To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
399
+ Training](./t5#training).
400
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
401
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
402
+ be used by default.
403
+ output_attentions (`bool`, *optional*):
404
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
405
+ tensors for more detail.
406
+ output_hidden_states (`bool`, *optional*):
407
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
408
+ more detail.
409
+ return_dict (`bool`, *optional*):
410
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
411
+ """
412
+
413
+ Husky_INPUTS_DOCSTRING = r"""
414
+ Args:
415
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
416
+ Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
417
+ details.
418
+
419
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
420
+ Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
421
+ provided to serve as text prompt, which the language model can continue.
422
+
423
+ Indices can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for details.
424
+
425
+ [What are input IDs?](../glossary#input-ids)
426
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
427
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
428
+
429
+ - 1 for tokens that are **not masked**,
430
+ - 0 for tokens that are **masked**.
431
+
432
+ [What are attention masks?](../glossary#attention-mask)
433
+
434
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
435
+ Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
436
+ encoder-decoder language model (like T5) is used.
437
+
438
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
439
+ [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
440
+
441
+ decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
442
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
443
+ be used by default.
444
+
445
+ Only relevant in case an encoder-decoder language model (like T5) is used.
446
+
447
+ output_attentions (`bool`, *optional*):
448
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
449
+ tensors for more detail.
450
+ output_hidden_states (`bool`, *optional*):
451
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
452
+ more detail.
453
+ return_dict (`bool`, *optional*):
454
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
455
+ """
456
+
457
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Husky
class HuskyEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`HuskyEncoderLayer`].

    Args:
        config (`HuskyConfig`):
            The corresponding vision configuration for the `HuskyEncoder`.
    """

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([HuskyEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Fall back to config-level defaults when the flags are not given explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Accumulators are tuples (or None when the caller didn't request them).
        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            # Record the hidden state *entering* each layer.
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Recompute activations on backward to save memory; the closure
                # re-binds `output_attentions` since checkpoint only forwards tensors.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Append the final layer output so hidden_states has num_layers + 1 entries.
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
548
+
549
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Husky, BLIP->Husky
class HuskyVisionModel(HuskyPreTrainedModel):
    """Vision tower: patch/frame embeddings, a transformer encoder, and a final
    LayerNorm; pools by taking the first ([CLS]) token."""

    main_input_name = "pixel_values"
    config_class = HuskyVisionConfig

    def __init__(self, config: HuskyVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        # Separate embedding modules for single images (4-D input) and videos (5-D input).
        self.embeddings = HuskyVisionEmbeddings(config)
        self.video_embeddings = HuskyVideoEmbeddings(config)

        self.encoder = HuskyEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.post_init()

    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode pixels (or precomputed embeddings) and return per-token and
        pooled features.

        Raises:
            ValueError: if neither `pixel_values` nor `pixel_embeds` is given,
                or if `pixel_values` is not 4-D (image) or 5-D (video).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError("You have to specify pixel_values or pixel_embeds")

        # Precomputed embeddings bypass the embedding modules entirely.
        if pixel_embeds is not None:
            hidden_states = pixel_embeds
        else:
            # Dispatch on tensor rank: 4-D -> image path, 5-D -> video path.
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            elif len(pixel_values.shape) == 5:
                hidden_states = self.video_embeddings(pixel_values)
            else:
                raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # Pool with the first token; NOTE(review): post_layernorm is applied twice
        # on the pooled token (once above, once here), matching the upstream BLIP code.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        # Image (4-D) embedding module.
        return self.embeddings

    def get_video_embeddings(self):
        # Video (5-D) embedding module.
        return self.video_embeddings
624
+
625
class HuskyQFormerMultiHeadAttention(nn.Module):
    """BERT-style multi-head attention for the Q-Former; supports self-attention,
    cross-attention over encoder states, cached key/values, and (optionally)
    relative position embeddings."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Keys/values come from the (vision) encoder, which may have a
            # different hidden size than the Q-Former itself.
            self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
            self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When True, cross-attention maps/gradients are stashed for inspection.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        # Hook target: stores gradients of the attention probabilities.
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (batch, seq, all_head) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Self-attention with a cache: append the new keys/values to it.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Always return the (possibly extended) cache alongside the output.
        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # Add learned relative-position scores (BERT-style relative attention).
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        # Scale by sqrt(d_k) before the softmax.
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            # Keep the map and hook its gradient for attribution/visualization.
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
755
+
756
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->HuskyQFormer
class HuskyQFormerSelfOutput(nn.Module):
    """Post-attention output block: dense projection, dropout, residual add,
    then LayerNorm (post-norm, as in the original BERT)."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states` and fold it back into `input_tensor`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
769
+
770
class HuskyQFormerAttention(nn.Module):
    """Wraps multi-head (self- or cross-) attention together with its output
    projection block, and supports head pruning."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.attention = HuskyQFormerMultiHeadAttention(config, is_cross_attention)
        self.output = HuskyQFormerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads in place; a no-op for an empty list."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Shrink the q/k/v projections and the output projection accordingly.
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Book-keeping: fewer heads, smaller combined size, remember what was pruned.
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention, apply the output block, and pass through any extra
        outputs (attention probs and/or cached key/values)."""
        attn_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        projected = self.output(attn_results[0], hidden_states)
        return (projected,) + attn_results[1:]
817
+
818
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->HuskyQFormer
class HuskyQFormerIntermediate(nn.Module):
    """Feed-forward up-projection (hidden_size -> intermediate_size) followed
    by the configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # The activation may be given as a registry key (string) or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up, then activate."""
        return self.intermediate_act_fn(self.dense(hidden_states))
832
+
833
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->HuskyQFormer
class HuskyQFormerOutput(nn.Module):
    """Feed-forward down-projection (intermediate_size -> hidden_size) with
    dropout, residual add, and post-norm LayerNorm."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project down and fold the result back into `input_tensor`."""
        down = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(down + input_tensor)
846
+
847
class HuskyQFormerLayer(nn.Module):
    """One Q-Former layer: self-attention over [queries; text], optional
    cross-attention (queries only) every `cross_attention_frequency` layers,
    and separate feed-forward paths for query and text tokens."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = HuskyQFormerAttention(config)

        self.layer_idx = layer_idx

        # Cross-attention is only instantiated on every Nth layer.
        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = HuskyQFormerAttention(config, is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        # Feed-forward path for the learned query tokens.
        # NOTE(review): `feed_forward_chunk` below references `self.intermediate`
        # and `self.output`, which are never assigned in this class — that path
        # would raise AttributeError if reached (i.e. when text tokens are present
        # or `query_length == 0`). Confirm against how callers invoke this layer.
        self.intermediate_query = HuskyQFormerIntermediate(config)
        self.output_query = HuskyQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Middle entries are the optional attention maps; last is the kv cache.
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # First `query_length` positions are the learned query tokens.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError("encoder_hidden_states must be given for cross-attention layers")
                # Only the query tokens attend to the encoder (vision) states.
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                # add cross attentions if we output attention weights
                outputs = outputs + cross_attention_outputs[1:-1]

            # Query tokens go through the query-specific feed-forward path.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            # Any remaining (text) tokens use the plain feed-forward path.
            if attention_output.shape[1] > query_length:
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # NOTE(review): `self.intermediate` / `self.output` are not defined in
        # __init__ above — see the note there before relying on this path.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # Feed-forward for the query tokens (modules defined in __init__).
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
945
+
946
class HuskyQFormerEncoder(nn.Module):
    """Stack of ``HuskyQFormerLayer`` blocks forming the Q-Former encoder.

    Supports optional gradient checkpointing during training; checkpointing
    is incompatible with key/value caching, so ``use_cache`` is force-disabled
    (with a warning) when both are requested.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [HuskyQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run every layer in order, optionally collecting per-layer outputs.

        Args:
            hidden_states: input embeddings (query tokens, possibly with text).
            attention_mask: pre-extended additive self-attention mask.
            head_mask: optional per-layer head mask.
            encoder_hidden_states / encoder_attention_mask: cross-attention
                inputs (e.g. image features) and their inverted mask.
            past_key_values: per-layer cached key/value tuples.
            use_cache: whether to return new key/value caches.
            query_length: number of leading query tokens in `hidden_states`,
                forwarded to each layer.

        Returns:
            A ``BaseModelOutputWithPastAndCrossAttentions`` when
            ``return_dict`` is true, otherwise a tuple of the non-None fields.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                # Record the hidden state *entering* this layer.
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    # Checkpointing recomputes activations, so cached
                    # key/values cannot be preserved across layers.
                    # Fixed: `logger.warn` is a deprecated alias of `warning`.
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # torch.utils.checkpoint only forwards positional tensor
                    # args; close over the remaining non-tensor arguments.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # Each layer returns its present key/value tuple last.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if layer_module.has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            # Append the final hidden state so the tuple has n_layers + 1 entries.
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
1046
+
1047
class HuskyQFormerModel(HuskyPreTrainedModel):
    """
    Querying Transformer (Q-Former), used in Husky.

    Takes a set of learned query embeddings (no token embedding table of its
    own), normalizes them, and runs them through ``HuskyQFormerEncoder``,
    optionally cross-attending to encoder hidden states (e.g. image features).
    """

    def __init__(self, config: HuskyQFormerConfig):
        super().__init__(config)
        self.config = config

        # Input normalization applied to the externally supplied query embeds.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = HuskyQFormerEncoder(config)

        self.post_init()

    def get_input_embeddings(self):
        # NOTE(review): `self.embeddings` is never created in __init__ (only
        # layernorm/dropout/encoder are), so calling this raises
        # AttributeError — confirm whether an embeddings module was intended.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # NOTE(review): see get_input_embeddings — `self.embeddings` does not
        # exist on this class as written.
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
        has_query: bool = False,
    ) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
            device (`torch.device`):
                The device of the input to the model.
            has_query (`bool`, *optional*):
                Unused here; kept for interface compatibility.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        query_embeds,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
            `(batch_size, sequence_length)`.
        use_cache (`bool`, `optional`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # past_key_values_length: the cached length excludes the query tokens,
        # which are re-fed on every step.
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        # Normalize + dropout the learned query embeddings (no token lookup).
        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            # Default: attend to all current and cached positions.
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            # A list of encoder states means a per-layer list of masks as well.
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        # "Pooled" output is simply the first query token's hidden state.
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
1238
+
1239
class AdapterMLP(nn.Module):
    """Bottleneck MLP that adapts vision features to the Q-Former width.

    Maps from the vision hidden size through a quarter-width SiLU bottleneck
    up to the Q-Former hidden size, finishing with a LayerNorm.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN["silu"]

        in_dim = config.vision_config.hidden_size
        bottleneck_dim = in_dim // 4
        out_dim = config.qformer_config.hidden_size

        self.fc1 = nn.Linear(in_dim, bottleneck_dim)
        self.fc2 = nn.Linear(bottleneck_dim, out_dim)
        self.layernorm = nn.LayerNorm(out_dim, eps=config.vision_config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project `(..., vision_hidden)` features to `(..., qformer_hidden)`."""
        squeezed = self.activation_fn(self.fc1(hidden_states))
        return self.layernorm(self.fc2(squeezed))
1258
+
1259
@add_start_docstrings(
    """
    Husky Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
    (Q-Former) and a language model.
    """,
    Husky_START_DOCSTRING,
)
class HuskyModel(HuskyPreTrainedModel):
    config_class = HuskyConfig
    main_input_name = "pixel_values"

    def __init__(self, config: HuskyConfig):
        super().__init__(config)

        self.vision_model = HuskyVisionModel(config.vision_config)

        # Learned query tokens consumed by the Q-Former (expanded per batch).
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        # Projects Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Number of prompt tokens preceding the <IMAGE> placeholder in the
        # fixed prompt template (see forward); the visual tokens are written
        # into positions [offset, offset + num_queries).
        self.offset = 5

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegate to the wrapped language model's token embedding table.
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        # Only meaningful for encoder-decoder language models.
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        # Encoder-decoder language models share one token embedding table.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    @add_start_docstrings_to_model_forward(Husky_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Runs only the language model (no vision / Q-Former path).

        Returns:
            text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
                The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
                contains the language model logits, the past key values and the hidden states if
                `output_hidden_states=True`.
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return text_outputs

    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Runs only the vision encoder and returns its raw outputs.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return vision_outputs

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    def get_qformer_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Runs the vision encoder followed by the Q-Former and returns the
        Q-Former outputs.

        Returns:
            vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
                The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that
                contains the image features, the pooled image features and the hidden states if
                `output_hidden_states=True`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return query_outputs

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
        """Vision encoder -> Q-Former -> language model, with optional LM loss."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Human: <img><IMAGE></img>. Give the describe Assistant:
        # position of <image>: [offset: offset+num_queries]

        # NOTE(review): in-place overwrite of the embedding outputs; assumes
        # input_ids reserves exactly num_queries placeholder tokens starting
        # at position `offset` — confirm against the prompt builder.
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs
        if attention_mask is None:
            attention_mask = torch.ones_like(
                input_ids, dtype=torch.long, device=language_model_inputs.device)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            labels = labels.to(logits.device)
            # Align logits to the labels by keeping only the trailing
            # labels.size(1) positions.
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")

            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return HuskyForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
1495
+
1496
+ @add_start_docstrings(
1497
+ """
1498
+ Husky Model for generating text given an image and an optional text prompt. The model consists of a vision
1499
+ encoder, Querying Transformer (Q-Former) and a language model.
1500
+
1501
+ One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
1502
+ the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
1503
+ """,
1504
+ Husky_START_DOCSTRING,
1505
+ )
1506
+ class HuskyForConditionalGeneration(HuskyPreTrainedModel):
1507
+ config_class = HuskyConfig
1508
+ main_input_name = "pixel_values"
1509
+
1510
    def __init__(self, config: HuskyConfig):
        """Build the vision encoder, Q-Former, adapters and language model."""
        super().__init__(config)

        self.vision_model = HuskyVisionModel(config.vision_config)
        # Learned query tokens consumed by the Q-Former (expanded per batch).
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        # Projects Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Number of prompt tokens that precede the image slot in the fixed
        # prompt template (see forward).
        self.offset = 5

        # Adapter + per-level LayerNorms for the 4 pooled CLS tokens taken
        # from evenly spaced vision-encoder layers (see extract_feature).
        self.vision_adapter = AdapterMLP(config)
        self.layer_norms = nn.ModuleList()
        for i in range(4):
            self.layer_norms.append(
                nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
            )

        # Initialize weights and apply final processing
        self.post_init()
1533
+
1534
    def get_input_embeddings(self):
        # Delegate to the wrapped language model's token embedding table.
        return self.language_model.get_input_embeddings()
1536
+
1537
    def set_input_embeddings(self, value):
        # Delegate to the wrapped language model.
        self.language_model.set_input_embeddings(value)
1539
+
1540
    def set_output_embeddings(self, new_embeddings):
        # Delegate to the wrapped language model.
        self.language_model.set_output_embeddings(new_embeddings)
1542
+
1543
    def get_output_embeddings(self) -> nn.Module:
        # Delegate to the wrapped language model's LM head.
        return self.language_model.get_output_embeddings()
1545
+
1546
    def get_encoder(self):
        # Only meaningful when the language model is encoder-decoder.
        return self.language_model.get_encoder()
1548
+
1549
    def get_decoder(self):
        # Delegate to the wrapped language model.
        return self.language_model.get_decoder()
1551
+
1552
    def extract_feature(
        self,
        pixel_values: torch.FloatTensor,
    ):
        """Compute the visual soft-prompt embeddings for `pixel_values`.

        Runs the vision encoder, pools the CLS token from 4 evenly spaced
        encoder layers through per-level LayerNorms and the vision adapter,
        runs the Q-Former over the full image embeddings, concatenates the
        pooled tokens in front of the query outputs, and projects everything
        into the language model's embedding space.
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
        )
        image_embeds = vision_outputs[0]

        # vision_outputs[2] holds the per-layer hidden states; pick 4 layers
        # spaced depth//4 apart (ending at the last layer).
        depth = len(vision_outputs[2])
        indices = range(depth // 4 - 1, depth, depth // 4)
        pooled_outputs = []
        for idx, layer_norm in zip(indices, self.layer_norms):
            # CLS token of the selected layer, kept as a length-1 sequence.
            pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
            pool_output = layer_norm(pool_output)
            pooled_outputs.append(pool_output)

        pooled_outputs = torch.cat(pooled_outputs, dim=1)
        pooled_outputs = self.vision_adapter(pooled_outputs)

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask
        )
        query_output = query_outputs[0]
        # soft_prompting: pooled multi-level tokens prefix the query outputs.
        query_output = torch.cat([pooled_outputs, query_output], dim=1)
        language_model_inputs = self.language_projection(query_output)

        return language_model_inputs
1588
+
1589
    def _tie_weights(self):
        # Encoder-decoder language models share one token embedding table;
        # decoder-only models need no tying here.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared
1593
+
1594
    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + Husky + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility
1613
+
1614
+ @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
1615
+ # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
1616
+ def forward(
1617
+ self,
1618
+ pixel_values: Optional[torch.FloatTensor] = None,
1619
+ input_ids: Optional[torch.FloatTensor] = None,
1620
+ attention_mask: Optional[torch.LongTensor] = None,
1621
+ output_attentions: Optional[bool] = None,
1622
+ output_hidden_states: Optional[bool] = None,
1623
+ labels: Optional[torch.LongTensor] = None,
1624
+ return_dict: Optional[bool] = None,
1625
+ pixel_embeds: Optional[torch.FloatTensor] = None,
1626
+ ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
1627
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1628
+
1629
+ # step 1: forward the images through the vision encoder,
1630
+ # to get image embeddings of shape (batch_size, seq_len, hidden_size)
1631
+ batch_size = input_ids.shape[0]
1632
+ vision_outputs = self.vision_model(
1633
+ pixel_values=pixel_values,
1634
+ output_attentions=output_attentions,
1635
+ output_hidden_states=True,
1636
+ return_dict=return_dict,
1637
+ pixel_embeds=pixel_embeds,
1638
+ )
1639
+ image_embeds = vision_outputs[0]
1640
+ depth = len(vision_outputs[2])
1641
+ indices = range(depth // 4 - 1, depth, depth // 4)
1642
+ pooled_outputs = []
1643
+ for idx, layer_norm in zip(indices, self.layer_norms):
1644
+ pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
1645
+ pool_output = layer_norm(pool_output)
1646
+ pooled_outputs.append(pool_output)
1647
+
1648
+ pooled_outputs = torch.cat(pooled_outputs, dim=1)
1649
+ pooled_outputs = self.vision_adapter(pooled_outputs)
1650
+
1651
+ # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
1652
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
1653
+
1654
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
1655
+ query_outputs = self.qformer(
1656
+ query_embeds=query_tokens,
1657
+ encoder_hidden_states=image_embeds,
1658
+ encoder_attention_mask=image_attention_mask,
1659
+ output_attentions=output_attentions,
1660
+ output_hidden_states=output_hidden_states,
1661
+ return_dict=return_dict,
1662
+ )
1663
+ query_output = query_outputs[0]
1664
+ query_output = torch.cat([pooled_outputs, query_output], dim=1) # 36 token
1665
+
1666
+ # step 3: use the language model, conditioned on the query outputs and the prompt
1667
+ language_model_inputs = self.language_projection(query_output)
1668
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
1669
+
1670
+ # Human: <img></img>. Give the describe Assistant:
1671
+ # position of <image>: [offset: offset+num_queries]
1672
+ prefix_embeds = inputs_embeds[:, :self.offset, :]
1673
+ postfix_embeds = inputs_embeds[:, self.offset:, :]
1674
+ inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)
1675
+ if attention_mask is None:
1676
+ attention_mask = torch.ones_like(
1677
+ inputs_embeds, dtype=torch.long, device=language_model_inputs.device)
1678
+ else:
1679
+ prefix_mask = attention_mask[:, :self.offset]
1680
+ postfix_mask = attention_mask[:, self.offset:]
1681
+ vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
1682
+ device=attention_mask.device)
1683
+ attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)
1684
+
1685
+ outputs = self.language_model(
1686
+ inputs_embeds=inputs_embeds,
1687
+ attention_mask=attention_mask,
1688
+ output_attentions=output_attentions,
1689
+ output_hidden_states=output_hidden_states,
1690
+ return_dict=return_dict,
1691
+ )
1692
+ logits = outputs.logits if return_dict else outputs[0]
1693
+ loss = None
1694
+ # we compute the loss here since we need to take into account the sequence length of the query embeds
1695
+ if labels is not None:
1696
+ labels = labels.to(logits.device)
1697
+ logits = logits[:, -labels.size(1):, :]
1698
+ # Shift so that tokens < n predict n
1699
+ shift_logits = logits[..., :-1, :].contiguous()
1700
+ shift_labels = labels[..., 1:].contiguous().to(logits.device)
1701
+
1702
+ # Flatten the tokens
1703
+ loss_fct = CrossEntropyLoss(reduction="mean")
1704
+
1705
+ loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
1706
+
1707
+ if not return_dict:
1708
+ output = (logits, vision_outputs, query_outputs, outputs)
1709
+ return ((loss,) + output) if loss is not None else output
1710
+
1711
+ return HuskyForConditionalGenerationModelOutput(
1712
+ loss=loss,
1713
+ logits=logits,
1714
+ vision_outputs=vision_outputs,
1715
+ qformer_outputs=query_outputs,
1716
+ language_model_outputs=outputs,
1717
+ )
1718
+
1719
+ @torch.no_grad()
1720
+ def generate(
1721
+ self,
1722
+ pixel_values: Optional[torch.FloatTensor] = None,
1723
+ input_ids: Optional[torch.LongTensor] = None,
1724
+ attention_mask: Optional[torch.LongTensor] = None,
1725
+ language_model_inputs: Optional[torch.FloatTensor] = None,
1726
+ generation_config: Optional[GenerationConfig] = None,
1727
+ **generate_kwargs,
1728
+ ) -> torch.LongTensor:
1729
+ """
1730
+ Overrides `generate` function to be able to use the model as a conditional generator.
1731
+
1732
+ Args:
1733
+ pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
1734
+ Input images to be processed.
1735
+ input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
1736
+ The sequence used as a prompt for the generation.
1737
+ attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
1738
+ Mask to avoid performing attention on padding token indices
1739
+ language_model_inputs (`torch.LongTensor` of shape (batch_size, sequence_length, num_channel), *optional*):
1740
+ The sequence used as the input for the generation
1741
+ language_model_inputs (`torch.LongTensor` of shape (batch_size, sequence_length, num_channel), *optional*):
1742
+ The sequence used as the input for the generation
1743
+ generation_config (`~generation.GenerationConfig`, *optional*):
1744
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
1745
+ passed to generate matching the attributes of `generation_config` will override them. If
1746
+ `generation_config` is not provided, the default will be used, which had the following loading
1747
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
1748
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
1749
+ default values, whose documentation should be checked to parameterize generation.
1750
+
1751
+ Returns:
1752
+ captions (list): A list of strings of length batch_size * num_captions.
1753
+ """
1754
+
1755
+ if hasattr(self, "hf_device_map"):
1756
+ # preprocess for `accelerate`
1757
+ self._preprocess_accelerate()
1758
+
1759
+ if language_model_inputs is None:
1760
+ vision_outputs = self.vision_model(
1761
+ pixel_values=pixel_values,
1762
+ output_hidden_states=True,
1763
+ )
1764
+ image_embeds = vision_outputs[0]
1765
+
1766
+ depth = len(vision_outputs[2])
1767
+ indices = range(depth // 4 - 1, depth, depth // 4)
1768
+ pooled_outputs = []
1769
+ for idx, layer_norm in zip(indices, self.layer_norms):
1770
+ pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
1771
+ pool_output = layer_norm(pool_output)
1772
+ pooled_outputs.append(pool_output)
1773
+
1774
+ pooled_outputs = torch.cat(pooled_outputs, dim=1)
1775
+ pooled_outputs = self.vision_adapter(pooled_outputs)
1776
+
1777
+ image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
1778
+
1779
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
1780
+ query_outputs = self.qformer(
1781
+ query_embeds=query_tokens,
1782
+ encoder_hidden_states=image_embeds,
1783
+ encoder_attention_mask=image_attention_mask,
1784
+ )
1785
+ query_output = query_outputs[0]
1786
+ query_output = torch.cat([pooled_outputs, query_output], dim=1)
1787
+
1788
+ language_model_inputs = self.language_projection(query_output)
1789
+
1790
+ batch_size = language_model_inputs.shape[0]
1791
+ inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
1792
+
1793
+ prefix_embeds = inputs_embeds[:, :self.offset, :]
1794
+ postfix_embeds = inputs_embeds[:, self.offset:, :]
1795
+ inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)
1796
+
1797
+ if input_ids is None:
1798
+ input_ids = (
1799
+ torch.LongTensor([[self.config.text_config.bos_token_id]])
1800
+ .repeat(batch_size, 1)
1801
+ .to(inputs_embeds.device)
1802
+ )
1803
+
1804
+ if attention_mask is None:
1805
+ attention_mask = torch.ones_like(
1806
+ input_ids, dtype=torch.long, device=language_model_inputs.device)
1807
+ else:
1808
+ prefix_mask = attention_mask[:, :self.offset]
1809
+ postfix_mask = attention_mask[:, self.offset:]
1810
+ vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
1811
+ device=attention_mask.device)
1812
+ attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)
1813
+
1814
+ outputs = self.language_model.generate(
1815
+ inputs_embeds=inputs_embeds,
1816
+ attention_mask=attention_mask,
1817
+ generation_config=generation_config,
1818
+ **generate_kwargs,
1819
+ )
1820
+
1821
+ return outputs
robohusky/model/processing_husky.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Processor class for Husky. Largely copy of Blip2Processor with addition of a tokenizer for the Q-Former.
17
+ """
18
+
19
+ from typing import List, Optional, Union
20
+
21
+ from transformers.processing_utils import ProcessorMixin
22
+ from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, \
23
+ TruncationStrategy
24
+ from transformers.utils import TensorType
25
+ from transformers.models.auto import AutoTokenizer
26
+
27
+
28
class HuskyProcessor(ProcessorMixin):
    r"""
    Constructs a Husky processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single
    processor.

    [`HuskyProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`AutoTokenizer`]. See the
    docstring of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information.

    Args:
        image_processor (`BlipImageProcessor`):
            An instance of [`BlipImageProcessor`]. The image processor is a required input.
        tokenizer (`AutoTokenizer`):
            An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input.
    """
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "BlipImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor

        # The Q-Former has its own BERT tokenizer with a dedicated decoder BOS token.
        # NOTE(review): downloaded at construction time; consider injecting it instead.
        self.qformer_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
        self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})

    def __call__(
        self,
        images=None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_token_type_ids: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchEncoding:
        """
        This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
        [`BertTokenizerFast.__call__`] to prepare text for the model.

        Please refer to the docstring of the above two methods for more information.
        """
        if images is None and text is None:
            raise ValueError("You have to specify either images or text.")

        # One shared kwargs bundle for every tokenizer call below; the original
        # repeated this ~16-keyword block three times verbatim.
        tokenize_kwargs = dict(
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_token_type_ids=return_token_type_ids,
            return_length=return_length,
            verbose=verbose,
            return_tensors=return_tensors,
            **kwargs,
        )

        # Text-only path: tokenize and return immediately.
        if images is None:
            self.current_processor = self.tokenizer
            return self.tokenizer(text=text, **tokenize_kwargs)

        # Image path: compute pixel_values, optionally merging text encodings.
        encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)

        if text is not None:
            text_encoding = self.tokenizer(text=text, **tokenize_kwargs)
            qformer_text_encoding = self.qformer_tokenizer(text=text, **tokenize_kwargs)
            # Namespace the Q-Former fields so they don't collide with the LLM fields.
            qformer_text_encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
            qformer_text_encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
            text_encoding.update(qformer_text_encoding)
            encoding_image_processor.update(text_encoding)

        return encoding_image_processor

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
robohusky/train/.DS_Store ADDED
Binary file (6.15 kB). View file
 
robohusky/train/llama_flash_attn_monkey_patch.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ from typing import Optional, Tuple
3
+
4
+ import torch
5
+ from flash_attn import __version__ as flash_attn_version
6
+ from flash_attn.bert_padding import pad_input, unpad_input
7
+ from flash_attn.flash_attn_interface import (
8
+ flash_attn_func,
9
+ flash_attn_varlen_kvpacked_func,
10
+ )
11
+ from transformers.models.llama.modeling_llama import (
12
+ LlamaAttention,
13
+ LlamaModel,
14
+ rotate_half,
15
+ )
16
+
17
def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
    """Rotate q and k by the cos/sin tables gathered at `position_ids`.

    `cos_sin` is the `(cos, sin)` pair produced by the rotary embedding module;
    the tables are gathered per batch position so non-contiguous position ids
    (e.g. with a kv cache) are handled correctly.
    """
    # Build a [bsz, seq_len, n, d] gather index from [bsz, seq_len] position ids.
    gather_indices = position_ids[:, :, None, None].repeat(
        1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
    )
    bsz = gather_indices.shape[0]

    cos = torch.gather(cos_sin[0].transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
    sin = torch.gather(cos_sin[1].transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
29
+
30
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Flash-attention replacement for `LlamaAttention.forward`.

    Expects `attention_mask` to be a boolean [bsz, seq_len] key-padding mask
    (or None for a fully dense batch) — see the patched
    `_prepare_decoder_attention_mask` below. Attention weights are never
    returned (always None in the second tuple slot).
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()
    # Grouped-query attention support: fall back to num_heads when the model
    # has no separate kv head count.
    kv_heads = getattr(self, "num_key_value_heads", self.num_heads)

    q, k, v = (
        op(hidden_states).view(bsz, q_len, nh, self.head_dim)
        for op, nh in (
            (self.q_proj, self.num_heads),
            (self.k_proj, kv_heads),
            (self.v_proj, kv_heads),
        )
    )
    # shape: (b, s, num_heads, head_dim)

    kv_seq_len = k.shape[1]
    past_kv_len = 0
    if past_key_value is not None:
        # Cached tensors are stored as (b, num_heads, s, head_dim).
        past_kv_len = past_key_value[0].shape[2]
        kv_seq_len += past_kv_len

    # Rotary embeddings over the full (past + current) sequence length.
    cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
    q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)

    if past_key_value is not None:
        assert (
            flash_attn_version >= "2.1.0"
        ), "past_key_value support requires flash-attn >= 2.1.0"
        # reuse k, v
        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)

    # Cache in the HF (b, num_heads, s, head_dim) layout expected by callers.
    past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None

    if attention_mask is None:
        # Dense path: no padding anywhere in the batch.
        output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
            bsz, q_len, -1
        )
    else:
        # Varlen path: strip padding, run packed flash attention, re-pad.
        # The query mask only covers the current (non-cached) tokens.
        q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
        # We can skip concat and call unpad twice but seems better to call unpad only once.
        kv, _, cu_k_lens, max_k = unpad_input(
            torch.stack((k, v), dim=2), attention_mask
        )
        output_unpad = flash_attn_varlen_kvpacked_func(
            q,
            kv,
            cu_q_lens,
            cu_k_lens,
            max_s,
            max_k,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len)

    return self.o_proj(output), None, past_key_value
102
+
103
+ # Disable the transformation of the attention mask in LlamaModel as flash attention
104
+ # takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
105
+ def _prepare_decoder_attention_mask(
106
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
107
+ ):
108
+ # [bsz, seq_len]
109
+ if past_key_values_length > 0 and attention_mask is not None:
110
+ attention_mask = torch.cat(
111
+ (
112
+ torch.full(
113
+ (input_shape[0], past_key_values_length),
114
+ True,
115
+ dtype=attention_mask.dtype,
116
+ device=attention_mask.device,
117
+ ),
118
+ attention_mask,
119
+ ),
120
+ dim=-1,
121
+ )
122
+
123
+ if attention_mask is not None and torch.all(attention_mask):
124
+ return None # This uses the faster call when training with full samples
125
+
126
+ return attention_mask
127
+
128
def replace_llama_attn_with_flash_attn():
    """Monkey-patch `LlamaAttention.forward` and `LlamaModel`'s mask preparation
    with the flash-attention implementations defined in this module.

    Warns (but still patches) on pre-Ampere GPUs, where the flash-attention
    backward pass is unsupported for head dim > 64.
    """
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )

    LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    LlamaAttention.forward = forward
138
+
139
def test():
    """Numerically compare the patched flash-attention forward against the stock
    `LlamaAttention.forward`, with and without padding and with a kv cache.

    Requires a CUDA device and flash-attn; prints mean absolute differences and
    allclose results rather than asserting tolerances.
    """
    from robohusky.train.llama_flash_attn_monkey_patch import forward as fastchat_forward
    from transformers.models.llama.configuration_llama import LlamaConfig

    # Tiny single-layer config so the comparison runs quickly.
    config = LlamaConfig(
        hidden_size=1024,
        intermediate_size=128,
        num_hidden_layers=1,
        num_attention_heads=8,
        max_position_embeddings=16,
    )
    device = torch.device("cuda")
    model = LlamaModel(config)
    attn = LlamaAttention(config).to(device).half()
    bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
    position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
        -1, seqlen
    )

    mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
    for i in range(4):
        hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
        # Progressively pad the two sequences from opposite ends.
        if i:
            mask[0, -i:] = False
            mask[1, :i] = False

        # Reference: stock attention with the (unpatched) expanded causal mask.
        lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
        ref, _, _ = attn.forward(
            hidden, attention_mask=lmask, position_ids=position_ids
        )

        # Patched forward fed the raw boolean padding mask directly.
        fast, _, _ = fastchat_forward(
            attn, hidden, attention_mask=mask, position_ids=position_ids
        )

        # Patched forward fed the mask via the patched preparation helper.
        lmask = _prepare_decoder_attention_mask(
            model, mask, hidden.shape[:2], hidden, 0
        )
        # NOTE: local name `test` shadows this function inside its own body.
        test, _, _ = forward(
            attn, hidden, attention_mask=lmask, position_ids=position_ids
        )

        print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}")
        print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}")
        print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}")
        print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
        print(f"allclose(fast, test) = {torch.allclose(fast, test)}")

    with torch.no_grad():
        # Also check that past_kv is handled properly
        hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
        part_len = seqlen // 4
        assert part_len * 4 == seqlen
        mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
        mask[0, -2:] = False
        lmask = _prepare_decoder_attention_mask(
            model, mask, hidden.shape[:2], hidden, 0
        )
        # Full-sequence pass in one shot...
        oneshot, _, _ = forward(
            attn, hidden, attention_mask=lmask, position_ids=position_ids
        )
        parts = []
        past_kv, past_kv_len = None, 0
        # ...versus 4 incremental quarter-length passes through the kv cache.
        for i in range(4):
            start = part_len * i
            end = start + part_len
            hidden_part = hidden[:, start:end, ...]
            lmask = _prepare_decoder_attention_mask(
                model,
                mask[:, start:end],
                hidden_part.shape[:2],
                hidden_part,
                past_kv_len,
            )
            part, _, past_kv = forward(
                attn,
                hidden_part.clone(),
                attention_mask=lmask,
                position_ids=position_ids[:, start:end],
                past_key_value=past_kv,
                use_cache=True,
            )
            parts.append(part)
            past_kv_len = past_kv[0].shape[2]

        print(
            f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
        )
        print(
            f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}"
        )


if __name__ == "__main__":
    test()
robohusky/train/llama_rmsnorm_monkey_patch.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import transformers
2
+
3
def replace_llama_rmsnorm_with_fused_rmsnorm():
    """Swap `LlamaRMSNorm` for apex's `FusedRMSNorm` when apex is importable.

    Falls back silently to the stock implementation when apex is not installed,
    and prints a notice if apex exists but fails to load.
    """
    try:
        from apex.normalization import FusedRMSNorm
        from functools import partial

        # Bind eps to match LlamaRMSNorm's default before swapping the class in.
        fused_norm = partial(FusedRMSNorm, eps=1e-6)  # noqa
        transformers.models.llama.modeling_llama.LlamaRMSNorm = fused_norm
        print("Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm")
    except ImportError:
        # apex not installed: keep the normal LlamaRMSNorm
        pass
    except Exception:
        print("discovered apex but it failed to load, falling back to LlamaRMSNorm")
robohusky/train/train.py ADDED
@@ -0,0 +1,597 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright Qing-Long Zhang. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for sequence to sequence.
18
+ """
19
+ import json
20
+ import logging
21
+ import os
22
+ import sys
23
+ import warnings
24
+ from functools import partial
25
+
26
+ from multiprocessing import cpu_count
27
+
28
+ from typing import Optional
29
+ from dataclasses import dataclass, field
30
+
31
+ from torch.utils.data import Dataset, ConcatDataset
32
+ from datasets import load_dataset, load_from_disk
33
+
34
+ from robohusky.dist_utils import init_dist
35
+ from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
36
+
37
+ import transformers
38
+ from transformers import (
39
+ HfArgumentParser,
40
+ TrainingArguments,
41
+ LlamaTokenizer,
42
+ Trainer,
43
+ set_seed,
44
+ default_data_collator,
45
+ DataCollatorForSeq2Seq,
46
+ )
47
+
48
+ from peft import (
49
+ LoraConfig,
50
+ get_peft_model,
51
+ prepare_model_for_int8_training,
52
+ )
53
+
54
+ from robohusky.base_dataset import (
55
+ process_func,
56
+ BaseDataset,
57
+ CephDataset,
58
+ build_transform
59
+ )
60
+
61
+ from transformers.trainer_utils import get_last_checkpoint
62
+ from transformers.utils import check_min_version
63
+ from transformers.utils.versions import require_version
64
+
65
+ from transformers.utils.logging import (
66
+ set_verbosity_info,
67
+ set_verbosity,
68
+ enable_default_handler,
69
+ enable_explicit_format,
70
+ )
71
+ from robohusky.train.llama_flash_attn_monkey_patch import (
72
+ replace_llama_attn_with_flash_attn
73
+ )
74
+
75
+ from robohusky.train.llama_rmsnorm_monkey_patch import (
76
+ replace_llama_rmsnorm_with_fused_rmsnorm
77
+ )
78
+
79
+ replace_llama_attn_with_flash_attn()
80
+ replace_llama_rmsnorm_with_fused_rmsnorm()
81
+
82
+ IGNORE_INDEX = -100
83
+ DEFAULT_UNK_TOKEN = "<unk>"
84
+ DEFAULT_IMG_START_TOKEN = "<img>"
85
+ DEFAULT_IMG_END_TOKEN = "</img>"
86
+
87
+ DEFAULT_VIDEO_START_TOKEN = "<vid>"
88
+ DEFAULT_VIDEO_END_TOKEN = "</vid>"
89
+
90
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
91
+ check_min_version("4.32.0.dev0")
92
+ require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
93
+
94
+ warnings.filterwarnings('ignore')
95
+ logger = logging.getLogger(__name__)
96
+
97
+ os.environ["WANDB_DISABLED"] = "true"
98
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
99
+
100
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    # FIX: the original help strings for the freeze/un_freeze flags were all
    # copy-pasted from an unrelated "head dimensions" option (and
    # `un_freeze_llm_head` described video patch embeddings); rewritten to
    # describe what each flag actually controls.
    freeze_model: bool = field(
        default=False,
        metadata={"help": "Freeze the entire model (no parameters are updated during training)."},
    )
    freeze_vision_model: bool = field(
        default=False,
        metadata={"help": "Freeze the vision encoder parameters during training."},
    )
    freeze_vision_adapter: bool = field(
        default=False,
        metadata={"help": "Freeze the vision adapter parameters during training."},
    )
    freeze_text_model: bool = field(
        default=False,
        metadata={"help": "Freeze the language model parameters during training."},
    )
    freeze_qformer: bool = field(
        default=False,
        metadata={"help": "Freeze the Q-Former parameters during training."},
    )
    un_freeze_vision_embedding: bool = field(
        default=False,
        metadata={"help": "Will enable to tuning image patch_embedding when vision_model are frozen"},
    )
    un_freeze_video_embedding: bool = field(
        default=False,
        metadata={"help": "Will enable to tuning video patch_embedding when vision_model are frozen"},
    )
    un_freeze_llm_head: bool = field(
        default=False,
        metadata={"help": "Keep the language-model head trainable even when the language model is frozen."},
    )
    use_lora: bool = field(
        default=False, metadata={"help": "add the LoRA adapters to the base model"}
    )
171
+
172
+ @dataclass
173
+ class DataTrainingArguments:
174
+ """
175
+ Arguments pertaining to what data we are going to input our model for training and eval.
176
+ """
177
+
178
+ dataset_name: Optional[str] = field(
179
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
180
+ )
181
+ dataset_config_name: Optional[str] = field(
182
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
183
+ )
184
+ data_dir: Optional[str] = field(
185
+ default=None, metadata={"help": "The data directory containing input files."})
186
+ train_file: Optional[str] = field(
187
+ default=None, metadata={"help": "The input training data file (a jsonlines)."})
188
+ validation_file: Optional[str] = field(
189
+ default=None,
190
+ metadata={
191
+ "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
192
+ },
193
+ )
194
+ train_val_split: Optional[float] = field(
195
+ default=0.0, metadata={"help": "Percent to split off of train for validation."}
196
+ )
197
+ test_file: Optional[str] = field(
198
+ default=None,
199
+ metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
200
+ )
201
+ image_path: Optional[str] = field(
202
+ default=None,
203
+ metadata={"help": "An optional image path"},
204
+ )
205
+ video_path: Optional[str] = field(
206
+ default=None,
207
+ metadata={"help": "An optional video path"},
208
+ )
209
+ input_size: Optional[int] = field(
210
+ default=224,
211
+ metadata={"help": "The input size of images."},
212
+ )
213
+ overwrite_cache: bool = field(
214
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
215
+ )
216
+ preprocessing_num_workers: Optional[int] = field(
217
+ default=None,
218
+ metadata={"help": "The number of processes to use for the preprocessing."},
219
+ )
220
+ max_seq_length: Optional[int] = field(
221
+ default=128,
222
+ metadata={
223
+ "help": (
224
+ "The maximum total input sequence length after tokenization. Sequences longer "
225
+ "than this will be truncated, sequences shorter will be padded."
226
+ )
227
+ },
228
+ )
229
+ pad_to_max_length: bool = field(
230
+ default=False,
231
+ metadata={
232
+ "help": (
233
+ "Whether to pad all samples to model maximum sentence length. "
234
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
235
+ "efficient on GPU but very bad for TPU."
236
+ )
237
+ },
238
+ )
239
+ val_max_length: Optional[int] = field(
240
+ default=None,
241
+ metadata={
242
+ "help": (
243
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
244
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
245
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
246
+ "during ``evaluate`` and ``predict``."
247
+ )
248
+ },
249
+ )
250
+ max_train_samples: Optional[int] = field(
251
+ default=None,
252
+ metadata={
253
+ "help": (
254
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
255
+ "value if set."
256
+ )
257
+ },
258
+ )
259
+ max_eval_samples: Optional[int] = field(
260
+ default=None,
261
+ metadata={
262
+ "help": (
263
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
264
+ "value if set."
265
+ )
266
+ },
267
+ )
268
+ max_predict_samples: Optional[int] = field(
269
+ default=None,
270
+ metadata={
271
+ "help": (
272
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
273
+ "value if set."
274
+ )
275
+ },
276
+ )
277
+ conv_style: Optional[str] = field(
278
+ default=None, metadata={"help": "prompt style for a conversation."}
279
+ )
280
+ save_data_path: Optional[str] = field(
281
+ default=None, metadata={"help": "prompt style for a conversation."}
282
+ )
283
+ num_beams: Optional[int] = field(
284
+ default=None,
285
+ metadata={
286
+ "help": (
287
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
288
+ "which is used during ``evaluate`` and ``predict``."
289
+ )
290
+ },
291
+ )
292
+ ignore_pad_token_for_loss: bool = field(
293
+ default=True,
294
+ metadata={
295
+ "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
296
+ },
297
+ )
298
+ source_prefix: Optional[str] = field(
299
+ default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
300
+ )
301
+ forced_bos_token: Optional[str] = field(
302
+ default=None,
303
+ metadata={
304
+ "help": (
305
+ "The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for"
306
+ " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
307
+ " be the target language token.(Usually it is the target language token)"
308
+ )
309
+ },
310
+ )
311
+
312
+ def __post_init__(self):
313
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
314
+ raise ValueError("Need either a dataset name or a training/validation file.")
315
+ # accepting both json and jsonl file extensions, as
316
+ # many jsonlines files actually have a .json extension
317
+ else:
318
+ if self.train_file is not None:
319
+ extension = self.train_file.split(".")[-1]
320
+ assert extension in ["csv", "json", "jsonl", "parquet"], "`train_file` should be a csv or a json file."
321
+ if self.validation_file is not None:
322
+ extension = self.validation_file.split(".")[-1]
323
+ assert extension in ["csv", "json", "jsonl",
324
+ "parquet"], "`validation_file` should be a csv or a json file."
325
+ if self.test_file is not None:
326
+ extension = self.test_file.split(".")[-1]
327
+ assert extension == "json", "`test_file` should be a json file."
328
+
329
+ def main():
330
+ # 1. Parse input arguments
331
+ # See all possible arguments in src/transformers/training_args.py
332
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
333
+ init_dist(launcher='slurm', backend='nccl', port=29598)
334
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
335
+ # If we pass only one argument to the script, and it's the path to a json file,
336
+ # let's parse it to get our arguments.
337
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
338
+ else:
339
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
340
+
341
+ # 2. Setup logging
342
+ logging.basicConfig(
343
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
344
+ datefmt="%m/%d/%Y %H:%M:%S",
345
+ handlers=[logging.StreamHandler(sys.stdout)],
346
+ )
347
+
348
+ if training_args.should_log:
349
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
350
+ transformers.utils.logging.set_verbosity_info()
351
+
352
+ log_level = training_args.get_process_log_level()
353
+ logger.setLevel(log_level)
354
+ set_verbosity(log_level)
355
+ enable_default_handler()
356
+ enable_explicit_format()
357
+
358
+ # Log on each process the small summary:
359
+ logger.warning(
360
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
361
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
362
+ )
363
+ logger.info(f"Training/evaluation parameters {training_args}")
364
+
365
+ # 3. Detecting last checkpoint and eventually continue from last checkpoint.
366
+ last_checkpoint = None
367
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
368
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
369
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
370
+ raise ValueError(
371
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
372
+ "Use --overwrite_output_dir to overcome."
373
+ )
374
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
375
+ logger.info(
376
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
377
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
378
+ )
379
+
380
+ # Set seed before initializing model.
381
+ set_seed(training_args.seed)
382
+
383
+ # 4. Get the datasets
384
+ # you can either provide your own JSON training and evaluation files (see below)
385
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
386
+ # (the dataset will be downloaded automatically from the datasets Hub).
387
+ #
388
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
389
+ # download the dataset.
390
+
391
+ if data_args.dataset_name is not None:
392
+ # Downloading and loading a dataset from the hub.
393
+ ds = load_dataset(
394
+ data_args.dataset_name,
395
+ data_args.dataset_config_name,
396
+ data_dir=data_args.data_dir,
397
+ cache_dir=model_args.cache_dir,
398
+ use_auth_token=True if model_args.use_auth_token else None,
399
+ )
400
+ else:
401
+ data_files = {}
402
+ if data_args.train_file is not None:
403
+ data_files["train"] = data_args.train_file
404
+ extension = data_args.train_file.split(".")[-1]
405
+ if data_args.validation_file is not None:
406
+ data_files["validation"] = data_args.validation_file
407
+ extension = data_args.validation_file.split(".")[-1]
408
+ if data_args.test_file is not None:
409
+ data_files["test"] = data_args.test_file
410
+ extension = data_args.test_file.split(".")[-1]
411
+
412
+ # ds = load_dataset(
413
+ # "json" if extension == "jsonl" else extension,
414
+ # data_files=data_files,
415
+ # split="train"
416
+ # )
417
+ ds = json.load(open(data_args.train_file, "r"))
418
+
419
+ # 5. Load pretrained model, tokenizer, and image processor
420
+ #
421
+ # Distributed training: The .from_pretrained methods guarantee that only one local process can concurrently
422
+ # download model & vocab.
423
+ tokenizer = LlamaTokenizer.from_pretrained(
424
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
425
+ cache_dir=model_args.cache_dir,
426
+ use_fast=model_args.use_fast_tokenizer,
427
+ legacy=True,
428
+ )
429
+ # add special token
430
+ tokenizer.pad_token_id = 0
431
+ if tokenizer.unk_token is None:
432
+ tokenizer.add_special_tokens({"unk_token": DEFAULT_UNK_TOKEN})
433
+
434
+ tokens_list = [
435
+ DEFAULT_IMG_START_TOKEN, DEFAULT_IMG_END_TOKEN,
436
+ DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
437
+ ]
438
+ tokenizer.add_tokens(tokens_list, special_tokens=True)
439
+
440
+ model = HuskyForConditionalGeneration.from_pretrained(
441
+ model_args.model_name_or_path, ignore_mismatched_sizes=True
442
+ )
443
+ embedding_size = model.language_model.get_input_embeddings().weight.shape[0]
444
+
445
+ # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
446
+ # on a small vocab and want a smaller embedding size, remove this test.
447
+
448
+ if len(tokenizer) > embedding_size:
449
+ model.resize_token_embeddings(len(tokenizer))
450
+ model.language_model.resize_token_embeddings(len(tokenizer))
451
+ model.config.text_config.vocab_size = len(tokenizer)
452
+
453
+ model.config.use_cache = False
454
+
455
+ def _freeze_params(module):
456
+ for param in module.parameters():
457
+ param.requires_grad = False
458
+
459
+ if model_args.freeze_model:
460
+ _freeze_params(model)
461
+ # only update language projection
462
+ model.language_projection.weight.requires_grad = True
463
+
464
+ if model_args.freeze_vision_model:
465
+ model.vision_model = model.vision_model.eval()
466
+ _freeze_params(model.vision_model)
467
+
468
+ if model_args.freeze_vision_adapter:
469
+ _freeze_params(model.vision_adapter)
470
+
471
+ if model_args.freeze_qformer:
472
+ model.qformer = model.qformer.eval()
473
+ _freeze_params(model.qformer)
474
+ model.query_tokens.requires_grad = False
475
+
476
+ if model_args.freeze_text_model:
477
+ _freeze_params(model.language_model)
478
+
479
+ if model_args.use_lora:
480
+ training_args.ddp_find_unused_parameters = False
481
+ _freeze_params(model)
482
+ lora_config = LoraConfig(
483
+ r=16,
484
+ target_modules=["q_proj", "v_proj"],
485
+ lora_alpha=32,
486
+ lora_dropout=0.05,
487
+ bias="none",
488
+ task_type="CAUSAL_LM",
489
+ )
490
+ model.language_model = get_peft_model(model.language_model, lora_config)
491
+ model.language_model.print_trainable_parameters()
492
+
493
+ if model_args.un_freeze_video_embedding:
494
+ _freeze_params(model)
495
+ model.vision_model.video_embeddings.patch_embedding.weight.requires_grad = True
496
+ model.vision_model.video_embeddings.class_embedding.requires_grad = True
497
+ model.vision_model.video_embeddings.position_embedding.requires_grad = True
498
+
499
+ if model_args.un_freeze_llm_head:
500
+ model.language_model.lm_head.weight.requires_grad = True
501
+
502
+ # set seed for torch dataloaders
503
+ set_seed(training_args.seed)
504
+
505
+ # 7. Preprocessing the datasets.
506
+ # We need to tokenize input captions and transform the images.
507
+
508
+ # set padding.
509
+ padding = "max_length" if data_args.pad_to_max_length else False
510
+
511
+ def husky_processor(examples):
512
+ processor = partial(
513
+ process_func,
514
+ tokenizer=tokenizer,
515
+ max_seq_length=data_args.max_seq_length,
516
+ )
517
+ model_inputs = processor(examples)
518
+ return model_inputs
519
+
520
+ # Data collator
521
+ label_pad_token_id = IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
522
+ if data_args.pad_to_max_length:
523
+ data_collator = default_data_collator
524
+ else:
525
+ data_collator = DataCollatorForSeq2Seq(
526
+ tokenizer,
527
+ model=model,
528
+ label_pad_token_id=label_pad_token_id,
529
+ pad_to_multiple_of=8 if training_args.fp16 else None,
530
+ )
531
+
532
+ concat_dataset = []
533
+ for data in ds:
534
+ data_file = data["text_file"]
535
+ extension = data_file.split(".")[-1]
536
+ extension = "json" if extension == "jsonl" else extension
537
+ logger.info(f"Loading dataset: {data['data_name']}")
538
+
539
+ raw_dataset = load_dataset(extension, data_files=data_file, num_proc=cpu_count(), split="train")
540
+ if data["data_type"] == "base":
541
+ temp = BaseDataset(
542
+ raw_dataset,
543
+ processor=husky_processor,
544
+ image_path=data["image_path"],
545
+ input_size=data_args.input_size
546
+ )
547
+ else:
548
+ temp = CephDataset(
549
+ raw_dataset,
550
+ processor=husky_processor,
551
+ input_size=data_args.input_size
552
+ )
553
+ concat_dataset.append(temp)
554
+
555
+ logger.info(f"All datasets have been loaded!")
556
+
557
+ if len(concat_dataset) > 1:
558
+ train_dataset = ConcatDataset(concat_dataset)
559
+ # train_dataset = train_dataset.shuffle(seed=42)
560
+ else:
561
+ train_dataset = concat_dataset[0]
562
+
563
+ # 8. Initialize our Trainer
564
+ trainer = Trainer(
565
+ model=model,
566
+ args=training_args,
567
+ train_dataset=train_dataset if training_args.do_train else None,
568
+ eval_dataset=None,
569
+ tokenizer=tokenizer,
570
+ data_collator=data_collator,
571
+ )
572
+
573
+ # 9. Training
574
+ if training_args.do_train:
575
+ checkpoint = None
576
+ if training_args.resume_from_checkpoint is not None:
577
+ checkpoint = training_args.resume_from_checkpoint
578
+ elif last_checkpoint is not None:
579
+ checkpoint = last_checkpoint
580
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
581
+ if model_args.use_lora:
582
+ model.language_model.save_pretrained(training_args.output_dir)
583
+ else:
584
+ trainer.save_model() # Saves the tokenizer too for easy upload
585
+
586
+ metrics = train_result.metrics
587
+ max_train_samples = (
588
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
589
+ )
590
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
591
+
592
+ trainer.log_metrics("train", metrics)
593
+ trainer.save_metrics("train", metrics)
594
+ trainer.save_state()
595
+
596
+ if __name__ == "__main__":
597
+ main()
robohusky/train/train_uni.py ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright Qing-Long Zhang. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for sequence to sequence.
18
+ """
19
+ import json
20
+ import logging
21
+ import os
22
+ import sys
23
+ import warnings
24
+ from functools import partial
25
+
26
+ from multiprocessing import cpu_count
27
+
28
+ from typing import Optional
29
+ from dataclasses import dataclass, field
30
+
31
+ from torch.utils.data import Dataset, ConcatDataset
32
+ from datasets import load_dataset, load_from_disk
33
+
34
+ from robohusky.dist_utils import init_dist
35
+ from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
36
+
37
+ import transformers
38
+ from transformers import (
39
+ HfArgumentParser,
40
+ TrainingArguments,
41
+ LlamaTokenizer,
42
+ Trainer,
43
+ set_seed,
44
+ default_data_collator,
45
+ DataCollatorForSeq2Seq,
46
+ )
47
+
48
+ from peft import (
49
+ LoraConfig,
50
+ get_peft_model,
51
+ prepare_model_for_int8_training,
52
+ )
53
+
54
+ from robohusky.base_dataset_uni import (
55
+ process_func,
56
+ BaseDataset,
57
+ WeightedConcatDataset
58
+ )
59
+
60
+ from transformers.trainer_utils import get_last_checkpoint
61
+ from transformers.utils import check_min_version
62
+ from transformers.utils.versions import require_version
63
+
64
+ from transformers.utils.logging import (
65
+ set_verbosity_info,
66
+ set_verbosity,
67
+ enable_default_handler,
68
+ enable_explicit_format,
69
+ )
70
+ from robohusky.train.llama_flash_attn_monkey_patch import (
71
+ replace_llama_attn_with_flash_attn
72
+ )
73
+
74
+ from robohusky.train.llama_rmsnorm_monkey_patch import (
75
+ replace_llama_rmsnorm_with_fused_rmsnorm
76
+ )
77
+
78
+ replace_llama_attn_with_flash_attn()
79
+ replace_llama_rmsnorm_with_fused_rmsnorm()
80
+
81
+ IGNORE_INDEX = -100
82
+ DEFAULT_UNK_TOKEN = "<unk>"
83
+ DEFAULT_IMG_START_TOKEN = "<img>"
84
+ DEFAULT_IMG_END_TOKEN = "</img>"
85
+
86
+ DEFAULT_VIDEO_START_TOKEN = "<vid>"
87
+ DEFAULT_VIDEO_END_TOKEN = "</vid>"
88
+
89
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
90
+ check_min_version("4.32.0.dev0")
91
+ require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
92
+
93
+ warnings.filterwarnings('ignore')
94
+ logger = logging.getLogger(__name__)
95
+
96
+ os.environ["WANDB_DISABLED"] = "true"
97
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
98
+
99
+ @dataclass
100
+ class ModelArguments:
101
+ """
102
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
103
+ """
104
+
105
+ model_name_or_path: str = field(
106
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
107
+ )
108
+ config_name: Optional[str] = field(
109
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
110
+ )
111
+ tokenizer_name: Optional[str] = field(
112
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
113
+ )
114
+ cache_dir: Optional[str] = field(
115
+ default=None,
116
+ metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
117
+ )
118
+ use_fast_tokenizer: bool = field(
119
+ default=False,
120
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
121
+ )
122
+ model_revision: str = field(
123
+ default="main",
124
+ metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
125
+ )
126
+ use_auth_token: bool = field(
127
+ default=False,
128
+ metadata={
129
+ "help": (
130
+ "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
131
+ "with private models)."
132
+ )
133
+ },
134
+ )
135
+ freeze_model: bool = field(
136
+ default=False,
137
+ metadata={"help": "Will enable to load a pretrained model whose head dimensions are different."},
138
+ )
139
+ freeze_vision_model: bool = field(
140
+ default=False,
141
+ metadata={"help": "Will enable to load a pretrained vision model whose head dimensions are different."},
142
+ )
143
+ freeze_vision_adapter: bool = field(
144
+ default=False,
145
+ metadata={"help": "Will enable to load a pretrained vision adapter whose head dimensions are different."},
146
+ )
147
+ freeze_text_model: bool = field(
148
+ default=False,
149
+ metadata={"help": "Will enable to load a pretrained text model whose head dimensions are different."},
150
+ )
151
+ freeze_qformer: bool = field(
152
+ default=False,
153
+ metadata={"help": "Will enable to load a pretrained qformer model whose head dimensions are different."},
154
+ )
155
+ un_freeze_vision_embedding: bool = field(
156
+ default=False,
157
+ metadata={"help": "Will enable to tuning image patch_embedding when vision_model are frozen"},
158
+ )
159
+ un_freeze_video_embedding: bool = field(
160
+ default=False,
161
+ metadata={"help": "Will enable to tuning video patch_embedding when vision_model are frozen"},
162
+ )
163
+ un_freeze_llm_head: bool = field(
164
+ default=False,
165
+ metadata={"help": "Will enable to tuning video patch_embedding when vision_model are frozen"},
166
+ )
167
+ use_lora: bool = field(
168
+ default=False, metadata={"help": "add the LoRA adapters to the base model"}
169
+ )
170
+
171
+ @dataclass
172
+ class DataTrainingArguments:
173
+ """
174
+ Arguments pertaining to what data we are going to input our model for training and eval.
175
+ """
176
+
177
+ dataset_name: Optional[str] = field(
178
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
179
+ )
180
+ dataset_config_name: Optional[str] = field(
181
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
182
+ )
183
+ data_dir: Optional[str] = field(
184
+ default=None, metadata={"help": "The data directory containing input files."})
185
+ train_file: Optional[str] = field(
186
+ default=None, metadata={"help": "The input training data file (a jsonlines)."})
187
+ validation_file: Optional[str] = field(
188
+ default=None,
189
+ metadata={
190
+ "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
191
+ },
192
+ )
193
+ train_val_split: Optional[float] = field(
194
+ default=0.0, metadata={"help": "Percent to split off of train for validation."}
195
+ )
196
+ test_file: Optional[str] = field(
197
+ default=None,
198
+ metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
199
+ )
200
+ image_path: Optional[str] = field(
201
+ default=None,
202
+ metadata={"help": "An optional image path"},
203
+ )
204
+ video_path: Optional[str] = field(
205
+ default=None,
206
+ metadata={"help": "An optional video path"},
207
+ )
208
+ input_size: Optional[int] = field(
209
+ default=224,
210
+ metadata={"help": "The input size of images."},
211
+ )
212
+ overwrite_cache: bool = field(
213
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
214
+ )
215
+ preprocessing_num_workers: Optional[int] = field(
216
+ default=None,
217
+ metadata={"help": "The number of processes to use for the preprocessing."},
218
+ )
219
+ max_seq_length: Optional[int] = field(
220
+ default=128,
221
+ metadata={
222
+ "help": (
223
+ "The maximum total input sequence length after tokenization. Sequences longer "
224
+ "than this will be truncated, sequences shorter will be padded."
225
+ )
226
+ },
227
+ )
228
+ pad_to_max_length: bool = field(
229
+ default=False,
230
+ metadata={
231
+ "help": (
232
+ "Whether to pad all samples to model maximum sentence length. "
233
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
234
+ "efficient on GPU but very bad for TPU."
235
+ )
236
+ },
237
+ )
238
+ val_max_length: Optional[int] = field(
239
+ default=None,
240
+ metadata={
241
+ "help": (
242
+ "The maximum total sequence length for validation target text after tokenization. Sequences longer "
243
+ "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
244
+ "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
245
+ "during ``evaluate`` and ``predict``."
246
+ )
247
+ },
248
+ )
249
+ max_train_samples: Optional[int] = field(
250
+ default=None,
251
+ metadata={
252
+ "help": (
253
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
254
+ "value if set."
255
+ )
256
+ },
257
+ )
258
+ max_eval_samples: Optional[int] = field(
259
+ default=None,
260
+ metadata={
261
+ "help": (
262
+ "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
263
+ "value if set."
264
+ )
265
+ },
266
+ )
267
+ max_predict_samples: Optional[int] = field(
268
+ default=None,
269
+ metadata={
270
+ "help": (
271
+ "For debugging purposes or quicker training, truncate the number of prediction examples to this "
272
+ "value if set."
273
+ )
274
+ },
275
+ )
276
+ conv_style: Optional[str] = field(
277
+ default=None, metadata={"help": "prompt style for a conversation."}
278
+ )
279
+ save_data_path: Optional[str] = field(
280
+ default=None, metadata={"help": "prompt style for a conversation."}
281
+ )
282
+ num_beams: Optional[int] = field(
283
+ default=None,
284
+ metadata={
285
+ "help": (
286
+ "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
287
+ "which is used during ``evaluate`` and ``predict``."
288
+ )
289
+ },
290
+ )
291
+ ignore_pad_token_for_loss: bool = field(
292
+ default=True,
293
+ metadata={
294
+ "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
295
+ },
296
+ )
297
+ source_prefix: Optional[str] = field(
298
+ default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
299
+ )
300
+ forced_bos_token: Optional[str] = field(
301
+ default=None,
302
+ metadata={
303
+ "help": (
304
+ "The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for"
305
+ " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
306
+ " be the target language token.(Usually it is the target language token)"
307
+ )
308
+ },
309
+ )
310
+
311
+ def __post_init__(self):
312
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
313
+ raise ValueError("Need either a dataset name or a training/validation file.")
314
+ # accepting both json and jsonl file extensions, as
315
+ # many jsonlines files actually have a .json extension
316
+ else:
317
+ if self.train_file is not None:
318
+ extension = self.train_file.split(".")[-1]
319
+ assert extension in ["csv", "json", "jsonl", "parquet"], "`train_file` should be a csv or a json file."
320
+ if self.validation_file is not None:
321
+ extension = self.validation_file.split(".")[-1]
322
+ assert extension in ["csv", "json", "jsonl",
323
+ "parquet"], "`validation_file` should be a csv or a json file."
324
+ if self.test_file is not None:
325
+ extension = self.test_file.split(".")[-1]
326
+ assert extension == "json", "`test_file` should be a json file."
327
+
328
+ def main():
329
+ # 1. Parse input arguments
330
+ # See all possible arguments in src/transformers/training_args.py
331
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
332
+ init_dist(launcher='slurm', backend='nccl', port=29598)
333
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
334
+ # If we pass only one argument to the script, and it's the path to a json file,
335
+ # let's parse it to get our arguments.
336
+ model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
337
+ else:
338
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
339
+
340
+ # 2. Setup logging
341
+ logging.basicConfig(
342
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
343
+ datefmt="%m/%d/%Y %H:%M:%S",
344
+ handlers=[logging.StreamHandler(sys.stdout)],
345
+ )
346
+
347
+ if training_args.should_log:
348
+ # The default of training_args.log_level is passive, so we set log level at info here to have that default.
349
+ transformers.utils.logging.set_verbosity_info()
350
+
351
+ log_level = training_args.get_process_log_level()
352
+ logger.setLevel(log_level)
353
+ set_verbosity(log_level)
354
+ enable_default_handler()
355
+ enable_explicit_format()
356
+
357
+ # Log on each process the small summary:
358
+ logger.warning(
359
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
360
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
361
+ )
362
+ logger.info(f"Training/evaluation parameters {training_args}")
363
+
364
+ # 3. Detecting last checkpoint and eventually continue from last checkpoint.
365
+ last_checkpoint = None
366
+ if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
367
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
368
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
369
+ raise ValueError(
370
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
371
+ "Use --overwrite_output_dir to overcome."
372
+ )
373
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
374
+ logger.info(
375
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
376
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
377
+ )
378
+
379
+ # Set seed before initializing model.
380
+ set_seed(training_args.seed)
381
+
382
+ # 4. Get the datasets
383
+ # you can either provide your own JSON training and evaluation files (see below)
384
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
385
+ # (the dataset will be downloaded automatically from the datasets Hub).
386
+ #
387
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
388
+ # download the dataset.
389
+
390
+ if data_args.dataset_name is not None:
391
+ # Downloading and loading a dataset from the hub.
392
+ ds = load_dataset(
393
+ data_args.dataset_name,
394
+ data_args.dataset_config_name,
395
+ data_dir=data_args.data_dir,
396
+ cache_dir=model_args.cache_dir,
397
+ use_auth_token=True if model_args.use_auth_token else None,
398
+ )
399
+ else:
400
+ data_files = {}
401
+ if data_args.train_file is not None:
402
+ data_files["train"] = data_args.train_file
403
+ extension = data_args.train_file.split(".")[-1]
404
+ if data_args.validation_file is not None:
405
+ data_files["validation"] = data_args.validation_file
406
+ extension = data_args.validation_file.split(".")[-1]
407
+ if data_args.test_file is not None:
408
+ data_files["test"] = data_args.test_file
409
+ extension = data_args.test_file.split(".")[-1]
410
+
411
+ # ds = load_dataset(
412
+ # "json" if extension == "jsonl" else extension,
413
+ # data_files=data_files,
414
+ # split="train"
415
+ # )
416
+ ds = json.load(open(data_args.train_file, "r"))
417
+
418
+ # 5. Load pretrained model, tokenizer, and image processor
419
+ #
420
+ # Distributed training: The .from_pretrained methods guarantee that only one local process can concurrently
421
+ # download model & vocab.
422
+ tokenizer = LlamaTokenizer.from_pretrained(
423
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
424
+ cache_dir=model_args.cache_dir,
425
+ use_fast=model_args.use_fast_tokenizer,
426
+ legacy=True,
427
+ )
428
+ # add special token
429
+ tokenizer.pad_token_id = 0
430
+ if tokenizer.unk_token is None:
431
+ tokenizer.add_special_tokens({"unk_token": DEFAULT_UNK_TOKEN})
432
+
433
+ tokens_list = [
434
+ DEFAULT_IMG_START_TOKEN, DEFAULT_IMG_END_TOKEN,
435
+ DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
436
+ ]
437
+ tokenizer.add_tokens(tokens_list, special_tokens=True)
438
+
439
+ model = HuskyForConditionalGeneration.from_pretrained(
440
+ model_args.model_name_or_path, ignore_mismatched_sizes=True
441
+ )
442
+ embedding_size = model.language_model.get_input_embeddings().weight.shape[0]
443
+
444
+ # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
445
+ # on a small vocab and want a smaller embedding size, remove this test.
446
+
447
+ if len(tokenizer) > embedding_size:
448
+ model.resize_token_embeddings(len(tokenizer))
449
+ model.language_model.resize_token_embeddings(len(tokenizer))
450
+ model.config.text_config.vocab_size = len(tokenizer)
451
+
452
+ model.config.use_cache = False
453
+
454
+ def _freeze_params(module):
455
+ for param in module.parameters():
456
+ param.requires_grad = False
457
+
458
+ if model_args.freeze_model:
459
+ _freeze_params(model)
460
+ # only update language projection
461
+ model.language_projection.weight.requires_grad = True
462
+
463
+ if model_args.freeze_vision_model:
464
+ model.vision_model = model.vision_model.eval()
465
+ _freeze_params(model.vision_model)
466
+
467
+ if model_args.freeze_vision_adapter:
468
+ _freeze_params(model.vision_adapter)
469
+
470
+ if model_args.freeze_qformer:
471
+ model.qformer = model.qformer.eval()
472
+ _freeze_params(model.qformer)
473
+ model.query_tokens.requires_grad = False
474
+
475
+ if model_args.freeze_text_model:
476
+ _freeze_params(model.language_model)
477
+
478
+ if model_args.use_lora:
479
+ training_args.ddp_find_unused_parameters = False
480
+ _freeze_params(model)
481
+ lora_config = LoraConfig(
482
+ r=16,
483
+ target_modules=["q_proj", "v_proj"],
484
+ lora_alpha=32,
485
+ lora_dropout=0.05,
486
+ bias="none",
487
+ task_type="CAUSAL_LM",
488
+ )
489
+ model.language_model = get_peft_model(model.language_model, lora_config)
490
+ model.language_model.print_trainable_parameters()
491
+
492
+ if model_args.un_freeze_video_embedding:
493
+ _freeze_params(model)
494
+ model.vision_model.video_embeddings.patch_embedding.weight.requires_grad = True
495
+ model.vision_model.video_embeddings.class_embedding.requires_grad = True
496
+ model.vision_model.video_embeddings.position_embedding.requires_grad = True
497
+
498
+ if model_args.un_freeze_llm_head:
499
+ model.language_model.lm_head.weight.requires_grad = True
500
+
501
+ # set seed for torch dataloaders
502
+ set_seed(training_args.seed)
503
+
504
+ # 7. Preprocessing the datasets.
505
+ # We need to tokenize input captions and transform the images.
506
+
507
+ # set padding.
508
+ padding = "max_length" if data_args.pad_to_max_length else False
509
+
510
+ def husky_processor(examples):
511
+ processor = partial(
512
+ process_func,
513
+ tokenizer=tokenizer,
514
+ max_seq_length=data_args.max_seq_length,
515
+ )
516
+ model_inputs = processor(examples)
517
+ return model_inputs
518
+
519
+ # Data collator
520
+ label_pad_token_id = IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
521
+ if data_args.pad_to_max_length:
522
+ data_collator = default_data_collator
523
+ else:
524
+ data_collator = DataCollatorForSeq2Seq(
525
+ tokenizer,
526
+ model=model,
527
+ label_pad_token_id=label_pad_token_id,
528
+ pad_to_multiple_of=8 if training_args.fp16 else None,
529
+ )
530
+
531
+ concat_dataset = []
532
+ weights = []
533
+ batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
534
+ for data in ds:
535
+ data_name = data['data_name']
536
+ data_file = data["text_file"]
537
+ extension = data_file.split(".")[-1]
538
+ extension = "json" if extension == "jsonl" else extension
539
+ logger.info(f"Loading dataset: {data_name}")
540
+
541
+ raw_dataset = load_dataset(extension, data_files=data_file, num_proc=cpu_count(), split="train")
542
+ raw_dataset = raw_dataset.shuffle(seed=0)
543
+ max_train_sample = min(len(raw_dataset), batch_size * (len(raw_dataset) // batch_size))
544
+ raw_dataset = raw_dataset.select(range(max_train_sample))
545
+
546
+ media_type = data["data_type"]
547
+ input_size = data_args.video_size if media_type == "video" else data_args.input_size
548
+
549
+ temp = BaseDataset(
550
+ raw_dataset,
551
+ processor=husky_processor,
552
+ image_path=data["image_path"],
553
+ input_size=input_size,
554
+ num_segments=8,
555
+ norm_type="openai",
556
+ media_type=media_type
557
+ )
558
+
559
+ concat_dataset.append(temp)
560
+ weights.append(1 / len(temp))
561
+ logger.info(f"All datasets have been loaded!")
562
+
563
+ if len(concat_dataset) > 1:
564
+ train_dataset = WeightedConcatDataset(datasets=concat_dataset, weights=weights, batch_size=batch_size)
565
+ else:
566
+ train_dataset = concat_dataset[0]
567
+
568
+
569
+ # 8. Initialize our Trainer
570
+ trainer = Trainer(
571
+ model=model,
572
+ args=training_args,
573
+ train_dataset=train_dataset if training_args.do_train else None,
574
+ eval_dataset=None,
575
+ tokenizer=tokenizer,
576
+ data_collator=data_collator,
577
+ )
578
+
579
+ # 9. Training
580
+ if training_args.do_train:
581
+ checkpoint = None
582
+ if training_args.resume_from_checkpoint is not None:
583
+ checkpoint = training_args.resume_from_checkpoint
584
+ elif last_checkpoint is not None:
585
+ checkpoint = last_checkpoint
586
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
587
+ if model_args.use_lora:
588
+ model.language_model.save_pretrained(training_args.output_dir)
589
+ else:
590
+ trainer.save_model() # Saves the tokenizer too for easy upload
591
+
592
+ metrics = train_result.metrics
593
+ max_train_samples = (
594
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
595
+ )
596
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
597
+
598
+ trainer.log_metrics("train", metrics)
599
+ trainer.save_metrics("train", metrics)
600
+ trainer.save_state()
601
+
602
+ if __name__ == "__main__":
603
+ main()
robohusky/utils.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from asyncio import AbstractEventLoop
2
+ import json
3
+ import logging
4
+ import logging.handlers
5
+ import os
6
+ import platform
7
+ import sys
8
+ from typing import AsyncGenerator, Generator
9
+ import warnings
10
+
11
+ import requests
12
+ import torch
13
+
14
+ from husky.constants import LOGDIR
15
+
16
+ handler = None
17
+
18
+
19
def build_logger(logger_name, logger_filename):
    """Create (or fetch) a named logger and wire up process-wide logging.

    Side effects (process-wide):
      * installs a uniform format on the root console handler
      * replaces ``sys.stdout`` / ``sys.stderr`` with StreamToLogger wrappers
      * on first call, attaches one shared daily-rotating UTF-8 file handler
        (under LOGDIR) to every logger known at that point

    :param logger_name: name passed to ``logging.getLogger``
    :param logger_filename: log file name created inside LOGDIR
    :returns: the configured ``logging.Logger``
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        if sys.version_info[1] >= 9:
            # Python >= 3.9: basicConfig accepts `encoding` (needed for
            # UTF-8 console output on Windows).
            logging.basicConfig(level=logging.INFO, encoding="utf-8")
        else:
            if platform.system() == "Windows":
                warnings.warn(
                    "If you are running on Windows, "
                    "we recommend you use Python >= 3.9 for UTF-8 encoding."
                )
            logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers so that print()/tracebacks
    # end up in the log file as well.
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers (module-level `handler` guards
    # against attaching it more than once per process).
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True, encoding="utf-8"
        )
        handler.setFormatter(formatter)

        # NOTE: loggers created *after* this call will not get the file
        # handler automatically.
        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger
70
+
71
+
72
class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """

    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        # Delegate any file-like attribute we don't implement to the
        # real stdout captured at construction time.
        return getattr(self.terminal, attr)

    def write(self, buf):
        pending = self.linebuf + buf
        self.linebuf = ""
        for chunk in pending.splitlines(True):
            # splitlines(True) keeps the terminators, so a chunk ending in
            # "\n" is a complete line and gets logged; an unterminated tail
            # is buffered until the next write delivers its newline.
            if chunk.endswith("\n"):
                message = chunk.encode("utf-8", "ignore").decode("utf-8")
                self.logger.log(self.log_level, message.rstrip())
            else:
                self.linebuf += chunk

    def flush(self):
        # Emit whatever partial line is still buffered.
        if self.linebuf:
            message = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
            self.logger.log(self.log_level, message.rstrip())
            self.linebuf = ""
106
+
107
+
108
def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    # Replace reset_parameters with a no-op on the layer types whose
    # default init is wasted work when weights are loaded right after.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
116
+
117
+
118
def get_gpu_memory(max_gpus=None):
    """Get available memory for each GPU."""
    visible = torch.cuda.device_count()
    num_gpus = visible if max_gpus is None else min(max_gpus, visible)

    gpu_memory = []
    for gpu_id in range(num_gpus):
        with torch.cuda.device(gpu_id):
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            # Report free memory in GiB: total minus what torch has allocated.
            total_gib = props.total_memory / (1024 ** 3)
            allocated_gib = torch.cuda.memory_allocated() / (1024 ** 3)
            gpu_memory.append(total_gib - allocated_gib)
    return gpu_memory
136
+
137
+
138
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Best effort: any network failure or unexpected response shape is
    treated as "not flagged" rather than raised.

    :param text: the text to check (newlines are stripped before sending)
    :returns: True only when the moderation API flags the text
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
    }
    text = text.replace("\n", "")
    # Build the payload with json.dumps so quotes/backslashes/control chars
    # in `text` are escaped properly. The previous manual concatenation
    # ('{"input": "' + text + '"}') produced invalid JSON for any input
    # containing a double quote.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except KeyError:
        flagged = False

    return flagged
159
+
160
+
161
# Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings,
# Use this function to make sure it can be correctly loaded.
def clean_flant5_ckpt(ckpt_path):
    """Repair a sharded flan-t5 checkpoint in place.

    Copies the intact ``shared.weight`` tensor over the encoder/decoder
    ``embed_tokens.weight`` entries in whichever shard files hold them.

    :param ckpt_path: directory containing ``pytorch_model.bin.index.json``
        and the shard files it references.
    """
    index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
    # Context manager closes the index file promptly; the original
    # `json.load(open(...))` left the handle dangling until GC.
    with open(index_file, "r") as f:
        index_json = json.load(f)

    weightmap = index_json["weight_map"]

    share_weight_file = weightmap["shared.weight"]
    share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[
        "shared.weight"
    ]

    for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]:
        weight_file = weightmap[weight_name]
        weight = torch.load(os.path.join(ckpt_path, weight_file))
        weight[weight_name] = share_weight
        torch.save(weight, os.path.join(ckpt_path, weight_file))
179
+
180
+
181
def pretty_print_semaphore(semaphore):
    """Print a semaphore in better format."""
    if semaphore is None:
        return "None"
    value, locked = semaphore._value, semaphore.locked()
    return f"Semaphore(value={value}, locked={locked})"
186
+
187
+
188
+ """A javascript function to get url parameters for the gradio web server."""
189
+ get_window_url_params_js = """
190
+ function() {
191
+ const params = new URLSearchParams(window.location.search);
192
+ url_params = Object.fromEntries(params);
193
+ console.log("url_params", url_params);
194
+ return url_params;
195
+ }
196
+ """
197
+
198
+
199
def iter_over_async(
    async_gen: AsyncGenerator, event_loop: AbstractEventLoop
) -> Generator:
    """
    Convert async generator to sync generator

    :param async_gen: the AsyncGenerator to convert
    :param event_loop: the event loop to run on
    :returns: Sync generator
    """
    iterator = async_gen.__aiter__()

    async def fetch_one():
        # Pull one item; signal exhaustion with a (done, value) pair since
        # StopAsyncIteration cannot propagate through run_until_complete.
        try:
            return False, await iterator.__anext__()
        except StopAsyncIteration:
            return True, None

    while True:
        exhausted, item = event_loop.run_until_complete(fetch_one())
        if exhausted:
            return
        yield item
223
+
224
+
225
def detect_language(text: str) -> str:
    """Detect the language of a string, or "unknown" on failure."""
    import polyglot  # pip3 install polyglot pyicu pycld2
    from polyglot.detect import Detector
    from polyglot.detect.base import logger as polyglot_logger
    import pycld2

    # Silence polyglot's per-call warnings.
    polyglot_logger.setLevel("ERROR")

    try:
        return Detector(text).language.name
    except (pycld2.error, polyglot.detect.base.UnknownLanguage):
        return "unknown"
robohusky/video_transformers.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import random
3
+ from PIL import Image, ImageOps
4
+ import numpy as np
5
+ import numbers
6
+ import math
7
+ import torch
8
+
9
class GroupRandomCrop(object):
    """Crop every frame in a group at one shared random location."""

    def __init__(self, size):
        # Accept either an int (square crop) or an explicit (h, w) pair.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        w, h = img_group[0].size
        th, tw = self.size

        # A single random offset shared by all frames keeps the clip aligned.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)

        cropped = list()
        for img in img_group:
            assert (img.size[0] == w and img.size[1] == h)
            if w == tw and h == th:
                cropped.append(img)
            else:
                cropped.append(img.crop((x1, y1, x1 + tw, y1 + th)))

        return cropped
34
+
35
class MultiGroupRandomCrop(object):
    """Like GroupRandomCrop, but samples `groups` independent crop positions
    and returns the whole frame group cropped at each, concatenated."""

    def __init__(self, size, groups=1):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.groups = groups

    def __call__(self, img_group):
        w, h = img_group[0].size
        th, tw = self.size

        out_images = list()
        for _ in range(self.groups):
            # Fresh random offset per group; every frame shares it.
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)

            for img in img_group:
                assert (img.size[0] == w and img.size[1] == h)
                if w == tw and h == th:
                    out_images.append(img)
                else:
                    out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))

        return out_images
62
+
63
class GroupCenterCrop(object):
    """Apply torchvision's CenterCrop to every frame in the group."""

    def __init__(self, size):
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_group):
        return list(map(self.worker, img_group))
69
+
70
class GroupRandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        # Single coin flip decides the fate of the whole group.
        if random.random() >= 0.5:
            return img_group

        flipped = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
        if self.is_flow:
            # invert flow pixel values when flipping (even indices —
            # presumably the x-flow channels; see GroupOverSample).
            for i in range(0, len(flipped), 2):
                flipped[i] = ImageOps.invert(flipped[i])
        return flipped
88
+
89
class GroupNormalize(object):
    """In-place per-channel normalize for a tensor whose dim 0 is a
    multiple of len(mean) (stacked frames)."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Tile mean/std so they cover every stacked channel along dim 0.
        rep_mean = self.mean * (tensor.size(0) // len(self.mean))
        rep_std = self.std * (tensor.size(0) // len(self.std))

        # TODO: make efficient
        for channel, m, s in zip(tensor, rep_mean, rep_std):
            channel.sub_(m).div_(s)

        return tensor
103
+
104
class GroupScale(object):
    """ Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.worker = torchvision.transforms.Resize(size, interpolation)

    def __call__(self, img_group):
        return list(map(self.worker, img_group))
118
+
119
class GroupOverSample(object):
    """Corner/center over-sampling of a frame group (10-crop when flip=True).

    For each of the five fixed offsets returned by
    GroupMultiScaleCrop.fill_fix_offset (4 corners + center), every frame is
    cropped and, when `flip` is enabled, its horizontal mirror is appended.
    """

    def __init__(self, crop_size, scale_size=None, flip=True):
        # Normalize an int crop size to a (w, h) pair.
        self.crop_size = crop_size if not isinstance(
            crop_size, int) else (crop_size, crop_size)

        if scale_size is not None:
            # Optionally rescale the whole group before cropping.
            self.scale_worker = GroupScale(scale_size)
        else:
            self.scale_worker = None
        self.flip = flip

    def __call__(self, img_group):

        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        # more_fix_crop=False -> exactly five deterministic offsets.
        offsets = GroupMultiScaleCrop.fill_fix_offset(
            False, image_w, image_h, crop_w, crop_h)
        oversample_group = list()
        for o_w, o_h in offsets:
            normal_group = list()
            flip_group = list()
            for i, img in enumerate(img_group):
                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
                normal_group.append(crop)
                flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)

                # 'L'-mode frames at even indices get their pixel values
                # inverted when mirrored — presumably x-flow channels;
                # mirrors GroupRandomHorizontalFlip's is_flow handling.
                if img.mode == 'L' and i % 2 == 0:
                    flip_group.append(ImageOps.invert(flip_crop))
                else:
                    flip_group.append(flip_crop)

            oversample_group.extend(normal_group)
            if self.flip:
                oversample_group.extend(flip_group)
        return oversample_group
158
+
159
class GroupFullResSample(object):
    """Deterministic 3-crop sampling (left / right / center along the width),
    optionally with horizontal mirrors, applied to every frame in the group."""

    def __init__(self, crop_size, scale_size=None, flip=True):
        # Normalize an int crop size to a (w, h) pair.
        self.crop_size = crop_size if not isinstance(
            crop_size, int) else (crop_size, crop_size)

        if scale_size is not None:
            self.scale_worker = GroupScale(scale_size)
        else:
            self.scale_worker = None
        self.flip = flip

    def __call__(self, img_group):

        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        # Quarter-steps of the slack between image and crop size.
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        offsets = list()
        offsets.append((0 * w_step, 2 * h_step))  # left
        offsets.append((4 * w_step, 2 * h_step))  # right
        offsets.append((2 * w_step, 2 * h_step))  # center

        oversample_group = list()
        for o_w, o_h in offsets:
            normal_group = list()
            flip_group = list()
            for i, img in enumerate(img_group):
                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
                normal_group.append(crop)
                if self.flip:
                    flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)

                    # 'L'-mode frames at even indices get inverted when
                    # mirrored — presumably x-flow channels; mirrors
                    # GroupRandomHorizontalFlip's is_flow handling.
                    if img.mode == 'L' and i % 2 == 0:
                        flip_group.append(ImageOps.invert(flip_crop))
                    else:
                        flip_group.append(flip_crop)

            oversample_group.extend(normal_group)
            oversample_group.extend(flip_group)
        return oversample_group
204
+
205
class GroupMultiScaleCrop(object):
    """Multi-scale cropping shared across a frame group.

    Picks one (width, height) crop from `scales` fractions of the shorter
    image edge (aspect distortion bounded by `max_distort`), crops every
    frame at the same offset, then resizes the crops to `input_size`.
    """

    def __init__(self, input_size, scales=None, max_distort=1,
                 fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        # fix_crop: sample the offset from a fixed grid instead of uniformly.
        self.fix_crop = fix_crop
        # more_fix_crop: enlarge that grid from 5 to 13 candidate offsets.
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [
            input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):

        im_size = img_group[0].size

        # One crop window is sampled once and applied to every frame.
        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        crop_img_group = [
            img.crop(
                (offset_w,
                 offset_h,
                 offset_w +
                 crop_w,
                 offset_h +
                 crop_h)) for img in img_group]
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
                         for img in crop_img_group]
        return ret_img_group

    def _sample_crop_size(self, im_size):
        # Randomly choose (crop_w, crop_h, offset_w, offset_h).
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        # Snap candidate sizes within 3px of the target input size onto it,
        # so near-exact matches avoid a later resample.
        crop_h = [
            self.input_size[1] if abs(
                x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [
            self.input_size[0] if abs(
                x - self.input_size[0]) < 3 else x for x in crop_sizes]

        # Pair widths/heights whose scale indices differ by at most
        # max_distort, bounding the aspect-ratio distortion.
        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            # Fully random placement anywhere inside the image.
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(
                image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        # Pick one offset uniformly from the fixed candidate grid.
        offsets = self.fill_fix_offset(
            self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        # Enumerate candidate (w, h) crop offsets on a quarter-step grid
        # of the slack between image and crop size.
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter

        return ret
292
+
293
class GroupRandomSizedCrop(object):
    """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
    and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
    This is popularly used to train the Inception networks
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        # Try up to 10 random (area, aspect-ratio) proposals that fit inside
        # the first frame; all frames then share the chosen crop window.
        for attempt in range(10):
            area = img_group[0].size[0] * img_group[0].size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Randomly swap w/h to also cover the inverse aspect ratio.
            if random.random() < 0.5:
                w, h = h, w

            if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
                x1 = random.randint(0, img_group[0].size[0] - w)
                y1 = random.randint(0, img_group[0].size[1] - h)
                found = True
                break
            else:
                found = False
                x1 = 0
                y1 = 0

        if found:
            out_group = list()
            for img in img_group:
                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert (img.size == (w, h))
                out_group.append(
                    img.resize(
                        (self.size, self.size), self.interpolation))
            return out_group
        else:
            # Fallback: scale the short edge to `size`, then random-crop.
            scale = GroupScale(self.size, interpolation=self.interpolation)
            crop = GroupRandomCrop(self.size)
            return crop(scale(img_group))
341
+
342
class ConvertDataFormat(object):
    """Reshape a stacked (T*3, H, W) tensor into (3, T, H, W) for 3D models;
    input for 2D models passes through unchanged."""

    def __init__(self, model_type):
        self.model_type = model_type

    def __call__(self, images):
        if self.model_type == '2D':
            return images
        tc, h, w = images.size()
        frames = tc // 3
        # (T*3, H, W) -> (T, 3, H, W) -> (3, T, H, W)
        return images.view(frames, 3, h, w).permute(1, 0, 2, 3)
354
+
355
class Stack(object):
    """Stack a group of frames along the channel axis into one ndarray."""

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        mode = img_group[0].mode
        if mode == 'L':
            # Greyscale: give each frame an explicit channel dim first.
            expanded = [np.expand_dims(frame, 2) for frame in img_group]
            return np.concatenate(expanded, axis=2)
        elif mode == 'RGB':
            if self.roll:
                # Reverse channel order (RGB -> BGR) per frame before stacking.
                rolled = [np.array(frame)[:, :, ::-1] for frame in img_group]
                return np.concatenate(rolled, axis=2)
            else:
                return np.concatenate(img_group, axis=2)
372
+
373
class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """

    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy array: HWC -> CHW.
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        else:
            # PIL Image: read the raw bytes, reshape to H x W x C,
            # then move channels first (this transpose dominates load time).
            img = torch.ByteTensor(
                torch.ByteStorage.from_buffer(
                    pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if self.div:
            return img.float().div(255)
        return img.float()
394
+
395
class IdentityTransform(object):
    """No-op transform: returns its input unchanged."""

    def __call__(self, data):
        return data
399
+
400
def get_index(num_frames, num_segments):
    """Uniformly sample `num_segments` frame indices from `num_frames` frames,
    taking roughly the middle frame of each equal-length segment."""
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    return np.array(
        [start + int(np.round(seg_size * idx)) for idx in range(num_segments)]
    )