Commit
·
cf932d8
1
Parent(s):
9e99e54
Add model code supports
Browse files- .idea/.gitignore +8 -0
- .idea/embodied_explainer.iml +12 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/modules.xml +8 -0
- .idea/vcs.xml +6 -0
- handler.py +52 -0
- inference.py +415 -0
- requirements.txt +13 -0
- robohusky/.DS_Store +0 -0
- robohusky/base_dataset.py +226 -0
- robohusky/base_dataset_uni.py +434 -0
- robohusky/compression.py +230 -0
- robohusky/configuration_husky.py +326 -0
- robohusky/constants.py +47 -0
- robohusky/conversation.py +511 -0
- robohusky/convert_fp16.py +27 -0
- robohusky/convert_husky_fp16.py +28 -0
- robohusky/convert_reward_fp16.py +27 -0
- robohusky/dist_utils.py +100 -0
- robohusky/llama2_flash_attn_monkey_patch.py +232 -0
- robohusky/model/__init__.py +70 -0
- robohusky/model/__pycache__/__init__.cpython-38.pyc +0 -0
- robohusky/model/__pycache__/configuration_husky.cpython-38.pyc +0 -0
- robohusky/model/__pycache__/modeling_husky_embody2.cpython-38.pyc +0 -0
- robohusky/model/compression.py +0 -0
- robohusky/model/configuration_husky.py +331 -0
- robohusky/model/configuration_husky_ori.py +327 -0
- robohusky/model/modeling_husky.py +1820 -0
- robohusky/model/modeling_husky_embody2.py +1962 -0
- robohusky/model/modeling_husky_embody2_ori.py +1821 -0
- robohusky/model/processing_husky.py +178 -0
- robohusky/train/.DS_Store +0 -0
- robohusky/train/llama_flash_attn_monkey_patch.py +232 -0
- robohusky/train/llama_rmsnorm_monkey_patch.py +15 -0
- robohusky/train/train.py +597 -0
- robohusky/train/train_uni.py +603 -0
- robohusky/utils.py +238 -0
- robohusky/video_transformers.py +406 -0
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 默认忽略的文件
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# 基于编辑器的 HTTP 客户端请求
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
.idea/embodied_explainer.iml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="inheritedJdk" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
<component name="PyDocumentationSettings">
|
| 9 |
+
<option name="format" value="GOOGLE" />
|
| 10 |
+
<option name="myDocStringFormat" value="Google" />
|
| 11 |
+
</component>
|
| 12 |
+
</module>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/embodied_explainer.iml" filepath="$PROJECT_DIR$/.idea/embodied_explainer.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/vcs.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="VcsDirectoryMappings">
|
| 4 |
+
<mapping directory="" vcs="Git" />
|
| 5 |
+
</component>
|
| 6 |
+
</project>
|
handler.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image

from inference import Chat  # Chat class from the bundled inference.py
from robohusky.conversation import get_conv_template
|
| 8 |
+
|
| 9 |
+
class EndpointHandler:
    """Inference-endpoint wrapper around the Husky multimodal ``Chat`` model.

    Accepts a payload dict with an ``"inputs"`` text query and optionally raw
    ``"image"`` or ``"video"`` bytes; returns ``{"output": reply}``.
    """

    def __init__(self, path: str = "."):
        # Requires the module-level ``import torch`` (the original file used
        # torch here without importing it, which raised NameError).
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.chat = Chat(
            model_path=path,
            device=self.device,
            num_gpus=1,
            max_new_tokens=1024,
            load_8bit=False,
        )
        self.vision_feature = None  # cached image/video embedding for the current request
        self.modal_type = "text"    # one of "text" | "image" | "video"
        self.conv = get_conv_template("husky").copy()

    def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Reset conversation state and extract vision features, if any."""
        query = inputs.get("inputs", "")
        # Every request starts a fresh conversation with no vision context.
        self.conv = get_conv_template("husky").copy()
        self.vision_feature = None
        self.modal_type = "text"

        if "image" in inputs:
            image = Image.open(BytesIO(inputs["image"])).convert("RGB")
            # NOTE(review): the embedding helper expects a file path, so the
            # decoded image is round-tripped through a temp file in the CWD.
            image.save("temp.jpg")
            self.vision_feature = self.chat.get_image_embedding("temp.jpg")
            self.modal_type = "image"
        elif "video" in inputs:
            with open("temp.mp4", "wb") as f:
                f.write(inputs["video"])
            self.vision_feature = self.chat.get_video_embedding("temp.mp4")
            self.modal_type = "video"

        return {"query": query}

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, str]:
        """Run one inference turn and return the model's reply."""
        processed = self.preprocess(inputs)
        query = processed["query"]

        conversations = self.chat.ask(text=query, conv=self.conv, modal_type=self.modal_type)
        outputs = self.chat.answer(conversations, self.vision_feature, modal_type=self.modal_type)
        # Store the stripped reply back into the conversation history.
        self.conv.messages[-1][1] = outputs.strip()
        return {"output": outputs.strip()}
|
inference.py
ADDED
|
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
srun -p INTERN2 --job-name='husky_multi_test' --gres=gpu:1 --cpus-per-task=8 --quotatype="auto" python -u demo/inference_new.py
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import abc
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import requests
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from io import BytesIO
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torchvision.transforms as T
|
| 15 |
+
from peft import PeftModel
|
| 16 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 17 |
+
|
| 18 |
+
from transformers import (
|
| 19 |
+
LlamaTokenizer,
|
| 20 |
+
GenerationConfig,
|
| 21 |
+
StoppingCriteria,
|
| 22 |
+
StoppingCriteriaList,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
|
| 26 |
+
|
| 27 |
+
from robohusky.conversation import (
|
| 28 |
+
conv_templates,
|
| 29 |
+
get_conv_template,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from robohusky.video_transformers import (
|
| 33 |
+
GroupNormalize,
|
| 34 |
+
GroupScale,
|
| 35 |
+
GroupCenterCrop,
|
| 36 |
+
Stack,
|
| 37 |
+
ToTorchFormatTensor,
|
| 38 |
+
get_index,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
from robohusky.compression import compress_module
|
| 42 |
+
from decord import VideoReader, cpu
|
| 43 |
+
|
| 44 |
+
# import deepspeed
|
| 45 |
+
|
| 46 |
+
IGNORE_INDEX = -100
|
| 47 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
| 48 |
+
DEFAULT_IMG_START_TOKEN = "<img>"
|
| 49 |
+
DEFAULT_IMG_END_TOKEN = "</img>"
|
| 50 |
+
|
| 51 |
+
DEFAULT_VIDEO_START_TOKEN = "<vid>"
|
| 52 |
+
DEFAULT_VIDEO_END_TOKEN = "</vid>"
|
| 53 |
+
|
| 54 |
+
def get_gpu_memory(max_gpus=None):
    """Return the currently free memory of each visible CUDA device, in GiB.

    Args:
        max_gpus: optional cap on how many devices to inspect.

    Returns:
        list[float]: available (total - allocated) memory per GPU.
    """
    visible = torch.cuda.device_count()
    num_gpus = visible if max_gpus is None else min(max_gpus, visible)

    gpu_memory = []
    for gpu_id in range(num_gpus):
        with torch.cuda.device(gpu_id):
            props = torch.cuda.get_device_properties(torch.cuda.current_device())
            total_gib = props.total_memory / (1024 ** 3)
            allocated_gib = torch.cuda.memory_allocated() / (1024 ** 3)
            gpu_memory.append(total_gib - allocated_gib)
    return gpu_memory
|
| 71 |
+
|
| 72 |
+
def load_model(
    model_path, device, num_gpus, max_gpu_memory=None, load_8bit=False, lora_weights=None
):
    """Load a Husky model and its Llama tokenizer onto ``device``.

    Args:
        model_path: checkpoint directory or hub id.
        device: "cpu" or "cuda" (anything else raises ValueError).
        num_gpus: int count or the string "auto" for HF auto device mapping.
        max_gpu_memory: optional per-GPU memory cap (e.g. "13GiB").
        load_8bit: apply weight compression after loading.
        lora_weights: optional PEFT/LoRA adapter path wrapped around the LM.

    Returns:
        (model, tokenizer) with the model in eval mode.
    """
    if device == "cpu":
        kwargs = {}
    elif device == "cuda":
        kwargs = {"torch_dtype": torch.float16}
        if num_gpus == "auto":
            kwargs["device_map"] = "auto"
        else:
            num_gpus = int(num_gpus)
            if num_gpus != 1:
                kwargs["device_map"] = "auto"
                if max_gpu_memory is not None:
                    kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
                else:
                    # Cards may have different VRAM sizes, so fill them
                    # sequentially based on what is actually free right now.
                    kwargs["device_map"] = "sequential"
                    free = get_gpu_memory(num_gpus)
                    kwargs["max_memory"] = {
                        i: str(int(free[i] * 0.85)) + "GiB" for i in range(num_gpus)
                    }
    else:
        raise ValueError(f"Invalid device: {device}")

    tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=False)

    if lora_weights is not None:
        # PEFT adapters are dispatched with an automatic device map.
        kwargs["device_map"] = "auto"
    model = HuskyForConditionalGeneration.from_pretrained(
        model_path, low_cpu_mem_usage=True, **kwargs
    )
    if lora_weights is not None:
        model.language_model = PeftModel.from_pretrained(
            model.language_model,
            lora_weights,
            **kwargs
        )

    if load_8bit:
        compress_module(model, device)

    # Single-GPU CUDA (and Apple MPS) place the whole model explicitly.
    if (device == "cuda" and num_gpus == 1) or device == "mps":
        model.to(device)

    return model.eval(), tokenizer
|
| 125 |
+
|
| 126 |
+
def load_image(image_file, input_size=224):
    """Load an image from a path or http(s) URL and preprocess it.

    Applies the 224/256 center-crop convention plus ImageNet normalization
    and returns a CHW float tensor.
    """
    if image_file.startswith(('http', 'https')):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')

    # Resize so that cropping input_size pixels reproduces the usual
    # crop_pct = 224/256 evaluation pipeline.
    crop_pct = 224 / 256
    resize_to = int(input_size / crop_pct)
    preprocess = T.Compose([
        T.Lambda(lambda img: img if img.mode == 'RGB' else img.convert('RGB')),
        T.Resize(resize_to, interpolation=InterpolationMode.BICUBIC),
        T.CenterCrop(input_size),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
    return preprocess(image)
|
| 144 |
+
|
| 145 |
+
def load_video(video_path, num_segments=8):
    """Decode ``num_segments`` evenly spaced frames from a video file.

    Frames are scaled, center-cropped to 224, stacked, and normalized with
    CLIP-style statistics; returns the stacked frame tensor.
    """
    reader = VideoReader(video_path, ctx=cpu(0))
    frame_indices = get_index(len(reader), num_segments)

    # CLIP normalization statistics.
    input_mean = [0.48145466, 0.4578275, 0.40821073]
    input_std = [0.26862954, 0.26130258, 0.27577711]

    preprocess = T.Compose([
        GroupScale(224, interpolation=InterpolationMode.BICUBIC),
        GroupCenterCrop(224),
        Stack(),
        ToTorchFormatTensor(),
        GroupNormalize(input_mean, input_std),
    ])

    frames = [Image.fromarray(reader[idx].asnumpy()) for idx in frame_indices]
    return preprocess(frames)
|
| 170 |
+
|
| 171 |
+
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once any stop-token sequence ends the sampled output."""

    def __init__(self, stops, encounters=1):
        # ``encounters`` is accepted for API compatibility but is not used.
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        # Check whether the tail of the (single) generated sequence matches
        # any of the registered stop-id tensors.
        return any(
            torch.all(stop == input_ids[0][-len(stop):]).item()
            for stop in self.stops
        )
|
| 183 |
+
|
| 184 |
+
@torch.inference_mode()
def generate_stream(
    model, tokenizer, image_processor, params, device
):
    """Run a single generation pass for a prompt with optional vision input.

    ``params`` carries "prompt" plus optional "images"/"videos" paths,
    "temperature" and "max_new_tokens". Returns the decoded output strings.
    """
    prompt = params["prompt"]
    images = params.get("images", None)
    videos = params.get("videos", None)
    temperature = float(params.get("temperature", 0.7))
    max_new_tokens = int(params.get("max_new_tokens", 1024))

    num_queries = model.config.num_query_tokens  # kept for parity; unused below

    # Any of these markers terminates sampling early.
    stop_words = ["Human: ", "Assistant: ", "###", "\n\n"]
    stop_ids = [
        tokenizer(word, return_tensors='pt')['input_ids'].squeeze()
        for word in stop_words
    ]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_ids)])

    generation_config = GenerationConfig(
        bos_token_id=1,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        stopping_criteria=stopping_criteria,
    )

    pixel_values = None
    if images is not None:
        pixel_values = load_image(images).to(device)  # only one image supported
        prompt = prompt.replace("<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN)
    elif videos is not None:
        pixel_values = load_video(videos).to(device)
        prompt = prompt.replace("<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN)

    model_inputs = tokenizer([prompt], return_tensors="pt")
    model_inputs.pop("token_type_ids", None)

    if pixel_values is not None:
        # Vision path: the full multimodal model consumes the pixel values.
        model_inputs["pixel_values"] = pixel_values
        generation_output = model.generate(
            **model_inputs,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
        )
    else:
        # Text-only path: bypass the vision tower entirely.
        generation_output = model.language_model.generate(
            **model_inputs,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
        )

    return tokenizer.batch_decode(generation_output.sequences, skip_special_tokens=True)
|
| 242 |
+
|
| 243 |
+
class Chat:
    """Stateful chat wrapper pairing a Husky model with conversation templates."""

    def __init__(
        self,
        model_path,
        device,
        num_gpus=1,
        load_8bit=False,
        temperature=0.7,
        max_new_tokens=512,
        lora_path=None,
    ):
        model, tokenizer = load_model(
            model_path, device, num_gpus, load_8bit=load_8bit, lora_weights=lora_path
        )

        self.model = model
        # self.model.language_model = deepspeed.init_inference(
        #     self.model.language_model, mp_size=1, dtype=torch.float16, checkpoint=None, replace_with_kernel_inject=True)
        self.tokenizer = tokenizer
        num_queries = model.config.num_query_tokens  # kept for parity; unused

        self.device = device
        self.dtype = model.dtype

        # Any of these markers terminates sampling early.
        stop_words = ["Human: ", "Assistant: ", "###", "\n\n"]
        stop_ids = [
            tokenizer(word, return_tensors='pt')['input_ids'].squeeze()
            for word in stop_words
        ]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_ids)])

        self.conv = get_conv_template("husky")

        self.image_query = DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
        self.video_query = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN

        self.generation_config = GenerationConfig(
            bos_token_id=1,
            do_sample=True,
            top_k=20,
            top_p=0.9,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
        )

    def ask(self, text, conv, modal_type="image"):
        """Append a user turn (vision tokens only on the opening turn) and
        return the rendered prompt list."""
        assert modal_type in ["text", "image", "video"]

        if len(conv.messages) > 0 or modal_type == "text":
            # Follow-up turns (or pure text chats) carry no vision placeholder.
            user_turn = text
        elif modal_type == "image":
            user_turn = self.image_query + "\n" + text
        else:
            user_turn = self.video_query + "\n" + text

        conv.append_message(conv.roles[0], user_turn)
        conv.append_message(conv.roles[1], None)
        return [conv.get_prompt()]

    @torch.no_grad()
    def get_image_embedding(self, image_file):
        """Encode one image file into language-model input features."""
        pixel_values = load_image(image_file)
        pixel_values = pixel_values.unsqueeze(0).to(self.device, dtype=self.dtype)
        return self.model.extract_feature(pixel_values)

    @torch.no_grad()
    def get_video_embedding(self, video_file):
        """Encode one video file into language-model input features."""
        pixel_values = load_video(video_file)
        TC, H, W = pixel_values.shape
        # Unstack frames into [C, T, H, W], then add a batch dimension.
        pixel_values = pixel_values.reshape(TC // 3, 3, H, W).transpose(0, 1)
        pixel_values = pixel_values.unsqueeze(0).to(self.device, dtype=self.dtype)
        assert len(pixel_values.shape) == 5
        return self.model.extract_feature(pixel_values)

    @torch.no_grad()
    def answer(self, conversations, language_model_inputs, modal_type="image"):
        """Generate a reply for the rendered prompt(s) plus optional vision
        features; returns the decoded reply string."""
        model_inputs = self.tokenizer(
            conversations,
            return_tensors="pt",
        )
        model_inputs.pop("token_type_ids", None)

        input_ids = model_inputs["input_ids"].to(self.device)
        attention_mask = model_inputs["attention_mask"].to(self.device)

        if modal_type == "text":
            generation_output = self.model.language_model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=self.generation_config,
                return_dict_in_generate=True,
                output_scores=True,
            )
        else:
            pixel_values = model_inputs.pop("pixel_values", None)
            if pixel_values is not None:
                pixel_values = pixel_values.to(self.device)

            generation_output = self.model.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                language_model_inputs=language_model_inputs,
                generation_config=self.generation_config,
                return_dict_in_generate=True,
                output_scores=True,
            )

        outputs = self.tokenizer.batch_decode(
            generation_output.sequences, skip_special_tokens=True
        )[0]

        if modal_type == "text":
            # NOTE(review): offset assumes each "</s>" shrinks by 3 characters
            # after decoding — verify against the tokenizer in use.
            skip_echo_len = len(conversations[0]) - conversations[0].count("</s>") * 3
            outputs = outputs[skip_echo_len:].strip()

        return outputs
|
| 360 |
+
|
| 361 |
+
if __name__ == '__main__':
    # model_path = "/mnt/petrelfs/zhangqinglong/Documents/Husky/work_dirs/husky_v3/EmbodiedGPT/pretrain_0727"
    model_path = "/mnt/petrelfs/share_data/gvembodied/workdirs/align_new_myyf"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    chat = Chat(model_path, device=device, num_gpus=1, max_new_tokens=1024, load_8bit=False)

    vision_feature = None
    image_state = False
    video_state = False

    image_exts = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
    video_exts = ('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")

    while True:
        query = input("\n")

        # A path to a media file loads it and starts a fresh conversation.
        if query.lower().endswith(image_exts):
            if os.path.exists(query):
                print("received.")
                vision_feature = chat.get_image_embedding(query)
                chat.conv = get_conv_template("husky").copy()
                image_state = True
            continue
        if query.lower().endswith(video_exts):
            if os.path.exists(query):
                print("received.")
                vision_feature = chat.get_video_embedding(query)
                chat.conv = get_conv_template("husky").copy()
                video_state = True
            continue

        if query == "stop":
            break
        if query == "clear" or query == "" or query == "\n":
            # Reset all conversation and vision state.
            chat.conv = get_conv_template("husky").copy()
            image_state = False
            video_state = False
            os.system("clear")
            print("欢迎使用 husky-13b-zh 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue

        if image_state:
            modal_type = "image"
        elif video_state:
            modal_type = "video"
        else:
            modal_type = "text"

        conversations = chat.ask(text=query, conv=chat.conv, modal_type=modal_type)
        outputs = chat.answer(conversations, vision_feature, modal_type=modal_type)
        # NOTE: strip is important to align with the training data.
        chat.conv.messages[-1][1] = outputs.strip()

        print(f"Husky: \n{outputs}")
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.0.1
|
| 2 |
+
torchvision==0.15.2
|
| 3 |
+
torchaudio==2.0.2
|
| 4 |
+
transformers==4.34.1
|
| 5 |
+
decord
|
| 6 |
+
peft
|
| 7 |
+
huggingface_hub
|
| 8 |
+
Pillow
|
| 9 |
+
einops
|
| 10 |
+
scipy
|
| 11 |
+
numpy
|
| 12 |
+
tqdm
|
| 13 |
+
flash-attn
|
robohusky/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
robohusky/base_dataset.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
from typing import Dict, Optional, Sequence
|
| 5 |
+
from PIL import PngImagePlugin, Image, ImageFile
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import Dataset
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 11 |
+
|
| 12 |
+
from robohusky.train.tcsloader import TCSLoader
|
| 13 |
+
from robohusky.conversation import get_conv_template
|
| 14 |
+
|
| 15 |
+
IGNORE_INDEX = -100
|
| 16 |
+
|
| 17 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 18 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 19 |
+
MaximumDecompressedSize = 1024
|
| 20 |
+
MegaByte = 2 ** 20
|
| 21 |
+
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
|
| 22 |
+
|
| 23 |
+
DEFAULT_IMG_START_TOKEN = "<img>"
|
| 24 |
+
DEFAULT_IMG_END_TOKEN = "</img>"
|
| 25 |
+
|
| 26 |
+
DEFAULT_VIDEO_START_TOKEN = "<vid>"
|
| 27 |
+
DEFAULT_VIDEO_END_TOKEN = "</vid>"
|
| 28 |
+
|
| 29 |
+
def is_image(image_file):
    """Return True if *image_file* has a known image extension (case-insensitive)."""
    # str.endswith accepts a tuple of suffixes, so no if/else is needed.
    return image_file.lower().endswith(
        ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
    )
|
| 34 |
+
|
| 35 |
+
def is_video(image_file):
    """Return True if *image_file* has a known video extension (case-insensitive)."""
    # str.endswith accepts a tuple of suffixes, so no if/else is needed.
    return image_file.lower().endswith(
        ('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")
    )
|
| 40 |
+
|
| 41 |
+
def build_transform(input_size):
    """Build the square-resize + ImageNet-normalize preprocessing pipeline."""
    return T.Compose([
        # Force RGB mode for grayscale / palette images.
        T.Lambda(lambda img: img if img.mode == 'RGB' else img.convert('RGB')),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ])
|
| 49 |
+
|
| 50 |
+
def format_inputs(sources):
    """Render raw human/gpt message lists into husky-template prompts.

    Returns a tuple of (rendered prompt strings, the conversation template).
    """
    conv = get_conv_template("husky").copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    conversations = []

    for i, source in enumerate(sources):
        # Conversations must open with a human turn; drop a leading gpt turn.
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            # Vision placeholders are only handled in human turns: a trailing
            # "<image>"/"<video>" is moved to the front, then swapped for the
            # start/end token pair.
            if role == conv.roles[0]:
                value = sentence["value"]
                if "<image>" in value:
                    if value.endswith("\n<image>"):
                        value = "<image>\n" + value.replace("\n<image>", "")
                    sentence["value"] = value.replace(
                        "<image>", DEFAULT_IMG_START_TOKEN + DEFAULT_IMG_END_TOKEN
                    )
                elif "<video>" in value:
                    if value.endswith("\n<video>"):
                        value = "<video>\n" + value.replace("\n<video>", "")
                    sentence["value"] = value.replace(
                        "<video>", DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_END_TOKEN
                    )
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    return conversations, conv
|
| 84 |
+
|
| 85 |
+
def process_func(examples, tokenizer, max_seq_length):
    """Tokenize conversations and build loss-mask labels for supervised training.

    Renders each example through the husky conversation template, tokenizes
    with fixed-length padding, and masks everything except the assistant
    replies with IGNORE_INDEX so only those tokens contribute to the loss.

    Args:
        examples: mapping with a 'conversations' list of message dicts.
        tokenizer: Llama-style tokenizer (the offsets below are tuned for it).
        max_seq_length: padding/truncation length for tokenization.

    Returns:
        The tokenizer output dict with an added "labels" tensor.
    """
    conversations, conv = format_inputs(examples['conversations'])
    model_inputs = tokenizer(
        conversations,
        max_length=max_seq_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    model_inputs.pop("token_type_ids", None)
    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    targets = model_inputs["input_ids"].clone()

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Number of real (non-pad) tokens; used as a sanity check below.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        # Mask the trailing padding region.
        target[cur_len:] = IGNORE_INDEX

        # If the cursor and the real token count disagree, the masking went
        # out of sync — discard the whole example from the loss.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX

    model_inputs["labels"] = targets
    return model_inputs
|
| 141 |
+
|
| 142 |
+
class BaseDataset(Dataset):
    """Image-text dataset: runs non-image fields through the Husky processor
    and attaches the transformed image as ``pixel_values``.

    Fully processed samples are cached per index (unbounded cache).
    """

    def __init__(self, dataset, processor, image_path="", input_size=224):
        super(BaseDataset, self).__init__()
        self.dataset = dataset
        self.image_path = image_path
        self.transform = build_transform(input_size)
        self.husky_processor = processor
        # index -> processed sample; grows without bound over an epoch
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        record = self.dataset[i]
        image_file = record.pop("image", None)

        if self.image_path != "":
            image_file = os.path.join(self.image_path, image_file)
            if not os.path.exists(image_file):
                # Missing file: deterministically fall through to the next index.
                return self.__getitem__((i + 1) % len(self.dataset))
        image = Image.open(image_file)

        # The processor expects batched fields: wrap each value, then unwrap
        # the singleton batch dimension from its outputs.
        for key in list(record.keys()):
            record[key] = [record[key]]
        sample = self.husky_processor(record)
        for key in list(sample.keys()):
            sample[key] = sample[key][0]

        sample["pixel_values"] = self.transform(image)

        self.cached_data_dict[i] = sample
        return sample
|
| 182 |
+
|
| 183 |
+
class CephDataset(Dataset):
    """Dataset that fetches images from Ceph object storage via ``TCSLoader``.

    Non-image fields are run through the Husky processor; the image is decoded
    remotely, converted to RGB, and transformed into ``pixel_values``.
    """

    def __init__(self, dataset, processor, input_size=224):
        super(CephDataset, self).__init__()
        self.dataset = dataset

        self.transform = build_transform(input_size)
        self.husky_processor = processor

        # Petrel/Ceph client configuration, resolved to an absolute path.
        conf_path = "./petrelf.conf"
        self.conf_path = os.path.abspath(conf_path)

        # Create the loader exactly once (guarded by `initialized`).
        self.initialized = False
        self._init_memcached()

    def _init_memcached(self):
        # Idempotent: only constructs the TCSLoader on the first call.
        if not self.initialized:
            assert self.conf_path is not None
            self.mt_loader = TCSLoader(self.conf_path)
            self.initialized = True

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        data = self.dataset[i]
        image_file = data.pop("image", None)

        try:
            image = self.mt_loader(image_file).convert('RGB')
        except (AttributeError, OSError):
            # Unreadable file: append its path to error.txt, then retry with
            # a random other sample (modulo keeps the index in range).
            with open("error.txt", 'a') as f:
                f.write(image_file + '\n')
            i = random.randint(0, len(self.dataset))
            return self.__getitem__(i % len(self.dataset))

        # Processor expects batched fields; wrap each value in a list,
        # then unwrap the singleton batch dimension from the outputs.
        for k, v in data.items():
            data[k] = [v]

        ret = self.husky_processor(data)
        for k, v in ret.items():
            ret[k] = v[0]
        pixel_values = self.transform(image)
        ret["pixel_values"] = pixel_values
        return ret
|
robohusky/base_dataset_uni.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
from typing import Dict, Optional, Sequence, Iterator, List, Iterable, Union
|
| 5 |
+
from PIL import PngImagePlugin, Image, ImageFile, ImageOps
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from torch.utils.data import (
|
| 11 |
+
Dataset,
|
| 12 |
+
ConcatDataset,
|
| 13 |
+
Sampler,
|
| 14 |
+
WeightedRandomSampler
|
| 15 |
+
)
|
| 16 |
+
import torchvision.transforms as T
|
| 17 |
+
from torchvision.transforms.functional import InterpolationMode
|
| 18 |
+
|
| 19 |
+
from robohusky.train.tcsloader import TCSLoader
|
| 20 |
+
|
| 21 |
+
from decord import VideoReader, cpu
|
| 22 |
+
from robohusky.video_transformers import (
|
| 23 |
+
GroupNormalize,
|
| 24 |
+
GroupScale,
|
| 25 |
+
GroupCenterCrop,
|
| 26 |
+
Stack,
|
| 27 |
+
ToTorchFormatTensor,
|
| 28 |
+
get_index,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
from robohusky.conversation import get_conv_template
|
| 32 |
+
|
| 33 |
+
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
|
| 34 |
+
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
|
| 35 |
+
IMAGENET_STANDARD_MEAN = [0.5, 0.5, 0.5]
|
| 36 |
+
IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
|
| 37 |
+
OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
|
| 38 |
+
OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
|
| 39 |
+
|
| 40 |
+
IGNORE_INDEX = -100
|
| 41 |
+
|
| 42 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 43 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 44 |
+
MaximumDecompressedSize = 1024
|
| 45 |
+
MegaByte = 2 ** 20
|
| 46 |
+
PngImagePlugin.MAX_TEXT_CHUNK = MaximumDecompressedSize * MegaByte
|
| 47 |
+
|
| 48 |
+
DEFAULT_IMG_START_TOKEN = "<img>"
|
| 49 |
+
DEFAULT_IMG_END_TOKEN = "</img>"
|
| 50 |
+
|
| 51 |
+
DEFAULT_VIDEO_START_TOKEN = "<vid>"
|
| 52 |
+
DEFAULT_VIDEO_END_TOKEN = "</vid>"
|
| 53 |
+
|
| 54 |
+
DEFAULT_EMBED_TOKEN = "<quad>"
|
| 55 |
+
|
| 56 |
+
conf_path = "/your path to/petrelf.conf"
|
| 57 |
+
|
| 58 |
+
def is_image(image_file):
    """Return True if *image_file* has a known still-image extension (case-insensitive)."""
    image_suffixes = ('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')
    return image_file.lower().endswith(image_suffixes)
|
| 63 |
+
|
| 64 |
+
def is_video(image_file):
    """Return True if *image_file* has a known video-container extension (case-insensitive)."""
    video_suffixes = ('.mp4', '.mkv', '.avi', '.wmv', '.iso', ".webm")
    return image_file.lower().endswith(video_suffixes)
|
| 69 |
+
|
| 70 |
+
def is_numpy(image_file):
    """Return True if *image_file* is a ``.npy`` array dump.

    Note: unlike is_image/is_video this check is case-sensitive,
    mirroring how the dumps are written.
    """
    return image_file.endswith(".npy")
|
| 75 |
+
|
| 76 |
+
def get_media_type(image_file):
    """Classify a path as "image", "video", "numpy", or fall back to "text"."""
    for media_type, predicate in (("image", is_image), ("video", is_video), ("numpy", is_numpy)):
        if predicate(image_file):
            return media_type
    return "text"
|
| 85 |
+
|
| 86 |
+
def build_transform(input_size, norm_type="openai", media_type="image"):
    """Build the preprocessing pipeline for the given media type.

    ``norm_type`` selects normalization statistics ("openai" for CLIP;
    "imagenet" and any unknown value both fall back to the ImageNet stats).
    ``media_type`` picks a single-image pipeline, a grouped video pipeline,
    or ``None`` for anything else.
    """
    if norm_type == "openai":
        mean, std = OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
    else:
        # "imagenet" and every other value share the ImageNet statistics.
        mean, std = IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

    if media_type == "image":
        return T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ])
    if media_type == "video":
        return T.Compose([
            GroupScale(int(input_size), interpolation=InterpolationMode.BICUBIC),
            GroupCenterCrop(input_size),
            Stack(),
            ToTorchFormatTensor(),
            GroupNormalize(mean=mean, std=std),
        ])
    return None
|
| 115 |
+
|
| 116 |
+
def check_format(data):
    """Validate one conversation record before training.

    A valid record has ``id``, ``image`` and an even-length ``conversations``
    list that alternates human/gpt turns starting with the human, embeds the
    ``<image>`` placeholder at the head or tail of the first turn, and has a
    non-empty ``value`` for every message.

    Invalid records are reported to stdout and rejected.

    Returns:
        bool: True if the record is well-formed, False otherwise.
    """
    if not ('id' in data and 'image' in data and 'conversations' in data and len(data['conversations']) % 2 == 0):
        # Fixed garbled diagnostic ("Lake field") — the record lacks a required field.
        print(f"Lacks required field: {data}")
        return False
    for i, message in enumerate(data['conversations']):
        if i == 0:
            # The opening human turn must carry the image placeholder.
            if not (message['value'].startswith("<image>\n") or message['value'].endswith("\n<image>")):
                print(f"No <image>: {data}")
                return False
        if i % 2 == 0:
            # Even turns come from the human side.
            if not (message['from'] == 'human'):
                print(f"Not from human: {data}")
                return False
        else:
            # Odd turns come from the assistant ("gpt") side.
            if not (message['from'] == 'gpt'):
                print(f"Not from gpt: {data}")
                return False
        if message['value'] is None or (len(message['value']) == 0):
            print(f"No Message: {data}")
            return False
    return True
|
| 137 |
+
|
| 138 |
+
def format_inputs(sources, conv_tempt="husky", num_query_tokens=256):
    """Render raw conversation lists into prompt strings.

    Expands ``<image>``/``<video>`` placeholders (human turns only) into
    start/end markers wrapping ``num_query_tokens`` embed tokens, moving a
    trailing placeholder to the front of the turn first.  The expansion is
    written back into each sentence dict in place.

    Returns:
        (list of prompt strings, conversation template used).
    """
    conv = get_conv_template(conv_tempt).copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
    # Build the expanded media queries once; they are turn-independent.
    image_query = DEFAULT_IMG_START_TOKEN + num_query_tokens * DEFAULT_EMBED_TOKEN + DEFAULT_IMG_END_TOKEN
    video_query = DEFAULT_VIDEO_START_TOKEN + num_query_tokens * DEFAULT_EMBED_TOKEN + DEFAULT_VIDEO_END_TOKEN

    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Drop a leading non-human turn.
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            # Vision placeholders are only honored on human turns.
            if role == conv.roles[0]:
                value = sentence["value"]
                if "<image>" in value:
                    if value.endswith("\n<image>"):
                        value = "<image>\n" + value.replace("\n<image>", "")
                    sentence["value"] = value.replace("<image>", image_query)
                elif "<video>" in value:
                    if value.endswith("\n<video>"):
                        value = "<video>\n" + value.replace("\n<video>", "")
                    sentence["value"] = value.replace("<video>", video_query)
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    return conversations, conv
|
| 174 |
+
|
| 175 |
+
def process_func(examples, tokenizer, max_seq_length=-1, conv_tempt="husky", num_query_tokens=256):
    """Tokenize formatted conversations and build loss-masked labels.

    Prompts are rendered via ``format_inputs`` and tokenized (padded to
    ``max_seq_length`` when it is non-negative, otherwise only truncated to
    ``tokenizer.model_max_length``).  Labels are a clone of ``input_ids``
    with everything except assistant replies set to ``IGNORE_INDEX``.

    NOTE(review): the offsets below are hardcoded for the Llama tokenizer
    (the "-2" and the legacy-mode "-1" adjustments) — other tokenizers may
    mis-mask; confirm before reusing.
    """
    conversations, conv = format_inputs(examples['conversations'], conv_tempt, num_query_tokens)
    if max_seq_length < 0:
        model_inputs = tokenizer(
            conversations,
            return_tensors="pt",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
    else:
        model_inputs = tokenizer(
            conversations,
            max_length=max_seq_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    # token_type_ids is meaningless for causal LM training; drop it if present.
    model_inputs.pop("token_type_ids", None)
    # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
    # padding in the loss.
    targets = model_inputs["input_ids"].clone()

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        # Number of real (non-pad) tokens, used for the consistency check below.
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        turns = conversation.split(conv.sep2)
        cur_len = 1  # skip the BOS token
        target[:cur_len] = IGNORE_INDEX
        for i, turn in enumerate(turns):
            if turn == "":
                break
            turn_len = len(tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
            instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                instruction_len -= 1

            # Ignore the user instructions
            target[cur_len: cur_len + instruction_len] = IGNORE_INDEX
            cur_len += turn_len

            if i != 0 and not tokenizer.legacy:
                # The legacy and non-legacy modes handle special tokens differently
                cur_len -= 1

        # Mask everything after the last completed turn (padding / truncated tail).
        target[cur_len:] = IGNORE_INDEX

        # If the walked length disagrees with the actual token count, the
        # offsets above went wrong for this sample — mask it out entirely
        # rather than train on misaligned labels.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX

    model_inputs["labels"] = targets
    return model_inputs
|
| 239 |
+
|
| 240 |
+
class BaseDataset(Dataset):
    """Unified multimodal dataset over image, video, numpy-feature or text-only samples.

    Media can live on local disk or on Ceph ("s3://" paths, fetched via
    ``TCSLoader``).  Processed samples are cached per index (unbounded cache).
    """

    def __init__(
        self,
        dataset,
        processor,
        image_path="",
        input_size=224,
        num_segments=8,
        norm_type="openai",
        media_type="image"
    ):
        super(BaseDataset, self).__init__()
        self.dataset = dataset
        self.image_path = image_path  # optional root prefixed to relative media paths
        self.input_size = input_size
        self.num_segments = num_segments  # frames sampled per video clip

        self.media_type = media_type
        self.transform = build_transform(input_size, norm_type, media_type)
        self.husky_processor = processor
        # Ceph/S3 loader for "s3://" paths; module-level conf_path must point
        # at a valid petrel config — presumably site-specific, verify.
        self.tcs_loader = TCSLoader(os.path.abspath(conf_path), media_type=media_type)

        # index -> processed sample; grows without bound over an epoch.
        self.cached_data_dict = {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        if i in self.cached_data_dict:
            return self.cached_data_dict[i]

        data = self.dataset[i]
        image_file = data["image"] if "image" in data else data["video"]

        if self.media_type == "llm" or image_file == "":
            # Text-only sample: no visual input attached.
            # pixel_values = torch.zeros(size=(3, self.input_size, self.input_size))
            pixel_values = None
        else:
            if self.image_path != "":
                image_file = os.path.join(self.image_path, image_file)
            # Local file missing (Ceph paths are exempt): retry a random sample.
            if "s3://" not in image_file and not os.path.exists(image_file):
                i = random.randint(0, len(self.dataset))
                return self.__getitem__(i % len(self.dataset))

            try:
                if self.media_type == "image":
                    # load from ceph
                    if "s3://" in image_file:
                        image = self.tcs_loader(image_file)
                    else:
                        image = Image.open(image_file).convert('RGB')

                    # process image with extreme aspect ratios
                    # NOTE(review): PIL's Image.size is (width, height), so these
                    # names are swapped; the two padding branches still end up
                    # squaring wide/tall images correctly — confirm before renaming.
                    height, width = image.size
                    if height / width >= 1.8:
                        delta = height - width
                        padding = (0, delta // 2, 0, delta - delta // 2)
                        image = ImageOps.expand(image, padding)
                    elif height / width <= 0.56:
                        delta = width - height
                        padding = (delta // 2, 0, delta - delta // 2, 0)
                        image = ImageOps.expand(image, padding)
                    pixel_values = self.transform(image)
                elif self.media_type == "video":
                    if "s3://" in image_file:
                        vr = self.tcs_loader(image_file)
                    else:
                        vr = VideoReader(image_file, ctx=cpu(0))

                    # Uniformly sample num_segments frames across the clip.
                    num_frames = len(vr)
                    frame_indices = get_index(num_frames, self.num_segments)
                    images_group = list()
                    for frame_index in frame_indices:
                        img = Image.fromarray(vr[frame_index].asnumpy())
                        images_group.append(img)
                    pixel_values = self.transform(images_group)
                    TC, H, W = pixel_values.shape
                    pixel_values = pixel_values.reshape(TC // 3, 3, H, W).transpose(0, 1)  # [C, T, H, W]
                else:
                    # load numpy
                    if "s3://" in image_file:
                        pixel_values = self.tcs_loader(image_file)
                    else:
                        pixel_values = np.load(image_file)
                    pixel_values = torch.tensor(pixel_values).transpose(0, 1)
            except (AttributeError, OSError):
                # Record the unreadable file, then fall back to a random sample.
                with open("error.txt", 'a') as f:
                    f.write(image_file + '\n')
                i = random.randint(0, len(self.dataset))
                return self.__getitem__(i % len(self.dataset))

        # Processor expects batched fields; wrap then unwrap the singleton batch.
        for k, v in data.items():
            data[k] = [v]
        ret = self.husky_processor(data)
        for k, v in ret.items():
            ret[k] = v[0]

        if pixel_values is not None:
            ret["pixel_values"] = pixel_values

        self.cached_data_dict[i] = ret
        return ret
|
| 343 |
+
|
| 344 |
+
class WeightedConcatDataset(ConcatDataset):
    """Concatenation of datasets sampled according to per-dataset weights.

    With ``batch_size <= 0`` a plain ``WeightedRandomSampler`` is built over
    the combined length; otherwise a ``WeightedBatchSampler`` is used so each
    batch of draws shares a single weighted choice.
    """

    def __init__(
        self,
        datasets: List[Dataset],
        weights: Sequence[float] = None,
        replacement: bool = True,
        batch_size: int = -1,
        generator=None
    ) -> None:
        super().__init__(datasets)
        if weights is None:
            # Default to uniform weighting across datasets.
            weights = [1.0] * len(self.datasets)
        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if weights_tensor.dim() != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             "weights have shape {}".format(tuple(weights_tensor.shape)))
        self.weights = weights_tensor
        self.batch_size = batch_size
        self.replacement = replacement
        self.generator = generator

        if self.batch_size <= 0:
            self.num_samples = sum(len(d) for d in datasets)
            self.sampler = WeightedRandomSampler(
                weights=self.weights,
                num_samples=self.num_samples,
                replacement=self.replacement
            )
        else:
            # Only whole batches per dataset are counted.
            self.task_batches = [len(d) // batch_size for d in datasets]
            self.num_samples = sum(self.task_batches) * batch_size
            self.sampler = WeightedBatchSampler(
                weights=self.weights,
                num_samples=self.num_samples,
                batch_size=self.batch_size,
                replacement=self.replacement
            )

    def __iter__(self) -> Iterator[int]:
        return iter(self.sampler)

    def __len__(self) -> int:
        return self.num_samples
|
| 388 |
+
|
| 389 |
+
class WeightedBatchSampler(Sampler[int]):
    """Sampler that emits weighted choices in homogeneous runs.

    Draws ``num_samples // batch_size`` weighted indices with
    ``torch.multinomial`` and repeats each draw ``batch_size`` times, so
    every run of ``batch_size`` consecutive indices is identical.
    """

    weights: torch.Tensor
    num_samples: int
    batch_size: int
    replacement: bool

    def __init__(
        self,
        weights: Sequence[float],
        num_samples: int,
        batch_size: int,
        replacement: bool = True,
        generator=None
    ) -> None:
        # Validate scalar arguments (bool is excluded because it is an int subclass).
        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integer value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))

        weights_tensor = torch.as_tensor(weights, dtype=torch.double)
        if weights_tensor.dim() != 1:
            raise ValueError("weights should be a 1d sequence but given "
                             "weights have shape {}".format(tuple(weights_tensor.shape)))

        self.weights = weights_tensor
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.num_batches = num_samples // batch_size
        self.replacement = replacement
        self.generator = generator

    def __iter__(self) -> Iterator[int]:
        draws = torch.multinomial(self.weights, self.num_batches, self.replacement, generator=self.generator)
        # Expand each draw into a full batch worth of identical indices.
        yield from draws.repeat_interleave(self.batch_size).tolist()

    def __len__(self) -> int:
        return self.num_samples
|
robohusky/compression.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
import gc
|
| 3 |
+
import glob
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from accelerate import init_empty_weights
|
| 7 |
+
from accelerate.utils import set_module_tensor_to_device
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
from torch.nn import functional as F
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclasses.dataclass
class CompressionConfig:
    """Group-wise quantization."""

    num_bits: int       # bit width of quantized values (compress() asserts <= 8)
    group_size: int     # number of elements quantized together along group_dim
    group_dim: int      # tensor dimension that is split into groups
    symmetric: bool     # True: zero-centered int8; False: min/max affine uint8
    enabled: bool = True  # False makes compress/decompress pass-throughs
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Default scheme: symmetric 8-bit quantization over groups of 256 along dim 1
# (the input-feature dimension of nn.Linear weights).
default_compression_config = CompressionConfig(
    num_bits=8, group_size=256, group_dim=1, symmetric=True, enabled=True
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class CLinear(nn.Module):
    """Compressed Linear Layer.

    Stores its weight in group-quantized form and dequantizes on the fly in
    ``forward``; the bias (if any) is kept uncompressed.
    """

    def __init__(self, weight=None, bias=None, device=None):
        super().__init__()
        if weight is None:
            self.weight = None
        elif isinstance(weight, Tensor):
            # Raw tensor: quantize it on `device` with the default scheme.
            self.weight = compress(weight.data.to(device), default_compression_config)
        else:
            # Already-packed weight tuple: adopt as-is.
            self.weight = weight
        self.bias = bias

    def forward(self, input: Tensor) -> Tensor:
        weight = decompress(self.weight, default_compression_config)
        bias = None if self.bias is None else self.bias.to(weight.dtype)
        return F.linear(input.to(weight.dtype), weight, bias)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def compress_module(module, target_device):
    """Recursively replace every ``nn.Linear`` under *module* with a ``CLinear``."""
    for attr_name in dir(module):
        candidate = getattr(module, attr_name)
        if type(candidate) == torch.nn.Linear:
            replacement = CLinear(candidate.weight, candidate.bias, target_device)
            setattr(module, attr_name, replacement)
    for _, child in module.named_children():
        compress_module(child, target_device)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def get_compressed_list(module, prefix=""):
|
| 66 |
+
compressed_list = []
|
| 67 |
+
for attr_str in dir(module):
|
| 68 |
+
target_attr = getattr(module, attr_str)
|
| 69 |
+
if type(target_attr) == torch.nn.Linear:
|
| 70 |
+
full_name = (
|
| 71 |
+
f"{prefix}.{attr_str}.weight" if prefix else f"{attr_str}.weight"
|
| 72 |
+
)
|
| 73 |
+
compressed_list.append(full_name)
|
| 74 |
+
for name, child in module.named_children():
|
| 75 |
+
child_prefix = f"{prefix}.{name}" if prefix else name
|
| 76 |
+
for each in get_compressed_list(child, child_prefix):
|
| 77 |
+
compressed_list.append(each)
|
| 78 |
+
return compressed_list
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def apply_compressed_weight(module, compressed_state_dict, target_device, prefix=""):
    """Recursively install pre-compressed weights from *compressed_state_dict*.

    Every ``nn.Linear`` found under *module* is swapped for a ``CLinear``
    built from the packed weight stored under its dotted state-dict name.
    """
    for attr_name in dir(module):
        candidate = getattr(module, attr_name)
        if type(candidate) == torch.nn.Linear:
            full_name = f"{prefix}.{attr_name}.weight" if prefix else f"{attr_name}.weight"
            setattr(
                module,
                attr_name,
                CLinear(compressed_state_dict[full_name], candidate.bias, target_device),
            )
    for child_name, child in module.named_children():
        child_prefix = f"{prefix}.{child_name}" if prefix else child_name
        apply_compressed_weight(child, compressed_state_dict, target_device, child_prefix)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def load_compress_model(model_path, device, torch_dtype, use_fast=False):
    """Load a causal LM from *model_path* with Linear weights quantized on the fly.

    Checkpoint shards are streamed one at a time so the full-precision model
    never has to sit in memory at once: Linear weights are compressed as they
    are read, everything else is moved straight to *device*.

    Returns:
        (model, tokenizer) with every ``nn.Linear`` replaced by ``CLinear``.
    """
    # partially load model
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=use_fast)
    base_pattern = os.path.join(model_path, "pytorch_model*.bin")
    files = glob.glob(base_pattern)

    # Build the module skeleton on the meta device: no weight memory allocated.
    with init_empty_weights():
        config = AutoConfig.from_pretrained(
            model_path, low_cpu_mem_usage=True, torch_dtype=torch_dtype
        )
        model = AutoModelForCausalLM.from_config(config)
        linear_weights = get_compressed_list(model)

    compressed_state_dict = {}

    for filename in tqdm(files):
        tmp_state_dict = torch.load(filename)
        for name in tmp_state_dict:
            if name in linear_weights:
                # Quantize Linear weights immediately to cap peak memory.
                tensor = tmp_state_dict[name].to(device).data.to(torch_dtype)
                compressed_state_dict[name] = compress(
                    tensor, default_compression_config
                )
            else:
                compressed_state_dict[name] = tmp_state_dict[name].to(device)
            # Drop references so the shard's tensors can be reclaimed.
            tmp_state_dict[name] = None
            tensor = None
            gc.collect()
            torch.cuda.empty_cache()

    # Materialize all non-Linear tensors onto the target device.
    for name in model.state_dict():
        if name not in linear_weights:
            set_module_tensor_to_device(
                model, name, device, value=compressed_state_dict[name]
            )
    # Swap each nn.Linear for a CLinear holding its packed weight.
    apply_compressed_weight(model, compressed_state_dict, device)

    model.to(device)

    return model, tokenizer
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def compress(tensor, config):
    """Simulate group-wise quantization.

    Splits ``config.group_dim`` into groups of ``config.group_size``
    (zero-padding the tail group), then quantizes each group either
    symmetrically (signed int8, packed as ``(data, scale, original_shape)``)
    or asymmetrically (uint8, packed as ``(data, mn, scale, original_shape)``).
    Returns *tensor* unchanged when compression is disabled.
    """
    if not config.enabled:
        return tensor

    num_bits = config.num_bits
    group_size = config.group_size
    group_dim = config.group_dim
    assert num_bits <= 8

    original_shape = tensor.shape
    num_groups = (original_shape[group_dim] + group_size - 1) // group_size
    grouped_shape = (
        original_shape[:group_dim]
        + (num_groups, group_size)
        + original_shape[group_dim + 1:]
    )

    # Zero-pad the grouped dimension up to a multiple of group_size.
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if pad_len != 0:
        pad_shape = original_shape[:group_dim] + (pad_len,) + original_shape[group_dim + 1:]
        filler = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
        tensor = torch.cat([tensor, filler], dim=group_dim)
    data = tensor.view(grouped_shape)

    if config.symmetric:
        # Zero-centered: scale each group so its max magnitude maps to B.
        B = 2 ** (num_bits - 1) - 1
        scale = B / torch.max(data.abs(), dim=group_dim + 1, keepdim=True)[0]
        data = (data * scale).clamp_(-B, B).round_().to(torch.int8)
        return data, scale, original_shape

    # Affine: shift each group to [0, B] via its min/max.
    B = 2 ** num_bits - 1
    mn = torch.min(data, dim=group_dim + 1, keepdim=True)[0]
    mx = torch.max(data, dim=group_dim + 1, keepdim=True)[0]
    scale = B / (mx - mn)
    data = (data - mn).mul_(scale).clamp_(0, B).round_().to(torch.uint8)
    return data, mn, scale, original_shape
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def decompress(packed_data, config):
    """Simulate group-wise dequantization.

    Inverts `compress`: rescales (and, for the asymmetric scheme, re-shifts)
    the quantized groups and restores the original tensor shape, trimming any
    zero padding that `compress` appended. When `config.enabled` is False,
    `packed_data` is returned untouched.
    """
    if not config.enabled:
        return packed_data

    group_size = config.group_size
    group_dim = config.group_dim

    if config.symmetric:
        data, scale, original_shape = packed_data
        data = data / scale
    else:
        data, mn, scale, original_shape = packed_data
        data = data / scale
        data.add_(mn)

    # Trim the padding added during compression, if any.
    pad_len = (group_size - original_shape[group_dim] % group_size) % group_size
    if not pad_len:
        return data.view(original_shape)

    padded_shape = (
        original_shape[:group_dim]
        + (original_shape[group_dim] + pad_len,)
        + original_shape[group_dim + 1:]
    )
    region = [slice(0, n) for n in original_shape]
    return data.reshape(padded_shape)[region].contiguous()
|
robohusky/configuration_husky.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" Husky model configuration"""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import os
|
| 19 |
+
from typing import Union
|
| 20 |
+
|
| 21 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 22 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 23 |
+
from transformers.utils import logging
|
| 24 |
+
from transformers.models.auto import CONFIG_MAPPING
|
| 25 |
+
|
| 26 |
+
logger = logging.get_logger(__name__)
|
| 27 |
+
|
| 28 |
+
# Map of pretrained model shorthand names to their hosted config.json URLs.
HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "wofmanaf/husky-7b": "https://huggingface.co/wofmanaf/husky-7b/resolve/main/config.json",
}
|
| 31 |
+
|
| 32 |
+
class HuskyVisionConfig(PretrainedConfig):
    r"""
    Configuration for the Husky vision encoder ([`HuskyVisionModel`]).

    Stores the hyper-parameters defining the image tower: layer widths and
    counts, patch/image geometry, activation, normalization and initialization
    settings. Inherits from [`PretrainedConfig`]; see that class for the
    options shared by all model configurations. The defaults reproduce the
    standard Husky vision architecture.

    Args:
        hidden_size (`int`, *optional*, defaults to 1408):
            Width of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Width of the feed-forward ("intermediate") layers.
        projection_dim (`int`, *optional*, defaults to 512):
            Dimension of the projection head.
        num_hidden_layers (`int`, *optional*, defaults to 39):
            Number of transformer encoder layers.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Attention heads per encoder layer.
        num_channels (`int`, *optional*, defaults to 3):
            Number of input image channels.
        image_size (`int`, *optional*, defaults to 224):
            Input image resolution.
        patch_size (`int`, *optional*, defaults to 14):
            Resolution of each image patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            Non-linear activation in the encoder and pooler; `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            Epsilon used by the layer-normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for fully connected layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Dropout probability for attention weights.
        initializer_range (`float`, *optional*, defaults to 1e-10):
            Standard deviation of the truncated-normal weight initializer.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            Global multiplier for weight initialization (kept at 1 except for
            internal initialization testing).
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries and values in self-attention.
    """

    model_type = "husky_vision_model"

    def __init__(
        self,
        hidden_size=1408,
        intermediate_size=6144,
        projection_dim=512,
        num_hidden_layers=39,
        num_attention_heads=16,
        num_channels=3,
        image_size=224,
        patch_size=14,
        hidden_act="gelu",
        layer_norm_eps=0.00001,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=1e-10,
        initializer_factor=1.0,
        qkv_bias=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # NOTE: attribute assignment order is kept identical to the reference
        # implementation so that serialized configs list keys the same way.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.dropout = dropout
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.qkv_bias = qkv_bias

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cfg, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A composite HuskyConfig nests this config under "vision_config".
        if cfg.get("model_type") == "husky":
            cfg = cfg["vision_config"]

        type_mismatch = (
            "model_type" in cfg
            and hasattr(cls, "model_type")
            and cfg["model_type"] != cls.model_type
        )
        if type_mismatch:
            logger.warning(
                f"You are using a model of type {cfg['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(cfg, **kwargs)
|
| 125 |
+
|
| 126 |
+
class HuskyQFormerConfig(PretrainedConfig):
    r"""
    Configuration for [`HuskyQFormerModel`], the Querying Transformer
    (Q-Former) bridging the vision encoder and the language model.

    The Q-Former is a BERT-style encoder (very similar to [`BertLMHeadModel`])
    with cross-attention layers interleaved every
    ``cross_attention_frequency`` blocks. Inherits from [`PretrainedConfig`];
    see that class for the options shared by all model configurations.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Size of the token vocabulary usable via `inputs_ids`.
        hidden_size (`int`, *optional*, defaults to 768):
            Width of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of transformer layers.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Attention heads per layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Width of the feed-forward layers.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            Activation in the encoder and pooler; `"gelu"`, `"relu"`,
            `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability for fully connected layers.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            Dropout probability for attention weights.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            Maximum supported sequence length.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated-normal weight initializer.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            Layer-normalization epsilon.
        pad_token_id (`int`, *optional*, defaults to 0):
            Id of the padding token.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            One of `"absolute"`, `"relative_key"`, `"relative_key_query"`;
            see Shaw et al. (https://arxiv.org/abs/1803.02155) and Huang et
            al. (https://arxiv.org/abs/2009.13658) for the relative variants.
        classifier_dropout (`float`, *optional*):
            Dropout probability for the classification head.
        cross_attention_frequency (`int`, *optional*, defaults to 2):
            Insert a cross-attention layer every this many transformer blocks.
        encoder_hidden_size (`int`, *optional*, defaults to 1408):
            Hidden size of the encoder states attended to by cross-attention.
    """

    model_type = "husky_qformer"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        classifier_dropout=None,
        cross_attention_frequency=2,
        encoder_hidden_size=1408,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # NOTE: attribute assignment order matches the reference
        # implementation so serialized configs list keys identically.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.classifier_dropout = classifier_dropout
        self.cross_attention_frequency = cross_attention_frequency
        self.encoder_hidden_size = encoder_hidden_size

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        cfg, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A composite HuskyConfig nests this config under "qformer_config".
        if cfg.get("model_type") == "husky":
            cfg = cfg["qformer_config"]

        type_mismatch = (
            "model_type" in cfg
            and hasattr(cls, "model_type")
            and cfg["model_type"] != cls.model_type
        )
        if type_mismatch:
            logger.warning(
                f"You are using a model of type {cfg['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(cfg, **kwargs)
|
| 230 |
+
|
| 231 |
+
class HuskyConfig(PretrainedConfig):
    r"""
    Composite configuration for [`HuskyForConditionalGeneration`].

    Bundles the three sub-configurations that define a Husky model — a vision
    encoder, a Q-Former and a language model — together with the number of
    query tokens bridging vision and text. Inherits from [`PretrainedConfig`];
    see that class for the options shared by all model configurations.

    Args:
        vision_config (`dict`, *optional*):
            Options used to build a [`HuskyVisionConfig`].
        qformer_config (`dict`, *optional*):
            Options used to build a [`HuskyQFormerConfig`].
        text_config (`dict`, *optional*):
            Options used to build the language-model [`PretrainedConfig`]
            (an `OPTConfig` when no `model_type` is given).
        num_query_tokens (`int`, *optional*, defaults to 32):
            Number of query tokens passed through the Transformer.
        kwargs (*optional*):
            Extra keyword arguments forwarded to [`PretrainedConfig`].
    """

    model_type = "husky"
    is_composition = True

    def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
        super().__init__(**kwargs)

        # Missing sub-configs fall back to each class's defaults.
        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. initializing the HuskyVisionConfig with default values.")

        if qformer_config is None:
            qformer_config = {}
            logger.info("qformer_config is None. Initializing the HuskyQFormerConfig with default values.")

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")

        self.vision_config = HuskyVisionConfig(**vision_config)
        self.qformer_config = HuskyQFormerConfig(**qformer_config)
        self.text_config = CONFIG_MAPPING[text_config.get("model_type", "opt")](**text_config)

        # Mirror the language model's embedding/encoder-decoder settings.
        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder

        self.num_query_tokens = num_query_tokens
        # The Q-Former cross-attends over vision features, so its encoder
        # width must equal the vision tower's hidden size.
        self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
        self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_qformer_text_configs(
        cls,
        vision_config: HuskyVisionConfig,
        qformer_config: HuskyQFormerConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        r"""
        Build a [`HuskyConfig`] (or a derived class) from already-instantiated
        vision, Q-Former and language-model configurations.

        Returns:
            [`HuskyConfig`]: An instance of a configuration object
        """
        return cls(
            vision_config=vision_config.to_dict(),
            qformer_config=qformer_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )

    def to_dict(self):
        """
        Serialize this instance to a Python dictionary, expanding each
        sub-configuration (overrides [`~PretrainedConfig.to_dict`]).

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up
            this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["qformer_config"] = self.qformer_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
|
| 324 |
+
|
| 325 |
+
if __name__ == '__main__':
    # Smoke test: building a default composite config exercises all three
    # sub-configurations. The previous code referenced the non-existent
    # attribute ``HuskyConfig.from_pretrain`` (a truncated ``from_pretrained``
    # with no argument), which raised AttributeError whenever this module
    # was executed directly.
    config = HuskyConfig()
    print(config.to_dict()["model_type"])
|
robohusky/constants.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import IntEnum
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# For the gradio web server
# Message shown when a backend/network request fails.
SERVER_ERROR_MSG = (
    "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
)
# Message shown when the content-moderation check rejects an input.
MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN."
# Message shown when a conversation exceeds the length limit below.
CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION."
INPUT_CHAR_LEN_LIMIT = 2560  # maximum characters per user input
CONVERSATION_LEN_LIMIT = 50  # maximum messages per conversation
LOGDIR = "."  # directory where the web server writes its logs

# For the controller and workers (could be overwritten through ENV variables.)
CONTROLLER_HEART_BEAT_EXPIRATION = int(
    os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90)
)
WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 30))
WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100))
WORKER_API_EMBEDDING_BATCH_SIZE = int(os.getenv("WORKER_API_EMBEDDING_BATCH_SIZE", 4))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ErrorCode(IntEnum):
    """
    Integer error codes returned by the API layer, modeled after
    https://platform.openai.com/docs/guides/error-codes/api-errors
    """

    # 400xx: request validation failures
    VALIDATION_TYPE_ERROR = 40001

    # 401xx: authentication failures
    INVALID_AUTH_KEY = 40101
    INCORRECT_AUTH_KEY = 40102
    NO_PERMISSION = 40103

    # 403xx: request rejected (model / parameter / context problems)
    INVALID_MODEL = 40301
    PARAM_OUT_OF_RANGE = 40302
    CONTEXT_OVERFLOW = 40303

    # 429xx: throttling
    RATE_LIMIT = 42901
    QUOTA_EXCEEDED = 42902
    ENGINE_OVERLOADED = 42903

    # 500xx: server-side failures
    INTERNAL_ERROR = 50001
    CUDA_OUT_OF_MEMORY = 50002
    GRADIO_REQUEST_ERROR = 50003
    GRADIO_STREAM_UNKNOWN_ERROR = 50004
    CONTROLLER_NO_WORKER = 50005
    CONTROLLER_WORKER_TIMEOUT = 50006
CONTROLLER_WORKER_TIMEOUT = 50006
|
robohusky/conversation.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Conversation prompt templates.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import dataclasses
|
| 6 |
+
from enum import auto, Enum
|
| 7 |
+
from typing import List, Any, Dict
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class SeparatorStyle(Enum):
    """Separator styles; each selects a prompt layout in `Conversation.get_prompt`."""

    ADD_COLON_SINGLE = auto()  # "role: message" + one separator after every turn
    ADD_COLON_TWO = auto()  # "role: message", alternating between two separators
    ADD_COLON_SPACE_SINGLE = auto()  # like ADD_COLON_SINGLE, but an open turn ends with "role: "
    NO_COLON_SINGLE = auto()  # role and message concatenated directly, no colon
    ADD_NEW_LINE_SINGLE = auto()  # "role\nmessage" + one separator
    DOLLY = auto()  # "role:\nmessage", two separators, extra blank line after responses
    RWKV = auto()  # "role: message" with newline normalization, "\n\n" between turns
    PHOENIX = auto()  # messages wrapped in "<s>...</s>"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclasses.dataclass
|
| 24 |
+
class Conversation:
|
| 25 |
+
"""A class that keeps all conversation history."""
|
| 26 |
+
|
| 27 |
+
# The name of this template
|
| 28 |
+
name: str
|
| 29 |
+
# The system prompt
|
| 30 |
+
system: str
|
| 31 |
+
# Two roles
|
| 32 |
+
roles: List[str]
|
| 33 |
+
# All messages. Each item is (role, message).
|
| 34 |
+
messages: List[List[str]]
|
| 35 |
+
# The number of few shot examples
|
| 36 |
+
offset: int
|
| 37 |
+
# Separators
|
| 38 |
+
sep_style: SeparatorStyle
|
| 39 |
+
sep: str
|
| 40 |
+
sep2: str = None
|
| 41 |
+
# Stop criteria (the default one is EOS token)
|
| 42 |
+
stop_str: str = None
|
| 43 |
+
# Stops generation if meeting any token in this list
|
| 44 |
+
stop_token_ids: List[int] = None
|
| 45 |
+
|
| 46 |
+
def get_prompt(self) -> str:
|
| 47 |
+
"""Get the prompt for generation."""
|
| 48 |
+
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
| 49 |
+
ret = self.system + self.sep
|
| 50 |
+
for role, message in self.messages:
|
| 51 |
+
if message:
|
| 52 |
+
ret += role + ": " + message + self.sep
|
| 53 |
+
else:
|
| 54 |
+
ret += role + ":"
|
| 55 |
+
return ret
|
| 56 |
+
elif self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
| 57 |
+
seps = [self.sep, self.sep2]
|
| 58 |
+
ret = self.system + seps[0]
|
| 59 |
+
for i, (role, message) in enumerate(self.messages):
|
| 60 |
+
if message:
|
| 61 |
+
ret += role + ": " + message + seps[i % 2]
|
| 62 |
+
else:
|
| 63 |
+
ret += role + ":"
|
| 64 |
+
return ret
|
| 65 |
+
elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
| 66 |
+
ret = self.system + self.sep
|
| 67 |
+
for role, message in self.messages:
|
| 68 |
+
if message:
|
| 69 |
+
ret += role + ": " + message + self.sep
|
| 70 |
+
else:
|
| 71 |
+
ret += role + ": " # must be end with a space
|
| 72 |
+
return ret
|
| 73 |
+
elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
| 74 |
+
ret = self.system
|
| 75 |
+
for role, message in self.messages:
|
| 76 |
+
if message:
|
| 77 |
+
ret += role + message + self.sep
|
| 78 |
+
else:
|
| 79 |
+
ret += role
|
| 80 |
+
return ret
|
| 81 |
+
elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
| 82 |
+
ret = self.system + self.sep
|
| 83 |
+
for role, message in self.messages:
|
| 84 |
+
if message:
|
| 85 |
+
ret += role + "\n" + message + self.sep
|
| 86 |
+
else:
|
| 87 |
+
ret += role + "\n"
|
| 88 |
+
return ret
|
| 89 |
+
elif self.sep_style == SeparatorStyle.DOLLY:
|
| 90 |
+
seps = [self.sep, self.sep2]
|
| 91 |
+
ret = self.system
|
| 92 |
+
for i, (role, message) in enumerate(self.messages):
|
| 93 |
+
if message:
|
| 94 |
+
ret += role + ":\n" + message + seps[i % 2]
|
| 95 |
+
if i % 2 == 1:
|
| 96 |
+
ret += "\n\n"
|
| 97 |
+
else:
|
| 98 |
+
ret += role + ":\n"
|
| 99 |
+
return ret
|
| 100 |
+
elif self.sep_style == SeparatorStyle.RWKV:
|
| 101 |
+
ret = self.system
|
| 102 |
+
for i, (role, message) in enumerate(self.messages):
|
| 103 |
+
if message:
|
| 104 |
+
ret += (
|
| 105 |
+
role
|
| 106 |
+
+ ": "
|
| 107 |
+
+ message.replace("\r\n", "\n").replace("\n\n", "\n")
|
| 108 |
+
)
|
| 109 |
+
ret += "\n\n"
|
| 110 |
+
else:
|
| 111 |
+
ret += role + ":"
|
| 112 |
+
return ret
|
| 113 |
+
elif self.sep_style == SeparatorStyle.PHOENIX:
|
| 114 |
+
ret = self.system
|
| 115 |
+
for role, message in self.messages:
|
| 116 |
+
if message:
|
| 117 |
+
ret += role + ": " + "<s>" + message + "</s>"
|
| 118 |
+
else:
|
| 119 |
+
ret += role + ": " + "<s>"
|
| 120 |
+
return ret
|
| 121 |
+
else:
|
| 122 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
| 123 |
+
|
| 124 |
+
def append_message(self, role: str, message: str):
|
| 125 |
+
"""Append a new message."""
|
| 126 |
+
self.messages.append([role, message])
|
| 127 |
+
|
| 128 |
+
def update_last_message(self, message: str):
|
| 129 |
+
"""Update the last output.
|
| 130 |
+
|
| 131 |
+
The last message is typically set to be None when constructing the prompt,
|
| 132 |
+
so we need to update it in-place after getting the response from a model.
|
| 133 |
+
"""
|
| 134 |
+
self.messages[-1][1] = message
|
| 135 |
+
|
| 136 |
+
def to_gradio_chatbot(self):
|
| 137 |
+
"""Convert the conversation to gradio chatbot format"""
|
| 138 |
+
ret = []
|
| 139 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
| 140 |
+
if i % 2 == 0:
|
| 141 |
+
ret.append([msg, None])
|
| 142 |
+
else:
|
| 143 |
+
ret[-1][-1] = msg
|
| 144 |
+
return ret
|
| 145 |
+
|
| 146 |
+
def to_openai_api_messages(self):
|
| 147 |
+
"""Convert the conversation to OpenAI chat completion format."""
|
| 148 |
+
ret = [{"role": "system", "content": self.system}]
|
| 149 |
+
|
| 150 |
+
for i, (_, msg) in enumerate(self.messages[self.offset:]):
|
| 151 |
+
if i % 2 == 0:
|
| 152 |
+
ret.append({"role": "user", "content": msg})
|
| 153 |
+
else:
|
| 154 |
+
if msg is not None:
|
| 155 |
+
ret.append({"role": "assistant", "content": msg})
|
| 156 |
+
return ret
|
| 157 |
+
|
| 158 |
+
def copy(self):
    """Return a copy: the message pairs are re-created, everything else shared."""
    cloned_messages = [[r, m] for r, m in self.messages]
    return Conversation(
        name=self.name,
        system=self.system,
        roles=self.roles,
        messages=cloned_messages,
        offset=self.offset,
        sep_style=self.sep_style,
        sep=self.sep,
        sep2=self.sep2,
        stop_str=self.stop_str,
        stop_token_ids=self.stop_token_ids,
    )
|
| 171 |
+
|
| 172 |
+
def dict(self):
    """Return the lightweight, serializable state of the conversation."""
    return {
        key: getattr(self, key)
        for key in ("name", "system", "roles", "messages", "offset")
    }
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a conversation template under its name.

    Args:
        template: The template to add to the global ``conv_templates`` registry.
        override: If True, silently replace any template already registered
            under the same name.

    Raises:
        ValueError: If ``template.name`` is already registered and ``override``
            is False.  An explicit raise is used instead of ``assert`` so the
            duplicate check still runs under ``python -O``.
    """
    if not override and template.name in conv_templates:
        raise ValueError(f"{template.name} has been registered.")
    conv_templates[template.name] = template
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def get_conv_template(name: str) -> Conversation:
    """Look up a registered template and return a private copy of it."""
    template = conv_templates[name]
    return template.copy()
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# A template with one conversation example
|
| 199 |
+
register_conv_template(
|
| 200 |
+
Conversation(
|
| 201 |
+
name="one_shot",
|
| 202 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
| 203 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
| 204 |
+
roles=("Human", "Assistant"),
|
| 205 |
+
messages=(
|
| 206 |
+
(
|
| 207 |
+
"Human",
|
| 208 |
+
"Got any creative ideas for a 10 year old’s birthday?",
|
| 209 |
+
),
|
| 210 |
+
(
|
| 211 |
+
"Assistant",
|
| 212 |
+
"""Of course! Here are some creative ideas for a 10-year-old's birthday party:
|
| 213 |
+
1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.
|
| 214 |
+
2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.
|
| 215 |
+
3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.
|
| 216 |
+
4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.
|
| 217 |
+
5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.
|
| 218 |
+
6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.
|
| 219 |
+
7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.
|
| 220 |
+
8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.
|
| 221 |
+
Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!""",
|
| 222 |
+
),
|
| 223 |
+
),
|
| 224 |
+
offset=2,
|
| 225 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
| 226 |
+
sep="\n### ",
|
| 227 |
+
stop_str="###",
|
| 228 |
+
)
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# Vicuna v1.1 template
|
| 232 |
+
register_conv_template(
|
| 233 |
+
Conversation(
|
| 234 |
+
name="vicuna_v1.1",
|
| 235 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
| 236 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
| 237 |
+
roles=("USER", "ASSISTANT"),
|
| 238 |
+
messages=(),
|
| 239 |
+
offset=0,
|
| 240 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
| 241 |
+
sep=" ",
|
| 242 |
+
sep2="</s>",
|
| 243 |
+
)
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# Husky template
|
| 247 |
+
register_conv_template(
|
| 248 |
+
Conversation(
|
| 249 |
+
name="husky",
|
| 250 |
+
system="",
|
| 251 |
+
roles=("Human", "Assistant"),
|
| 252 |
+
messages=(),
|
| 253 |
+
offset=0,
|
| 254 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
| 255 |
+
sep=" ",
|
| 256 |
+
sep2="</s>",
|
| 257 |
+
)
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
# Koala default template
|
| 261 |
+
register_conv_template(
|
| 262 |
+
Conversation(
|
| 263 |
+
name="koala_v1",
|
| 264 |
+
system="BEGINNING OF CONVERSATION:",
|
| 265 |
+
roles=("USER", "GPT"),
|
| 266 |
+
messages=(),
|
| 267 |
+
offset=0,
|
| 268 |
+
sep_style=SeparatorStyle.ADD_COLON_TWO,
|
| 269 |
+
sep=" ",
|
| 270 |
+
sep2="</s>",
|
| 271 |
+
)
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Alpaca default template
|
| 275 |
+
register_conv_template(
|
| 276 |
+
Conversation(
|
| 277 |
+
name="alpaca",
|
| 278 |
+
system="Below is an instruction that describes a task. Write a response that appropriately completes the request.",
|
| 279 |
+
roles=("### Instruction:", "### Response:"),
|
| 280 |
+
messages=(),
|
| 281 |
+
offset=0,
|
| 282 |
+
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
| 283 |
+
sep="\n\n",
|
| 284 |
+
)
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Dolly V2 default template
|
| 288 |
+
register_conv_template(
|
| 289 |
+
Conversation(
|
| 290 |
+
name="dolly_v2",
|
| 291 |
+
system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
|
| 292 |
+
roles=("### Instruction", "### Response"),
|
| 293 |
+
messages=(),
|
| 294 |
+
offset=0,
|
| 295 |
+
sep_style=SeparatorStyle.DOLLY,
|
| 296 |
+
sep="\n\n",
|
| 297 |
+
sep2="### End",
|
| 298 |
+
)
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# OpenAssistant Pythia default template
|
| 302 |
+
register_conv_template(
|
| 303 |
+
Conversation(
|
| 304 |
+
name="oasst_pythia",
|
| 305 |
+
system="",
|
| 306 |
+
roles=("<|prompter|>", "<|assistant|>"),
|
| 307 |
+
messages=(),
|
| 308 |
+
offset=0,
|
| 309 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
| 310 |
+
sep="<|endoftext|>",
|
| 311 |
+
)
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# StableLM Alpha default template
|
| 315 |
+
register_conv_template(
|
| 316 |
+
Conversation(
|
| 317 |
+
name="stablelm",
|
| 318 |
+
system="""<|SYSTEM|># StableLM Tuned (Alpha version)
|
| 319 |
+
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
| 320 |
+
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
| 321 |
+
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
| 322 |
+
- StableLM will refuse to participate in anything that could harm a human.
|
| 323 |
+
""",
|
| 324 |
+
roles=("<|USER|>", "<|ASSISTANT|>"),
|
| 325 |
+
messages=(),
|
| 326 |
+
offset=0,
|
| 327 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
| 328 |
+
sep="",
|
| 329 |
+
stop_token_ids=[50278, 50279, 50277, 1, 0],
|
| 330 |
+
)
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
# Baize default template
|
| 334 |
+
register_conv_template(
|
| 335 |
+
Conversation(
|
| 336 |
+
name="baize",
|
| 337 |
+
system="The following is a conversation between a human and an AI assistant named Baize (named after a mythical creature in Chinese folklore). Baize is an open-source AI assistant developed by UCSD and Sun Yat-Sen University. The human and the AI assistant take turns chatting. Human statements start with [|Human|] and AI assistant statements start with [|AI|]. The AI assistant always provides responses in as much detail as possible, and in Markdown format. The AI assistant always declines to engage with topics, questions and instructions related to unethical, controversial, or sensitive issues. Complete the transcript in exactly that format.\n",
|
| 338 |
+
roles=("[|Human|]", "[|AI|]"),
|
| 339 |
+
messages=(
|
| 340 |
+
("[|Human|]", "Hello!"),
|
| 341 |
+
("[|AI|]", "Hi!"),
|
| 342 |
+
),
|
| 343 |
+
offset=2,
|
| 344 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
| 345 |
+
sep="\n",
|
| 346 |
+
stop_str="[|Human|]",
|
| 347 |
+
)
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
# RWKV-4-Raven default template
|
| 351 |
+
register_conv_template(
|
| 352 |
+
Conversation(
|
| 353 |
+
name="rwkv",
|
| 354 |
+
system="",
|
| 355 |
+
roles=("Bob", "Alice"),
|
| 356 |
+
messages=(
|
| 357 |
+
("Bob", "hi"),
|
| 358 |
+
(
|
| 359 |
+
"Alice",
|
| 360 |
+
"Hi. I am your assistant and I will provide expert full response in full details. Please feel free to ask any question and I will always answer it.",
|
| 361 |
+
),
|
| 362 |
+
),
|
| 363 |
+
offset=2,
|
| 364 |
+
sep_style=SeparatorStyle.RWKV,
|
| 365 |
+
sep="",
|
| 366 |
+
stop_str="\n\n",
|
| 367 |
+
)
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
# Buddy default template
|
| 371 |
+
register_conv_template(
|
| 372 |
+
Conversation(
|
| 373 |
+
name="openbuddy",
|
| 374 |
+
system="""Consider a conversation between User (a human) and Assistant (named Buddy).
|
| 375 |
+
Buddy is an INTP-T, a friendly, intelligent and multilingual AI assistant, by OpenBuddy team. GitHub: https://github.com/OpenBuddy/OpenBuddy
|
| 376 |
+
Buddy cannot access the Internet.
|
| 377 |
+
Buddy can fluently speak the user's language (e.g. English, Chinese).
|
| 378 |
+
Buddy can generate poems, stories, code, essays, songs, parodies, and more.
|
| 379 |
+
Buddy possesses vast knowledge about the world, history, and culture.
|
| 380 |
+
Buddy's responses are always safe, creative, high-quality, human-like, and interesting.
|
| 381 |
+
Buddy strictly refuses to discuss political, NSFW, or other unsafe topics.
|
| 382 |
+
|
| 383 |
+
User: Hi.
|
| 384 |
+
Assistant: Hi, I'm Buddy, your AI assistant. How can I help you today?""",
|
| 385 |
+
roles=("User", "Assistant"),
|
| 386 |
+
messages=(),
|
| 387 |
+
offset=0,
|
| 388 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
| 389 |
+
sep="\n",
|
| 390 |
+
)
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
# Phoenix default template
|
| 394 |
+
register_conv_template(
|
| 395 |
+
Conversation(
|
| 396 |
+
name="phoenix",
|
| 397 |
+
system="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
|
| 398 |
+
roles=("Human", "Assistant"),
|
| 399 |
+
messages=(),
|
| 400 |
+
offset=0,
|
| 401 |
+
sep_style=SeparatorStyle.PHOENIX,
|
| 402 |
+
sep="</s>",
|
| 403 |
+
)
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
# ChatGPT default template
|
| 407 |
+
register_conv_template(
|
| 408 |
+
Conversation(
|
| 409 |
+
name="chatgpt",
|
| 410 |
+
system="You are a helpful assistant.",
|
| 411 |
+
roles=("user", "assistant"),
|
| 412 |
+
messages=(),
|
| 413 |
+
offset=0,
|
| 414 |
+
sep_style=None,
|
| 415 |
+
sep=None,
|
| 416 |
+
)
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Claude default template
|
| 420 |
+
register_conv_template(
|
| 421 |
+
Conversation(
|
| 422 |
+
name="claude",
|
| 423 |
+
system="",
|
| 424 |
+
roles=("Human", "Assistant"),
|
| 425 |
+
messages=(),
|
| 426 |
+
offset=0,
|
| 427 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
| 428 |
+
sep="\n\n",
|
| 429 |
+
)
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
# MPT default template
|
| 433 |
+
register_conv_template(
|
| 434 |
+
Conversation(
|
| 435 |
+
name="mpt",
|
| 436 |
+
system="""<|im_start|>system
|
| 437 |
+
- You are a helpful assistant chatbot trained by MosaicML.
|
| 438 |
+
- You answer questions.
|
| 439 |
+
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
| 440 |
+
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.
|
| 441 |
+
""",
|
| 442 |
+
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
| 443 |
+
messages=(),
|
| 444 |
+
offset=0,
|
| 445 |
+
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
| 446 |
+
sep="<|im_end|>",
|
| 447 |
+
stop_token_ids=[50278, 0],
|
| 448 |
+
)
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
# Bard default template
|
| 452 |
+
register_conv_template(
|
| 453 |
+
Conversation(
|
| 454 |
+
name="bard",
|
| 455 |
+
system="",
|
| 456 |
+
roles=("0", "1"),
|
| 457 |
+
messages=(),
|
| 458 |
+
offset=0,
|
| 459 |
+
sep_style=None,
|
| 460 |
+
sep=None,
|
| 461 |
+
)
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
# BiLLa default template
|
| 465 |
+
register_conv_template(
|
| 466 |
+
Conversation(
|
| 467 |
+
name="billa",
|
| 468 |
+
system="",
|
| 469 |
+
roles=("Human", "Assistant"),
|
| 470 |
+
messages=(),
|
| 471 |
+
offset=0,
|
| 472 |
+
sep_style=SeparatorStyle.ADD_COLON_SPACE_SINGLE,
|
| 473 |
+
sep="\n",
|
| 474 |
+
stop_str="Human:",
|
| 475 |
+
)
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
# RedPajama INCITE default template
|
| 479 |
+
register_conv_template(
|
| 480 |
+
Conversation(
|
| 481 |
+
name="redpajama-incite",
|
| 482 |
+
system="",
|
| 483 |
+
roles=("<human>", "<bot>"),
|
| 484 |
+
messages=(),
|
| 485 |
+
offset=0,
|
| 486 |
+
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
|
| 487 |
+
sep="\n",
|
| 488 |
+
stop_str="<human>",
|
| 489 |
+
)
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
# h2oGPT default template
|
| 493 |
+
register_conv_template(
|
| 494 |
+
Conversation(
|
| 495 |
+
name="h2ogpt",
|
| 496 |
+
system="",
|
| 497 |
+
roles=("<|prompt|>", "<|answer|>"),
|
| 498 |
+
messages=(),
|
| 499 |
+
offset=0,
|
| 500 |
+
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
| 501 |
+
sep="</s>",
|
| 502 |
+
)
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
if __name__ == "__main__":
    # Smoke test: build a short husky-style dialogue and print the prompt.
    conv = get_conv_template("husky")
    human, assistant = conv.roles
    conv.append_message(human, "Hello!")
    conv.append_message(assistant, "Hi!")
    conv.append_message(human, "How are you?")
    conv.append_message(assistant, None)
    print(conv.get_prompt())
|
robohusky/convert_fp16.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
srun -p INTERN2 --job-name='convert_2_fp16' --gres=gpu:0 --cpus-per-task=8 --quotatype="auto" python -u husky/convert_fp16.py --in-checkpoint work_dirs/llm/husky-13b/zh_bell/checkpoint-9500 --out-checkpoint work_dirs/llm/husky-13b/zh_bell/
|
| 4 |
+
"""
|
| 5 |
+
import argparse
|
| 6 |
+
import os.path
|
| 7 |
+
|
| 8 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
def convert_fp16(in_checkpoint, out_checkpoint):
    """Load a causal-LM checkpoint, cast it to fp16, and save it elsewhere.

    Args:
        in_checkpoint: Path to the source checkpoint directory.
        out_checkpoint: Directory to write the fp16 model and tokenizer to;
            created (including missing parents) if it does not exist.
    """
    tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=False
    )
    # makedirs(exist_ok=True) is race-free and also creates missing parent
    # directories, unlike the bare os.mkdir it replaces (which raised if the
    # directory already existed or the parent was missing).
    os.makedirs(out_checkpoint, exist_ok=True)
    model.save_pretrained(out_checkpoint)
    tokenizer.save_pretrained(out_checkpoint)
|
| 20 |
+
|
| 21 |
+
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("--in-checkpoint", type=str, help="Path to the model")
    cli.add_argument("--out-checkpoint", type=str, help="Path to the output model")
    opts = cli.parse_args()
    convert_fp16(opts.in_checkpoint, opts.out_checkpoint)
|
robohusky/convert_husky_fp16.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
srun -p INTERN2 --job-name='convert_2_fp16' --gres=gpu:0 --cpus-per-task=8 --quotatype="auto" python -u husky/convert_husky_fp16.py --in-checkpoint work_dirs/husky_v3/multi_align/checkpoint-48000 --out-checkpoint work_dirs/husky_v3/multi_align_fp16
|
| 4 |
+
"""
|
| 5 |
+
import argparse
|
| 6 |
+
import os.path
|
| 7 |
+
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from husky.model.modeling_husky_multi import HuskyForConditionalGeneration
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
def convert_fp16(in_checkpoint, out_checkpoint):
    """Load a Husky checkpoint, cast it to fp16, and save it to a new directory.

    Args:
        in_checkpoint: Path to the source checkpoint directory.
        out_checkpoint: Destination directory; created (with any missing
            parents) if it does not exist.
    """
    tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False)
    model = HuskyForConditionalGeneration.from_pretrained(
        in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=False
    )
    # makedirs(exist_ok=True) avoids the mkdir race and handles missing parents.
    os.makedirs(out_checkpoint, exist_ok=True)
    model.save_pretrained(out_checkpoint)
    tokenizer.save_pretrained(out_checkpoint)
|
| 21 |
+
|
| 22 |
+
if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--in-checkpoint", type=str, help="Path to the model")
    arg_parser.add_argument("--out-checkpoint", type=str, help="Path to the output model")
    parsed = arg_parser.parse_args()
    convert_fp16(parsed.in_checkpoint, parsed.out_checkpoint)
|
robohusky/convert_reward_fp16.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Usage:
|
| 3 |
+
srun -p INTERN2 --job-name='convert_2_fp16' --gres=gpu:0 --cpus-per-task=8 --quotatype="auto" python -u husky/convert_reward_fp16.py --in-checkpoint work_dirs/llm/Ziya-LLaMA-7B-Reward --out-checkpoint work_dirs/llm/reward_model
|
| 4 |
+
"""
|
| 5 |
+
import argparse
|
| 6 |
+
import os.path
|
| 7 |
+
|
| 8 |
+
from transformers import LlamaTokenizer, AutoModelForSequenceClassification
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
def convert_fp16(in_checkpoint, out_checkpoint):
    """Load a reward-model checkpoint, cast it to fp16, and save it elsewhere.

    Args:
        in_checkpoint: Path to the source checkpoint directory.
        out_checkpoint: Destination directory; created (with any missing
            parents) if it does not exist.
    """
    tokenizer = LlamaTokenizer.from_pretrained(in_checkpoint, use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained(
        in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=False
    )
    # makedirs(exist_ok=True) avoids the mkdir race and handles missing parents.
    os.makedirs(out_checkpoint, exist_ok=True)
    model.save_pretrained(out_checkpoint)
    tokenizer.save_pretrained(out_checkpoint)
|
| 20 |
+
|
| 21 |
+
if __name__ == "__main__":
    ap = argparse.ArgumentParser()
    ap.add_argument("--in-checkpoint", type=str, help="Path to the model")
    ap.add_argument("--out-checkpoint", type=str, help="Path to the output model")
    ns = ap.parse_args()
    convert_fp16(ns.in_checkpoint, ns.out_checkpoint)
|
robohusky/dist_utils.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import os
|
| 3 |
+
import socket
|
| 4 |
+
import subprocess
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.multiprocessing as mp
|
| 9 |
+
from torch import distributed as dist
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def _find_free_port():
|
| 13 |
+
# Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
|
| 14 |
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
| 15 |
+
# Binding to port 0 will cause the OS to find an available port for us
|
| 16 |
+
sock.bind(('', 0))
|
| 17 |
+
port = sock.getsockname()[1]
|
| 18 |
+
sock.close()
|
| 19 |
+
# NOTE: there is still a chance the port could be taken by other processes.
|
| 20 |
+
return port
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _is_free_port(port):
|
| 24 |
+
ips = socket.gethostbyname_ex(socket.gethostname())[-1]
|
| 25 |
+
ips.append('localhost')
|
| 26 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 27 |
+
return all(s.connect_ex((ip, port)) != 0 for ip in ips)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def init_dist(launcher, backend='nccl', **kwargs):
    """Initialize distributed training for the given launcher type.

    Supported launchers: 'pytorch', 'mpi', 'slurm'.
    """
    # Force the 'spawn' start method once, before any workers are created.
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    dispatch = {
        'pytorch': _init_dist_pytorch,
        'mpi': _init_dist_mpi,
        'slurm': _init_dist_slurm,
    }
    try:
        initializer = dispatch[launcher]
    except KeyError:
        raise ValueError(f'Invalid launcher type: {launcher}') from None
    initializer(backend, **kwargs)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _init_dist_pytorch(backend, **kwargs):
    """Initialize distributed training from a PyTorch-style launch.

    Expects ``RANK`` to be present in the environment (set by the PyTorch
    distributed launcher).
    """
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    # Round-robin mapping of the global rank onto this node's visible GPUs.
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _init_dist_mpi(backend, **kwargs):
    """Initialize distributed training from an OpenMPI launch.

    Translates the ``OMPI_*`` environment variables set by mpirun into the
    ``RANK``/``WORLD_SIZE`` variables that torch.distributed expects.

    Raises:
        KeyError: If ``MASTER_ADDR`` is not set in the environment.
    """
    local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    if 'MASTER_PORT' not in os.environ:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    if 'MASTER_ADDR' not in os.environ:
        raise KeyError('The environment variable MASTER_ADDR is not set')
    os.environ['WORLD_SIZE'] = os.environ['OMPI_COMM_WORLD_SIZE']
    os.environ['RANK'] = os.environ['OMPI_COMM_WORLD_RANK']
    dist.init_process_group(backend=backend, **kwargs)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    # Round-robin: slurm task i on a node uses GPU i % num_gpus.
    torch.cuda.set_device(proc_id % num_gpus)
    # The first hostname in the allocation acts as the rendezvous master.
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # if torch.distributed default port(29500) is available
        # then use it, else find a free port
        if _is_free_port(29500):
            os.environ['MASTER_PORT'] = '29500'
        else:
            os.environ['MASTER_PORT'] = str(_find_free_port())
    # use MASTER_ADDR in the environment variable if it already exists
    if 'MASTER_ADDR' not in os.environ:
        os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)
|
robohusky/llama2_flash_attn_monkey_patch.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from flash_attn import __version__ as flash_attn_version
|
| 6 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 7 |
+
from flash_attn.flash_attn_interface import (
|
| 8 |
+
flash_attn_func,
|
| 9 |
+
flash_attn_varlen_kvpacked_func,
|
| 10 |
+
)
|
| 11 |
+
from transformers.models.llama.modeling_llama import (
|
| 12 |
+
LlamaAttention,
|
| 13 |
+
LlamaModel,
|
| 14 |
+
rotate_half,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
    """Apply rotary position embeddings to q and k.

    Unlike the stock HF helper, q/k here use the flash-attn layout
    (bsz, seq_len, num_heads, head_dim) — this matches `forward` below, which
    builds them with ``.view(bsz, q_len, nh, head_dim)``.

    NOTE(review): the gather below assumes ``cos_sin`` tables are 4-D with the
    sequence axis at dim 2 before the transpose — confirm against the model's
    ``rotary_emb`` implementation.
    """
    gather_indices = position_ids[:, :, None, None]  # [bsz, seq_len, 1, 1]
    # Broadcast the index tensor so torch.gather can select, per token, the
    # cos/sin row matching that token's absolute position.
    gather_indices = gather_indices.repeat(
        1, 1, cos_sin[0].shape[1], cos_sin[0].shape[3]
    )
    bsz = gather_indices.shape[0]
    cos, sin = (
        torch.gather(x.transpose(1, 2).repeat(bsz, 1, 1, 1), 1, gather_indices)
        for x in cos_sin
    )
    # Standard rotary formula: x * cos + rotate_half(x) * sin.
    q, k = ((x * cos) + (rotate_half(x) * sin) for x in (q, k))
    return q, k
|
| 29 |
+
|
| 30 |
+
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Flash-attention replacement for ``LlamaAttention.forward``.

    Same signature and return structure as the HF original, but attention
    weights are never materialized: the second tuple element is always None.
    ``attention_mask`` is expected to be the boolean [bsz, seq_len] padding
    mask produced by the patched ``_prepare_decoder_attention_mask`` below
    (or None when all samples are full/unpadded).
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()
    # Grouped-query attention: fall back to num_heads when the config does
    # not define num_key_value_heads.
    kv_heads = getattr(self, "num_key_value_heads", self.num_heads)

    q, k, v = (
        op(hidden_states).view(bsz, q_len, nh, self.head_dim)
        for op, nh in (
            (self.q_proj, self.num_heads),
            (self.k_proj, kv_heads),
            (self.v_proj, kv_heads),
        )
    )
    # shape: (b, s, num_heads, head_dim)

    kv_seq_len = k.shape[1]
    past_kv_len = 0
    if past_key_value is not None:
        # Cached tensors use HF's (b, h, s, d) layout, so seq len is dim 2.
        past_kv_len = past_key_value[0].shape[2]
        kv_seq_len += past_kv_len

    # Rotary tables must cover past + current positions.
    cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
    q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)

    if past_key_value is not None:
        assert (
            flash_attn_version >= "2.1.0"
        ), "past_key_value support requires flash-attn >= 2.1.0"
        # reuse k, v: transpose cache back to (b, s, h, d) and prepend it.
        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)

    # Store the cache in HF's (b, h, s, d) layout for the next step.
    past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None

    if attention_mask is None:
        # No padding anywhere: use the dense fixed-length kernel.
        output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
            bsz, q_len, -1
        )
    else:
        # Padded batch: strip pad tokens and use the varlen kernel.
        # Queries only cover the current step, hence the [:, -q_len:] slice.
        q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
        # We can skip concat and call unpad twice but seems better to call unpad only once.
        kv, _, cu_k_lens, max_k = unpad_input(
            torch.stack((k, v), dim=2), attention_mask
        )
        output_unpad = flash_attn_varlen_kvpacked_func(
            q,
            kv,
            cu_q_lens,
            cu_k_lens,
            max_s,
            max_k,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        # Scatter the unpadded rows back to the (b, s, hidden) layout.
        output = pad_input(output_unpad, indices, bsz, q_len)

    return self.o_proj(output), None, past_key_value
|
| 102 |
+
|
| 103 |
+
# Disable the transformation of the attention mask in LlamaModel as flash attention
|
| 104 |
+
# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
|
| 105 |
+
def _prepare_decoder_attention_mask(
|
| 106 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
| 107 |
+
):
|
| 108 |
+
# [bsz, seq_len]
|
| 109 |
+
if past_key_values_length > 0 and attention_mask is not None:
|
| 110 |
+
attention_mask = torch.cat(
|
| 111 |
+
(
|
| 112 |
+
torch.full(
|
| 113 |
+
(input_shape[0], past_key_values_length),
|
| 114 |
+
True,
|
| 115 |
+
dtype=attention_mask.dtype,
|
| 116 |
+
device=attention_mask.device,
|
| 117 |
+
),
|
| 118 |
+
attention_mask,
|
| 119 |
+
),
|
| 120 |
+
dim=-1,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
if attention_mask is not None and torch.all(attention_mask):
|
| 124 |
+
return None # This uses the faster call when training with full samples
|
| 125 |
+
|
| 126 |
+
return attention_mask
|
| 127 |
+
|
| 128 |
+
def replace_llama_attn_with_flash_attn():
    """Monkey-patch transformers' LLaMA attention to use flash-attn kernels."""
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        # Pre-Ampere GPUs (compute capability < 8) can't run the backward
        # pass for head dims > 64; warn but patch anyway.
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )

    LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    LlamaAttention.forward = forward
|
| 138 |
+
|
| 139 |
+
def test():
    """Compare three attention forwards for numerical agreement (needs CUDA).

    Runs the stock ``LlamaAttention.forward``, the fastchat flash-attention
    variant, and this module's patched ``forward`` on random fp16 inputs with
    varying padding, printing mean absolute differences. Then verifies that
    feeding the sequence in four chunks through the past-KV cache reproduces
    the single full-sequence pass.
    """
    from robohusky.train.llama_flash_attn_monkey_patch import forward as fastchat_forward
    from transformers.models.llama.configuration_llama import LlamaConfig

    # Tiny single-layer config so the comparison runs quickly.
    config = LlamaConfig(
        hidden_size=1024,
        intermediate_size=128,
        num_hidden_layers=1,
        num_attention_heads=8,
        max_position_embeddings=16,
    )
    device = torch.device("cuda")
    model = LlamaModel(config)
    attn = LlamaAttention(config).to(device).half()
    bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
    position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
        -1, seqlen
    )

    mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
    for i in range(4):
        hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
        if i:
            # Pad a growing number of positions: tail of sample 0, head of sample 1.
            mask[0, -i:] = False
            mask[1, :i] = False

        # Reference: stock attention with the standard (additive) expanded mask.
        lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
        ref, _, _ = attn.forward(
            hidden, attention_mask=lmask, position_ids=position_ids
        )

        # fastchat variant consumes the boolean padding mask directly.
        fast, _, _ = fastchat_forward(
            attn, hidden, attention_mask=mask, position_ids=position_ids
        )

        # This module's patched forward with its matching mask preparation.
        # NOTE: the local name `test` below shadows this function inside the loop.
        lmask = _prepare_decoder_attention_mask(
            model, mask, hidden.shape[:2], hidden, 0
        )
        test, _, _ = forward(
            attn, hidden, attention_mask=lmask, position_ids=position_ids
        )

        print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}")
        print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}")
        print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}")
        print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
        print(f"allclose(fast, test) = {torch.allclose(fast, test)}")

    with torch.no_grad():
        # Also check that past_kv is handled properly
        hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
        part_len = seqlen // 4
        assert part_len * 4 == seqlen
        mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
        mask[0, -2:] = False
        lmask = _prepare_decoder_attention_mask(
            model, mask, hidden.shape[:2], hidden, 0
        )
        # Full-sequence pass used as the ground truth.
        oneshot, _, _ = forward(
            attn, hidden, attention_mask=lmask, position_ids=position_ids
        )
        parts = []
        past_kv, past_kv_len = None, 0
        for i in range(4):
            # Feed the sequence in four equal chunks, reusing the KV cache.
            start = part_len * i
            end = start + part_len
            hidden_part = hidden[:, start:end, ...]
            lmask = _prepare_decoder_attention_mask(
                model,
                mask[:, start:end],
                hidden_part.shape[:2],
                hidden_part,
                past_kv_len,
            )
            part, _, past_kv = forward(
                attn,
                hidden_part.clone(),
                attention_mask=lmask,
                position_ids=position_ids[:, start:end],
                past_key_value=past_kv,
                use_cache=True,
            )
            parts.append(part)
            # Cached key tensor is [bsz, heads, kv_len, head_dim]; track kv_len.
            past_kv_len = past_kv[0].shape[2]

        print(
            f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
        )
        print(
            f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}"
        )
|
| 230 |
+
|
| 231 |
+
if __name__ == "__main__":
    # Manual check: run the flash-attention equivalence test (requires a CUDA GPU).
    test()
|
robohusky/model/__init__.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import TYPE_CHECKING
|
| 15 |
+
|
| 16 |
+
from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
|
| 17 |
+
|
| 18 |
+
# Registry mapping each submodule to the public names it defines. _LazyModule
# uses this to defer importing heavy submodules until an attribute is accessed.
_import_structure = {
    "configuration_husky": [
        "HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP",
        "HuskyConfig",
        "HuskyQFormerConfig",
        "HuskyVisionConfig",
    ],
    "processing_husky": ["HuskyProcessor"],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    # The modeling classes require torch; only advertise them when it is installed.
    _import_structure["modeling_husky"] = [
        "HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST",
        "HuskyModel",
        "HuskyQFormerModel",
        "HuskyPreTrainedModel",
        "HuskyForConditionalGeneration",
        "HuskyVisionModel",
    ]

if TYPE_CHECKING:
    # Eager imports for static type checkers only; at runtime the module is
    # replaced by the _LazyModule proxy in the else branch below.
    from .configuration_husky import (
        HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP,
        HuskyConfig,
        HuskyVisionConfig,
        HuskyQFormerConfig
    )
    from .processing_husky import HuskyProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_husky import (
            HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST,
            HuskyForConditionalGeneration,
            HuskyModel,
            HuskyPreTrainedModel,
            HuskyQFormerModel,
            HuskyVisionModel,
        )

else:
    import sys

    # Install a lazy proxy so submodules are imported on first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
robohusky/model/__pycache__/__init__.cpython-38.pyc
ADDED
|
Binary file (1.14 kB). View file
|
|
|
robohusky/model/__pycache__/configuration_husky.cpython-38.pyc
ADDED
|
Binary file (13.4 kB). View file
|
|
|
robohusky/model/__pycache__/modeling_husky_embody2.cpython-38.pyc
ADDED
|
Binary file (54.6 kB). View file
|
|
|
robohusky/model/compression.py
ADDED
|
File without changes
|
robohusky/model/configuration_husky.py
ADDED
|
@@ -0,0 +1,331 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" Husky model configuration"""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import os
|
| 19 |
+
from typing import Union
|
| 20 |
+
|
| 21 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 22 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 23 |
+
from transformers.utils import logging
|
| 24 |
+
|
| 25 |
+
from transformers.models.auto import CONFIG_MAPPING
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
| 30 |
+
"wofmanaf/husky-7b": "https://huggingface.co/wofmanaf/husky-7b/resolve/main/config.json",
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
class HuskyVisionConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`HuskyVisionModel`]. It is used to
    instantiate a Husky vision encoder according to the specified arguments, defining the model architecture.
    Instantiating a configuration with defaults will yield a similar configuration to that of the Husky architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1408):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        projection_dim (`int`, *optional*, defaults to 512):
            Dimensionality of the projection layers, where used by the model.
        num_hidden_layers (`int`, *optional*, defaults to 39):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_channels (`int`, *optional*, defaults to 3):
            The number of input image channels.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"selu"` and `"gelu_new"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 1e-10):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
            testing).
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries and values in the self-attention layers.
        _flash_attn_2_enabled (`bool`, *optional*, defaults to `True`):
            Internal flag stored on the config; presumably toggles Flash Attention 2 in the
            attention layers — confirm against the modeling code.
    """

    model_type = "husky_vision_model"

    def __init__(
        self,
        hidden_size=1408,
        intermediate_size=6144,
        projection_dim=512,
        num_hidden_layers=39,
        num_attention_heads=16,
        num_channels=3,
        image_size=224,
        patch_size=14,
        hidden_act="gelu",
        layer_norm_eps=0.00001,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=1e-10,
        initializer_factor=1.0,
        qkv_bias=True,
        _flash_attn_2_enabled=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.dropout = dropout
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.qkv_bias = qkv_bias
        self._flash_attn_2_enabled = _flash_attn_2_enabled

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load this vision config from a checkpoint, unwrapping a composite `HuskyConfig` if needed."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the vision config dict if we are loading from HuskyConfig
        if config_dict.get("model_type") == "husky":
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
|
| 128 |
+
|
| 129 |
+
class HuskyQFormerConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`HuskyQFormerModel`]. It is used to
    instantiate a Husky Querying Transformer (Q-Former) model according to the specified arguments, defining the
    model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
    the Husky [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
    architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs.
    Read the documentation from [`PretrainedConfig`] for more information.

    Note that [`HuskyQFormerModel`] is very similar to [`BertLMHeadModel`] with interleaved cross-attention.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the Q-Former model. Defines the number of different tokens that can be represented by
            the `inputs_ids` passed when calling the model.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
            `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
            positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
            [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
            For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
            with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.
        cross_attention_frequency (`int`, *optional*, defaults to 2):
            The frequency of adding cross-attention to the Transformer layers.
        encoder_hidden_size (`int`, *optional*, defaults to 1408):
            The hidden size of the hidden states for cross-attention.
        _flash_attn_2_enabled (`bool`, *optional*, defaults to `True`):
            Internal flag stored on the config; presumably toggles Flash Attention 2 in the
            attention layers — confirm against the modeling code.
    """
    model_type = "husky_qformer"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        classifier_dropout=None,
        cross_attention_frequency=2,
        encoder_hidden_size=1408,
        _flash_attn_2_enabled=True,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.classifier_dropout = classifier_dropout
        self.cross_attention_frequency = cross_attention_frequency
        self.encoder_hidden_size = encoder_hidden_size
        self._flash_attn_2_enabled = _flash_attn_2_enabled

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load this Q-Former config from a checkpoint, unwrapping a composite `HuskyConfig` if needed."""
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        # get the qformer config dict if we are loading from HuskyConfig
        if config_dict.get("model_type") == "husky":
            config_dict = config_dict["qformer_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
|
| 235 |
+
|
| 236 |
+
class HuskyConfig(PretrainedConfig):
    r"""
    [`HuskyConfig`] is the configuration class to store the configuration of a
    [`HuskyForConditionalGeneration`]. It is used to instantiate a Husky model according to the specified
    arguments, defining the vision model, Q-Former model and language model configs. Instantiating a configuration with
    the defaults will yield a similar configuration to that of the Husky
    [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`HuskyVisionConfig`].
        qformer_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`HuskyQFormerConfig`].
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize any [`PretrainedConfig`].
            When no `model_type` key is present, an OPT text config is assumed.
        num_query_tokens (`int`, *optional*, defaults to 32):
            The number of query tokens passed through the Transformer.

        kwargs (*optional*):
            Dictionary of keyword arguments.
    """

    model_type = "husky"
    is_composition = True

    def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. initializing the HuskyVisionConfig with default values.")

        if qformer_config is None:
            qformer_config = {}
            logger.info("qformer_config is None. Initializing the HuskyQFormerConfig with default values.")

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")

        self.vision_config = HuskyVisionConfig(**vision_config)
        self.qformer_config = HuskyQFormerConfig(**qformer_config)
        # The language model is pluggable: resolve its config class from the
        # transformers auto-mapping, defaulting to OPT when no type is given.
        text_model_type = text_config["model_type"] if "model_type" in text_config else "opt"
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        # Mirror key text-model properties on the composite config.
        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder

        self.num_query_tokens = num_query_tokens
        # The Q-Former cross-attends over vision features, so its encoder
        # hidden size must track the vision tower's hidden size.
        self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
        self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_qformer_text_configs(
        cls,
        vision_config: HuskyVisionConfig,
        qformer_config: HuskyQFormerConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`HuskyConfig`] (or a derived class) from a Husky vision model, Q-Former and
        language model configurations.

        Returns:
            [`HuskyConfig`]: An instance of a configuration object
        """

        return cls(
            vision_config=vision_config.to_dict(),
            qformer_config=qformer_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        output = copy.deepcopy(self.__dict__)
        # Sub-configs are objects on the instance; expand them to plain dicts.
        output["vision_config"] = self.vision_config.to_dict()
        output["qformer_config"] = self.qformer_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        return output
|
| 329 |
+
|
| 330 |
+
if __name__ == '__main__':
|
| 331 |
+
config = HuskyConfig.from_pretrain
|
robohusky/model/configuration_husky_ori.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" Husky model configuration"""
|
| 16 |
+
|
| 17 |
+
import copy
|
| 18 |
+
import os
|
| 19 |
+
from typing import Union
|
| 20 |
+
|
| 21 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 22 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 23 |
+
from transformers.utils import logging
|
| 24 |
+
|
| 25 |
+
from transformers.models.auto import CONFIG_MAPPING
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
HUSKY_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
| 30 |
+
"wofmanaf/husky-7b": "https://huggingface.co/wofmanaf/husky-7b/resolve/main/config.json",
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
class HuskyVisionConfig(PretrainedConfig):
    r"""
    Configuration for the Husky vision encoder ([`HuskyVisionModel`]).

    Stores every architectural hyper-parameter of the vision tower. Instantiating
    with no arguments yields the default Husky vision architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        hidden_size (`int`, *optional*, defaults to 1408):
            Dimensionality of the encoder layers and the pooler layer.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimensionality of the feed-forward layer in each Transformer block.
        projection_dim (`int`, *optional*, defaults to 512):
            Dimensionality of the projection head.
        num_hidden_layers (`int`, *optional*, defaults to 39):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer.
        num_channels (`int`, *optional*, defaults to 3):
            Number of input image channels.
        image_size (`int`, *optional*, defaults to 224):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to 14):
            The size (resolution) of each patch.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function in the encoder and pooler.
            `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
        layer_norm_eps (`float`, *optional*, defaults to 1e-5):
            The epsilon used by the layer normalization layers.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability for all fully connected layers in the
            embeddings, encoder, and pooler.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 1e-10):
            The standard deviation of the truncated-normal initializer for
            initializing all weight matrices.
        initializer_factor (`float`, *optional*, defaults to 1.0):
            A factor for initializing all weight matrices (should be kept to 1,
            used internally for initialization testing).
        qkv_bias (`bool`, *optional*, defaults to `True`):
            Whether to add a bias to the queries and values in the
            self-attention layers.
    """

    model_type = "husky_vision_model"

    def __init__(
        self,
        hidden_size=1408,
        intermediate_size=6144,
        projection_dim=512,
        num_hidden_layers=39,
        num_attention_heads=16,
        num_channels=3,
        image_size=224,
        patch_size=14,
        hidden_act="gelu",
        layer_norm_eps=0.00001,
        dropout=0.0,
        attention_dropout=0.0,
        initializer_range=1e-10,
        initializer_factor=1.0,
        qkv_bias=True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Core transformer geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.projection_dim = projection_dim
        # Input / patchification layout.
        self.num_channels = num_channels
        self.image_size = image_size
        self.patch_size = patch_size
        # Regularization and numerics.
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        # Initialization and activation.
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.hidden_act = hidden_act
        self.qkv_bias = qkv_bias

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the vision config, extracting the nested section from a full `HuskyConfig` if needed."""
        loaded, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A composite "husky" checkpoint nests the vision settings under `vision_config`.
        if loaded.get("model_type") == "husky":
            loaded = loaded["vision_config"]

        if "model_type" in loaded and hasattr(cls, "model_type") and loaded["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {loaded['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(loaded, **kwargs)
|
| 126 |
+
|
| 127 |
+
class HuskyQFormerConfig(PretrainedConfig):
    r"""
    Configuration for the Husky Querying Transformer ([`HuskyQFormerModel`]).

    Stores every architectural hyper-parameter of the Q-Former, the module that
    bridges the vision encoder and the language model. Instantiating with the
    defaults yields a configuration similar to that of the Husky
    [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Note that [`HuskyQFormerModel`] is very similar to [`BertLMHeadModel`] with
    interleaved cross-attention.

    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the Q-Former model.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the feed-forward layer in each Transformer block.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function in the encoder and pooler.
            `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the
            embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated-normal initializer for
            initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        pad_token_id (`int`, *optional*, defaults to 0):
            Id of the padding token, forwarded to [`PretrainedConfig`].
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. One of `"absolute"`, `"relative_key"`,
            `"relative_key_query"`. See
            [Shaw et al.](https://arxiv.org/abs/1803.02155) and *Method 4* in
            [Huang et al.](https://arxiv.org/abs/2009.13658) for the relative
            variants.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.
        cross_attention_frequency (`int`, *optional*, defaults to 2):
            How often cross-attention is inserted into the Transformer layers.
        encoder_hidden_size (`int`, *optional*, defaults to 1408):
            The hidden size of the encoder hidden states used by cross-attention.
    """
    model_type = "husky_qformer"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        classifier_dropout=None,
        cross_attention_frequency=2,
        encoder_hidden_size=1408,
        **kwargs,
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        # Core transformer geometry.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        # Activation, regularization and numerics.
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.classifier_dropout = classifier_dropout
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        # Position embeddings and cross-attention wiring.
        self.position_embedding_type = position_embedding_type
        self.cross_attention_frequency = cross_attention_frequency
        self.encoder_hidden_size = encoder_hidden_size

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the Q-Former config, extracting the nested section from a full `HuskyConfig` if needed."""
        loaded, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A composite "husky" checkpoint nests the Q-Former settings under `qformer_config`.
        if loaded.get("model_type") == "husky":
            loaded = loaded["qformer_config"]

        if "model_type" in loaded and hasattr(cls, "model_type") and loaded["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {loaded['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(loaded, **kwargs)
|
| 231 |
+
|
| 232 |
+
class HuskyConfig(PretrainedConfig):
    r"""
    Composite configuration for [`HuskyForConditionalGeneration`].

    Aggregates three sub-configurations — the vision encoder, the Q-Former
    bridge and the language model — plus the number of query tokens that flow
    between them. Instantiating with the defaults yields a configuration
    similar to that of the Husky
    [Salesforce/instruct-blip-flan-t5](https://huggingface.co/Salesforce/instruct-blip-flan-t5)
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs. Read the documentation from [`PretrainedConfig`]
    for more information.

    Args:
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`HuskyVisionConfig`].
        qformer_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize [`HuskyQFormerConfig`].
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize any [`PretrainedConfig`].
        num_query_tokens (`int`, *optional*, defaults to 32):
            The number of query tokens passed through the Transformer.
        kwargs (*optional*):
            Dictionary of keyword arguments forwarded to [`PretrainedConfig`].
    """

    model_type = "husky"
    is_composition = True

    def __init__(self, vision_config=None, qformer_config=None, text_config=None, num_query_tokens=32, **kwargs):
        super().__init__(**kwargs)

        # Missing sub-configs fall back to their architecture defaults.
        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. initializing the HuskyVisionConfig with default values.")

        if qformer_config is None:
            qformer_config = {}
            logger.info("qformer_config is None. Initializing the HuskyQFormerConfig with default values.")

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the text config with default values (`OPTConfig`).")

        self.vision_config = HuskyVisionConfig(**vision_config)
        self.qformer_config = HuskyQFormerConfig(**qformer_config)
        # The language model may be any architecture registered in CONFIG_MAPPING.
        text_model_type = text_config.get("model_type", "opt")
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)

        # Mirror language-model properties that generation code inspects on the top level.
        self.tie_word_embeddings = self.text_config.tie_word_embeddings
        self.is_encoder_decoder = self.text_config.is_encoder_decoder

        self.num_query_tokens = num_query_tokens
        # The Q-Former cross-attends to the vision tower, so their widths must agree.
        self.qformer_config.encoder_hidden_size = self.vision_config.hidden_size
        self.use_decoder_only_language_model = self.text_config.model_type in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        self.initializer_factor = 1.0
        self.initializer_range = 0.02

    @classmethod
    def from_vision_qformer_text_configs(
        cls,
        vision_config: HuskyVisionConfig,
        qformer_config: HuskyQFormerConfig,
        text_config: PretrainedConfig,
        **kwargs,
    ):
        r"""
        Instantiate a [`HuskyConfig`] (or a derived class) from a Husky vision model, Q-Former and
        language model configurations.

        Returns:
            [`HuskyConfig`]: An instance of a configuration object
        """
        return cls(
            vision_config=vision_config.to_dict(),
            qformer_config=qformer_config.to_dict(),
            text_config=text_config.to_dict(),
            **kwargs,
        )

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
        """
        serialized = copy.deepcopy(self.__dict__)
        # Sub-configs are objects in __dict__; replace each with its dict form.
        for section in ("vision_config", "qformer_config", "text_config"):
            serialized[section] = getattr(self, section).to_dict()
        serialized["model_type"] = self.__class__.model_type
        return serialized
|
| 325 |
+
|
| 326 |
+
if __name__ == '__main__':
    # Smoke test when run as a script.
    # NOTE: the original file ended with a truncated statement
    # (`config = HuskyConfig.from_pretrain`) which raised AttributeError at
    # runtime; build a default config and show its serialized form instead.
    config = HuskyConfig()
    print(config.to_dict())
|
robohusky/model/modeling_husky.py
ADDED
|
@@ -0,0 +1,1820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" PyTorch Husky model."""
|
| 16 |
+
|
| 17 |
+
import contextlib
|
| 18 |
+
import math
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Any, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torch.nn import CrossEntropyLoss
|
| 26 |
+
|
| 27 |
+
from transformers.activations import ACT2FN
|
| 28 |
+
from transformers.modeling_outputs import (
|
| 29 |
+
BaseModelOutput,
|
| 30 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 31 |
+
BaseModelOutputWithPooling,
|
| 32 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 33 |
+
)
|
| 34 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 35 |
+
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
| 36 |
+
from transformers.utils import (
|
| 37 |
+
ModelOutput,
|
| 38 |
+
add_start_docstrings,
|
| 39 |
+
add_start_docstrings_to_model_forward,
|
| 40 |
+
logging,
|
| 41 |
+
replace_return_docstrings,
|
| 42 |
+
)
|
| 43 |
+
from transformers import AutoModelForCausalLM, GenerationConfig
|
| 44 |
+
|
| 45 |
+
from .configuration_husky import HuskyConfig, HuskyQFormerConfig, HuskyVisionConfig
|
| 46 |
+
|
| 47 |
+
logger = logging.get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
_CHECKPOINT_FOR_DOC = "wofmanaf/husky-7b"
|
| 50 |
+
|
| 51 |
+
HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 52 |
+
"wofmanaf/husky-7b",
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
@dataclass
class HuskyForConditionalGenerationModelOutput(ModelOutput):
    """
    Output type of [`HuskyForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        """Flatten to a tuple, recursively converting nested sub-model outputs."""
        # These three fields hold ModelOutput instances of their own, so they
        # are flattened via their to_tuple() rather than returned as objects.
        nested_fields = ("vision_outputs", "qformer_outputs", "language_model_outputs")
        flattened = []
        for key in self.keys():
            if key in nested_fields:
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
|
| 86 |
+
|
| 87 |
+
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Husky
class HuskyVisionEmbeddings(nn.Module):
    """Patchify an image, prepend a learned [CLS] token, and add position embeddings."""

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learned [CLS] token prepended to every patch sequence.
        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        # Non-overlapping patch projection: stride == kernel == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1  # +1 for the [CLS] slot

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        num_images = pixel_values.shape[0]
        weight_dtype = self.patch_embedding.weight.dtype

        # [B, 3, H, W] -> [B, embed_dim, grid, grid] -> [B, grid*grid, embed_dim]
        patch_tokens = self.patch_embedding(pixel_values)
        patch_tokens = patch_tokens.flatten(2).transpose(1, 2)

        cls_token = self.class_embedding.expand(num_images, 1, -1).to(weight_dtype)
        tokens = torch.cat([cls_token, patch_tokens], dim=1)
        # Slice the position table to the actual sequence length before adding.
        tokens = tokens + self.position_embedding[:, : tokens.size(1), :].to(weight_dtype)
        return tokens
|
| 119 |
+
|
| 120 |
+
class HuskyVideoEmbeddings(nn.Module):
    """Patchify a batch of video clips (3D conv over time+space) and prepend a
    learnable [CLS] token.

    Input:  `pixel_values` of shape `(batch, 3, num_frames, image_size, image_size)`.
    Output: embeddings of shape `(batch, num_patches + 1, hidden_size)`.
    """

    def __init__(self, config: "HuskyVisionConfig"):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Defaults mirror the original code when the config lacks these fields.
        self.num_frames = getattr(self.config, "num_frames", 8)
        self.frame_stride = getattr(self.config, "frame_stride", 2)

        # Learnable [CLS] embedding prepended to the patch sequence.
        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        # Tubelet embedding: non-overlapping in time (frame_stride) and space (patch_size).
        self.patch_embedding = nn.Conv3d(
            in_channels=3, out_channels=self.embed_dim,
            kernel_size=(self.frame_stride, self.patch_size, self.patch_size),
            stride=(self.frame_stride, self.patch_size, self.patch_size)
        )

        self.num_patches = int(self.num_frames // self.frame_stride) * (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # Fix: cast inputs to the conv weight dtype so a half/bf16 model accepts
        # fp32 pixel inputs (same fix as HuskyVisionEmbeddings).
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, t, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        # Position table is sliced so shorter clips still work.
        embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        return embeddings
|
| 155 |
+
|
| 156 |
+
class HuskyAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.num_heads * self.head_dim != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim ** -0.5
        self.dropout = nn.Dropout(config.attention_dropout)

        # Fused q/k/v projection; bias is assembled manually below (small tweak
        # compared to CLIP: the projection itself carries no bias).
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)

        if config.qkv_bias:
            # Only q and v get learnable biases; k's slot is filled with zeros.
            q_bias = nn.Parameter(torch.zeros(self.embed_dim))
            v_bias = nn.Parameter(torch.zeros(self.embed_dim))
            self.qkv.bias = nn.Parameter(
                torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
            )

        self.projection = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape (bsz, seq, embed) -> (bsz, heads, seq, head_dim).
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, seq_len, embed_dim = hidden_states.size()

        # Project once and split into per-head q/k/v:
        # (bsz, seq, 3*dim) -> (3, bsz, heads, seq, head_dim)
        fused = self.qkv(hidden_states)
        fused = fused.reshape(bsz, seq_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
            2, 0, 3, 1, 4
        )
        queries, keys, values = fused[0], fused[1], fused[2]

        # Scaled dot-product attention scores.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) * self.scale

        # Normalize to probabilities, then apply attention dropout (drops whole
        # tokens to attend to, as in the original Transformer paper).
        probs = self.dropout(nn.functional.softmax(scores, dim=-1))

        # Optionally mask out entire heads.
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, values).permute(0, 2, 1, 3)
        context = context.reshape(context.size()[:-2] + (self.embed_dim,))

        attn_output = self.projection(context)

        return (attn_output, probs) if output_attentions else (attn_output, None)
|
| 239 |
+
|
| 240 |
+
# Copied from transformers.models.blip.modeling_blip.BlipMLP
|
| 241 |
+
class HuskyMLP(nn.Module):
    """Two-layer feed-forward block: hidden -> intermediate -> hidden, with the
    non-linearity selected by `config.hidden_act`."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand, apply the configured activation, project back down.
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
|
| 254 |
+
|
| 255 |
+
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Husky
|
| 256 |
+
class HuskyEncoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention then MLP, each wrapped in a
    residual connection."""

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = HuskyAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = HuskyMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`): forwarded to the attention
                module as its `head_mask` argument.
            output_attentions (`bool`, *optional*):
                Whether to also return the attention weights of this layer.
        """
        # --- attention sub-block (pre-norm + residual) ---
        attn_residual = hidden_states
        normed = self.layer_norm1(hidden_states)
        attn_out, attn_weights = self.self_attn(
            hidden_states=normed,
            head_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = attn_out + attn_residual

        # --- MLP sub-block (pre-norm + residual) ---
        mlp_residual = hidden_states
        hidden_states = self.mlp(self.layer_norm2(hidden_states)) + mlp_residual

        if output_attentions:
            return (hidden_states, attn_weights)
        return (hidden_states,)
|
| 302 |
+
|
| 303 |
+
class HuskyPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HuskyConfig
    base_model_prefix = "husky"
    supports_gradient_checkpointing = True
    # Weights that may legitimately be absent from a checkpoint (e.g. tied or
    # re-created embeddings / LM head) — do not warn or fail on them at load.
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"language_model.encoder.embed_tokens.weight",
        r"language_model.decoder.embed_tokens.weight",
        r"language_model.lm_head.weight",
    ]
    # Modules that must stay on a single device when the model is sharded.
    _no_split_modules = ["HuskyAttention", "LlamaDecoderLayer", "LlamaForCausalLM"]
    _skip_keys_device_placement = "past_key_values"
    _keep_in_fp32_modules = ["wo"]

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        if isinstance(module, nn.Conv2d) or isinstance(module, nn.Embedding) or isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=factor)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

        if isinstance(module, HuskyVisionEmbeddings):
            # Vision embeddings use the vision tower's own initializer range
            # when one is configured on the top-level config.
            if hasattr(self.config, "vision_config"):
                factor = self.config.vision_config.initializer_range
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)

        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        # NOTE(review): nn.Linear biases are already zeroed by the first branch
        # above, so this elif appears redundant; kept for parity with upstream.
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Gradient checkpointing is only toggled on the vision encoder stack.
        if isinstance(module, HuskyEncoder):
            module.gradient_checkpointing = value
|
| 345 |
+
|
| 346 |
+
Husky_START_DOCSTRING = r"""
|
| 347 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 348 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 349 |
+
etc.)
|
| 350 |
+
|
| 351 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 352 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 353 |
+
and behavior.
|
| 354 |
+
|
| 355 |
+
Parameters:
|
| 356 |
+
config ([`HuskyConfig`]): Model configuration class with all the parameters of the model.
|
| 357 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 358 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
Husky_VISION_INPUTS_DOCSTRING = r"""
|
| 362 |
+
Args:
|
| 363 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 364 |
+
Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
|
| 365 |
+
details.
|
| 366 |
+
output_attentions (`bool`, *optional*):
|
| 367 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 368 |
+
tensors for more detail.
|
| 369 |
+
output_hidden_states (`bool`, *optional*):
|
| 370 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 371 |
+
more detail.
|
| 372 |
+
return_dict (`bool`, *optional*):
|
| 373 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 374 |
+
"""
|
| 375 |
+
|
| 376 |
+
Husky_TEXT_INPUTS_DOCSTRING = r"""
|
| 377 |
+
Args:
|
| 378 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 379 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 380 |
+
it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 381 |
+
[`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
|
| 382 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 383 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 384 |
+
- 1 for tokens that are **not masked**,
|
| 385 |
+
- 0 for tokens that are **masked**.
|
| 386 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 387 |
+
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 388 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
| 389 |
+
|
| 390 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 391 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 392 |
+
|
| 393 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
| 394 |
+
|
| 395 |
+
T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
|
| 396 |
+
is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
| 397 |
+
|
| 398 |
+
To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
|
| 399 |
+
Training](./t5#training).
|
| 400 |
+
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 401 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
| 402 |
+
be used by default.
|
| 403 |
+
output_attentions (`bool`, *optional*):
|
| 404 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 405 |
+
tensors for more detail.
|
| 406 |
+
output_hidden_states (`bool`, *optional*):
|
| 407 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 408 |
+
more detail.
|
| 409 |
+
return_dict (`bool`, *optional*):
|
| 410 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
Husky_INPUTS_DOCSTRING = r"""
|
| 414 |
+
Args:
|
| 415 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 416 |
+
Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
|
| 417 |
+
details.
|
| 418 |
+
|
| 419 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 420 |
+
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
|
| 421 |
+
provided to serve as text prompt, which the language model can continue.
|
| 422 |
+
|
| 423 |
+
Indices can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for details.
|
| 424 |
+
|
| 425 |
+
[What are input IDs?](../glossary#input-ids)
|
| 426 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 427 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 428 |
+
|
| 429 |
+
- 1 for tokens that are **not masked**,
|
| 430 |
+
- 0 for tokens that are **masked**.
|
| 431 |
+
|
| 432 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 433 |
+
|
| 434 |
+
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 435 |
+
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
|
| 436 |
+
encoder-decoder language model (like T5) is used.
|
| 437 |
+
|
| 438 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 439 |
+
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
|
| 440 |
+
|
| 441 |
+
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 442 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
| 443 |
+
be used by default.
|
| 444 |
+
|
| 445 |
+
Only relevant in case an encoder-decoder language model (like T5) is used.
|
| 446 |
+
|
| 447 |
+
output_attentions (`bool`, *optional*):
|
| 448 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 449 |
+
tensors for more detail.
|
| 450 |
+
output_hidden_states (`bool`, *optional*):
|
| 451 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 452 |
+
more detail.
|
| 453 |
+
return_dict (`bool`, *optional*):
|
| 454 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 455 |
+
"""
|
| 456 |
+
|
| 457 |
+
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Husky
|
| 458 |
+
class HuskyEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`HuskyEncoderLayer`].

    Args:
        config (`HuskyConfig`):
            The corresponding vision configuration for the `HuskyEncoder`.
    """

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([HuskyEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Toggled externally via HuskyPreTrainedModel._set_gradient_checkpointing.
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Fall back to config defaults for any unspecified flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            # Record the hidden state *entering* each layer.
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Re-compute activations in the backward pass to save memory.
                # The closure pins `output_attentions` since checkpoint() only
                # forwards tensor arguments.

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        # Append the final hidden state so encoder_states has num_layers + 1 entries.
        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
|
| 548 |
+
|
| 549 |
+
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Husky, BLIP->Husky
|
| 550 |
+
class HuskyVisionModel(HuskyPreTrainedModel):
    """Vision tower: embeds images (4-D input) or video clips (5-D input),
    runs the shared transformer encoder, and pools the [CLS] position."""

    main_input_name = "pixel_values"
    config_class = HuskyVisionConfig

    def __init__(self, config: HuskyVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        # Separate patch embedders for images (Conv2d) and videos (Conv3d);
        # both feed the same encoder.
        self.embeddings = HuskyVisionEmbeddings(config)
        self.video_embeddings = HuskyVideoEmbeddings(config)

        self.encoder = HuskyEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.post_init()

    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        # Fall back to config defaults for any unspecified flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        # Dispatch on rank: 4-D = image batch, 5-D = video batch.
        if len(pixel_values.shape) == 4:
            hidden_states = self.embeddings(pixel_values)
        elif len(pixel_values.shape) == 5:
            hidden_states = self.video_embeddings(pixel_values)
        else:
            raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # Pool the [CLS] position; the extra layernorm on the pooled vector
        # mirrors the upstream BLIP vision model.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        # Image (Conv2d) patch embedder.
        return self.embeddings

    def get_video_embeddings(self):
        # Video (Conv3d) patch embedder.
        return self.video_embeddings
|
| 620 |
+
|
| 621 |
+
class HuskyQFormerMultiHeadAttention(nn.Module):
    """BERT-style multi-head attention for the Q-Former.

    Supports self-attention, cross-attention over vision-encoder states
    (when `is_cross_attention=True`), cached past key/values for incremental
    decoding, and optional relative position embeddings.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Keys/values project from the (possibly wider) vision encoder states.
            self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
            self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When True (and cross-attending), attention maps and their gradients
        # are stashed on the module for visualization/analysis.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        # Backward hook target: stores gradients of the attention map.
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (batch, seq, all_head) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Incremental decoding: append the new keys/values to the cache.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Always returned (last tuple element) so callers can cache it.
        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # Add learned biases keyed on the (clipped) query-key distance.
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        # (context, [attention_probs,] past_key_value)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
|
| 751 |
+
|
| 752 |
+
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->HuskyQFormer
|
| 753 |
+
class HuskyQFormerSelfOutput(nn.Module):
|
| 754 |
+
def __init__(self, config):
|
| 755 |
+
super().__init__()
|
| 756 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 757 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 758 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 759 |
+
|
| 760 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 761 |
+
hidden_states = self.dense(hidden_states)
|
| 762 |
+
hidden_states = self.dropout(hidden_states)
|
| 763 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 764 |
+
return hidden_states
|
| 765 |
+
|
| 766 |
+
class HuskyQFormerAttention(nn.Module):
    """Attention sub-block: multi-head (self- or cross-) attention followed by
    the residual output projection, with support for head pruning."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.attention = HuskyQFormerMultiHeadAttention(config, is_cross_attention)
        self.output = HuskyQFormerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads and shrink the projections accordingly."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Shrink q/k/v along the head dimension and the output projection
        # along its input dimension.
        attn = self.attention
        attn.query = prune_linear_layer(attn.query, index)
        attn.key = prune_linear_layer(attn.key, index)
        attn.value = prune_linear_layer(attn.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Book-keeping: record the new head count and remember what was pruned.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run multi-head attention, then the residual output layer.

        Returns a tuple: (attention_output, [attention_probs,] past_key_value).
        """
        attn_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        # Project the attention context back with a residual connection; any
        # attention probabilities / cached key-values pass through unchanged.
        attention_output = self.output(attn_results[0], hidden_states)
        return (attention_output,) + attn_results[1:]
|
| 813 |
+
|
| 814 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->HuskyQFormer
|
| 815 |
+
class HuskyQFormerIntermediate(nn.Module):
|
| 816 |
+
def __init__(self, config):
|
| 817 |
+
super().__init__()
|
| 818 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
| 819 |
+
if isinstance(config.hidden_act, str):
|
| 820 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
| 821 |
+
else:
|
| 822 |
+
self.intermediate_act_fn = config.hidden_act
|
| 823 |
+
|
| 824 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 825 |
+
hidden_states = self.dense(hidden_states)
|
| 826 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
| 827 |
+
return hidden_states
|
| 828 |
+
|
| 829 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->HuskyQFormer
|
| 830 |
+
class HuskyQFormerOutput(nn.Module):
|
| 831 |
+
def __init__(self, config):
|
| 832 |
+
super().__init__()
|
| 833 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
| 834 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 835 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 836 |
+
|
| 837 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
| 838 |
+
hidden_states = self.dense(hidden_states)
|
| 839 |
+
hidden_states = self.dropout(hidden_states)
|
| 840 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
| 841 |
+
return hidden_states
|
| 842 |
+
|
| 843 |
+
class HuskyQFormerLayer(nn.Module):
    """One Q-Former transformer layer.

    Runs self-attention over the input sequence (query tokens optionally
    followed by text tokens), cross-attends the query-token slice to the
    encoder (image) features every `config.cross_attention_frequency` layers,
    and applies separate feed-forward branches to the query and text portions.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # sequence axis used by apply_chunking_to_forward
        self.attention = HuskyQFormerAttention(config)

        self.layer_idx = layer_idx

        # Cross-attention to encoder features only on every
        # `cross_attention_frequency`-th layer (layer 0 included).
        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = HuskyQFormerAttention(config, is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        # Feed-forward branch for the query-token slice.
        # NOTE(review): `feed_forward_chunk` below references `self.intermediate`
        # and `self.output`, which are never created in this __init__ — that path
        # (query_length == 0, or text tokens present) would raise AttributeError.
        # Confirm whether those submodules are attached elsewhere.
        self.intermediate_query = HuskyQFormerIntermediate(config)
        self.output_query = HuskyQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run one Q-Former layer.

        Returns a tuple: (layer_output, [self/cross attention probs if
        requested,] present_key_value).
        """
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Middle entries are attention probabilities (when requested);
        # the last entry is the cached key/value pair.
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # First `query_length` positions are the learned query tokens.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError("encoder_hidden_states must be given for cross-attention layers")
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                # add cross attentions if we output attention weights
                outputs = outputs + cross_attention_outputs[1:-1]

            # Query tokens go through their dedicated feed-forward branch.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                # Remaining (text) positions use the plain feed-forward branch,
                # then are re-concatenated after the query outputs.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Feed-forward for non-query (text) positions.
        # NOTE(review): `self.intermediate` / `self.output` are not defined in
        # __init__ above — confirm this branch is reachable/initialized.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # Feed-forward for the query-token slice.
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
|
| 941 |
+
|
| 942 |
+
class HuskyQFormerEncoder(nn.Module):
    """Stack of `HuskyQFormerLayer`s.

    Runs the (query + optional text) sequence through all layers, optionally
    with gradient checkpointing during training, and returns either a tuple or
    a `BaseModelOutputWithPastAndCrossAttentions`.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [HuskyQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Forward through all layers, accumulating hidden states / attention
        maps / key-value caches as requested by the output flags."""
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            # NOTE(review): checkpointing is gated on the *config* attribute;
            # `self.gradient_checkpointing` set in __init__ is never read here.
            # Confirm which flag `gradient_checkpointing_enable()` toggles.
            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    # logger.warn is a deprecated alias; use logger.warning.
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # Close over the per-layer extras that
                    # torch.utils.checkpoint cannot pass positionally.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # Each layer appends its present key/value as the last output.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if layer_module.has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
| 1042 |
+
|
| 1043 |
+
class HuskyQFormerModel(HuskyPreTrainedModel):
    """
    Querying Transformer (Q-Former), used in Husky.

    Takes learned query embeddings (no input_ids) and cross-attends them to
    encoder (image) hidden states via `HuskyQFormerEncoder`.
    """

    def __init__(self, config: HuskyQFormerConfig):
        super().__init__(config)
        self.config = config

        # Input normalization/dropout applied to the query embeddings.
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = HuskyQFormerEncoder(config)

        self.post_init()

    def get_input_embeddings(self):
        # NOTE(review): no `self.embeddings` module is created in __init__ —
        # this accessor (and the setter below) would raise AttributeError if
        # called. Confirm whether it is ever used for this query-only model.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        # NOTE(review): see get_input_embeddings — `self.embeddings` does not exist here.
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
        has_query: bool = False,  # NOTE(review): currently unused in this override
    ) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
            device (`torch.device`):
                The device of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        query_embeds,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
            `(batch_size, sequence_length)`.
        use_cache (`bool`, `optional`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # past_key_values_length
        # Cached length excludes the query tokens, which are re-fed each step.
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            # A list of encoder states means one tensor (and mask) per frame/view.
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        # Pooled output = hidden state of the first query token.
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
|
| 1234 |
+
|
| 1235 |
+
class AdapterMLP(nn.Module):
    """Two-layer SiLU MLP that maps vision-encoder features down through a
    4x bottleneck and into the Q-Former hidden size."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN["silu"]
        vision_width = config.vision_config.hidden_size
        bottleneck = vision_width // 4  # 4x reduction bottleneck
        qformer_width = config.qformer_config.hidden_size

        self.fc1 = nn.Linear(vision_width, bottleneck)
        self.fc2 = nn.Linear(bottleneck, qformer_width)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> SiLU -> fc2."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
|
| 1257 |
+
|
| 1258 |
+
@add_start_docstrings(
|
| 1259 |
+
"""
|
| 1260 |
+
Husky Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
|
| 1261 |
+
(Q-Former) and a language model.
|
| 1262 |
+
""",
|
| 1263 |
+
Husky_START_DOCSTRING,
|
| 1264 |
+
)
|
| 1265 |
+
class HuskyModel(HuskyPreTrainedModel):
|
| 1266 |
+
config_class = HuskyConfig
|
| 1267 |
+
main_input_name = "pixel_values"
|
| 1268 |
+
|
| 1269 |
+
    def __init__(self, config: HuskyConfig):
        """Build the vision encoder -> Q-Former -> language model pipeline."""
        super().__init__(config)

        self.vision_model = HuskyVisionModel(config.vision_config)

        # Learned query tokens fed to the Q-Former
        # (shape: 1 x num_query_tokens x qformer hidden size).
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        # Projects Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        # Expose the LM hidden size as the model-level hidden size.
        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # NOTE(review): magic constant — presumably the number of special/prompt
        # tokens surrounding the image span; confirm against the tokenizer setup.
        self.offset = 5

        # Initialize weights and apply final processing
        self.post_init()
|
| 1286 |
+
|
| 1287 |
+
    def get_input_embeddings(self):
        """Delegate to the underlying language model's input embeddings."""
        return self.language_model.get_input_embeddings()
|
| 1289 |
+
|
| 1290 |
+
    def set_input_embeddings(self, value):
        """Delegate to the underlying language model's input embeddings."""
        self.language_model.set_input_embeddings(value)
|
| 1292 |
+
|
| 1293 |
+
    def set_output_embeddings(self, new_embeddings):
        """Delegate to the underlying language model's output embeddings."""
        self.language_model.set_output_embeddings(new_embeddings)
|
| 1295 |
+
|
| 1296 |
+
    def get_output_embeddings(self) -> nn.Module:
        """Delegate to the underlying language model's output embeddings."""
        return self.language_model.get_output_embeddings()
|
| 1298 |
+
|
| 1299 |
+
    def get_encoder(self):
        """Delegate to the underlying language model's encoder (enc-dec LMs only)."""
        return self.language_model.get_encoder()
|
| 1301 |
+
|
| 1302 |
+
    def get_decoder(self):
        """Delegate to the underlying language model's decoder (enc-dec LMs only)."""
        return self.language_model.get_decoder()
|
| 1304 |
+
|
| 1305 |
+
    def _tie_weights(self):
        # Encoder-decoder LMs share one embedding matrix (`shared`) between
        # encoder and decoder; decoder-only LMs handle tying themselves.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared
|
| 1309 |
+
|
| 1310 |
+
    @add_start_docstrings_to_model_forward(Husky_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Run only the language model on text inputs (no vision / Q-Former).

        Returns:
            text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
                The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
                contains the language model logits, the past key values and the hidden states if
                `output_hidden_states=True`.
        """
        # Fall back to the model-level config for any unspecified output flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return text_outputs
|
| 1341 |
+
|
| 1342 |
+
    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run only the vision encoder on `pixel_values` and return its raw outputs
        (no Q-Former / language model)."""
        # Fall back to the model-level config for any unspecified output flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return vision_outputs
|
| 1364 |
+
|
| 1365 |
+
    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    def get_qformer_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Run the vision encoder and Q-Former (no language model).

        Returns:
            vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
                The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that
                contains the image features, the pooled image features and the hidden states if
                `output_hidden_states=True`.
        """
        # Fall back to the model-level config for any unspecified output flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Last hidden state of the vision encoder feeds the Q-Former.
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # Broadcast the single learned query set across the batch.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return query_outputs
|
| 1409 |
+
|
| 1410 |
+
    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,  # token ids fed to the LM embedding table (annotation fixed from FloatTensor)
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
        """
        Vision -> Q-Former -> language-model forward pass.

        The image is encoded, distilled into ``self.num_queries`` query embeddings by the
        Q-Former, projected into the LM's hidden size, and written *in place* over the
        ``<IMAGE>`` placeholder slice of the prompt embeddings before running the LM.

        Args:
            pixel_values: images of shape ``(batch, channels, height, width)``.
            input_ids: tokenized prompt; positions ``[offset, offset + num_queries)`` are
                placeholder tokens that get overwritten with vision features.
            attention_mask: optional ``(batch, seq_len)`` mask; defaults to all ones.
            labels: optional LM targets; when given, a shifted cross-entropy loss is computed
                over the last ``labels.size(1)`` logit positions.
            return_dict: whether to return a `HuskyForConditionalGenerationModelOutput`
                instead of a plain tuple.

        Returns:
            ``(loss?, logits, vision_outputs, query_outputs, lm_outputs)`` as a tuple, or a
            `HuskyForConditionalGenerationModelOutput` when ``return_dict=True``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # The learned query bank is shared across the batch, hence the expand.
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        # Sanity check: the projected vision features must exactly fill the placeholder slot.
        assert language_model_inputs.shape[1] == self.num_queries

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Human: <img><IMAGE></img>. Give the describe Assistant:
        # position of <image>: [offset: offset+num_queries]

        # Overwrite the placeholder token embeddings with the projected vision features
        # (sequence length is unchanged, so input_ids-sized masks/labels stay aligned).
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs
        if attention_mask is None:
            attention_mask = torch.ones_like(
                input_ids, dtype=torch.long, device=language_model_inputs.device)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            labels = labels.to(logits.device)
            # Only the trailing positions that correspond to the labels contribute to the loss.
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")

            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return HuskyForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
|
| 1494 |
+
|
| 1495 |
+
@add_start_docstrings(
    """
    Husky Model for generating text given an image and an optional text prompt. The model consists of a vision
    encoder, Querying Transformer (Q-Former) and a language model.

    One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
    the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
    """,
    Husky_START_DOCSTRING,
)
class HuskyForConditionalGeneration(HuskyPreTrainedModel):
    config_class = HuskyConfig
    main_input_name = "pixel_values"

    def __init__(self, config: HuskyConfig):
        super().__init__(config)

        self.vision_model = HuskyVisionModel(config.vision_config)
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Number of prompt tokens that precede the image placeholder; the projected
        # vision features are spliced into the prompt right after this many tokens.
        self.offset = 5

        self.vision_adapter = AdapterMLP(config)
        # One LayerNorm per tapped intermediate vision layer; 4 evenly spaced layers
        # are pooled (their CLS tokens), which is where the "+ 4" in the attention
        # masks below comes from.
        self.layer_norms = nn.ModuleList()
        for i in range(4):
            self.layer_norms.append(
                nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _pool_intermediate_layers(self, hidden_states):
        """Pool the CLS token of 4 evenly spaced vision layers, LayerNorm each, and
        project the concatenation through the vision adapter.

        Shared by `extract_feature`, `forward` and `generate` (was copy-pasted in all
        three). Returns a tensor of shape ``(batch, 4, lm_hidden-compatible dim)``.
        """
        depth = len(hidden_states)
        indices = range(depth // 4 - 1, depth, depth // 4)
        pooled_outputs = []
        for idx, layer_norm in zip(indices, self.layer_norms):
            # CLS token of the selected layer, kept as a length-1 sequence.
            pool_output = hidden_states[idx][:, 0, :].unsqueeze(1)
            pooled_outputs.append(layer_norm(pool_output))
        return self.vision_adapter(torch.cat(pooled_outputs, dim=1))

    def extract_feature(
        self,
        pixel_values: torch.FloatTensor,
    ):
        """Encode images into LM-space features: Q-Former queries + 4 pooled CLS tokens,
        both projected by `language_projection`. Shape: ``(batch, num_queries + 4, lm_hidden)``."""
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
        )
        image_embeds = vision_outputs[0]

        pooled_outputs = self._pool_intermediate_layers(vision_outputs[2])

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask
        )
        query_output = query_outputs[0]
        query_output = torch.cat([query_output, pooled_outputs], dim=1)
        language_model_inputs = self.language_projection(query_output)

        return language_model_inputs

    def _tie_weights(self):
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + Husky + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
        """
        Vision -> Q-Former (+ pooled intermediate layers) -> language-model forward pass.

        Unlike the placeholder-overwrite variant, this class *inserts* the
        ``num_queries + 4`` vision tokens between prompt positions ``offset-1`` and
        ``offset``, lengthening the sequence; masks and labels are adjusted accordingly.

        Args:
            pixel_values: images of shape ``(batch, channels, height, width)``.
            input_ids: tokenized prompt.
            attention_mask: optional ``(batch, seq_len)`` mask over `input_ids`.
            labels: optional LM targets; loss is a shifted cross-entropy over the
                last ``labels.size(1)`` logit positions.
            return_dict: return a model-output dataclass instead of a tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        batch_size = input_ids.shape[0]
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        pooled_outputs = self._pool_intermediate_layers(vision_outputs[2])

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]
        query_output = torch.cat([query_output, pooled_outputs], dim=1)

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Human: <img></img>. Give the describe Assistant:
        # Vision tokens are inserted (not overwritten) after the first `offset` tokens.
        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

        if attention_mask is None:
            # BUGFIX: the previous code used `torch.ones_like(inputs_embeds, ...)`, producing a
            # 3-D mask; attention masks must be (batch, seq_len), so drop the hidden dim.
            attention_mask = torch.ones(
                inputs_embeds.shape[:-1], dtype=torch.long, device=language_model_inputs.device)
        else:
            # Extend the caller's 2-D mask to cover the inserted vision tokens
            # (num_queries Q-Former tokens + 4 pooled CLS tokens).
            prefix_mask = attention_mask[:, :self.offset]
            postfix_mask = attention_mask[:, self.offset:]
            vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
                                     device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            labels = labels.to(logits.device)
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")

            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return HuskyForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        language_model_inputs: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
                Input images to be processed. Ignored when `language_model_inputs` is given.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation. Defaults to a single BOS token.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            language_model_inputs (`torch.FloatTensor` of shape (batch_size, sequence_length, num_channel), *optional*):
                Precomputed vision features (see `extract_feature`); skips the vision encoder when given.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """
        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        if language_model_inputs is None:
            vision_outputs = self.vision_model(
                pixel_values=pixel_values,
                output_hidden_states=True,
            )
            image_embeds = vision_outputs[0]

            pooled_outputs = self._pool_intermediate_layers(vision_outputs[2])

            image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_outputs = self.qformer(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_attention_mask,
            )
            query_output = query_outputs[0]
            query_output = torch.cat([query_output, pooled_outputs], dim=1)

            language_model_inputs = self.language_projection(query_output)

        batch_size = language_model_inputs.shape[0]

        if input_ids is None:
            # BUGFIX: this fallback used to run *after* the embedding lookup below, which
            # crashed on input_ids=None before the default BOS prompt could be built.
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(language_model_inputs.device)
            )

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

        if attention_mask is None:
            # BUGFIX: must cover the prompt *plus* the spliced-in vision tokens;
            # `ones_like(input_ids)` was too short after the concatenation above.
            attention_mask = torch.ones(
                inputs_embeds.shape[:-1], dtype=torch.long, device=language_model_inputs.device)
        else:
            prefix_mask = attention_mask[:, :self.offset]
            postfix_mask = attention_mask[:, self.offset:]
            vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
                                     device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **generate_kwargs,
        )

        return outputs
|
robohusky/model/modeling_husky_embody2.py
ADDED
|
@@ -0,0 +1,1962 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" PyTorch Husky model."""
|
| 16 |
+
|
| 17 |
+
import contextlib
|
| 18 |
+
import math
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Any, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torch.nn import CrossEntropyLoss
|
| 26 |
+
|
| 27 |
+
from transformers.activations import ACT2FN
|
| 28 |
+
from transformers.modeling_outputs import (
|
| 29 |
+
BaseModelOutput,
|
| 30 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 31 |
+
BaseModelOutputWithPooling,
|
| 32 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 33 |
+
)
|
| 34 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 35 |
+
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
| 36 |
+
from transformers.utils import (
|
| 37 |
+
ModelOutput,
|
| 38 |
+
add_start_docstrings,
|
| 39 |
+
add_start_docstrings_to_model_forward,
|
| 40 |
+
logging,
|
| 41 |
+
replace_return_docstrings,
|
| 42 |
+
is_flash_attn_available
|
| 43 |
+
)
|
| 44 |
+
from transformers import AutoModelForCausalLM, GenerationConfig
|
| 45 |
+
|
| 46 |
+
from .configuration_husky import HuskyConfig, HuskyQFormerConfig, HuskyVisionConfig
|
| 47 |
+
|
| 48 |
+
if is_flash_attn_available():
|
| 49 |
+
from flash_attn import flash_attn_func
|
| 50 |
+
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
|
| 51 |
+
|
| 52 |
+
try:
|
| 53 |
+
from apex.normalization import FusedLayerNorm as LayerNorm
|
| 54 |
+
except ImportError:
|
| 55 |
+
from torch.nn import LayerNorm as LayerNorm
|
| 56 |
+
|
| 57 |
+
logger = logging.get_logger(__name__)

# Default checkpoint referenced in documentation examples.
_CHECKPOINT_FOR_DOC = "wofmanaf/husky-7b"

# Hub checkpoints known to be compatible with this architecture.
HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "wofmanaf/husky-7b",
]
|
| 64 |
+
|
| 65 |
+
@dataclass
class HuskyForConditionalGenerationModelOutput(ModelOutput):
    """
    Class defining the outputs of [`HuskyForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        # Flatten the nested sub-model outputs so the overall tuple contains only
        # plain values; other fields are passed through unchanged.
        return tuple(
            self[k]
            if k not in ["vision_outputs", "qformer_outputs", "language_model_outputs"]
            else getattr(self, k).to_tuple()
            for k in self.keys()
        )
|
| 96 |
+
|
| 97 |
+
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Husky
class HuskyVisionEmbeddings(nn.Module):
    """Embeds an image into a sequence of patch tokens prefixed by a [CLS] token.

    Output shape: (batch, (image_size // patch_size) ** 2 + 1, hidden_size).
    """

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # Fix: cast pixel values to the conv weight dtype; previously fp32 inputs fed
        # to a fp16/bf16 model crashed inside Conv2d (matches upstream BLIP behavior).
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        # The position table is sliced so shorter sequences still broadcast correctly.
        embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        return embeddings
|
| 129 |
+
|
| 130 |
+
class HuskyVideoEmbeddings(nn.Module):
    """Embeds a video clip (batch, 3, frames, H, W) into patch tokens plus a [CLS] token.

    Frames are grouped `frame_stride` at a time by a 3D convolution, so the sequence
    length is (num_frames // frame_stride) * (image_size // patch_size) ** 2 + 1.
    """

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Defaults preserved for configs that predate these fields.
        self.num_frames = getattr(self.config, "num_frames", 8)
        self.frame_stride = getattr(self.config, "frame_stride", 2)

        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        self.patch_embedding = nn.Conv3d(
            in_channels=3, out_channels=self.embed_dim,
            kernel_size=(self.frame_stride, self.patch_size, self.patch_size),
            stride=(self.frame_stride, self.patch_size, self.patch_size)
        )

        self.num_patches = int(self.num_frames // self.frame_stride) * (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # Fix: cast inputs to the conv weight dtype so fp32 pixel values work with a
        # fp16/bf16 model (mirrors the image embedding path / upstream BLIP).
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, t, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        return embeddings
|
| 165 |
+
|
| 166 |
+
class HuskyAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim ** -0.5
        self.dropout = nn.Dropout(config.attention_dropout)

        # small tweak here compared to CLIP, no bias here
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)

        if config.qkv_bias:
            q_bias = nn.Parameter(torch.zeros(self.embed_dim))
            v_bias = nn.Parameter(torch.zeros(self.embed_dim))
        else:
            q_bias = None
            v_bias = None

        if q_bias is not None:
            # Query and value get learnable biases while the key bias stays fixed at
            # zero; the three are packed into one vector for the fused qkv linear.
            qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
            self.qkv.bias = nn.Parameter(qkv_bias)

        self.projection = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, bsz: int, seq_len: int):
        # Reshape (bsz, seq, embed) -> (bsz, heads, seq, head_dim).
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        # (bsz, seq, 3*embed) -> (3, bsz, heads, seq, head_dim)
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
            2, 0, 3, 1, 4
        )
        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))

        attention_scores = attention_scores * self.scale

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # (bsz, heads, seq, head_dim) -> (bsz, seq, heads, head_dim)
        context_layer = torch.matmul(attention_probs, value_states).permute(0, 2, 1, 3)

        # Merge the head dimension back into the embedding dimension.
        new_context_layer_shape = context_layer.size()[:-2] + (self.embed_dim,)
        context_layer = context_layer.reshape(new_context_layer_shape)

        output = self.projection(context_layer)

        outputs = (output, attention_probs) if output_attentions else (output, None)

        return outputs
|
| 249 |
+
|
| 250 |
+
class HuskyFlashAttention2(HuskyAttention):
    """
    Husky flash attention module. This module inherits from `HuskyAttention` as the weights of the module stays
    untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
    flash attention and deal with padding tokens in case the input contains any of them.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, None]:
        """Flash-attention forward over `hidden_states` of shape (batch, seq_len, embed_dim).

        `head_mask` is accepted for interface compatibility with `HuskyAttention` but is
        never applied by the flash kernel.
        """
        # Fix: flash attention cannot return attention probabilities; raise explicitly
        # instead of using `assert`, which is stripped when Python runs with -O.
        if output_attentions:
            raise ValueError("HuskyFlashAttention2 does not support output_attentions=True")

        bsz, tgt_len, embed_dim = hidden_states.size()

        mixed_qkv = self.qkv(hidden_states)

        # Flash attention requires the input to have the shape batch_size x seq_len x num_heads x head_dim
        # therefore we just need to keep the original shape
        mixed_qkv = mixed_qkv.reshape(bsz, tgt_len, 3, self.num_heads, embed_dim // self.num_heads).permute(
            2, 0, 1, 3, 4
        )

        query_states, key_states, value_states = (
            mixed_qkv[0],
            mixed_qkv[1],
            mixed_qkv[2],
        )

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # flash-attn only supports fp16/bf16; cast fp32 activations down.
            # Handle the case where the model is quantized
            if hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.qkv.weight.dtype

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = flash_attn_func(
            query_states, key_states, value_states
        )

        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim).contiguous()
        output = self.projection(attn_output)

        # Second element mirrors HuskyAttention's (output, attention_probs) contract.
        outputs = (output, None)
        return outputs
|
| 303 |
+
|
| 304 |
+
class HuskyMLP(nn.Module):
    """Position-wise feed-forward block: hidden -> intermediate -> hidden with a nonlinearity."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # The activation is looked up by name from the shared transformers registry.
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Expand, apply the configured nonlinearity, then project back down.
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
|
| 317 |
+
|
| 318 |
+
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Husky
class HuskyEncoderLayer(nn.Module):
    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        # Use the flash-attention variant only when explicitly enabled on the config.
        self.self_attn = (
            HuskyAttention(config=config)
            if not getattr(config, "_flash_attn_2_enabled", False)
            else HuskyFlashAttention2(config=config)
        )
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = HuskyMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                `(config.encoder_attention_heads,)`.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        residual = hidden_states

        # Pre-norm self-attention with residual connection.
        hidden_states = self.layer_norm1(hidden_states)
        # NOTE(review): the encoder-level attention_mask is forwarded as `head_mask`
        # here (HuskyEncoder passes its mask straight through) — confirm intended.
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            head_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + residual
        residual = hidden_states
        # Pre-norm MLP with residual connection.
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)

        hidden_states = hidden_states + residual

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
|
| 369 |
+
|
| 370 |
+
class HuskyPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HuskyConfig
    base_model_prefix = "husky"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"language_model.encoder.embed_tokens.weight",
        r"language_model.decoder.embed_tokens.weight",
        r"language_model.lm_head.weight",
    ]
    _no_split_modules = [
        "HuskyAttention",
        "HuskyFlashAttention2",
        "LlamaDecoderLayer",
    ]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _keep_in_fp32_modules = ["wo"]

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        # Fix: include nn.Conv3d so HuskyVideoEmbeddings.patch_embedding gets the same
        # normal init as the 2D patch embedding instead of PyTorch's default init.
        if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=factor)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

        # Fix: also cover HuskyVideoEmbeddings, whose class/position embeddings were
        # previously left at their torch.randn defaults.
        if isinstance(module, (HuskyVisionEmbeddings, HuskyVideoEmbeddings)):
            if hasattr(self.config, "vision_config"):
                factor = self.config.vision_config.initializer_range
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)

        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggled by PreTrainedModel.gradient_checkpointing_enable/disable.
        if isinstance(module, HuskyEncoder):
            module.gradient_checkpointing = value
|
| 417 |
+
|
| 418 |
+
# Docstring templates consumed by the add_start_docstrings* decorators below.
Husky_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`HuskyConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

Husky_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
            details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

Husky_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
            [What are attention masks?](../glossary#attention-mask)
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
            Training](./t5#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

Husky_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
            details.

        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
            provided to serve as text prompt, which the language model can continue.

            Indices can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
            encoder-decoder language model (like T5) is used.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)

        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.

            Only relevant in case an encoder-decoder language model (like T5) is used.

        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
| 528 |
+
|
| 529 |
+
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Husky
class HuskyEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`HuskyEncoderLayer`].

    Args:
        config (`HuskyConfig`):
            The corresponding vision configuration for the `HuskyEncoder`.
    """

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([HuskyEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Toggled by HuskyPreTrainedModel._set_gradient_checkpointing.
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Fall back to config defaults when flags are not given explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            # Hidden states are recorded *before* each layer; the final layer's output
            # is appended after the loop.
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Recompute activations in backward instead of storing them; the
                # closure forwards `output_attentions` as a positional argument.

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
|
| 620 |
+
|
| 621 |
+
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Husky, BLIP->Husky
|
| 622 |
+
class HuskyVisionModel(HuskyPreTrainedModel):
|
| 623 |
+
main_input_name = "pixel_values"
|
| 624 |
+
config_class = HuskyVisionConfig
|
| 625 |
+
|
| 626 |
+
def __init__(self, config: HuskyVisionConfig):
|
| 627 |
+
super().__init__(config)
|
| 628 |
+
self.config = config
|
| 629 |
+
embed_dim = config.hidden_size
|
| 630 |
+
|
| 631 |
+
self.embeddings = HuskyVisionEmbeddings(config)
|
| 632 |
+
self.video_embeddings = HuskyVideoEmbeddings(config)
|
| 633 |
+
|
| 634 |
+
self.encoder = HuskyEncoder(config)
|
| 635 |
+
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
|
| 636 |
+
|
| 637 |
+
self.post_init()
|
| 638 |
+
|
| 639 |
+
@add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
|
| 640 |
+
# @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=HuskyVisionConfig)
|
| 641 |
+
def forward(
|
| 642 |
+
self,
|
| 643 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 644 |
+
output_attentions: Optional[bool] = None,
|
| 645 |
+
output_hidden_states: Optional[bool] = None,
|
| 646 |
+
return_dict: Optional[bool] = None,
|
| 647 |
+
pixel_embeds: Optional[torch.FloatTensor] = None,
|
| 648 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
| 649 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 650 |
+
output_hidden_states = (
|
| 651 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 652 |
+
)
|
| 653 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 654 |
+
|
| 655 |
+
if pixel_values is None and pixel_embeds is None:
|
| 656 |
+
raise ValueError("You have to specify pixel_values or pixel_embeds")
|
| 657 |
+
|
| 658 |
+
if pixel_embeds is not None:
|
| 659 |
+
hidden_states = pixel_embeds
|
| 660 |
+
else:
|
| 661 |
+
if len(pixel_values.shape) == 4:
|
| 662 |
+
hidden_states = self.embeddings(pixel_values)
|
| 663 |
+
elif len(pixel_values.shape) == 5:
|
| 664 |
+
hidden_states = self.video_embeddings(pixel_values)
|
| 665 |
+
else:
|
| 666 |
+
raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
|
| 667 |
+
|
| 668 |
+
encoder_outputs = self.encoder(
|
| 669 |
+
inputs_embeds=hidden_states,
|
| 670 |
+
output_attentions=output_attentions,
|
| 671 |
+
output_hidden_states=output_hidden_states,
|
| 672 |
+
return_dict=return_dict,
|
| 673 |
+
)
|
| 674 |
+
|
| 675 |
+
last_hidden_state = encoder_outputs[0]
|
| 676 |
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
| 677 |
+
|
| 678 |
+
pooled_output = last_hidden_state[:, 0, :]
|
| 679 |
+
pooled_output = self.post_layernorm(pooled_output)
|
| 680 |
+
|
| 681 |
+
if not return_dict:
|
| 682 |
+
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
| 683 |
+
|
| 684 |
+
return BaseModelOutputWithPooling(
|
| 685 |
+
last_hidden_state=last_hidden_state,
|
| 686 |
+
pooler_output=pooled_output,
|
| 687 |
+
hidden_states=encoder_outputs.hidden_states,
|
| 688 |
+
attentions=encoder_outputs.attentions,
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
    def get_input_embeddings(self):
        """Return the embedding module used for 4-D (image) ``pixel_values``.

        Note this is the patch-embedding module itself, not a token table
        (this vision tower has no word embeddings).
        """
        return self.embeddings
|
| 693 |
+
|
| 694 |
+
    def get_video_embeddings(self):
        """Return the embedding module used for 5-D (video) ``pixel_values``."""
        return self.video_embeddings
|
| 696 |
+
|
| 697 |
+
class HuskyQFormerMultiHeadAttention(nn.Module):
    """Multi-head (self- or cross-) attention used inside the Q-Former.

    When ``is_cross_attention`` is True, the key/value projections read from
    the vision-encoder hidden states (width ``config.encoder_hidden_size``);
    otherwise they read from the Q-Former's own hidden states.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Cross-attention: keys/values are projected from the encoder width.
            self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
            self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # Learned relative-position biases, one embedding per distance bucket.
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When True, cross-attention maps and their gradients are stashed on
        # the module (used for attention visualization / Grad-CAM style tools).
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        # Hook target registered in forward(): stores d(loss)/d(attention_probs).
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # [batch, seq, all_head_size] -> [batch, num_heads, seq, head_size]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        """Scaled dot-product attention.

        Returns ``(context, [attention_probs,] past_key_value)``; the
        probabilities are included only when ``output_attentions`` is True,
        and the new ``(key, value)`` cache tuple is always appended last.
        """
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Extend the cached keys/values along the sequence dimension.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            # NOTE(review): relative distances are computed from the current
            # `hidden_states` length only; with a `past_key_value` cache the
            # key sequence is longer than `seq_length` — confirm this mode is
            # never combined with caching.
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # [batch, heads, seq, head_size] -> [batch, seq, all_head_size]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
|
| 827 |
+
|
| 828 |
+
class HuskyQFormerFlashAttention2(HuskyQFormerMultiHeadAttention):
    """Variant of ``HuskyQFormerMultiHeadAttention`` that runs the attention
    through ``flash_attn_func``.

    Limitations inherited from the flash kernel: ``attention_mask``,
    ``encoder_attention_mask`` and ``head_mask`` are ignored, and attention
    probabilities cannot be materialized, so ``output_attentions`` is also
    ignored (the returned tuple never contains them).
    """

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        bsz, tgt_len, embed_dim = hidden_states.size()
        # Cross-attention takes keys/values from the encoder; self-attention
        # takes them from `hidden_states`, optionally extended with the cache.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        past_key_value = (key_layer, value_layer)

        # transpose_for_scores returns [batch, num_heads, seq_len, head_dim];
        # flash_attn expects [batch, seq_len, num_heads, head_dim].
        query_layer = query_layer.transpose(1, 2)
        key_layer = key_layer.transpose(1, 2)
        value_layer = value_layer.transpose(1, 2)

        # BUGFIX: `self.dropout` is an nn.Dropout *module*; the kernel needs
        # the float rate. Previously the module object was stored here and
        # never used, silently disabling attention dropout in training.
        dropout_rate = self.dropout.p if self.training else 0.0

        # flash-attn only supports fp16/bf16; downcast fp32 activations (can
        # happen e.g. when layer norms are kept in fp32 upcasting the stream).
        input_dtype = query_layer.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            else:
                target_dtype = self.query.weight.dtype

            query_layer = query_layer.to(target_dtype)
            key_layer = key_layer.to(target_dtype)
            value_layer = value_layer.to(target_dtype)

        attn_output = flash_attn_func(
            query_layer, key_layer, value_layer, dropout_p=dropout_rate, causal=False
        )

        # BUGFIX: the parent class defines `all_head_size`, not `embed_size`
        # (the old attribute access raised AttributeError on every call).
        context_layer = attn_output.reshape(bsz, tgt_len, self.all_head_size).contiguous()
        outputs = (context_layer,)

        outputs = outputs + (past_key_value,)
        return outputs
|
| 893 |
+
|
| 894 |
+
class HuskyQFormerSelfOutput(nn.Module):
    """Post-attention projection: Linear(hidden → hidden), dropout, then a
    residual LayerNorm over the block input (BERT-style post-norm)."""

    def __init__(self, config):
        super().__init__()
        # Sub-module names (`dense`, `LayerNorm`, `dropout`) are kept so
        # existing checkpoints load unchanged.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
| 906 |
+
|
| 907 |
+
class HuskyQFormerAttention(nn.Module):
    """Attention block: multi-head (or flash) attention followed by the
    residual output projection, with support for head pruning."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        # Sub-module names (`attention`, `output`) are kept stable so that
        # existing checkpoints load unchanged.
        if getattr(config, "_flash_attn_2_enabled", False):
            self.attention = HuskyQFormerFlashAttention2(config, is_cross_attention)
        else:
            self.attention = HuskyQFormerMultiHeadAttention(config, is_cross_attention)
        self.output = HuskyQFormerSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v and output
        projections in place (no-op for an empty collection)."""
        if not heads:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Shrink every projection to the surviving head dimensions.
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Book-keeping: fewer heads, smaller flattened size, remember what
        # was pruned so repeated calls stay consistent.
        self.attention.num_attention_heads -= len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Attend, then project with the residual connection back onto
        ``hidden_states``; extra entries (probs/cache) pass through."""
        attn_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        projected = self.output(attn_results[0], hidden_states)
        return (projected,) + attn_results[1:]
|
| 958 |
+
|
| 959 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->HuskyQFormer
|
| 960 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->HuskyQFormer
class HuskyQFormerIntermediate(nn.Module):
    """Feed-forward expansion: Linear(hidden → intermediate) followed by the
    configured activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` may be either a key into ACT2FN or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``act(dense(hidden_states))``."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
| 973 |
+
|
| 974 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->HuskyQFormer
|
| 975 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->HuskyQFormer
class HuskyQFormerOutput(nn.Module):
    """Feed-forward contraction: Linear(intermediate → hidden), dropout, then
    a residual LayerNorm over the FFN input (post-norm)."""

    def __init__(self, config):
        super().__init__()
        # Sub-module names are kept so existing checkpoints load unchanged.
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(dropout(dense(hidden_states)) + input_tensor)``."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
| 987 |
+
|
| 988 |
+
class HuskyQFormerLayer(nn.Module):
    """One Q-Former block: self-attention over [query; text] tokens, optional
    cross-attention from the query tokens into the encoder (image) features,
    then a feed-forward network with dedicated weights for the query branch.

    Cross-attention is instantiated only every
    ``config.cross_attention_frequency`` layers (including layer 0).
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1  # FFN chunking is applied along the sequence axis
        self.attention = HuskyQFormerAttention(config)

        self.layer_idx = layer_idx

        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = HuskyQFormerAttention(config, is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        # NOTE(review): only the query-branch FFN is created here, yet
        # `feed_forward_chunk` below references `self.intermediate` /
        # `self.output`. Any call with `query_length == 0` or with extra text
        # tokens would raise AttributeError — confirm those paths are never
        # exercised before relying on them.
        self.intermediate_query = HuskyQFormerIntermediate(config)
        self.output_query = HuskyQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        """Run self-attention, optional cross-attention (query part only),
        and the FFN; returns ``(layer_output, [attn maps...], present_kv)``."""
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]  # attention probs, if requested

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # The first `query_length` positions are the learned query tokens.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError("encoder_hidden_states must be given for cross-attention layers")
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                # add cross attentions if we output attention weights
                outputs = outputs + cross_attention_outputs[1:-1]

            # Query tokens go through their dedicated FFN.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                # Remaining (text) positions use the shared FFN, then both
                # parts are re-concatenated along the sequence axis.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # NOTE(review): `self.intermediate` / `self.output` are not defined in
        # __init__ above — this path raises AttributeError if ever reached.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # FFN applied to the learned query tokens (dedicated weights).
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
|
| 1086 |
+
|
| 1087 |
+
class HuskyQFormerEncoder(nn.Module):
    """Stack of ``config.num_hidden_layers`` ``HuskyQFormerLayer`` blocks,
    with optional gradient checkpointing and key/value caching."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [HuskyQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        # NOTE(review): this instance flag is never read below — the forward
        # pass gates checkpointing on `config.gradient_checkpointing` instead.
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """Run all layers sequentially.

        Returns a ``BaseModelOutputWithPastAndCrossAttentions`` when
        ``return_dict`` is True, otherwise a tuple of the non-None entries
        (last hidden state, cache, hidden states, self-/cross-attentions).
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    # Caching would break re-computation during backward.
                    logger.warn(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    # Close over the per-layer constants so only tensors are
                    # routed through torch.utils.checkpoint.
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # Each layer appends its present (key, value) tuple last.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if layer_module.has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple form skips any entry that is None.
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
| 1187 |
+
|
| 1188 |
+
class HuskyQFormerModel(HuskyPreTrainedModel):
    """
    Querying Transformer (Q-Former), used in Husky.

    Unlike a BERT-style model, the only inputs are the learned `query_embeds`
    (there is no word-embedding table here): they are layer-normed, dropped
    out and fed to the encoder, which cross-attends into
    `encoder_hidden_states` (the vision features).
    """

    def __init__(self, config: HuskyQFormerConfig):
        super().__init__(config)
        self.config = config

        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = HuskyQFormerEncoder(config)

        self.post_init()

    def get_input_embeddings(self):
        # NOTE(review): `self.embeddings` is never created in __init__ above,
        # so this accessor (and set_input_embeddings) would raise
        # AttributeError if called — confirm they are actually unused.
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
        has_query: bool = False,  # NOTE(review): unused in this override
    ) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
            device (`torch.device`):
                The device of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    def forward(
        self,
        query_embeds,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
            the model is configured as a decoder.
        encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
            the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
            shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
            value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
            used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
            value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
            `(batch_size, sequence_length)`.
        use_cache (`bool`, `optional`):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # past_key_values_length
        # The cached length excludes the query tokens themselves
        # (`config.query_length` of them are always re-fed each step).
        past_key_values_length = (
            past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
        )

        query_length = query_embeds.shape[1] if query_embeds is not None else 0

        # Input "embedding" stage: LayerNorm + dropout over the query tokens.
        embedding_output = self.layernorm(query_embeds)
        embedding_output = self.dropout(embedding_output)

        input_shape = embedding_output.size()[:-1]
        batch_size, seq_length = input_shape
        device = embedding_output.device

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if encoder_hidden_states is not None:
            # A list of encoder states means one tensor per image/frame group.
            if type(encoder_hidden_states) == list:
                encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
            else:
                (
                    encoder_batch_size,
                    encoder_sequence_length,
                    _,
                ) = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

            if type(encoder_attention_mask) == list:
                encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
            elif encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            query_length=query_length,
        )
        sequence_output = encoder_outputs[0]
        # Pool by taking the first query token's final hidden state.
        pooled_output = sequence_output[:, 0, :]

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
        )
|
| 1379 |
+
|
| 1380 |
+
class AdapterMLP(nn.Module):
    """Two-layer SiLU MLP that maps vision-encoder features into the Q-Former width.

    Projects from ``config.vision_config.hidden_size`` through a bottleneck of a
    quarter of that size down to ``config.qformer_config.hidden_size``, then
    applies a LayerNorm to the result.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN["silu"]
        in_dim = config.vision_config.hidden_size
        mid_dim = in_dim // 4
        out_dim = config.qformer_config.hidden_size

        self.fc1 = nn.Linear(in_dim, mid_dim)
        self.fc2 = nn.Linear(mid_dim, out_dim)
        self.layernorm = nn.LayerNorm(out_dim, eps=config.vision_config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the adapted features: LayerNorm(fc2(silu(fc1(x))))."""
        projected = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        return self.layernorm(projected)
|
| 1399 |
+
|
| 1400 |
+
@add_start_docstrings(
    """
    Husky Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
    (Q-Former) and a language model.
    """,
    Husky_START_DOCSTRING,
)
class HuskyModel(HuskyPreTrainedModel):
    config_class = HuskyConfig
    main_input_name = "pixel_values"

    def __init__(self, config: HuskyConfig):
        super().__init__(config)

        # Vision backbone that turns pixel_values into patch embeddings.
        self.vision_model = HuskyVisionModel(config.vision_config)

        # Learnable query tokens consumed by the Q-Former (shape: 1 x num_query_tokens x qformer_hidden).
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        # Projects Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Number of prompt tokens preceding the image placeholder in the templated
        # input (see the comment in `forward`); the visual embeddings are written
        # into positions [offset, offset + num_queries).
        self.offset = 5

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegate embedding accessors to the wrapped language model.
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        # Only encoder-decoder language models share an embedding table that
        # needs tying; decoder-only models are left untouched.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    @add_start_docstrings_to_model_forward(Husky_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Run only the language model on text inputs.

        Returns:
            text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
                The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
                contains the language model logits, the past key values and the hidden states if
                `output_hidden_states=True`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return text_outputs

    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run only the vision encoder on `pixel_values` and return its outputs."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return vision_outputs

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    def get_qformer_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Encode images and cross-attend the learned query tokens to them.

        Returns:
            vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
                The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that
                contains the image features, the pooled image features and the hidden states if
                `output_hidden_states=True`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return query_outputs

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
        """Full vision -> Q-Former -> language-model pass; computes an LM loss if `labels` are given."""
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Prompt template: Human: <img><IMAGE></img>. Give the describe Assistant:
        # position of <image> placeholder tokens: [offset: offset+num_queries]

        # Overwrite the placeholder token embeddings with the projected visual
        # features, in place. NOTE(review): in-place assignment into the embedding
        # output — confirm this interacts correctly with autograd during training.
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs
        if attention_mask is None:
            # Sequence length is unchanged (in-place splice), so a mask shaped
            # like input_ids is sufficient here.
            attention_mask = torch.ones_like(
                input_ids, dtype=torch.long, device=language_model_inputs.device)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            labels = labels.to(logits.device)
            # Keep only the logits aligned with the label positions.
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")

            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return HuskyForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
|
| 1636 |
+
|
| 1637 |
+
@add_start_docstrings(
    """
    Husky Model for generating text given an image and an optional text prompt. The model consists of a vision
    encoder, Querying Transformer (Q-Former) and a language model.

    One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
    the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
    """,
    Husky_START_DOCSTRING,
)
class HuskyForConditionalGeneration(HuskyPreTrainedModel):
    config_class = HuskyConfig
    main_input_name = "pixel_values"

    def __init__(self, config: HuskyConfig):
        super().__init__(config)

        self.vision_model = HuskyVisionModel(config.vision_config)
        # Learnable query tokens consumed by the Q-Former.
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        # Projects Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Number of prompt tokens preceding the insertion point of the visual
        # tokens in the templated input (see comments in `forward`).
        self.offset = 5

        # Adapter + per-level LayerNorms for 4 intermediate vision CLS features
        # that are prepended to the Q-Former output ("soft prompting").
        self.vision_adapter = AdapterMLP(config)
        self.layer_norms = nn.ModuleList()
        for i in range(4):
            self.layer_norms.append(
                nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
            )

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegate embedding accessors to the wrapped language model.
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def extract_feature(
        self,
        pixel_values: torch.FloatTensor,
    ):
        """Encode images into language-model-ready soft prompts.

        Returns a tensor of shape (batch, num_queries + 4, text_hidden_size): the
        projected Q-Former query outputs prefixed by 4 adapted multi-level CLS
        features from the vision encoder.
        """
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
        )
        image_embeds = vision_outputs[0]

        # Pick 4 evenly spaced hidden states, keep each one's CLS token and
        # normalize it with its dedicated LayerNorm.
        depth = len(vision_outputs[2])
        indices = range(depth // 4 - 1, depth, depth // 4)
        pooled_outputs = []
        for idx, layer_norm in zip(indices, self.layer_norms):
            pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
            pool_output = layer_norm(pool_output)
            pooled_outputs.append(pool_output)

        pooled_outputs = torch.cat(pooled_outputs, dim=1)
        pooled_outputs = self.vision_adapter(pooled_outputs)

        # Forward the query tokens through the QFormer, using the image embeddings for cross-attention.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask
        )
        query_output = query_outputs[0]
        # soft_prompting: prepend the 4 adapted CLS features to the query outputs
        query_output = torch.cat([pooled_outputs, query_output], dim=1)
        language_model_inputs = self.language_projection(query_output)

        return language_model_inputs

    def _tie_weights(self):
        # Only encoder-decoder language models share an embedding table that needs tying.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    def _preprocess_accelerate(self):
        r"""
        Some pre-processing hacks to make the model `accelerate` compatible. Check
        https://github.com/huggingface/transformers/pull/21707 for more details.
        """
        hf_device_map = self.hf_device_map

        if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
            # warn users about unexpected behavior when using multi-GPU + Husky + `accelerate`.
            logger.warning(
                "The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
                " in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
                " Please pass a `device_map` that contains `language_model` to remove this warning."
                " Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
                " more details on creating a `device_map` for large models.",
            )

        if hasattr(self.language_model, "_hf_hook"):
            self.language_model._hf_hook.io_same_device = True  # For `generate` compatibility

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
        """Full pass: vision encoder -> Q-Former (+4 adapted CLS features) -> language model.

        The `num_queries + 4` visual embeddings are *inserted* into the text
        embedding sequence at position `offset`, lengthening it; computes an LM
        loss over `labels` if given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        batch_size = input_ids.shape[0]
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]
        # Pick 4 evenly spaced hidden states, keep their CLS tokens, normalize and adapt them.
        depth = len(vision_outputs[2])
        indices = range(depth // 4 - 1, depth, depth // 4)
        pooled_outputs = []
        for idx, layer_norm in zip(indices, self.layer_norms):
            pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
            pool_output = layer_norm(pool_output)
            pooled_outputs.append(pool_output)

        pooled_outputs = torch.cat(pooled_outputs, dim=1)
        pooled_outputs = self.vision_adapter(pooled_outputs)

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]
        query_output = torch.cat([pooled_outputs, query_output], dim=1)  # num_queries + 4 tokens

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Prompt template: Human: <img></img>. Give the describe Assistant:
        # the visual tokens are inserted at position `offset`.
        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)
        if attention_mask is None:
            # BUGFIX: an attention mask must be 2-D (batch, seq_len);
            # `torch.ones_like(inputs_embeds)` produced a 3-D tensor that also
            # spanned the hidden dimension.
            attention_mask = torch.ones(
                inputs_embeds.shape[:-1], dtype=torch.long, device=inputs_embeds.device)
        else:
            # Splice a mask of ones for the inserted visual tokens into the text mask.
            prefix_mask = attention_mask[:, :self.offset]
            postfix_mask = attention_mask[:, self.offset:]
            vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
                                     device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            labels = labels.to(logits.device)
            # Keep only the logits aligned with the label positions.
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")

            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return HuskyForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )

    @torch.no_grad()
    def generate(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        language_model_inputs: Optional[torch.FloatTensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """
        Overrides `generate` function to be able to use the model as a conditional generator.

        Args:
            pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
                Input images to be processed. Ignored when `language_model_inputs` is given.
            input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                The sequence used as a prompt for the generation. Defaults to a single [BOS] token per sample.
            attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
                Mask to avoid performing attention on padding token indices.
            language_model_inputs (`torch.FloatTensor` of shape (batch_size, sequence_length, num_channel), *optional*):
                Precomputed visual soft prompts (see `extract_feature`); computed from `pixel_values` when omitted.
            generation_config (`~generation.GenerationConfig`, *optional*):
                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
                passed to generate matching the attributes of `generation_config` will override them. If
                `generation_config` is not provided, the default will be used, which had the following loading
                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
                default values, whose documentation should be checked to parameterize generation.

        Returns:
            captions (list): A list of strings of length batch_size * num_captions.
        """

        if hasattr(self, "hf_device_map"):
            # preprocess for `accelerate`
            self._preprocess_accelerate()

        if language_model_inputs is None:
            # CONSISTENCY: this duplicated `extract_feature`'s body verbatim;
            # call the method instead.
            language_model_inputs = self.extract_feature(pixel_values)

        batch_size = language_model_inputs.shape[0]

        # BUGFIX: the default BOS prompt must be created *before* the embedding
        # lookup; previously `input_ids` was embedded first, so the documented
        # `input_ids=None` path crashed.
        if input_ids is None:
            input_ids = (
                torch.LongTensor([[self.config.text_config.bos_token_id]])
                .repeat(batch_size, 1)
                .to(language_model_inputs.device)
            )

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

        # Insert the visual soft prompts at position `offset` (as in `forward`).
        prefix_embeds = inputs_embeds[:, :self.offset, :]
        postfix_embeds = inputs_embeds[:, self.offset:, :]
        inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

        if attention_mask is None:
            # BUGFIX: the mask must cover the full expanded sequence (text plus
            # the inserted `num_queries + 4` visual tokens); a mask shaped like
            # `input_ids` would be too short for `inputs_embeds`.
            attention_mask = torch.ones(
                inputs_embeds.shape[:-1], dtype=torch.long, device=inputs_embeds.device)
        else:
            prefix_mask = attention_mask[:, :self.offset]
            postfix_mask = attention_mask[:, self.offset:]
            vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
                                     device=attention_mask.device)
            attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

        outputs = self.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            generation_config=generation_config,
            **generate_kwargs,
        )

        return outputs
|
robohusky/model/modeling_husky_embody2_ori.py
ADDED
|
@@ -0,0 +1,1821 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The Salesforce Authors and The HuggingFace Team. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
""" PyTorch Husky model."""
|
| 16 |
+
|
| 17 |
+
import contextlib
|
| 18 |
+
import math
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import Any, Optional, Tuple, Union
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.utils.checkpoint
|
| 24 |
+
from torch import nn
|
| 25 |
+
from torch.nn import CrossEntropyLoss
|
| 26 |
+
|
| 27 |
+
from transformers.activations import ACT2FN
|
| 28 |
+
from transformers.modeling_outputs import (
|
| 29 |
+
BaseModelOutput,
|
| 30 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 31 |
+
BaseModelOutputWithPooling,
|
| 32 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
| 33 |
+
)
|
| 34 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 35 |
+
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
| 36 |
+
from transformers.utils import (
|
| 37 |
+
ModelOutput,
|
| 38 |
+
add_start_docstrings,
|
| 39 |
+
add_start_docstrings_to_model_forward,
|
| 40 |
+
logging,
|
| 41 |
+
replace_return_docstrings,
|
| 42 |
+
)
|
| 43 |
+
from transformers import AutoModelForCausalLM, GenerationConfig
|
| 44 |
+
|
| 45 |
+
from .configuration_husky import HuskyConfig, HuskyQFormerConfig, HuskyVisionConfig
|
| 46 |
+
|
| 47 |
+
logger = logging.get_logger(__name__)
|
| 48 |
+
|
| 49 |
+
_CHECKPOINT_FOR_DOC = "wofmanaf/husky-7b"
|
| 50 |
+
|
| 51 |
+
HUSKY_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 52 |
+
"wofmanaf/husky-7b",
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
@dataclass
class HuskyForConditionalGenerationModelOutput(ModelOutput):
    """
    Output container for [`HuskyForConditionalGeneration`].

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss from the language model.
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head of the language model.
        vision_outputs (`BaseModelOutputWithPooling`):
            Outputs of the vision encoder.
        qformer_outputs (`BaseModelOutputWithPoolingAndCrossAttentions`):
            Outputs of the Q-Former (Querying Transformer).
        language_model_outputs (`CausalLMOutputWithPast` or `Seq2SeqLMOutput`):
            Outputs of the language model.
    """

    loss: Optional[Tuple[torch.FloatTensor]] = None
    logits: Optional[Tuple[torch.FloatTensor]] = None
    vision_outputs: Optional[torch.FloatTensor] = None
    qformer_outputs: Optional[Tuple[torch.FloatTensor]] = None
    language_model_outputs: Optional[Tuple[torch.FloatTensor]] = None

    def to_tuple(self) -> Tuple[Any]:
        """Flatten to a tuple, recursing into the nested sub-model outputs."""
        # These three fields are themselves ModelOutputs and must be flattened too.
        nested_keys = ("vision_outputs", "qformer_outputs", "language_model_outputs")
        flattened = []
        for key in self.keys():
            if key in nested_keys:
                flattened.append(getattr(self, key).to_tuple())
            else:
                flattened.append(self[key])
        return tuple(flattened)
|
| 86 |
+
|
| 87 |
+
# Copied from transformers.models.blip.modeling_blip.BlipVisionEmbeddings with Blip->Husky
class HuskyVisionEmbeddings(nn.Module):
    """Patch + class-token embeddings for the Husky vision encoder (image input)."""

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Learnable [CLS]-style token prepended to the patch sequence.
        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        # Non-overlapping patch projection: kernel == stride == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed `(batch, 3, H, W)` pixels into `(batch, num_patches + 1, embed_dim)`."""
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # Fix: cast pixel_values to the conv weight's dtype (as upstream BLIP/BLIP-2 does),
        # otherwise fp16/bf16 checkpoints crash on float32 pixel inputs.
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        # Slice the position table so shorter sequences (if any) still broadcast correctly.
        embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        return embeddings
|
| 119 |
+
|
| 120 |
+
class HuskyVideoEmbeddings(nn.Module):
    """Tubelet + class-token embeddings for the Husky vision encoder (video input)."""

    def __init__(self, config: HuskyVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        # Optional video settings with defaults matching the released checkpoints.
        self.num_frames = getattr(self.config, "num_frames", 8)
        self.frame_stride = getattr(self.config, "frame_stride", 2)

        # Learnable [CLS]-style token prepended to the patch sequence.
        self.class_embedding = nn.Parameter(
            torch.randn(1, 1, self.embed_dim),
        )

        # 3D patch projection over (time, height, width); kernel == stride, so
        # frames are grouped into non-overlapping tubelets of `frame_stride` frames.
        self.patch_embedding = nn.Conv3d(
            in_channels=3, out_channels=self.embed_dim,
            kernel_size=(self.frame_stride, self.patch_size, self.patch_size),
            stride=(self.frame_stride, self.patch_size, self.patch_size)
        )

        self.num_patches = int(self.num_frames // self.frame_stride) * (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed `(batch, 3, T, H, W)` pixels into `(batch, num_patches + 1, embed_dim)`."""
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        # Fix: cast pixel_values to the conv weight's dtype (mirrors upstream BLIP/BLIP-2),
        # otherwise fp16/bf16 checkpoints crash on float32 pixel inputs.
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # shape = [*, width, t, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

        class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding[:, : embeddings.size(1), :].to(target_dtype)
        return embeddings
|
| 155 |
+
|
| 156 |
+
class HuskyAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )
        self.scale = self.head_dim ** -0.5
        self.dropout = nn.Dropout(config.attention_dropout)

        # small tweak here compared to CLIP, no bias here
        self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=False)

        if config.qkv_bias:
            q_bias = nn.Parameter(torch.zeros(self.embed_dim))
            v_bias = nn.Parameter(torch.zeros(self.embed_dim))
        else:
            q_bias = None
            v_bias = None

        if q_bias is not None:
            # BLIP-2-style bias layout: bias on q and v, a frozen zero bias on k.
            qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias, requires_grad=False), v_bias))
            self.qkv.bias = nn.Parameter(qkv_bias)

        self.projection = nn.Linear(self.embed_dim, self.embed_dim)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        # Reshape a flat (bsz, seq, dim) tensor into per-head (bsz, heads, seq, head_dim).
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel"""

        bsz, seq_len, dim = hidden_states.size()

        # Single fused projection, then split into per-head query/key/value.
        fused = self.qkv(hidden_states)
        fused = fused.reshape(bsz, seq_len, 3, self.num_heads, dim // self.num_heads).permute(2, 0, 3, 1, 4)
        queries, keys, values = fused.unbind(0)

        # Scaled dot-product attention scores.
        scores = torch.matmul(queries, keys.transpose(-1, -2)) * self.scale

        # Softmax-normalize, then drop whole attention entries (as in the original
        # Transformer paper) and apply any per-head mask.
        probs = self.dropout(nn.functional.softmax(scores, dim=-1))
        if head_mask is not None:
            probs = probs * head_mask

        context = torch.matmul(probs, values).permute(0, 2, 1, 3)
        context = context.reshape(context.size()[:-2] + (self.embed_dim,))

        attn_output = self.projection(context)

        return (attn_output, probs) if output_attentions else (attn_output, None)
|
| 239 |
+
|
| 240 |
+
# Copied from transformers.models.blip.modeling_blip.BlipMLP
class HuskyMLP(nn.Module):
    """Two-layer feed-forward block used inside each vision encoder layer."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Activation picked by name from the config (e.g. "gelu").
        self.activation_fn = ACT2FN[config.hidden_act]
        self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
        self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project up to `intermediate_size`, apply the nonlinearity, project back."""
        return self.fc2(self.activation_fn(self.fc1(hidden_states)))
|
| 254 |
+
|
| 255 |
+
# Copied from transformers.models.blip.modeling_blip.BlipEncoderLayer with Blip->Husky
class HuskyEncoderLayer(nn.Module):
    """One pre-norm vision-transformer block: self-attention then MLP, each with a residual."""

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = HuskyAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = HuskyMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                NOTE(review): it is forwarded to the attention module as `head_mask` (as in upstream BLIP).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
        """
        # Attention sub-block (pre-norm) with residual connection.
        attn_out, attn_weights = self.self_attn(
            hidden_states=self.layer_norm1(hidden_states),
            head_mask=attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm) with residual connection.
        hidden_states = hidden_states + self.mlp(self.layer_norm2(hidden_states))

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
|
| 302 |
+
|
| 303 |
+
class HuskyPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = HuskyConfig
    base_model_prefix = "husky"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"language_model.encoder.embed_tokens.weight",
        r"language_model.decoder.embed_tokens.weight",
        r"language_model.lm_head.weight",
    ]
    _no_split_modules = ["HuskyAttention", "LlamaDecoderLayer", "LlamaForCausalLM"]
    _skip_keys_device_placement = "past_key_values"
    _keep_in_fp32_modules = ["wo"]

    def _init_weights(self, module):
        """Initialize the weights"""
        factor = self.config.initializer_range
        # Generic weight init for projection/embedding layers.
        if isinstance(module, (nn.Conv2d, nn.Embedding, nn.Linear)):
            module.weight.data.normal_(mean=0.0, std=factor)
            if getattr(module, "bias", None) is not None:
                module.bias.data.zero_()

        if isinstance(module, HuskyVisionEmbeddings):
            # Vision embeddings use the vision-specific initializer range when available.
            if hasattr(self.config, "vision_config"):
                factor = self.config.vision_config.initializer_range
            nn.init.trunc_normal_(module.position_embedding, mean=0.0, std=factor)
            nn.init.trunc_normal_(module.class_embedding, mean=0.0, std=factor)

        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        # Gradient checkpointing is only toggled on the vision encoder stack.
        if isinstance(module, HuskyEncoder):
            module.gradient_checkpointing = value
|
| 345 |
+
|
| 346 |
+
Husky_START_DOCSTRING = r"""
|
| 347 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 348 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 349 |
+
etc.)
|
| 350 |
+
|
| 351 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 352 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 353 |
+
and behavior.
|
| 354 |
+
|
| 355 |
+
Parameters:
|
| 356 |
+
config ([`HuskyConfig`]): Model configuration class with all the parameters of the model.
|
| 357 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 358 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 359 |
+
"""
|
| 360 |
+
|
| 361 |
+
Husky_VISION_INPUTS_DOCSTRING = r"""
|
| 362 |
+
Args:
|
| 363 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 364 |
+
Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
|
| 365 |
+
details.
|
| 366 |
+
output_attentions (`bool`, *optional*):
|
| 367 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 368 |
+
tensors for more detail.
|
| 369 |
+
output_hidden_states (`bool`, *optional*):
|
| 370 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 371 |
+
more detail.
|
| 372 |
+
return_dict (`bool`, *optional*):
|
| 373 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 374 |
+
"""
|
| 375 |
+
|
| 376 |
+
Husky_TEXT_INPUTS_DOCSTRING = r"""
|
| 377 |
+
Args:
|
| 378 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 379 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 380 |
+
it. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 381 |
+
[`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids)
|
| 382 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 383 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 384 |
+
- 1 for tokens that are **not masked**,
|
| 385 |
+
- 0 for tokens that are **masked**.
|
| 386 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 387 |
+
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 388 |
+
Indices of decoder input sequence tokens in the vocabulary.
|
| 389 |
+
|
| 390 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 391 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
| 392 |
+
|
| 393 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
| 394 |
+
|
| 395 |
+
T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
|
| 396 |
+
is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
|
| 397 |
+
|
| 398 |
+
To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
|
| 399 |
+
Training](./t5#training).
|
| 400 |
+
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 401 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
| 402 |
+
be used by default.
|
| 403 |
+
output_attentions (`bool`, *optional*):
|
| 404 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 405 |
+
tensors for more detail.
|
| 406 |
+
output_hidden_states (`bool`, *optional*):
|
| 407 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 408 |
+
more detail.
|
| 409 |
+
return_dict (`bool`, *optional*):
|
| 410 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
Husky_INPUTS_DOCSTRING = r"""
|
| 414 |
+
Args:
|
| 415 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 416 |
+
Pixel values. Pixel values can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for
|
| 417 |
+
details.
|
| 418 |
+
|
| 419 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 420 |
+
Indices of input sequence tokens in the vocabulary of the language model. Input tokens can optionally be
|
| 421 |
+
provided to serve as text prompt, which the language model can continue.
|
| 422 |
+
|
| 423 |
+
Indices can be obtained using [`HuskyProcessor`]. See [`HuskyProcessor.__call__`] for details.
|
| 424 |
+
|
| 425 |
+
[What are input IDs?](../glossary#input-ids)
|
| 426 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 427 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 428 |
+
|
| 429 |
+
- 1 for tokens that are **not masked**,
|
| 430 |
+
- 0 for tokens that are **masked**.
|
| 431 |
+
|
| 432 |
+
[What are attention masks?](../glossary#attention-mask)
|
| 433 |
+
|
| 434 |
+
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 435 |
+
Indices of decoder input sequence tokens in the vocabulary of the language model. Only relevant in case an
|
| 436 |
+
encoder-decoder language model (like T5) is used.
|
| 437 |
+
|
| 438 |
+
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
| 439 |
+
[`PreTrainedTokenizer.__call__`] for details. [What are decoder input IDs?](../glossary#decoder-input-ids)
|
| 440 |
+
|
| 441 |
+
decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
|
| 442 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
| 443 |
+
be used by default.
|
| 444 |
+
|
| 445 |
+
Only relevant in case an encoder-decoder language model (like T5) is used.
|
| 446 |
+
|
| 447 |
+
output_attentions (`bool`, *optional*):
|
| 448 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
| 449 |
+
tensors for more detail.
|
| 450 |
+
output_hidden_states (`bool`, *optional*):
|
| 451 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
| 452 |
+
more detail.
|
| 453 |
+
return_dict (`bool`, *optional*):
|
| 454 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
| 455 |
+
"""
|
| 456 |
+
|
| 457 |
+
# Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Husky
|
| 458 |
+
class HuskyEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`HuskyEncoderLayer`].

    Args:
        config (`HuskyConfig`):
            The corresponding vision configuration for the `HuskyEncoder`.
    """

    def __init__(self, config: HuskyConfig):
        super().__init__()
        self.config = config
        # One identically-configured encoder layer per config.num_hidden_layers.
        self.layers = nn.ModuleList([HuskyEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Toggled externally (e.g. by gradient_checkpointing_enable); when True and in
        # training mode, layer activations are recomputed during backward to save memory.
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded representation of the inputs. Should be float, not int tokens.
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Fall back to config-level defaults for any flag that was not given explicitly.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        hidden_states = inputs_embeds
        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                # Record the hidden state *entering* this layer.
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Wrap the layer so the checkpoint API (which only forwards positional
                # tensors) can still receive the output_attentions flag.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            # Append the final layer's output so hidden_states has num_layers + 1 entries.
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
|
| 548 |
+
|
| 549 |
+
# Copied from transformers.models.blip.modeling_blip.BlipVisionModel with Blip->Husky, BLIP->Husky
|
| 550 |
+
class HuskyVisionModel(HuskyPreTrainedModel):
    """Vision tower: embeds images or videos, runs the `HuskyEncoder`, and pools the CLS token."""

    main_input_name = "pixel_values"
    config_class = HuskyVisionConfig

    def __init__(self, config: HuskyVisionConfig):
        super().__init__(config)
        self.config = config
        embed_dim = config.hidden_size

        # Patch embedding used for single images (4-D pixel_values) ...
        self.embeddings = HuskyVisionEmbeddings(config)
        # ... and a separate embedding used for videos (5-D pixel_values, extra frame axis).
        self.video_embeddings = HuskyVideoEmbeddings(config)

        self.encoder = HuskyEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

        self.post_init()

    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_embeds: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """
        Encode pixels (or precomputed pixel embeddings) into sequence features plus a pooled
        CLS representation. Exactly one of `pixel_values` / `pixel_embeds` must be provided;
        `pixel_embeds` takes precedence when both are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None and pixel_embeds is None:
            raise ValueError("You have to specify pixel_values or pixel_embeds")

        if pixel_embeds is not None:
            # Caller already embedded the pixels; skip the embedding modules entirely.
            hidden_states = pixel_embeds
        else:
            # Dispatch on rank: 4-D -> image path, 5-D -> video path.
            if len(pixel_values.shape) == 4:
                hidden_states = self.embeddings(pixel_values)
            elif len(pixel_values.shape) == 5:
                hidden_states = self.video_embeddings(pixel_values)
            else:
                raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        last_hidden_state = self.post_layernorm(last_hidden_state)

        # Pool by taking the first (CLS) token.
        # NOTE(review): post_layernorm is applied twice on the pooled token (once via
        # last_hidden_state, once here) — same as the upstream BLIP code this copies.
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def get_input_embeddings(self):
        # Returns the image (4-D) embedding module.
        return self.embeddings

    def get_video_embeddings(self):
        # Returns the video (5-D) embedding module.
        return self.video_embeddings
|
| 624 |
+
|
| 625 |
+
class HuskyQFormerMultiHeadAttention(nn.Module):
    """
    Multi-head (self- or cross-) attention for the Q-Former.

    When `is_cross_attention` is True, keys/values are projected from
    `encoder_hidden_states` (width `config.encoder_hidden_size`); otherwise from
    `hidden_states`. Supports a key/value cache (`past_key_value`) and optional
    relative position embeddings.
    """

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.config = config
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention heads (%d)"
                % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        if is_cross_attention:
            # Keys/values come from the (possibly wider) vision encoder output.
            self.key = nn.Linear(config.encoder_hidden_size, self.all_head_size)
            self.value = nn.Linear(config.encoder_hidden_size, self.all_head_size)
        else:
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
        # When True (and cross-attending), attention maps and their gradients are
        # stashed on the module for visualization/attribution.
        self.save_attention = False

    def save_attn_gradients(self, attn_gradients):
        self.attn_gradients = attn_gradients

    def get_attn_gradients(self):
        return self.attn_gradients

    def save_attention_map(self, attention_map):
        self.attention_map = attention_map

    def get_attention_map(self):
        return self.attention_map

    def transpose_for_scores(self, x):
        # (batch, seq, all_head) -> (batch, heads, seq, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
    ):
        # If this is instantiated as a cross-attention module, the keys
        # and values come from an encoder; the attention mask needs to be
        # such that the encoder's padding tokens are not attended to.
        is_cross_attention = encoder_hidden_states is not None

        if is_cross_attention:
            key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
            value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
            attention_mask = encoder_attention_mask
        elif past_key_value is not None:
            # Self-attention with a cache: append new keys/values after the cached ones.
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))
            key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
            value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
        else:
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            value_layer = self.transpose_for_scores(self.value(hidden_states))

        mixed_query_layer = self.query(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)

        # Always return the (possibly extended) key/value pair for caching.
        past_key_value = (key_layer, value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            seq_length = hidden_states.size()[1]
            position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r
            # Shift distances into [0, 2*max-2] so they index the embedding table.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        if is_cross_attention and self.save_attention:
            self.save_attention_map(attention_probs)
            attention_probs.register_hook(self.save_attn_gradients)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs_dropped = self.dropout(attention_probs)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs_dropped = attention_probs_dropped * head_mask

        context_layer = torch.matmul(attention_probs_dropped, value_layer)

        # (batch, heads, seq, head_size) -> (batch, seq, all_head)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        # The cache tuple is always the last element of the output.
        outputs = outputs + (past_key_value,)
        return outputs
|
| 755 |
+
|
| 756 |
+
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->HuskyQFormer
|
| 757 |
+
class HuskyQFormerSelfOutput(nn.Module):
    """Post-attention projection: dense -> dropout -> residual add -> LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # Attribute names (dense / LayerNorm / dropout) are part of the checkpoint
        # state-dict layout and must not be renamed.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Project `hidden_states`, add the residual `input_tensor`, and normalize."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
| 769 |
+
|
| 770 |
+
class HuskyQFormerAttention(nn.Module):
    """Attention block: `HuskyQFormerMultiHeadAttention` followed by the residual output layer."""

    def __init__(self, config, is_cross_attention=False):
        super().__init__()
        self.attention = HuskyQFormerMultiHeadAttention(config, is_cross_attention)
        self.output = HuskyQFormerSelfOutput(config)
        # Indices of heads already removed, so repeated pruning is cumulative.
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads by slicing the q/k/v and output projections in place."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
        )

        # Prune linear layers
        self.attention.query = prune_linear_layer(self.attention.query, index)
        self.attention.key = prune_linear_layer(self.attention.key, index)
        self.attention.value = prune_linear_layer(self.attention.value, index)
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Update hyper params and store pruned heads
        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        """Run attention, then combine its context output with the residual `hidden_states`."""
        self_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            past_key_value,
            output_attentions,
        )
        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs
|
| 817 |
+
|
| 818 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->HuskyQFormer
|
| 819 |
+
class HuskyQFormerIntermediate(nn.Module):
    """Feed-forward expansion: dense projection to `intermediate_size` plus activation."""

    def __init__(self, config):
        super().__init__()
        # `dense` must keep its name for state-dict compatibility.
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` may be given as an ACT2FN key or as a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project then apply the configured activation."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
| 832 |
+
|
| 833 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->HuskyQFormer
|
| 834 |
+
class HuskyQFormerOutput(nn.Module):
    """Feed-forward contraction: dense back to `hidden_size`, dropout, residual add, LayerNorm."""

    def __init__(self, config):
        super().__init__()
        # Attribute names are fixed by the checkpoint state-dict layout.
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Contract `hidden_states`, add the residual `input_tensor`, and normalize."""
        contracted = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(contracted + input_tensor)
|
| 846 |
+
|
| 847 |
+
class HuskyQFormerLayer(nn.Module):
    """
    One Q-Former layer: self-attention over [query tokens; text tokens], optional
    cross-attention of the query tokens into the vision encoder output (every
    `config.cross_attention_frequency` layers), then separate feed-forward paths
    for query tokens (`*_query`) and text tokens.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = HuskyQFormerAttention(config)

        self.layer_idx = layer_idx

        # Cross-attention is only instantiated on every cross_attention_frequency-th layer.
        if layer_idx % config.cross_attention_frequency == 0:
            self.crossattention = HuskyQFormerAttention(config, is_cross_attention=True)
            self.has_cross_attention = True
        else:
            self.has_cross_attention = False

        # Feed-forward used for the query-token part of the sequence.
        # NOTE(review): `feed_forward_chunk` below references `self.intermediate` /
        # `self.output`, which are NOT created here — that text-token path would raise
        # AttributeError if reached (same as the upstream BLIP-2 code). Confirm whether
        # the text path is ever exercised.
        self.intermediate_query = HuskyQFormerIntermediate(config)
        self.output_query = HuskyQFormerOutput(config)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        # Keep any attention maps; drop the trailing cache element for now.
        outputs = self_attention_outputs[1:-1]

        present_key_value = self_attention_outputs[-1]

        if query_length > 0:
            # First `query_length` positions are the learned query tokens.
            query_attention_output = attention_output[:, :query_length, :]

            if self.has_cross_attention:
                if encoder_hidden_states is None:
                    raise ValueError("encoder_hidden_states must be given for cross-attention layers")
                cross_attention_outputs = self.crossattention(
                    query_attention_output,
                    attention_mask,
                    head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    output_attentions=output_attentions,
                )
                query_attention_output = cross_attention_outputs[0]
                # add cross attentions if we output attention weights
                outputs = outputs + cross_attention_outputs[1:-1]

            # Query tokens go through the query-specific feed-forward.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk_query,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                query_attention_output,
            )

            if attention_output.shape[1] > query_length:
                # Remaining positions are text tokens; use the text feed-forward and
                # re-concatenate along the sequence axis.
                layer_output_text = apply_chunking_to_forward(
                    self.feed_forward_chunk,
                    self.chunk_size_feed_forward,
                    self.seq_len_dim,
                    attention_output[:, query_length:, :],
                )
                layer_output = torch.cat([layer_output, layer_output_text], dim=1)
        else:
            # No query tokens at all: everything uses the text feed-forward.
            layer_output = apply_chunking_to_forward(
                self.feed_forward_chunk,
                self.chunk_size_feed_forward,
                self.seq_len_dim,
                attention_output,
            )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward_chunk(self, attention_output):
        # Text-token feed-forward.
        # NOTE(review): self.intermediate / self.output are not assigned in __init__ —
        # see the note there.
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def feed_forward_chunk_query(self, attention_output):
        # Query-token feed-forward.
        intermediate_output = self.intermediate_query(attention_output)
        layer_output = self.output_query(intermediate_output, attention_output)
        return layer_output
|
| 945 |
+
|
| 946 |
+
class HuskyQFormerEncoder(nn.Module):
    """
    Stack of `config.num_hidden_layers` [`HuskyQFormerLayer`]s with optional gradient
    checkpointing and key/value caching.

    Fix vs. original: `logger.warn(...)` (a deprecated alias since Python 3.3) is
    replaced with `logger.warning(...)`; behavior is otherwise identical.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList(
            [HuskyQFormerLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        query_length=0,
    ):
        """
        Run the layer stack.

        Args:
            hidden_states: input sequence `[query tokens; text tokens]`.
            head_mask: per-layer head masks, indexed by layer.
            past_key_values: per-layer cached key/value tuples.
            use_cache: when True, collect each layer's present key/values.
            query_length: number of leading query tokens (forwarded to each layer).

        Returns:
            `BaseModelOutputWithPastAndCrossAttentions` or a tuple of the non-None
            members when `return_dict=False`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None

        for i in range(self.config.num_hidden_layers):
            layer_module = self.layer[i]
            if output_hidden_states:
                # Record the hidden state *entering* this layer.
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                if use_cache:
                    # Caching stores activations, which defeats checkpointing's purpose.
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                # Closure so checkpoint() (positional-tensor API) can still pass the
                # non-tensor arguments through to the layer.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions, query_length)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    query_length,
                )

            hidden_states = layer_outputs[0]
            if use_cache:
                # The layer's present key/value tuple is always its last output element.
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if layer_module.has_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
| 1046 |
+
|
| 1047 |
+
class HuskyQFormerModel(HuskyPreTrainedModel):
|
| 1048 |
+
"""
|
| 1049 |
+
Querying Transformer (Q-Former), used in Husky.
|
| 1050 |
+
"""
|
| 1051 |
+
|
| 1052 |
+
    def __init__(self, config: HuskyQFormerConfig):
        """Build the Q-Former: input LayerNorm + dropout feeding a `HuskyQFormerEncoder`."""
        super().__init__(config)
        self.config = config

        # Applied to the incoming query embeddings before the encoder stack (see forward).
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.encoder = HuskyQFormerEncoder(config)

        # NOTE(review): no `self.embeddings` module is created here, yet
        # get_input_embeddings/set_input_embeddings reference `self.embeddings.word_embeddings`;
        # those accessors would raise AttributeError unless it is set elsewhere — confirm.
        self.post_init()
|
| 1062 |
+
|
| 1063 |
+
    def get_input_embeddings(self):
        # NOTE(review): `self.embeddings` is not assigned in the visible __init__ of this
        # class; this accessor raises AttributeError unless it is set elsewhere — confirm.
        return self.embeddings.word_embeddings
|
| 1065 |
+
|
| 1066 |
+
    def set_input_embeddings(self, value):
        # NOTE(review): counterpart of get_input_embeddings; also depends on a
        # `self.embeddings` attribute not created in the visible __init__ — confirm.
        self.embeddings.word_embeddings = value
|
| 1068 |
+
|
| 1069 |
+
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        # Delegate to each layer's attention block, which slices its q/k/v projections.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
|
| 1076 |
+
|
| 1077 |
+
    def get_extended_attention_mask(
        self,
        attention_mask: torch.Tensor,
        input_shape: Tuple[int],
        device: torch.device,
        has_query: bool = False,  # NOTE(review): accepted but unused in this implementation
    ) -> torch.Tensor:
        """
        Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

        Arguments:
            attention_mask (`torch.Tensor`):
                Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
            input_shape (`Tuple[int]`):
                The shape of the input to the model.
            device (`torch.device`):
                The device of the input to the model.

        Returns:
            `torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
        """
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        if attention_mask.dim() == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif attention_mask.dim() == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - the model is an encoder, so make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            extended_attention_mask = attention_mask[:, None, None, :]
        else:
            raise ValueError(
                "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
                    input_shape, attention_mask.shape
                )
            )

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
|
| 1121 |
+
|
| 1122 |
+
def forward(
    self,
    query_embeds,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_values=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
):
    r"""
    Run the Q-Former encoder over the learnable query embeddings, cross-attending to the
    vision encoder output (`encoder_hidden_states`).

    encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, `optional`):
        Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
        the model is configured as a decoder.
    encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
        Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
        the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.
    past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of:
        shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): Contains precomputed key and
        value hidden states of the attention blocks. Can be used to speed up decoding. If `past_key_values` are
        used, the user can optionally input only the last `decoder_input_ids` (those that don't have their past key
        value states given to this model) of shape `(batch_size, 1)` instead of all `decoder_input_ids` of shape
        `(batch_size, sequence_length)`.
    use_cache (`bool`, `optional`):
        If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
        `past_key_values`).
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # past_key_values_length: the cached length excludes the query tokens themselves,
    # hence the subtraction of config.query_length.
    past_key_values_length = (
        past_key_values[0][0].shape[2] - self.config.query_length if past_key_values is not None else 0
    )

    query_length = query_embeds.shape[1] if query_embeds is not None else 0

    # The query embeddings are normalized and dropped out before entering the encoder.
    embedding_output = self.layernorm(query_embeds)
    embedding_output = self.dropout(embedding_output)

    input_shape = embedding_output.size()[:-1]
    batch_size, seq_length = input_shape
    device = embedding_output.device

    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, device)

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if encoder_hidden_states is not None:
        # A list of hidden states means one tensor per cross-attention layer; the
        # corresponding masks are then inverted one by one below.
        if type(encoder_hidden_states) == list:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size()
        else:
            (
                encoder_batch_size,
                encoder_sequence_length,
                _,
            ) = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)

        if type(encoder_attention_mask) == list:
            encoder_extended_attention_mask = [self.invert_attention_mask(mask) for mask in encoder_attention_mask]
        elif encoder_attention_mask is None:
            # No mask supplied: attend to every encoder position.
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
        encoder_extended_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicate we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        query_length=query_length,
    )
    sequence_output = encoder_outputs[0]
    # Pool by taking the first query position, CLS-style.
    pooled_output = sequence_output[:, 0, :]

    if not return_dict:
        return (sequence_output, pooled_output) + encoder_outputs[1:]

    return BaseModelOutputWithPoolingAndCrossAttentions(
        last_hidden_state=sequence_output,
        pooler_output=pooled_output,
        past_key_values=encoder_outputs.past_key_values,
        hidden_states=encoder_outputs.hidden_states,
        attentions=encoder_outputs.attentions,
        cross_attentions=encoder_outputs.cross_attentions,
    )
|
| 1238 |
+
|
| 1239 |
+
class AdapterMLP(nn.Module):
    """Small SiLU-activated MLP bridging vision features to the Q-Former width.

    Projects ``vision_config.hidden_size`` features down through a bottleneck
    (a quarter of the input width) and back up to
    ``qformer_config.hidden_size``, then layer-normalizes the result.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation_fn = ACT2FN["silu"]

        in_dim = config.vision_config.hidden_size
        bottleneck_dim = in_dim // 4
        out_dim = config.qformer_config.hidden_size

        self.fc1 = nn.Linear(in_dim, bottleneck_dim)
        self.fc2 = nn.Linear(bottleneck_dim, out_dim)
        self.layernorm = nn.LayerNorm(out_dim, eps=config.vision_config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> SiLU -> fc2 -> LayerNorm to ``hidden_states``."""
        projected = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        return self.layernorm(projected)
|
| 1258 |
+
|
| 1259 |
+
@add_start_docstrings(
    """
    Husky Model for generating text and image features. The model consists of a vision encoder, Querying Transformer
    (Q-Former) and a language model.
    """,
    Husky_START_DOCSTRING,
)
class HuskyModel(HuskyPreTrainedModel):
    # Config class used by `from_pretrained` / `from_config`.
    config_class = HuskyConfig
    main_input_name = "pixel_values"

    def __init__(self, config: HuskyConfig):
        """Build the vision encoder, Q-Former, projection layer and language model."""
        super().__init__(config)

        self.vision_model = HuskyVisionModel(config.vision_config)

        # Learnable query tokens fed into the Q-Former (shared across the batch).
        self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
        self.qformer = HuskyQFormerModel(config.qformer_config)

        # Maps Q-Former outputs into the language model's embedding space.
        self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
        self.language_model = AutoModelForCausalLM.from_config(config.text_config)

        self.config.hidden_size = config.text_config.hidden_size
        self.num_queries = config.num_query_tokens
        # Number of prompt tokens preceding the image placeholder in `input_ids`
        # (presumably matches the "Human: <img>" prefix produced by the tokenizer
        # — confirm against the prompt template used by callers).
        self.offset = 5

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        # Delegate to the wrapped language model.
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.language_model.get_output_embeddings()

    def get_encoder(self):
        return self.language_model.get_encoder()

    def get_decoder(self):
        return self.language_model.get_decoder()

    def _tie_weights(self):
        # Only encoder-decoder language models share an embedding table to tie.
        if not self.config.use_decoder_only_language_model:
            self.language_model.encoder.embed_tokens = self.language_model.shared
            self.language_model.decoder.embed_tokens = self.language_model.shared

    @add_start_docstrings_to_model_forward(Husky_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Run only the language model on `input_ids` (the vision path is bypassed).

        Returns:
            text_outputs (`CausalLMOutputWithPast`, or `tuple(torch.FloatTensor)` if `return_dict=False`):
                The language model outputs. If `return_dict=True`, the output is a [`CausalLMOutputWithPast`] that
                contains the language model logits, the past key values and the hidden states if
                `output_hidden_states=True`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return text_outputs

    @add_start_docstrings_to_model_forward(Husky_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run only the vision encoder and return its raw outputs."""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return vision_outputs

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    def get_qformer_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        Encode images and pass the query tokens through the Q-Former.

        Returns:
            vision_outputs (`BaseModelOutputWithPooling` or tuple of `torch.FloatTensor`):
                The vision model outputs. If `return_dict=True`, the output is a [`BaseModelOutputWithPooling`] that
                contains the image features, the pooled image features and the hidden states if
                `output_hidden_states=True`.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        return query_outputs

    @add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
    # @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
    def forward(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.FloatTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
        """Vision encoder -> Q-Former -> language model, splicing the projected
        query outputs into the token embeddings at the image-placeholder span.

        NOTE(review): `input_ids` must contain exactly `num_queries` placeholder
        tokens starting at position `self.offset`; the code overwrites that span
        in place and does not validate the prompt layout beyond an assert.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # step 1: forward the images through the vision encoder,
        # to get image embeddings of shape (batch_size, seq_len, hidden_size)
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        image_embeds = vision_outputs[0]

        # step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        # step 3: use the language model, conditioned on the query outputs and the prompt
        language_model_inputs = self.language_projection(query_output)
        assert language_model_inputs.shape[1] == self.num_queries

        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        # Human: <img><IMAGE></img>. Give the describe Assistant:
        # position of <image>: [offset: offset+num_queries]

        # Overwrite the placeholder span with the projected visual features.
        inputs_embeds[:, self.offset:self.offset + self.num_queries, :] = language_model_inputs
        if attention_mask is None:
            attention_mask = torch.ones_like(
                input_ids, dtype=torch.long, device=language_model_inputs.device)

        outputs = self.language_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = outputs.logits if return_dict else outputs[0]
        loss = None
        # we compute the loss here since we need to take into account the sequence length of the query embeds
        if labels is not None:
            labels = labels.to(logits.device)
            # Keep only the logits aligned with the label span (tail of the sequence).
            logits = logits[:, -labels.size(1):, :]
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous().to(logits.device)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(reduction="mean")

            loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits, vision_outputs, query_outputs, outputs)
            return ((loss,) + output) if loss is not None else output

        return HuskyForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            vision_outputs=vision_outputs,
            qformer_outputs=query_outputs,
            language_model_outputs=outputs,
        )
|
| 1495 |
+
|
| 1496 |
+
@add_start_docstrings(
|
| 1497 |
+
"""
|
| 1498 |
+
Husky Model for generating text given an image and an optional text prompt. The model consists of a vision
|
| 1499 |
+
encoder, Querying Transformer (Q-Former) and a language model.
|
| 1500 |
+
|
| 1501 |
+
One can optionally pass `input_ids` to the model, which serve as a text prompt, to make the language model continue
|
| 1502 |
+
the prompt. Otherwise, the language model starts generating text from the [BOS] (beginning-of-sequence) token.
|
| 1503 |
+
""",
|
| 1504 |
+
Husky_START_DOCSTRING,
|
| 1505 |
+
)
|
| 1506 |
+
class HuskyForConditionalGeneration(HuskyPreTrainedModel):
|
| 1507 |
+
config_class = HuskyConfig
|
| 1508 |
+
main_input_name = "pixel_values"
|
| 1509 |
+
|
| 1510 |
+
def __init__(self, config: HuskyConfig):
|
| 1511 |
+
super().__init__(config)
|
| 1512 |
+
|
| 1513 |
+
self.vision_model = HuskyVisionModel(config.vision_config)
|
| 1514 |
+
self.query_tokens = nn.Parameter(torch.zeros(1, config.num_query_tokens, config.qformer_config.hidden_size))
|
| 1515 |
+
self.qformer = HuskyQFormerModel(config.qformer_config)
|
| 1516 |
+
|
| 1517 |
+
self.language_projection = nn.Linear(config.qformer_config.hidden_size, config.text_config.hidden_size)
|
| 1518 |
+
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
|
| 1519 |
+
|
| 1520 |
+
self.config.hidden_size = config.text_config.hidden_size
|
| 1521 |
+
self.num_queries = config.num_query_tokens
|
| 1522 |
+
self.offset = 5
|
| 1523 |
+
|
| 1524 |
+
self.vision_adapter = AdapterMLP(config)
|
| 1525 |
+
self.layer_norms = nn.ModuleList()
|
| 1526 |
+
for i in range(4):
|
| 1527 |
+
self.layer_norms.append(
|
| 1528 |
+
nn.LayerNorm(config.vision_config.hidden_size, eps=config.vision_config.layer_norm_eps)
|
| 1529 |
+
)
|
| 1530 |
+
|
| 1531 |
+
# Initialize weights and apply final processing
|
| 1532 |
+
self.post_init()
|
| 1533 |
+
|
| 1534 |
+
def get_input_embeddings(self):
|
| 1535 |
+
return self.language_model.get_input_embeddings()
|
| 1536 |
+
|
| 1537 |
+
def set_input_embeddings(self, value):
|
| 1538 |
+
self.language_model.set_input_embeddings(value)
|
| 1539 |
+
|
| 1540 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1541 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
| 1542 |
+
|
| 1543 |
+
def get_output_embeddings(self) -> nn.Module:
|
| 1544 |
+
return self.language_model.get_output_embeddings()
|
| 1545 |
+
|
| 1546 |
+
def get_encoder(self):
|
| 1547 |
+
return self.language_model.get_encoder()
|
| 1548 |
+
|
| 1549 |
+
def get_decoder(self):
|
| 1550 |
+
return self.language_model.get_decoder()
|
| 1551 |
+
|
| 1552 |
+
def extract_feature(
|
| 1553 |
+
self,
|
| 1554 |
+
pixel_values: torch.FloatTensor,
|
| 1555 |
+
):
|
| 1556 |
+
vision_outputs = self.vision_model(
|
| 1557 |
+
pixel_values=pixel_values,
|
| 1558 |
+
output_hidden_states=True,
|
| 1559 |
+
)
|
| 1560 |
+
image_embeds = vision_outputs[0]
|
| 1561 |
+
|
| 1562 |
+
depth = len(vision_outputs[2])
|
| 1563 |
+
indices = range(depth // 4 - 1, depth, depth // 4)
|
| 1564 |
+
pooled_outputs = []
|
| 1565 |
+
for idx, layer_norm in zip(indices, self.layer_norms):
|
| 1566 |
+
pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
|
| 1567 |
+
pool_output = layer_norm(pool_output)
|
| 1568 |
+
pooled_outputs.append(pool_output)
|
| 1569 |
+
|
| 1570 |
+
pooled_outputs = torch.cat(pooled_outputs, dim=1)
|
| 1571 |
+
pooled_outputs = self.vision_adapter(pooled_outputs)
|
| 1572 |
+
|
| 1573 |
+
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
| 1574 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
| 1575 |
+
|
| 1576 |
+
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
| 1577 |
+
query_outputs = self.qformer(
|
| 1578 |
+
query_embeds=query_tokens,
|
| 1579 |
+
encoder_hidden_states=image_embeds,
|
| 1580 |
+
encoder_attention_mask=image_attention_mask
|
| 1581 |
+
)
|
| 1582 |
+
query_output = query_outputs[0]
|
| 1583 |
+
# soft_prompting
|
| 1584 |
+
query_output = torch.cat([pooled_outputs, query_output], dim=1)
|
| 1585 |
+
language_model_inputs = self.language_projection(query_output)
|
| 1586 |
+
|
| 1587 |
+
return language_model_inputs
|
| 1588 |
+
|
| 1589 |
+
def _tie_weights(self):
|
| 1590 |
+
if not self.config.use_decoder_only_language_model:
|
| 1591 |
+
self.language_model.encoder.embed_tokens = self.language_model.shared
|
| 1592 |
+
self.language_model.decoder.embed_tokens = self.language_model.shared
|
| 1593 |
+
|
| 1594 |
+
def _preprocess_accelerate(self):
|
| 1595 |
+
r"""
|
| 1596 |
+
Some pre-processing hacks to make the model `accelerate` compatible. Check
|
| 1597 |
+
https://github.com/huggingface/transformers/pull/21707 for more details.
|
| 1598 |
+
"""
|
| 1599 |
+
hf_device_map = self.hf_device_map
|
| 1600 |
+
|
| 1601 |
+
if len(hf_device_map) > 1 and "language_model" not in hf_device_map and torch.cuda.device_count() > 1:
|
| 1602 |
+
# warn users about unexpected behavior when using multi-GPU + Husky + `accelerate`.
|
| 1603 |
+
logger.warning(
|
| 1604 |
+
"The `language_model` is not in the `hf_device_map` dictionary and you are running your script"
|
| 1605 |
+
" in a multi-GPU environment. this may lead to unexpected behavior when using `accelerate`."
|
| 1606 |
+
" Please pass a `device_map` that contains `language_model` to remove this warning."
|
| 1607 |
+
" Please refer to https://github.com/huggingface/blog/blob/main/accelerate-large-models.md for"
|
| 1608 |
+
" more details on creating a `device_map` for large models.",
|
| 1609 |
+
)
|
| 1610 |
+
|
| 1611 |
+
if hasattr(self.language_model, "_hf_hook"):
|
| 1612 |
+
self.language_model._hf_hook.io_same_device = True # For `generate` compatibility
|
| 1613 |
+
|
| 1614 |
+
@add_start_docstrings_to_model_forward(Husky_INPUTS_DOCSTRING)
|
| 1615 |
+
# @replace_return_docstrings(output_type=HuskyForConditionalGenerationModelOutput, config_class=HuskyVisionConfig)
|
| 1616 |
+
def forward(
|
| 1617 |
+
self,
|
| 1618 |
+
pixel_values: Optional[torch.FloatTensor] = None,
|
| 1619 |
+
input_ids: Optional[torch.FloatTensor] = None,
|
| 1620 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 1621 |
+
output_attentions: Optional[bool] = None,
|
| 1622 |
+
output_hidden_states: Optional[bool] = None,
|
| 1623 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1624 |
+
return_dict: Optional[bool] = None,
|
| 1625 |
+
pixel_embeds: Optional[torch.FloatTensor] = None,
|
| 1626 |
+
) -> Union[Tuple, HuskyForConditionalGenerationModelOutput]:
|
| 1627 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1628 |
+
|
| 1629 |
+
# step 1: forward the images through the vision encoder,
|
| 1630 |
+
# to get image embeddings of shape (batch_size, seq_len, hidden_size)
|
| 1631 |
+
batch_size = input_ids.shape[0]
|
| 1632 |
+
vision_outputs = self.vision_model(
|
| 1633 |
+
pixel_values=pixel_values,
|
| 1634 |
+
output_attentions=output_attentions,
|
| 1635 |
+
output_hidden_states=True,
|
| 1636 |
+
return_dict=return_dict,
|
| 1637 |
+
pixel_embeds=pixel_embeds,
|
| 1638 |
+
)
|
| 1639 |
+
image_embeds = vision_outputs[0]
|
| 1640 |
+
depth = len(vision_outputs[2])
|
| 1641 |
+
indices = range(depth // 4 - 1, depth, depth // 4)
|
| 1642 |
+
pooled_outputs = []
|
| 1643 |
+
for idx, layer_norm in zip(indices, self.layer_norms):
|
| 1644 |
+
pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
|
| 1645 |
+
pool_output = layer_norm(pool_output)
|
| 1646 |
+
pooled_outputs.append(pool_output)
|
| 1647 |
+
|
| 1648 |
+
pooled_outputs = torch.cat(pooled_outputs, dim=1)
|
| 1649 |
+
pooled_outputs = self.vision_adapter(pooled_outputs)
|
| 1650 |
+
|
| 1651 |
+
# step 2: forward the query tokens through the QFormer, using the image embeddings for cross-attention
|
| 1652 |
+
image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
|
| 1653 |
+
|
| 1654 |
+
query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
|
| 1655 |
+
query_outputs = self.qformer(
|
| 1656 |
+
query_embeds=query_tokens,
|
| 1657 |
+
encoder_hidden_states=image_embeds,
|
| 1658 |
+
encoder_attention_mask=image_attention_mask,
|
| 1659 |
+
output_attentions=output_attentions,
|
| 1660 |
+
output_hidden_states=output_hidden_states,
|
| 1661 |
+
return_dict=return_dict,
|
| 1662 |
+
)
|
| 1663 |
+
query_output = query_outputs[0]
|
| 1664 |
+
query_output = torch.cat([pooled_outputs, query_output], dim=1) # 36 token
|
| 1665 |
+
|
| 1666 |
+
# step 3: use the language model, conditioned on the query outputs and the prompt
|
| 1667 |
+
language_model_inputs = self.language_projection(query_output)
|
| 1668 |
+
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
|
| 1669 |
+
|
| 1670 |
+
# Human: <img></img>. Give the describe Assistant:
|
| 1671 |
+
# position of <image>: [offset: offset+num_queries]
|
| 1672 |
+
prefix_embeds = inputs_embeds[:, :self.offset, :]
|
| 1673 |
+
postfix_embeds = inputs_embeds[:, self.offset:, :]
|
| 1674 |
+
inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)
|
| 1675 |
+
if attention_mask is None:
|
| 1676 |
+
attention_mask = torch.ones_like(
|
| 1677 |
+
inputs_embeds, dtype=torch.long, device=language_model_inputs.device)
|
| 1678 |
+
else:
|
| 1679 |
+
prefix_mask = attention_mask[:, :self.offset]
|
| 1680 |
+
postfix_mask = attention_mask[:, self.offset:]
|
| 1681 |
+
vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
|
| 1682 |
+
device=attention_mask.device)
|
| 1683 |
+
attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)
|
| 1684 |
+
|
| 1685 |
+
outputs = self.language_model(
|
| 1686 |
+
inputs_embeds=inputs_embeds,
|
| 1687 |
+
attention_mask=attention_mask,
|
| 1688 |
+
output_attentions=output_attentions,
|
| 1689 |
+
output_hidden_states=output_hidden_states,
|
| 1690 |
+
return_dict=return_dict,
|
| 1691 |
+
)
|
| 1692 |
+
logits = outputs.logits if return_dict else outputs[0]
|
| 1693 |
+
loss = None
|
| 1694 |
+
# we compute the loss here since we need to take into account the sequence length of the query embeds
|
| 1695 |
+
if labels is not None:
|
| 1696 |
+
labels = labels.to(logits.device)
|
| 1697 |
+
logits = logits[:, -labels.size(1):, :]
|
| 1698 |
+
# Shift so that tokens < n predict n
|
| 1699 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 1700 |
+
shift_labels = labels[..., 1:].contiguous().to(logits.device)
|
| 1701 |
+
|
| 1702 |
+
# Flatten the tokens
|
| 1703 |
+
loss_fct = CrossEntropyLoss(reduction="mean")
|
| 1704 |
+
|
| 1705 |
+
loss = loss_fct(shift_logits.view(-1, self.config.text_config.vocab_size), shift_labels.view(-1))
|
| 1706 |
+
|
| 1707 |
+
if not return_dict:
|
| 1708 |
+
output = (logits, vision_outputs, query_outputs, outputs)
|
| 1709 |
+
return ((loss,) + output) if loss is not None else output
|
| 1710 |
+
|
| 1711 |
+
return HuskyForConditionalGenerationModelOutput(
|
| 1712 |
+
loss=loss,
|
| 1713 |
+
logits=logits,
|
| 1714 |
+
vision_outputs=vision_outputs,
|
| 1715 |
+
qformer_outputs=query_outputs,
|
| 1716 |
+
language_model_outputs=outputs,
|
| 1717 |
+
)
|
| 1718 |
+
|
| 1719 |
+
@torch.no_grad()
def generate(
    self,
    pixel_values: Optional[torch.FloatTensor] = None,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    language_model_inputs: Optional[torch.FloatTensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    **generate_kwargs,
) -> torch.LongTensor:
    """
    Overrides `generate` function to be able to use the model as a conditional generator.

    Args:
        pixel_values (`torch.FloatTensor` of shape (batch_size, num_channels, height, width)):
            Input images to be processed.
        input_ids (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            The sequence used as a prompt for the generation.
        attention_mask (`torch.LongTensor` of shape (batch_size, sequence_length), *optional*):
            Mask to avoid performing attention on padding token indices
        language_model_inputs (`torch.LongTensor` of shape (batch_size, sequence_length, num_channel), *optional*):
            Precomputed visual prompt embeddings. When provided, the vision encoder and
            Q-Former forward passes are skipped entirely.
        generation_config (`~generation.GenerationConfig`, *optional*):
            The generation configuration to be used as base parametrization for the generation call. `**kwargs`
            passed to generate matching the attributes of `generation_config` will override them. If
            `generation_config` is not provided, the default will be used, which had the following loading
            priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
            configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
            default values, whose documentation should be checked to parameterize generation.

    Returns:
        captions (list): A list of strings of length batch_size * num_captions.
    """

    if hasattr(self, "hf_device_map"):
        # preprocess for `accelerate`
        self._preprocess_accelerate()

    if language_model_inputs is None:
        # Encode the image; hidden states of all encoder layers are needed so
        # that CLS tokens from 4 evenly spaced layers can feed the adapter.
        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_hidden_states=True,
        )
        image_embeds = vision_outputs[0]

        # Pick the CLS token from every (depth // 4)-th layer and normalize it.
        depth = len(vision_outputs[2])
        indices = range(depth // 4 - 1, depth, depth // 4)
        pooled_outputs = []
        for idx, layer_norm in zip(indices, self.layer_norms):
            pool_output = vision_outputs[2][idx][:, 0, :].unsqueeze(1)
            pool_output = layer_norm(pool_output)
            pooled_outputs.append(pool_output)

        pooled_outputs = torch.cat(pooled_outputs, dim=1)
        pooled_outputs = self.vision_adapter(pooled_outputs)

        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        # Cross-attend learned query tokens to the image features (Q-Former).
        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
        )
        query_output = query_outputs[0]
        # 4 adapter-pooled tokens are prepended to the query tokens.
        query_output = torch.cat([pooled_outputs, query_output], dim=1)

        language_model_inputs = self.language_projection(query_output)

    batch_size = language_model_inputs.shape[0]

    # BUG FIX: build the BOS fallback *before* the embedding lookup. The
    # original code called `get_input_embeddings()(input_ids)` first, which
    # raised whenever `input_ids` was None and made the fallback unreachable.
    if input_ids is None:
        input_ids = (
            torch.LongTensor([[self.config.text_config.bos_token_id]])
            .repeat(batch_size, 1)
            .to(language_model_inputs.device)
        )

    inputs_embeds = self.language_model.get_input_embeddings()(input_ids)

    # Splice the visual prompt between the first `self.offset` prompt tokens
    # and the remainder of the text prompt.
    prefix_embeds = inputs_embeds[:, :self.offset, :]
    postfix_embeds = inputs_embeds[:, self.offset:, :]
    inputs_embeds = torch.cat([prefix_embeds, language_model_inputs, postfix_embeds], dim=1)

    if attention_mask is None:
        # NOTE(review): this mask covers only the text tokens, not the spliced
        # visual tokens — confirm downstream `generate` tolerates the mismatch.
        attention_mask = torch.ones_like(
            input_ids, dtype=torch.long, device=language_model_inputs.device)
    else:
        prefix_mask = attention_mask[:, :self.offset]
        postfix_mask = attention_mask[:, self.offset:]
        # `num_queries` Q-Former tokens plus 4 adapter-pooled tokens were inserted above.
        vision_mask = torch.ones(size=(batch_size, self.num_queries + 4), dtype=torch.long,
                                 device=attention_mask.device)
        attention_mask = torch.cat([prefix_mask, vision_mask, postfix_mask], dim=-1)

    outputs = self.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        generation_config=generation_config,
        **generate_kwargs,
    )

    return outputs
|
robohusky/model/processing_husky.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Processor class for Husky. Largely copy of Blip2Processor with addition of a tokenizer for the Q-Former.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from typing import List, Optional, Union
|
| 20 |
+
|
| 21 |
+
from transformers.processing_utils import ProcessorMixin
|
| 22 |
+
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, \
|
| 23 |
+
TruncationStrategy
|
| 24 |
+
from transformers.utils import TensorType
|
| 25 |
+
from transformers.models.auto import AutoTokenizer
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class HuskyProcessor(ProcessorMixin):
    r"""
    Constructs an Husky processor which wraps a BLIP image processor and a LLaMa/T5 tokenizer into a single
    processor.

    [`HuskyProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`AutoTokenizer`]. See the
    docstring of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information.

    Args:
        image_processor (`BlipImageProcessor`):
            An instance of [`BlipImageProcessor`]. The image processor is a required input.
        tokenizer (`AutoTokenizer`):
            An instance of ['PreTrainedTokenizer`]. The tokenizer is a required input.
    """
    attributes = ["image_processor", "tokenizer"]
    image_processor_class = "BlipImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(self, image_processor, tokenizer):
        super().__init__(image_processor, tokenizer)
        self.current_processor = self.image_processor

        # add QFormer tokenizer
        self.qformer_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", truncation_side="left")
        self.qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})

    def __call__(
        self,
        images=None,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        pad_to_multiple_of: Optional[int] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_token_type_ids: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        return_tensors: Optional[Union[str, TensorType]] = None,
        **kwargs,
    ) -> BatchEncoding:
        """
        This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and
        [`BertTokenizerFast.__call__`] to prepare text for the model.

        Please refer to the docstring of the above two methods for more information.
        """
        if images is None and text is None:
            raise ValueError("You have to specify either images or text.")

        # All three tokenizer invocations below share exactly the same keyword
        # arguments; build them once instead of repeating the 14-kwarg list.
        tokenizer_kwargs = dict(
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            pad_to_multiple_of=pad_to_multiple_of,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_token_type_ids=return_token_type_ids,
            return_length=return_length,
            verbose=verbose,
            return_tensors=return_tensors,
            **kwargs,
        )

        # Get only text
        if images is None:
            self.current_processor = self.tokenizer
            return self.tokenizer(text=text, **tokenizer_kwargs)

        # add pixel_values
        encoding_image_processor = self.image_processor(images, return_tensors=return_tensors)

        if text is not None:
            # Main LLM tokenization plus a parallel Q-Former tokenization whose
            # keys are renamed so both can live in one BatchEncoding.
            text_encoding = self.tokenizer(text=text, **tokenizer_kwargs)
            qformer_text_encoding = self.qformer_tokenizer(text=text, **tokenizer_kwargs)
            qformer_text_encoding["qformer_input_ids"] = qformer_text_encoding.pop("input_ids")
            qformer_text_encoding["qformer_attention_mask"] = qformer_text_encoding.pop("attention_mask")
            text_encoding.update(qformer_text_encoding)

            encoding_image_processor.update(text_encoding)

        return encoding_image_processor

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer
        to the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
    def model_input_names(self):
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
robohusky/train/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
robohusky/train/llama_flash_attn_monkey_patch.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from typing import Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from flash_attn import __version__ as flash_attn_version
|
| 6 |
+
from flash_attn.bert_padding import pad_input, unpad_input
|
| 7 |
+
from flash_attn.flash_attn_interface import (
|
| 8 |
+
flash_attn_func,
|
| 9 |
+
flash_attn_varlen_kvpacked_func,
|
| 10 |
+
)
|
| 11 |
+
from transformers.models.llama.modeling_llama import (
|
| 12 |
+
LlamaAttention,
|
| 13 |
+
LlamaModel,
|
| 14 |
+
rotate_half,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
def apply_rotary_pos_emb(q, k, cos_sin, position_ids):
    """Rotate query/key tensors with position-dependent cos/sin tables.

    `q` and `k` are laid out as (bsz, seq_len, num_heads, head_dim);
    `cos_sin` is the (cos, sin) pair produced by the rotary embedding module.
    Returns the rotated (q, k) pair.
    """
    cos_table, sin_table = cos_sin
    batch = position_ids.shape[0]

    # Per-position gather indices, broadcast across the head and dim axes.
    idx = position_ids[:, :, None, None]  # [bsz, seq_len, 1, 1]
    idx = idx.repeat(1, 1, cos_table.shape[1], cos_table.shape[3])

    cos = torch.gather(cos_table.transpose(1, 2).repeat(batch, 1, 1, 1), 1, idx)
    sin = torch.gather(sin_table.transpose(1, 2).repeat(batch, 1, 1, 1), 1, idx)

    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)
    return q_rot, k_rot
|
| 29 |
+
|
| 30 |
+
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    padding_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Flash-attention replacement for `LlamaAttention.forward`.

    Keeps the original signature; `attention_mask` here is a boolean
    key-padding mask of shape [bsz, kv_seq_len] (see the companion
    `_prepare_decoder_attention_mask`), not an additive float mask.
    Attention weights are never returned (second tuple element is None).
    """
    if output_attentions:
        warnings.warn(
            "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
        )

    bsz, q_len, _ = hidden_states.size()
    # Falls back to num_heads for models without grouped-query attention.
    kv_heads = getattr(self, "num_key_value_heads", self.num_heads)

    q, k, v = (
        op(hidden_states).view(bsz, q_len, nh, self.head_dim)
        for op, nh in (
            (self.q_proj, self.num_heads),
            (self.k_proj, kv_heads),
            (self.v_proj, kv_heads),
        )
    )
    # shape: (b, s, num_heads, head_dim)

    kv_seq_len = k.shape[1]
    past_kv_len = 0
    if past_key_value is not None:
        # Cached keys/values are stored as (b, num_heads, s, head_dim).
        past_kv_len = past_key_value[0].shape[2]
        kv_seq_len += past_kv_len

    # Rotary tables must span the full (past + current) sequence length.
    cos_sin = self.rotary_emb(v, seq_len=kv_seq_len)
    q, k = apply_rotary_pos_emb(q, k, cos_sin, position_ids)

    if past_key_value is not None:
        # NOTE(review): lexicographic string comparison of versions — breaks
        # for e.g. "2.10.0" vs "2.2.0"; confirm against packaging.version.
        assert (
            flash_attn_version >= "2.1.0"
        ), "past_key_value support requires flash-attn >= 2.1.0"
        # reuse k, v: transpose cache back to (b, s, h, d) before concat.
        k = torch.cat([past_key_value[0].transpose(1, 2), k], dim=1)
        v = torch.cat([past_key_value[1].transpose(1, 2), v], dim=1)

    # Store cache in the (b, h, s, d) layout transformers expects.
    past_key_value = (k.transpose(1, 2), v.transpose(1, 2)) if use_cache else None

    if attention_mask is None:
        # Dense fast path: no padding anywhere in the batch.
        output = flash_attn_func(q, k, v, 0.0, softmax_scale=None, causal=True).view(
            bsz, q_len, -1
        )
    else:
        # Varlen path: strip padding, run packed flash attention, re-pad.
        # Queries use only the last q_len columns of the mask (past tokens
        # have no query rows).
        q, indices, cu_q_lens, max_s = unpad_input(q, attention_mask[:, -q_len:])
        # We can skip concat and call unpad twice but seems better to call unpad only once.
        kv, _, cu_k_lens, max_k = unpad_input(
            torch.stack((k, v), dim=2), attention_mask
        )
        output_unpad = flash_attn_varlen_kvpacked_func(
            q,
            kv,
            cu_q_lens,
            cu_k_lens,
            max_s,
            max_k,
            0.0,
            softmax_scale=None,
            causal=True,
        )
        output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
        output = pad_input(output_unpad, indices, bsz, q_len)

    return self.o_proj(output), None, past_key_value
|
| 102 |
+
|
| 103 |
+
# Disable the transformation of the attention mask in LlamaModel as flash attention
|
| 104 |
+
# takes a boolean key_padding_mask. Fills in the past kv length for use in forward.
|
| 105 |
+
def _prepare_decoder_attention_mask(
|
| 106 |
+
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
|
| 107 |
+
):
|
| 108 |
+
# [bsz, seq_len]
|
| 109 |
+
if past_key_values_length > 0 and attention_mask is not None:
|
| 110 |
+
attention_mask = torch.cat(
|
| 111 |
+
(
|
| 112 |
+
torch.full(
|
| 113 |
+
(input_shape[0], past_key_values_length),
|
| 114 |
+
True,
|
| 115 |
+
dtype=attention_mask.dtype,
|
| 116 |
+
device=attention_mask.device,
|
| 117 |
+
),
|
| 118 |
+
attention_mask,
|
| 119 |
+
),
|
| 120 |
+
dim=-1,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
if attention_mask is not None and torch.all(attention_mask):
|
| 124 |
+
return None # This uses the faster call when training with full samples
|
| 125 |
+
|
| 126 |
+
return attention_mask
|
| 127 |
+
|
| 128 |
+
def replace_llama_attn_with_flash_attn():
    """Monkey-patch transformers' LLaMA attention to use flash-attention kernels."""
    major, _minor = torch.cuda.get_device_capability()
    if major < 8:
        # Pre-Ampere GPUs lack the backward kernel for head_dim > 64.
        warnings.warn(
            "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
            "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
        )

    LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    LlamaAttention.forward = forward
|
| 138 |
+
|
| 139 |
+
def test():
    """Numerically compare the patched attention against the reference
    `LlamaAttention.forward` and the upstream fastchat patch, then verify
    past_key_value handling by running the sequence in 4 chunks.

    Requires a CUDA device, transformers and flash-attn; prints mean absolute
    differences rather than asserting tolerances (inspect output manually).
    """
    from robohusky.train.llama_flash_attn_monkey_patch import forward as fastchat_forward
    from transformers.models.llama.configuration_llama import LlamaConfig

    config = LlamaConfig(
        hidden_size=1024,
        intermediate_size=128,
        num_hidden_layers=1,
        num_attention_heads=8,
        max_position_embeddings=16,
    )
    device = torch.device("cuda")
    model = LlamaModel(config)
    attn = LlamaAttention(config).to(device).half()
    bsz, hs, seqlen = 2, config.hidden_size, config.max_position_embeddings
    position_ids = torch.arange(seqlen, dtype=torch.long, device=device).view(
        -1, seqlen
    )

    # Padding grows with i: row 0 padded at the end, row 1 at the start.
    # NOTE: `mask` is mutated in place, so padding accumulates across iterations.
    mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
    for i in range(4):
        hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
        if i:
            mask[0, -i:] = False
            mask[1, :i] = False

        # Reference: stock attention with the model's additive mask.
        lmask = model._prepare_decoder_attention_mask(mask, hidden.shape[:2], hidden, 0)
        ref, _, _ = attn.forward(
            hidden, attention_mask=lmask, position_ids=position_ids
        )

        # Upstream fastchat flash-attn patch, driven by the raw boolean mask.
        fast, _, _ = fastchat_forward(
            attn, hidden, attention_mask=mask, position_ids=position_ids
        )

        # This module's patch, driven through its own mask preparation.
        lmask = _prepare_decoder_attention_mask(
            model, mask, hidden.shape[:2], hidden, 0
        )
        test, _, _ = forward(
            attn, hidden, attention_mask=lmask, position_ids=position_ids
        )

        print(f"Mean(abs(ref)) = {torch.mean(torch.abs(ref))}")
        print(f"Mean(abs(ref - fast)) = {torch.mean(torch.abs(ref - fast))}")
        print(f"Mean(abs(ref - test)) = {torch.mean(torch.abs(ref - test))}")
        print(f"Mean(abs(fast - test)) = {torch.mean(torch.abs(fast - test))}")
        print(f"allclose(fast, test) = {torch.allclose(fast, test)}")

    with torch.no_grad():
        # Also check that past_kv is handled properly
        hidden = torch.rand((bsz, seqlen, hs), dtype=torch.float16, device=device)
        part_len = seqlen // 4
        assert part_len * 4 == seqlen
        mask = torch.full((bsz, seqlen), True, dtype=torch.bool, device=device)
        mask[0, -2:] = False
        lmask = _prepare_decoder_attention_mask(
            model, mask, hidden.shape[:2], hidden, 0
        )
        # Single-shot pass over the full sequence, used as ground truth below.
        oneshot, _, _ = forward(
            attn, hidden, attention_mask=lmask, position_ids=position_ids
        )
        parts = []
        past_kv, past_kv_len = None, 0
        for i in range(4):
            # Feed the sequence a quarter at a time, threading the kv cache.
            start = part_len * i
            end = start + part_len
            hidden_part = hidden[:, start:end, ...]
            lmask = _prepare_decoder_attention_mask(
                model,
                mask[:, start:end],
                hidden_part.shape[:2],
                hidden_part,
                past_kv_len,
            )
            part, _, past_kv = forward(
                attn,
                hidden_part.clone(),
                attention_mask=lmask,
                position_ids=position_ids[:, start:end],
                past_key_value=past_kv,
                use_cache=True,
            )
            parts.append(part)
            past_kv_len = past_kv[0].shape[2]

        print(
            f"allclose(oneshot[:, 0], parts[0]) = {torch.allclose(oneshot[:, :part_len], parts[0])}"
        )
        print(
            f"allclose(oneshot, parts) = {torch.allclose(oneshot, torch.cat(parts, dim=1))}"
        )
|
| 230 |
+
|
| 231 |
+
# Run the numerical comparison harness when executed as a script.
if __name__ == "__main__":
    test()
|
robohusky/train/llama_rmsnorm_monkey_patch.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import transformers
|
| 2 |
+
|
| 3 |
+
def replace_llama_rmsnorm_with_fused_rmsnorm():
    """Swap transformers' LlamaRMSNorm for apex's fused kernel when available.

    Silently keeps the stock implementation if apex is not installed; logs and
    falls back if apex is present but fails to load.
    """
    try:
        from functools import partial

        from apex.normalization import FusedRMSNorm

        # NOTE(review): eps is hard-coded to 1e-6 rather than read from the
        # model config — confirm it matches config.rms_norm_eps.
        fused_rmsnorm = partial(FusedRMSNorm, eps=1e-6)  # noqa
        transformers.models.llama.modeling_llama.LlamaRMSNorm = fused_rmsnorm
        print("Discovered apex.normalization.FusedRMSNorm - will use it instead of LlamaRMSNorm")
    except ImportError:
        # apex not installed: keep the normal LlamaRMSNorm.
        pass
    except Exception:
        print("discovered apex but it failed to load, falling back to LlamaRMSNorm")
|
robohusky/train/train.py
ADDED
|
@@ -0,0 +1,597 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright Qing-Long Zhang. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
Fine-tuning the library models for sequence to sequence.
|
| 18 |
+
"""
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import warnings
|
| 24 |
+
from functools import partial
|
| 25 |
+
|
| 26 |
+
from multiprocessing import cpu_count
|
| 27 |
+
|
| 28 |
+
from typing import Optional
|
| 29 |
+
from dataclasses import dataclass, field
|
| 30 |
+
|
| 31 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 32 |
+
from datasets import load_dataset, load_from_disk
|
| 33 |
+
|
| 34 |
+
from robohusky.dist_utils import init_dist
|
| 35 |
+
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
|
| 36 |
+
|
| 37 |
+
import transformers
|
| 38 |
+
from transformers import (
|
| 39 |
+
HfArgumentParser,
|
| 40 |
+
TrainingArguments,
|
| 41 |
+
LlamaTokenizer,
|
| 42 |
+
Trainer,
|
| 43 |
+
set_seed,
|
| 44 |
+
default_data_collator,
|
| 45 |
+
DataCollatorForSeq2Seq,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
from peft import (
|
| 49 |
+
LoraConfig,
|
| 50 |
+
get_peft_model,
|
| 51 |
+
prepare_model_for_int8_training,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
from robohusky.base_dataset import (
|
| 55 |
+
process_func,
|
| 56 |
+
BaseDataset,
|
| 57 |
+
CephDataset,
|
| 58 |
+
build_transform
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 62 |
+
from transformers.utils import check_min_version
|
| 63 |
+
from transformers.utils.versions import require_version
|
| 64 |
+
|
| 65 |
+
from transformers.utils.logging import (
|
| 66 |
+
set_verbosity_info,
|
| 67 |
+
set_verbosity,
|
| 68 |
+
enable_default_handler,
|
| 69 |
+
enable_explicit_format,
|
| 70 |
+
)
|
| 71 |
+
from robohusky.train.llama_flash_attn_monkey_patch import (
|
| 72 |
+
replace_llama_attn_with_flash_attn
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
from robohusky.train.llama_rmsnorm_monkey_patch import (
|
| 76 |
+
replace_llama_rmsnorm_with_fused_rmsnorm
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
replace_llama_attn_with_flash_attn()
|
| 80 |
+
replace_llama_rmsnorm_with_fused_rmsnorm()
|
| 81 |
+
|
| 82 |
+
IGNORE_INDEX = -100
|
| 83 |
+
DEFAULT_UNK_TOKEN = "<unk>"
|
| 84 |
+
DEFAULT_IMG_START_TOKEN = "<img>"
|
| 85 |
+
DEFAULT_IMG_END_TOKEN = "</img>"
|
| 86 |
+
|
| 87 |
+
DEFAULT_VIDEO_START_TOKEN = "<vid>"
|
| 88 |
+
DEFAULT_VIDEO_END_TOKEN = "</vid>"
|
| 89 |
+
|
| 90 |
+
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
| 91 |
+
check_min_version("4.32.0.dev0")
|
| 92 |
+
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")
|
| 93 |
+
|
| 94 |
+
warnings.filterwarnings('ignore')
|
| 95 |
+
logger = logging.getLogger(__name__)
|
| 96 |
+
|
| 97 |
+
os.environ["WANDB_DISABLED"] = "true"
|
| 98 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
| 99 |
+
|
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    The ``freeze_*`` / ``un_freeze_*`` flags control which sub-modules of the
    Husky model keep trainable parameters (see ``main`` for how they are applied).
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    # NOTE(review): the help strings below were copy-pasted from an unrelated
    # option in the original and did not describe the flags; rewritten to match
    # how main() actually uses each flag.
    freeze_model: bool = field(
        default=False,
        metadata={"help": "Freeze all model parameters; only the language projection remains trainable."},
    )
    freeze_vision_model: bool = field(
        default=False,
        metadata={"help": "Freeze the vision encoder and put it in eval mode."},
    )
    freeze_vision_adapter: bool = field(
        default=False,
        metadata={"help": "Freeze the vision adapter parameters."},
    )
    freeze_text_model: bool = field(
        default=False,
        metadata={"help": "Freeze the language model parameters."},
    )
    freeze_qformer: bool = field(
        default=False,
        metadata={"help": "Freeze the Q-Former (including its query tokens) and put it in eval mode."},
    )
    un_freeze_vision_embedding: bool = field(
        default=False,
        metadata={"help": "Keep the image patch embedding trainable while the vision model is frozen."},
    )
    un_freeze_video_embedding: bool = field(
        default=False,
        metadata={"help": "Keep the video patch embedding trainable while the vision model is frozen."},
    )
    un_freeze_llm_head: bool = field(
        default=False,
        metadata={"help": "Keep the language model's lm_head trainable while other parts are frozen."},
    )
    use_lora: bool = field(
        default=False, metadata={"help": "add the LoRA adapters to the base model"}
    )
@dataclass
|
| 173 |
+
class DataTrainingArguments:
|
| 174 |
+
"""
|
| 175 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
dataset_name: Optional[str] = field(
|
| 179 |
+
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
| 180 |
+
)
|
| 181 |
+
dataset_config_name: Optional[str] = field(
|
| 182 |
+
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
| 183 |
+
)
|
| 184 |
+
data_dir: Optional[str] = field(
|
| 185 |
+
default=None, metadata={"help": "The data directory containing input files."})
|
| 186 |
+
train_file: Optional[str] = field(
|
| 187 |
+
default=None, metadata={"help": "The input training data file (a jsonlines)."})
|
| 188 |
+
validation_file: Optional[str] = field(
|
| 189 |
+
default=None,
|
| 190 |
+
metadata={
|
| 191 |
+
"help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
|
| 192 |
+
},
|
| 193 |
+
)
|
| 194 |
+
train_val_split: Optional[float] = field(
|
| 195 |
+
default=0.0, metadata={"help": "Percent to split off of train for validation."}
|
| 196 |
+
)
|
| 197 |
+
test_file: Optional[str] = field(
|
| 198 |
+
default=None,
|
| 199 |
+
metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
|
| 200 |
+
)
|
| 201 |
+
image_path: Optional[str] = field(
|
| 202 |
+
default=None,
|
| 203 |
+
metadata={"help": "An optional image path"},
|
| 204 |
+
)
|
| 205 |
+
video_path: Optional[str] = field(
|
| 206 |
+
default=None,
|
| 207 |
+
metadata={"help": "An optional video path"},
|
| 208 |
+
)
|
| 209 |
+
input_size: Optional[int] = field(
|
| 210 |
+
default=224,
|
| 211 |
+
metadata={"help": "The input size of images."},
|
| 212 |
+
)
|
| 213 |
+
overwrite_cache: bool = field(
|
| 214 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
| 215 |
+
)
|
| 216 |
+
preprocessing_num_workers: Optional[int] = field(
|
| 217 |
+
default=None,
|
| 218 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
| 219 |
+
)
|
| 220 |
+
max_seq_length: Optional[int] = field(
|
| 221 |
+
default=128,
|
| 222 |
+
metadata={
|
| 223 |
+
"help": (
|
| 224 |
+
"The maximum total input sequence length after tokenization. Sequences longer "
|
| 225 |
+
"than this will be truncated, sequences shorter will be padded."
|
| 226 |
+
)
|
| 227 |
+
},
|
| 228 |
+
)
|
| 229 |
+
pad_to_max_length: bool = field(
|
| 230 |
+
default=False,
|
| 231 |
+
metadata={
|
| 232 |
+
"help": (
|
| 233 |
+
"Whether to pad all samples to model maximum sentence length. "
|
| 234 |
+
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
| 235 |
+
"efficient on GPU but very bad for TPU."
|
| 236 |
+
)
|
| 237 |
+
},
|
| 238 |
+
)
|
| 239 |
+
val_max_length: Optional[int] = field(
|
| 240 |
+
default=None,
|
| 241 |
+
metadata={
|
| 242 |
+
"help": (
|
| 243 |
+
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
| 244 |
+
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
| 245 |
+
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
| 246 |
+
"during ``evaluate`` and ``predict``."
|
| 247 |
+
)
|
| 248 |
+
},
|
| 249 |
+
)
|
| 250 |
+
max_train_samples: Optional[int] = field(
|
| 251 |
+
default=None,
|
| 252 |
+
metadata={
|
| 253 |
+
"help": (
|
| 254 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
| 255 |
+
"value if set."
|
| 256 |
+
)
|
| 257 |
+
},
|
| 258 |
+
)
|
| 259 |
+
max_eval_samples: Optional[int] = field(
|
| 260 |
+
default=None,
|
| 261 |
+
metadata={
|
| 262 |
+
"help": (
|
| 263 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
| 264 |
+
"value if set."
|
| 265 |
+
)
|
| 266 |
+
},
|
| 267 |
+
)
|
| 268 |
+
max_predict_samples: Optional[int] = field(
|
| 269 |
+
default=None,
|
| 270 |
+
metadata={
|
| 271 |
+
"help": (
|
| 272 |
+
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
| 273 |
+
"value if set."
|
| 274 |
+
)
|
| 275 |
+
},
|
| 276 |
+
)
|
| 277 |
+
conv_style: Optional[str] = field(
|
| 278 |
+
default=None, metadata={"help": "prompt style for a conversation."}
|
| 279 |
+
)
|
| 280 |
+
save_data_path: Optional[str] = field(
|
| 281 |
+
default=None, metadata={"help": "prompt style for a conversation."}
|
| 282 |
+
)
|
| 283 |
+
num_beams: Optional[int] = field(
|
| 284 |
+
default=None,
|
| 285 |
+
metadata={
|
| 286 |
+
"help": (
|
| 287 |
+
"Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
| 288 |
+
"which is used during ``evaluate`` and ``predict``."
|
| 289 |
+
)
|
| 290 |
+
},
|
| 291 |
+
)
|
| 292 |
+
ignore_pad_token_for_loss: bool = field(
|
| 293 |
+
default=True,
|
| 294 |
+
metadata={
|
| 295 |
+
"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
|
| 296 |
+
},
|
| 297 |
+
)
|
| 298 |
+
source_prefix: Optional[str] = field(
|
| 299 |
+
default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
|
| 300 |
+
)
|
| 301 |
+
forced_bos_token: Optional[str] = field(
|
| 302 |
+
default=None,
|
| 303 |
+
metadata={
|
| 304 |
+
"help": (
|
| 305 |
+
"The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for"
|
| 306 |
+
" multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
|
| 307 |
+
" be the target language token.(Usually it is the target language token)"
|
| 308 |
+
)
|
| 309 |
+
},
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
def __post_init__(self):
|
| 313 |
+
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
| 314 |
+
raise ValueError("Need either a dataset name or a training/validation file.")
|
| 315 |
+
# accepting both json and jsonl file extensions, as
|
| 316 |
+
# many jsonlines files actually have a .json extension
|
| 317 |
+
else:
|
| 318 |
+
if self.train_file is not None:
|
| 319 |
+
extension = self.train_file.split(".")[-1]
|
| 320 |
+
assert extension in ["csv", "json", "jsonl", "parquet"], "`train_file` should be a csv or a json file."
|
| 321 |
+
if self.validation_file is not None:
|
| 322 |
+
extension = self.validation_file.split(".")[-1]
|
| 323 |
+
assert extension in ["csv", "json", "jsonl",
|
| 324 |
+
"parquet"], "`validation_file` should be a csv or a json file."
|
| 325 |
+
if self.test_file is not None:
|
| 326 |
+
extension = self.test_file.split(".")[-1]
|
| 327 |
+
assert extension == "json", "`test_file` should be a json file."
|
| 328 |
+
|
| 329 |
+
def main():
    """Fine-tune a HuskyForConditionalGeneration model.

    Pipeline: parse CLI/JSON arguments, set up logging and distributed
    training, load the dataset list, build the tokenizer and model, apply the
    freeze/LoRA configuration requested via ``ModelArguments``, assemble the
    training datasets, and run the HF ``Trainer``.
    """
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    init_dist(launcher='slurm', backend='nccl', port=29598)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script, and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # 2. Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    set_verbosity(log_level)
    enable_default_handler()
    enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # 3. Detecting last checkpoint and eventually continue from last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 4. Get the datasets
    # you can either provide your own JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        ds = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            data_dir=data_args.data_dir,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        # NOTE(review): `data_files`/`extension` are assembled here but unused
        # below — `train_file` is expected to be a JSON *manifest* listing the
        # actual datasets (consumed by the per-entry loop further down).
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]

        # FIX: use a context manager so the manifest file handle is closed
        # (original leaked the handle via json.load(open(...))).
        with open(data_args.train_file, "r") as manifest:
            ds = json.load(manifest)

    # 5. Load pretrained model, tokenizer, and image processor
    #
    # Distributed training: The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    tokenizer = LlamaTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        legacy=True,
    )
    # add special token
    tokenizer.pad_token_id = 0
    if tokenizer.unk_token is None:
        tokenizer.add_special_tokens({"unk_token": DEFAULT_UNK_TOKEN})

    tokens_list = [
        DEFAULT_IMG_START_TOKEN, DEFAULT_IMG_END_TOKEN,
        DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
    ]
    tokenizer.add_tokens(tokens_list, special_tokens=True)

    model = HuskyForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path, ignore_mismatched_sizes=True
    )
    embedding_size = model.language_model.get_input_embeddings().weight.shape[0]

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))
        model.language_model.resize_token_embeddings(len(tokenizer))
        model.config.text_config.vocab_size = len(tokenizer)

    # KV-cache is only useful at generation time; disable during training.
    model.config.use_cache = False

    def _freeze_params(module):
        # Disable gradients for every parameter of `module`.
        for param in module.parameters():
            param.requires_grad = False

    if model_args.freeze_model:
        _freeze_params(model)
        # only update language projection
        model.language_projection.weight.requires_grad = True

    if model_args.freeze_vision_model:
        model.vision_model = model.vision_model.eval()
        _freeze_params(model.vision_model)

    if model_args.freeze_vision_adapter:
        _freeze_params(model.vision_adapter)

    if model_args.freeze_qformer:
        model.qformer = model.qformer.eval()
        _freeze_params(model.qformer)
        model.query_tokens.requires_grad = False

    if model_args.freeze_text_model:
        _freeze_params(model.language_model)

    if model_args.use_lora:
        training_args.ddp_find_unused_parameters = False
        _freeze_params(model)
        lora_config = LoraConfig(
            r=16,
            target_modules=["q_proj", "v_proj"],
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model.language_model = get_peft_model(model.language_model, lora_config)
        model.language_model.print_trainable_parameters()

    if model_args.un_freeze_video_embedding:
        # Freeze everything, then re-enable only the video embedding weights.
        _freeze_params(model)
        model.vision_model.video_embeddings.patch_embedding.weight.requires_grad = True
        model.vision_model.video_embeddings.class_embedding.requires_grad = True
        model.vision_model.video_embeddings.position_embedding.requires_grad = True

    if model_args.un_freeze_llm_head:
        model.language_model.lm_head.weight.requires_grad = True

    # set seed for torch dataloaders
    set_seed(training_args.seed)

    # 7. Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.

    # set padding.
    # NOTE(review): `padding` is computed but not used below; kept for parity
    # with the HF example this script derives from.
    padding = "max_length" if data_args.pad_to_max_length else False

    def husky_processor(examples):
        # Tokenize a batch of examples with the configured max sequence length.
        processor = partial(
            process_func,
            tokenizer=tokenizer,
            max_seq_length=data_args.max_seq_length,
        )
        model_inputs = processor(examples)
        return model_inputs

    # Data collator
    label_pad_token_id = IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

    # Build one dataset per manifest entry; "base" entries read images from a
    # local path, everything else goes through CephDataset.
    concat_dataset = []
    for data in ds:
        data_file = data["text_file"]
        extension = data_file.split(".")[-1]
        extension = "json" if extension == "jsonl" else extension
        logger.info(f"Loading dataset: {data['data_name']}")

        raw_dataset = load_dataset(extension, data_files=data_file, num_proc=cpu_count(), split="train")
        if data["data_type"] == "base":
            temp = BaseDataset(
                raw_dataset,
                processor=husky_processor,
                image_path=data["image_path"],
                input_size=data_args.input_size
            )
        else:
            temp = CephDataset(
                raw_dataset,
                processor=husky_processor,
                input_size=data_args.input_size
            )
        concat_dataset.append(temp)

    logger.info(f"All datasets have been loaded!")

    if len(concat_dataset) > 1:
        train_dataset = ConcatDataset(concat_dataset)
        # train_dataset = train_dataset.shuffle(seed=42)
    else:
        train_dataset = concat_dataset[0]

    # 8. Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # 9. Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        if model_args.use_lora:
            # Only the LoRA adapters need to be saved in this mode.
            model.language_model.save_pretrained(training_args.output_dir)
        else:
            trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
|
| 596 |
+
if __name__ == "__main__":
|
| 597 |
+
main()
|
robohusky/train/train_uni.py
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright Qing-Long Zhang. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""
|
| 17 |
+
Fine-tuning the library models for sequence to sequence.
|
| 18 |
+
"""
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
import os
|
| 22 |
+
import sys
|
| 23 |
+
import warnings
|
| 24 |
+
from functools import partial
|
| 25 |
+
|
| 26 |
+
from multiprocessing import cpu_count
|
| 27 |
+
|
| 28 |
+
from typing import Optional
|
| 29 |
+
from dataclasses import dataclass, field
|
| 30 |
+
|
| 31 |
+
from torch.utils.data import Dataset, ConcatDataset
|
| 32 |
+
from datasets import load_dataset, load_from_disk
|
| 33 |
+
|
| 34 |
+
from robohusky.dist_utils import init_dist
|
| 35 |
+
from robohusky.model.modeling_husky_embody2 import HuskyForConditionalGeneration
|
| 36 |
+
|
| 37 |
+
import transformers
|
| 38 |
+
from transformers import (
|
| 39 |
+
HfArgumentParser,
|
| 40 |
+
TrainingArguments,
|
| 41 |
+
LlamaTokenizer,
|
| 42 |
+
Trainer,
|
| 43 |
+
set_seed,
|
| 44 |
+
default_data_collator,
|
| 45 |
+
DataCollatorForSeq2Seq,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
from peft import (
|
| 49 |
+
LoraConfig,
|
| 50 |
+
get_peft_model,
|
| 51 |
+
prepare_model_for_int8_training,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
from robohusky.base_dataset_uni import (
|
| 55 |
+
process_func,
|
| 56 |
+
BaseDataset,
|
| 57 |
+
WeightedConcatDataset
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
from transformers.trainer_utils import get_last_checkpoint
|
| 61 |
+
from transformers.utils import check_min_version
|
| 62 |
+
from transformers.utils.versions import require_version
|
| 63 |
+
|
| 64 |
+
from transformers.utils.logging import (
|
| 65 |
+
set_verbosity_info,
|
| 66 |
+
set_verbosity,
|
| 67 |
+
enable_default_handler,
|
| 68 |
+
enable_explicit_format,
|
| 69 |
+
)
|
| 70 |
+
from robohusky.train.llama_flash_attn_monkey_patch import (
|
| 71 |
+
replace_llama_attn_with_flash_attn
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
from robohusky.train.llama_rmsnorm_monkey_patch import (
|
| 75 |
+
replace_llama_rmsnorm_with_fused_rmsnorm
|
| 76 |
+
)
|
| 77 |
+
|
# Apply performance monkey patches before any model is constructed:
# FlashAttention for LLaMA attention and a fused RMSNorm kernel.
replace_llama_attn_with_flash_attn()
replace_llama_rmsnorm_with_fused_rmsnorm()

# Label value ignored by the cross-entropy loss (HF convention).
IGNORE_INDEX = -100
DEFAULT_UNK_TOKEN = "<unk>"
# Sentinel tokens wrapping image embeddings in the text stream.
DEFAULT_IMG_START_TOKEN = "<img>"
DEFAULT_IMG_END_TOKEN = "</img>"

# Sentinel tokens wrapping video embeddings in the text stream.
DEFAULT_VIDEO_START_TOKEN = "<vid>"
DEFAULT_VIDEO_END_TOKEN = "</vid>"

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.32.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/translation/requirements.txt")

warnings.filterwarnings('ignore')
logger = logging.getLogger(__name__)

# Disable Weights & Biases reporting and allow parallel tokenization workers.
os.environ["WANDB_DISABLED"] = "true"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
| 99 |
+
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.

    The ``freeze_*`` / ``un_freeze_*`` flags are consumed in ``main()`` to
    control which sub-modules of the Husky model receive gradients.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
                "with private models)."
            )
        },
    )
    # Help strings below were fixed: the originals were copy-pasted from an
    # unrelated "ignore_mismatched_sizes" description and did not describe the flags.
    freeze_model: bool = field(
        default=False,
        metadata={"help": "Freeze the whole model; only the language projection remains trainable."},
    )
    freeze_vision_model: bool = field(
        default=False,
        metadata={"help": "Freeze the vision encoder and put it in eval mode."},
    )
    freeze_vision_adapter: bool = field(
        default=False,
        metadata={"help": "Freeze the vision adapter weights."},
    )
    freeze_text_model: bool = field(
        default=False,
        metadata={"help": "Freeze the language model weights."},
    )
    freeze_qformer: bool = field(
        default=False,
        metadata={"help": "Freeze the Q-Former (including its query tokens) and put it in eval mode."},
    )
    # NOTE(review): this flag is not referenced in main() of this file —
    # confirm it is consumed elsewhere before relying on it.
    un_freeze_vision_embedding: bool = field(
        default=False,
        metadata={"help": "Keep the image patch embedding trainable while the vision model is frozen."},
    )
    un_freeze_video_embedding: bool = field(
        default=False,
        metadata={"help": "Keep the video patch/class/position embeddings trainable while the vision model is frozen."},
    )
    un_freeze_llm_head: bool = field(
        default=False,
        metadata={"help": "Keep the language-model head (lm_head) trainable while the language model is frozen."},
    )
    use_lora: bool = field(
        default=False, metadata={"help": "add the LoRA adapters to the base model"}
    )
|
| 170 |
+
|
| 171 |
+
@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    data_dir: Optional[str] = field(
        default=None, metadata={"help": "The data directory containing input files."}
    )
    train_file: Optional[str] = field(
        default=None, metadata={"help": "The input training data file (a jsonlines)."}
    )
    validation_file: Optional[str] = field(
        default=None,
        metadata={
            "help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
        },
    )
    train_val_split: Optional[float] = field(
        default=0.0, metadata={"help": "Percent to split off of train for validation."}
    )
    test_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
    )
    image_path: Optional[str] = field(
        default=None,
        metadata={"help": "An optional image path"},
    )
    video_path: Optional[str] = field(
        default=None,
        metadata={"help": "An optional video path"},
    )
    input_size: Optional[int] = field(
        default=224,
        metadata={"help": "The input size of images."},
    )
    # Bug fix: main() reads `data_args.video_size` for video datasets, but this
    # field was missing, raising AttributeError. Default mirrors `input_size`
    # (224) — confirm the intended frame size for video training.
    video_size: Optional[int] = field(
        default=224,
        metadata={"help": "The input size of video frames (used when a dataset's data_type is 'video')."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    max_seq_length: Optional[int] = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    pad_to_max_length: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to pad all samples to model maximum sentence length. "
                "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
                "efficient on GPU but very bad for TPU."
            )
        },
    )
    val_max_length: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "The maximum total sequence length for validation target text after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
                "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
                "during ``evaluate`` and ``predict``."
            )
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of prediction examples to this "
                "value if set."
            )
        },
    )
    conv_style: Optional[str] = field(
        default=None, metadata={"help": "prompt style for a conversation."}
    )
    # Help fixed: the original string was copy-pasted from `conv_style`.
    save_data_path: Optional[str] = field(
        default=None, metadata={"help": "An optional path to save the processed dataset."}
    )
    num_beams: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
                "which is used during ``evaluate`` and ``predict``."
            )
        },
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
        },
    )
    source_prefix: Optional[str] = field(
        default=None, metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
    )
    forced_bos_token: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The token to force as the first generated token after the :obj:`decoder_start_token_id`.Useful for"
                " multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
                " be the target language token.(Usually it is the target language token)"
            )
        },
    )

    def __post_init__(self):
        """Validate that a data source was given and that file extensions are supported."""
        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        # accepting both json and jsonl file extensions, as
        # many jsonlines files actually have a .json extension
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                # Assert messages fixed to match the actually accepted extensions.
                assert extension in ["csv", "json", "jsonl", "parquet"], \
                    "`train_file` should be a csv, json, jsonl or parquet file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "jsonl", "parquet"], \
                    "`validation_file` should be a csv, json, jsonl or parquet file."
            if self.test_file is not None:
                extension = self.test_file.split(".")[-1]
                assert extension == "json", "`test_file` should be a json file."
|
| 327 |
+
|
| 328 |
+
def main():
    """Fine-tune a Husky conditional-generation model on multi-modal datasets.

    Parses HF-style arguments, sets up logging and distributed training,
    loads the tokenizer/model, applies the requested freeze/LoRA policies,
    builds one BaseDataset per entry of the dataset manifest, and runs the
    HF Trainer.
    """
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    # NOTE(review): hard-coded slurm/nccl launcher and port — confirm this is
    # intended for all deployment environments.
    init_dist(launcher='slurm', backend='nccl', port=29598)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script, and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # 2. Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    set_verbosity(log_level)
    enable_default_handler()
    enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # 3. Detecting last checkpoint and eventually continue from last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # 4. Get the datasets
    # you can either provide your own JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.

    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        ds = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            data_dir=data_args.data_dir,
            cache_dir=model_args.cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        data_files = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
            extension = data_args.train_file.split(".")[-1]
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
            extension = data_args.validation_file.split(".")[-1]
        if data_args.test_file is not None:
            data_files["test"] = data_args.test_file
            extension = data_args.test_file.split(".")[-1]

        # ds = load_dataset(
        #     "json" if extension == "jsonl" else extension,
        #     data_files=data_files,
        #     split="train"
        # )
        # NOTE(review): here `train_file` is treated as a manifest (a JSON list
        # of dataset descriptors iterated below), not as raw training data;
        # the file handle is not closed explicitly.
        ds = json.load(open(data_args.train_file, "r"))

    # 5. Load pretrained model, tokenizer, and image processor
    #
    # Distributed training: The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    tokenizer = LlamaTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        legacy=True,
    )
    # add special token
    tokenizer.pad_token_id = 0
    if tokenizer.unk_token is None:
        tokenizer.add_special_tokens({"unk_token": DEFAULT_UNK_TOKEN})

    tokens_list = [
        DEFAULT_IMG_START_TOKEN, DEFAULT_IMG_END_TOKEN,
        DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN
    ]
    tokenizer.add_tokens(tokens_list, special_tokens=True)

    model = HuskyForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path, ignore_mismatched_sizes=True
    )
    embedding_size = model.language_model.get_input_embeddings().weight.shape[0]

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.

    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))
        model.language_model.resize_token_embeddings(len(tokenizer))
        model.config.text_config.vocab_size = len(tokenizer)

    model.config.use_cache = False

    def _freeze_params(module):
        # Stop gradients for every parameter of `module`.
        for param in module.parameters():
            param.requires_grad = False

    if model_args.freeze_model:
        _freeze_params(model)
        # only update language projection
        model.language_projection.weight.requires_grad = True

    if model_args.freeze_vision_model:
        model.vision_model = model.vision_model.eval()
        _freeze_params(model.vision_model)

    if model_args.freeze_vision_adapter:
        _freeze_params(model.vision_adapter)

    if model_args.freeze_qformer:
        model.qformer = model.qformer.eval()
        _freeze_params(model.qformer)
        model.query_tokens.requires_grad = False

    if model_args.freeze_text_model:
        _freeze_params(model.language_model)

    if model_args.use_lora:
        training_args.ddp_find_unused_parameters = False
        _freeze_params(model)
        lora_config = LoraConfig(
            r=16,
            target_modules=["q_proj", "v_proj"],
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model.language_model = get_peft_model(model.language_model, lora_config)
        model.language_model.print_trainable_parameters()

    if model_args.un_freeze_video_embedding:
        # NOTE(review): this re-freezes the whole model before unfreezing the
        # video embeddings, clobbering earlier per-module choices — confirm.
        _freeze_params(model)
        model.vision_model.video_embeddings.patch_embedding.weight.requires_grad = True
        model.vision_model.video_embeddings.class_embedding.requires_grad = True
        model.vision_model.video_embeddings.position_embedding.requires_grad = True

    if model_args.un_freeze_llm_head:
        model.language_model.lm_head.weight.requires_grad = True

    # set seed for torch dataloaders
    set_seed(training_args.seed)

    # 7. Preprocessing the datasets.
    # We need to tokenize input captions and transform the images.

    # set padding.
    padding = "max_length" if data_args.pad_to_max_length else False

    def husky_processor(examples):
        # Tokenize one batch of examples with the configured max length.
        processor = partial(
            process_func,
            tokenizer=tokenizer,
            max_seq_length=data_args.max_seq_length,
        )
        model_inputs = processor(examples)
        return model_inputs

    # Data collator
    label_pad_token_id = IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    else:
        data_collator = DataCollatorForSeq2Seq(
            tokenizer,
            model=model,
            label_pad_token_id=label_pad_token_id,
            pad_to_multiple_of=8 if training_args.fp16 else None,
        )

    concat_dataset = []
    weights = []
    # Global batch size across devices and accumulation steps; each dataset is
    # truncated to a multiple of it below.
    batch_size = training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
    for data in ds:
        data_name = data['data_name']
        data_file = data["text_file"]
        extension = data_file.split(".")[-1]
        extension = "json" if extension == "jsonl" else extension
        logger.info(f"Loading dataset: {data_name}")

        raw_dataset = load_dataset(extension, data_files=data_file, num_proc=cpu_count(), split="train")
        raw_dataset = raw_dataset.shuffle(seed=0)
        max_train_sample = min(len(raw_dataset), batch_size * (len(raw_dataset) // batch_size))
        raw_dataset = raw_dataset.select(range(max_train_sample))

        media_type = data["data_type"]
        # NOTE(review): DataTrainingArguments does not declare `video_size`;
        # this line raises AttributeError for video datasets — verify the field
        # exists (or add it) before training on video data.
        input_size = data_args.video_size if media_type == "video" else data_args.input_size

        temp = BaseDataset(
            raw_dataset,
            processor=husky_processor,
            image_path=data["image_path"],
            input_size=input_size,
            num_segments=8,
            norm_type="openai",
            media_type=media_type
        )

        concat_dataset.append(temp)
        # Inverse-size weights so each dataset is sampled roughly uniformly.
        weights.append(1 / len(temp))
    logger.info(f"All datasets have been loaded!")

    if len(concat_dataset) > 1:
        train_dataset = WeightedConcatDataset(datasets=concat_dataset, weights=weights, batch_size=batch_size)
    else:
        train_dataset = concat_dataset[0]


    # 8. Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=None,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # 9. Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        if model_args.use_lora:
            # Only the LoRA adapters need to be persisted in this mode.
            model.language_model.save_pretrained(training_args.output_dir)
        else:
            trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics
        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
|
| 601 |
+
|
| 602 |
+
# Standard script entry point.
if __name__ == "__main__":
    main()
|
robohusky/utils.py
ADDED
|
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from asyncio import AbstractEventLoop
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import logging.handlers
|
| 5 |
+
import os
|
| 6 |
+
import platform
|
| 7 |
+
import sys
|
| 8 |
+
from typing import AsyncGenerator, Generator
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
import requests
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from husky.constants import LOGDIR
|
| 15 |
+
|
| 16 |
+
handler = None
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_logger(logger_name, logger_filename):
    """Create a logger that also mirrors stdout/stderr and rotates a log file.

    Side effects: configures the root handler's formatter, replaces
    ``sys.stdout``/``sys.stderr`` with :class:`StreamToLogger` wrappers, and
    (once per process, guarded by the module-global ``handler``) attaches a
    daily-rotating file handler under ``LOGDIR`` to every existing logger.

    :param logger_name: name of the logger to return.
    :param logger_filename: file name of the rotating log inside ``LOGDIR``.
    :returns: the configured :class:`logging.Logger`.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        # NOTE(review): this checks only the minor version (3.x with x >= 9),
        # where basicConfig gained the `encoding` argument.
        if sys.version_info[1] >= 9:
            # This is for windows
            logging.basicConfig(level=logging.INFO, encoding="utf-8")
        else:
            if platform.system() == "Windows":
                warnings.warn(
                    "If you are running on Windows, "
                    "we recommend you use Python >= 3.9 for UTF-8 encoding."
                )
            logging.basicConfig(level=logging.INFO)
        logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when="D", utc=True, encoding="utf-8"
        )
        handler.setFormatter(formatter)

        # Attach the file handler to every logger that already exists.
        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.

    Complete lines are emitted to the logger immediately; a trailing partial
    line is buffered until the next write or an explicit flush. Unknown
    attributes are delegated to the original stdout.
    """

    def __init__(self, logger, log_level=logging.INFO):
        # Keep a reference to the real stream first so __getattr__ delegation
        # is safe from the moment it is needed.
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ""

    def __getattr__(self, attr):
        # Delegate anything file-like we don't implement to the real stream.
        return getattr(self.terminal, attr)

    def write(self, buf):
        pending = self.linebuf + buf
        self.linebuf = ""
        # From the io.TextIOWrapper docs: on output, '\n' is the expected
        # newline and is translated by the underlying layer, so splitting on
        # line boundaries here stays cross-platform.
        for piece in pending.splitlines(True):
            if piece.endswith("\n"):
                clean = piece.encode("utf-8", "ignore").decode("utf-8")
                self.logger.log(self.log_level, clean.rstrip())
            else:
                # Incomplete line: keep it buffered for the next write().
                self.linebuf += piece

    def flush(self):
        # Emit whatever partial line is still buffered.
        if self.linebuf != "":
            clean = self.linebuf.encode("utf-8", "ignore").decode("utf-8")
            self.logger.log(self.log_level, clean.rstrip())
        self.linebuf = ""
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.

    Replaces ``reset_parameters`` on ``nn.Linear`` and ``nn.LayerNorm`` with a
    no-op, so constructing these modules skips weight initialization (the
    weights are overwritten by a checkpoint load anyway).
    """
    import torch

    for module_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        setattr(module_cls, "reset_parameters", lambda self: None)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def get_gpu_memory(max_gpus=None):
    """Get available memory for each GPU.

    :param max_gpus: optional cap on how many devices to inspect.
    :returns: list of free GiB (total minus currently allocated) per GPU.
    """
    visible = torch.cuda.device_count()
    num_gpus = visible if max_gpus is None else min(max_gpus, visible)

    gpu_memory = []
    for gpu_id in range(num_gpus):
        with torch.cuda.device(gpu_id):
            device = torch.cuda.current_device()
            props = torch.cuda.get_device_properties(device)
            total_gib = props.total_memory / (1024 ** 3)
            allocated_gib = torch.cuda.memory_allocated() / (1024 ** 3)
            gpu_memory.append(total_gib - allocated_gib)
    return gpu_memory
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Best-effort: returns ``False`` when the API key is missing, the request
    fails, or the response has an unexpected shape — it never raises.

    :param text: the user text to check (newlines are stripped).
    :returns: ``True`` if the API flagged the text, ``False`` otherwise.
    """
    # Fix: os.environ["OPENAI_API_KEY"] raised KeyError before the try block
    # when the variable was unset; treat a missing key as "not flagged",
    # consistent with the error handling below.
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key is None:
        return False

    url = "https://api.openai.com/v1/moderations"
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
    }
    text = text.replace("\n", "")
    # Fix: the body was previously built by string concatenation, producing
    # invalid JSON whenever `text` contained quotes or backslashes.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError):
        flagged = False

    return flagged
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
# Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings,
# Use this function to make sure it can be correctly loaded.
def clean_flant5_ckpt(ckpt_path):
    """Copy ``shared.weight`` back into the encoder/decoder embedding shards.

    Rewrites the shard files referenced by ``pytorch_model.bin.index.json``
    in place so the checkpoint loads with correctly tied embeddings.

    :param ckpt_path: directory holding the index file and weight shards.
    """
    index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json")
    # Fix: close the index file instead of leaking the handle
    # (was json.load(open(index_file, "r"))).
    with open(index_file, "r") as f:
        index_json = json.load(f)

    weightmap = index_json["weight_map"]

    share_weight_file = weightmap["shared.weight"]
    share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[
        "shared.weight"
    ]

    for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]:
        weight_file = weightmap[weight_name]
        weight = torch.load(os.path.join(ckpt_path, weight_file))
        weight[weight_name] = share_weight
        torch.save(weight, os.path.join(ckpt_path, weight_file))
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def pretty_print_semaphore(semaphore):
    """Print a semaphore in better format."""
    if semaphore is None:
        return "None"
    # Relies on the semaphore exposing `_value` and `locked()` (asyncio.Semaphore does).
    return "Semaphore(value={}, locked={})".format(semaphore._value, semaphore.locked())
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
"""A javascript function to get url parameters for the gradio web server."""
# NOTE: this JS snippet is meant to be passed to a gradio js hook; it returns
# the page's query string as a plain {param: value} object.
get_window_url_params_js = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log("url_params", url_params);
    return url_params;
}
"""
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def iter_over_async(
    async_gen: AsyncGenerator, event_loop: AbstractEventLoop
) -> Generator:
    """
    Convert async generator to sync generator

    :param async_gen: the AsyncGenerator to convert
    :param event_loop: the event loop to run on
    :returns: Sync generator
    """
    iterator = async_gen.__aiter__()

    async def _advance():
        # Translate StopAsyncIteration into an (exhausted, item) pair so the
        # synchronous side needs no exception plumbing.
        try:
            return False, await iterator.__anext__()
        except StopAsyncIteration:
            return True, None

    while True:
        exhausted, item = event_loop.run_until_complete(_advance())
        if exhausted:
            return
        yield item
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def detect_language(text: str) -> str:
    """Detect the language of a string, returning "unknown" on failure."""
    import polyglot  # pip3 install polyglot pyicu pycld2
    from polyglot.detect import Detector
    from polyglot.detect.base import logger as polyglot_logger
    import pycld2

    # Silence per-call warnings about short / low-confidence inputs.
    polyglot_logger.setLevel("ERROR")

    try:
        return Detector(text).language.name
    except (pycld2.error, polyglot.detect.base.UnknownLanguage):
        return "unknown"
|
robohusky/video_transformers.py
ADDED
|
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision
|
| 2 |
+
import random
|
| 3 |
+
from PIL import Image, ImageOps
|
| 4 |
+
import numpy as np
|
| 5 |
+
import numbers
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class GroupRandomCrop(object):
    """Crop every image in a group at the same randomly chosen origin.

    All images must share one size; images already at the target size are
    passed through untouched.
    """

    def __init__(self, size):
        # Accept a single int (square crop) or an explicit (h, w) pair.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, img_group):
        w, h = img_group[0].size
        th, tw = self.size

        # One crop origin shared by the whole group keeps frames aligned.
        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)

        cropped = []
        for img in img_group:
            assert (img.size[0] == w and img.size[1] == h)
            if w == tw and h == th:
                cropped.append(img)
            else:
                cropped.append(img.crop((x1, y1, x1 + tw, y1 + th)))

        return cropped
|
| 34 |
+
|
| 35 |
+
class MultiGroupRandomCrop(object):
    """Produce ``groups`` independent random crops of a whole image group.

    The output contains ``groups * len(img_group)`` images: for each group a
    fresh crop origin is drawn and applied to every input image.
    """

    def __init__(self, size, groups=1):
        # Accept a single int (square crop) or an explicit (h, w) pair.
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.groups = groups

    def __call__(self, img_group):
        w, h = img_group[0].size
        th, tw = self.size

        out_images = []
        for _ in range(self.groups):
            # New origin per group; shared within the group for alignment.
            x1 = random.randint(0, w - tw)
            y1 = random.randint(0, h - th)

            for img in img_group:
                assert (img.size[0] == w and img.size[1] == h)
                if w == tw and h == th:
                    out_images.append(img)
                else:
                    out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))

        return out_images
|
| 62 |
+
|
| 63 |
+
class GroupCenterCrop(object):
    """Apply torchvision's ``CenterCrop`` to every image in a group."""

    def __init__(self, size):
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_group):
        return list(map(self.worker, img_group))
|
| 69 |
+
|
| 70 |
+
class GroupRandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5
    """

    def __init__(self, is_flow=False):
        # When True, the group is treated as interleaved optical-flow frames:
        # even-indexed frames are additionally inverted after flipping (see
        # the inline comment in __call__).
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        # NOTE(review): the `is_flow` call argument is unused and shadows the
        # constructor flag; only `self.is_flow` takes effect.
        v = random.random()
        if v < 0.5:
            ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
            if self.is_flow:
                for i in range(0, len(ret), 2):
                    # invert flow pixel values when flipping
                    ret[i] = ImageOps.invert(ret[i])
            return ret
        else:
            return img_group
|
| 88 |
+
|
| 89 |
+
class GroupNormalize(object):
    """In-place channel-wise normalization of a stacked clip tensor.

    The first dimension is assumed to be channels of all frames stacked
    (C * T), so the per-channel mean/std lists are tiled across time.
    Mutates ``tensor`` and returns it.
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        # Tile the per-channel stats to cover every stacked frame.
        channels = tensor.size()[0]
        rep_mean = self.mean * (channels // len(self.mean))
        rep_std = self.std * (channels // len(self.std))

        # TODO: make efficient (vectorize instead of a Python loop)
        for channel, m, s in zip(tensor, rep_mean, rep_std):
            channel.sub_(m).div_(s)

        return tensor
|
| 103 |
+
|
| 104 |
+
class GroupScale(object):
    """ Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.worker = torchvision.transforms.Resize(size, interpolation)

    def __call__(self, img_group):
        return list(map(self.worker, img_group))
|
| 118 |
+
|
| 119 |
+
class GroupOverSample(object):
    """Deterministic over-sampling for test-time augmentation.

    For each of the five fixed crop offsets (four corners + center, from
    GroupMultiScaleCrop.fill_fix_offset) emits the crops of every image in
    the group and, when ``flip`` is set, their horizontal flips as well.
    """

    def __init__(self, crop_size, scale_size=None, flip=True):
        # crop_size may be an int (square) or an explicit (w, h) pair.
        self.crop_size = crop_size if not isinstance(
            crop_size, int) else (crop_size, crop_size)

        # Optional rescale of the smaller edge before cropping.
        if scale_size is not None:
            self.scale_worker = GroupScale(scale_size)
        else:
            self.scale_worker = None
        self.flip = flip

    def __call__(self, img_group):

        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        # more_fix_crop=False -> exactly 5 offsets: corners + center.
        offsets = GroupMultiScaleCrop.fill_fix_offset(
            False, image_w, image_h, crop_w, crop_h)
        oversample_group = list()
        for o_w, o_h in offsets:
            normal_group = list()
            flip_group = list()
            for i, img in enumerate(img_group):
                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
                normal_group.append(crop)
                flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)

                # Presumably even-indexed 'L' frames are horizontal flow,
                # which must be negated (inverted) when mirrored — TODO confirm.
                if img.mode == 'L' and i % 2 == 0:
                    flip_group.append(ImageOps.invert(flip_crop))
                else:
                    flip_group.append(flip_crop)

            oversample_group.extend(normal_group)
            if self.flip:
                oversample_group.extend(flip_group)
        return oversample_group
|
| 158 |
+
|
| 159 |
+
class GroupFullResSample(object):
    """Full-resolution test-time sampling: left / right / center crops.

    Each of three horizontal positions (vertically centered) contributes the
    crops of all images; when ``flip`` is set their horizontal flips are
    appended as well.
    """

    def __init__(self, crop_size, scale_size=None, flip=True):
        # crop_size may be an int (square) or an explicit (w, h) pair.
        self.crop_size = crop_size if not isinstance(
            crop_size, int) else (crop_size, crop_size)

        # Optional rescale of the smaller edge before cropping.
        if scale_size is not None:
            self.scale_worker = GroupScale(scale_size)
        else:
            self.scale_worker = None
        self.flip = flip

    def __call__(self, img_group):

        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        # Offsets expressed on a 4-step grid over the croppable area.
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        offsets = list()
        offsets.append((0 * w_step, 2 * h_step))  # left
        offsets.append((4 * w_step, 2 * h_step))  # right
        offsets.append((2 * w_step, 2 * h_step))  # center

        oversample_group = list()
        for o_w, o_h in offsets:
            normal_group = list()
            flip_group = list()
            for i, img in enumerate(img_group):
                crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
                normal_group.append(crop)
                if self.flip:
                    flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)

                    # Presumably even-indexed 'L' frames are horizontal flow,
                    # inverted to negate when mirrored — TODO confirm.
                    if img.mode == 'L' and i % 2 == 0:
                        flip_group.append(ImageOps.invert(flip_crop))
                    else:
                        flip_group.append(flip_crop)

            oversample_group.extend(normal_group)
            # flip_group is empty when self.flip is False, so this is a no-op.
            oversample_group.extend(flip_group)
        return oversample_group
|
| 204 |
+
|
| 205 |
+
class GroupMultiScaleCrop(object):
    """Scale-jittered cropping shared across a group of frames.

    Draws one (crop_w, crop_h) pair from scaled versions of the short side
    (limiting aspect distortion via ``max_distort``) plus one offset, applies
    the identical crop to every image, and resizes to ``input_size``.
    """

    def __init__(self, input_size, scales=None, max_distort=1,
                 fix_crop=True, more_fix_crop=True):
        # Candidate scale factors applied to the image's shorter edge.
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        # Max allowed index gap between the chosen w-scale and h-scale.
        self.max_distort = max_distort
        # fix_crop: sample the offset from a fixed grid instead of uniformly.
        self.fix_crop = fix_crop
        # more_fix_crop: enrich the fixed grid with edge/quarter positions.
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [
            input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):

        im_size = img_group[0].size

        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        # The same crop window for every frame keeps the clip aligned.
        crop_img_group = [
            img.crop(
                (offset_w,
                 offset_h,
                 offset_w +
                 crop_w,
                 offset_h +
                 crop_h)) for img in img_group]
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
                         for img in crop_img_group]
        return ret_img_group

    def _sample_crop_size(self, im_size):
        # Returns (crop_w, crop_h, offset_w, offset_h) for one random draw.
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        # Snap candidate sizes within 3px of the target exactly onto it.
        crop_h = [
            self.input_size[1] if abs(
                x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [
            self.input_size[0] if abs(
                x - self.input_size[0]) < 3 else x for x in crop_sizes]

        # Only allow (w, h) pairs whose scale indices differ by at most
        # max_distort, bounding the aspect-ratio distortion.
        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(
                image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        # Pick one offset uniformly from the fixed grid.
        offsets = self.fill_fix_offset(
            self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        # Offsets live on a 4-step grid spanning the croppable area.
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower righ quarter

        return ret
|
| 292 |
+
|
| 293 |
+
class GroupRandomSizedCrop(object):
    """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
    and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
    This is popularly used to train the Inception networks
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        # Try up to 10 random (area, aspect-ratio) draws that fit in-bounds.
        for attempt in range(10):
            area = img_group[0].size[0] * img_group[0].size[1]
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # Randomly swap so both orientations of the ratio are sampled.
            if random.random() < 0.5:
                w, h = h, w

            if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
                x1 = random.randint(0, img_group[0].size[0] - w)
                y1 = random.randint(0, img_group[0].size[1] - h)
                found = True
                break
        else:
            # for/else: runs only if no attempt fit after 10 tries.
            found = False
            x1 = 0
            y1 = 0

        if found:
            # Same crop window applied to every image in the group.
            out_group = list()
            for img in img_group:
                img = img.crop((x1, y1, x1 + w, y1 + h))
                assert (img.size == (w, h))
                out_group.append(
                    img.resize(
                        (self.size, self.size), self.interpolation))
            return out_group
        else:
            # Fallback
            scale = GroupScale(self.size, interpolation=self.interpolation)
            crop = GroupRandomCrop(self.size)
            return crop(scale(img_group))
|
| 341 |
+
|
| 342 |
+
class ConvertDataFormat(object):
    """Rearrange a stacked-frame tensor for 3D models.

    For '2D' model types the input is returned unchanged; otherwise the
    (T*3, H, W) tensor is reshaped to (3, T, H, W), i.e. channels first
    with the time axis second. Assumes 3 channels per frame.
    """

    def __init__(self, model_type):
        self.model_type = model_type

    def __call__(self, images):
        if self.model_type == '2D':
            return images
        tc, h, w = images.size()
        num_frames = tc // 3
        # (T*3, H, W) -> (T, 3, H, W) -> (3, T, H, W)
        frames = images.view(num_frames, 3, h, w)
        return frames.permute(1, 0, 2, 3)
|
| 354 |
+
|
| 355 |
+
class Stack(object):
    """Stack a list of PIL images into one H x W x (C*T) numpy array.

    Grayscale ('L') frames contribute one channel each; RGB frames are
    concatenated along the channel axis, optionally with channels reversed.
    """

    def __init__(self, roll=False):
        # roll=True reverses the channel order (RGB -> BGR), e.g. for
        # Caffe-style pretrained backbones — presumably; confirm with caller.
        self.roll = roll

    def __call__(self, img_group):
        # NOTE(review): images whose mode is neither 'L' nor 'RGB' fall
        # through both branches and return None implicitly.
        if img_group[0].mode == 'L':
            return np.concatenate([np.expand_dims(x, 2)
                                   for x in img_group], axis=2)
        elif img_group[0].mode == 'RGB':
            if self.roll:
                return np.concatenate([np.array(x)[:, :, ::-1]
                                       for x in img_group], axis=2)
            else:
                # print(np.concatenate(img_group, axis=2).shape)
                # print(img_group[0].shape)
                return np.concatenate(img_group, axis=2)
|
| 372 |
+
|
| 373 |
+
class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """

    def __init__(self, div=True):
        # div=False keeps raw 0-255 float values instead of scaling to [0, 1].
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # numpy array: HWC -> CHW
            tensor = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        else:
            # PIL Image: go through the raw byte buffer
            tensor = torch.ByteTensor(
                torch.ByteStorage.from_buffer(pic.tobytes()))
            tensor = tensor.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            tensor = tensor.transpose(0, 1).transpose(0, 2).contiguous()
        return tensor.float().div(255) if self.div else tensor.float()
|
| 394 |
+
|
| 395 |
+
class IdentityTransform(object):
    """No-op transform: returns its input unchanged."""

    def __call__(self, data):
        return data
|
| 399 |
+
|
| 400 |
+
def get_index(num_frames, num_segments):
    """Pick ``num_segments`` frame indices spread evenly over ``num_frames``.

    Each segment contributes the frame nearest its midpoint, so sampling is
    deterministic (suitable for evaluation).
    """
    seg_size = float(num_frames - 1) / num_segments
    start = int(seg_size / 2)
    return np.array(
        [start + int(np.round(seg_size * idx)) for idx in range(num_segments)]
    )
|