Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- generation/control/oldm/hack.py +111 -0
- generation/control/oldm/lora.py +1119 -0
- generation/control/oldm/lora_ldm.py +343 -0
- generation/control/oldm/model.py +28 -0
- generation/control/oldm/oft_ldm.py +353 -0
- generation/subject/download_dreambooth.sh +4 -0
- generation/subject/evaluate.py +462 -0
- generation/subject/get_result.py +62 -0
- generation/subject/oft_utils/__init__.py +2 -0
- generation/subject/oft_utils/attention_processor.py +1036 -0
- generation/subject/oft_utils/mhe.py +360 -0
- generation/subject/train_dreambooth_hra.py +1123 -0
- generation/subject/train_dreambooth_hra.sh +186 -0
- llama/data/MATH_test.jsonl +0 -0
- llama/data/gsm8k_test.jsonl +0 -0
- llama/data/oft/__init__.py +20 -0
- llama/data/oft/config.py +119 -0
- llama/data/oft/layer.py +388 -0
- llama/data/oft/model.py +106 -0
- llama/finetune_32.py +368 -0
- llama/inference/MATH_inference.py +108 -0
- llama/inference/grader.py +141 -0
- llama/inference/gsm8k_inference.py +127 -0
- llama/inference/util.py +253 -0
- llama/merge_adapter_to_base_model.py +27 -0
- llama/output/cp1e4/ft/README.md +202 -0
- llama/output/cp1e4/ft/adapter_config.json +23 -0
- llama/output/cp1e4/ft/added_tokens.json +3 -0
- llama/output/cp1e4/ft/special_tokens_map.json +30 -0
- llama/output/cp1e4/ft/tokenizer.json +0 -0
- llama/output/cp1e4/ft/tokenizer_config.json +51 -0
- llama/output/cp1e5/ft/README.md +202 -0
- llama/output/cp1e5/ft/adapter_config.json +23 -0
- llama/output/cp1e5/trainer_state.json +30 -0
- llama/output/cp1e5N/ft/README.md +202 -0
- llama/output/cp1e5N/ft/adapter_config.json +23 -0
- llama/output/cp1e5N/ft/added_tokens.json +3 -0
- llama/output/cp1e5N/ft/special_tokens_map.json +30 -0
- llama/output/cp1e5N/ft/tokenizer.json +0 -0
- llama/output/cp1e5N/ft/tokenizer_config.json +51 -0
- llama/output/cp3e5/ft/README.md +202 -0
- llama/output/cp3e5/ft/adapter_config.json +23 -0
- llama/output/cp3e5/trainer_state.json +72 -0
- llama/output/cp3e5N/ft/README.md +202 -0
- llama/output/cp3e5N/ft/adapter_config.json +23 -0
- llama/output/cp3e5N/ft/added_tokens.json +3 -0
- llama/output/cp3e5N/ft/special_tokens_map.json +30 -0
- llama/output/cp3e5N/ft/tokenizer.json +0 -0
- llama/output/cp3e5N/ft/tokenizer_config.json +51 -0
- llama/output/cpr1/ft/README.md +202 -0
generation/control/oldm/hack.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import einops
|
| 3 |
+
|
| 4 |
+
import ldm.modules.encoders.modules
|
| 5 |
+
import ldm.modules.attention
|
| 6 |
+
|
| 7 |
+
from transformers import logging
|
| 8 |
+
from ldm.modules.attention import default
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def disable_verbosity():
    """Reduce `transformers` logging to errors only."""
    logging.set_verbosity_error()
    print('logging improved.')
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def enable_sliced_attention():
    """Monkey-patch CrossAttention.forward with the memory-saving sliced variant."""
    ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
    print('Enabled sliced_attention.')
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def hack_everything(clip_skip=0):
    """Apply all monkey-patches: quiet logging plus the long-prompt CLIP forward.

    `clip_skip` is stored on FrozenCLIPEmbedder and read by _hacked_clip_forward.
    """
    disable_verbosity()
    clip_cls = ldm.modules.encoders.modules.FrozenCLIPEmbedder
    clip_cls.forward = _hacked_clip_forward
    clip_cls.clip_skip = clip_skip
    print('Enabled clip hacks.')
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Written by Lvmin
def _hacked_clip_forward(self, text):
    """Replacement FrozenCLIPEmbedder.forward supporting prompts longer than
    CLIP's 77-token window.

    The raw token stream is split into three 75-token chunks, each chunk is
    wrapped with BOS/EOS and padded to 77 tokens, the three chunks are encoded
    in one batched transformer call, and the per-chunk embeddings are
    concatenated along the sequence dimension: output shape is
    (batch, 3 * 77, channels). Tokens past index 225 are silently dropped.
    """
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
    BOS = self.tokenizer.bos_token_id

    def tokenize(t):
        # Raw token ids with no BOS/EOS so we can re-add them per chunk below.
        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]

    def transformer_encode(t):
        if self.clip_skip > 1:
            # "clip skip": take an earlier hidden state and re-apply the final layer norm.
            rt = self.transformer(input_ids=t, output_hidden_states=True)
            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
        else:
            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state

    def split(x):
        # Three consecutive 75-token windows; anything beyond 3 * 75 is discarded.
        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]

    def pad(x, p, i):
        # Truncate or right-pad x to exactly i entries with pad value p.
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    raw_tokens_list = tokenize(text)
    tokens_list = []

    for raw_tokens in raw_tokens_list:
        raw_tokens_123 = split(raw_tokens)
        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
        tokens_list.append(raw_tokens_123)

    tokens_list = torch.IntTensor(tokens_list).to(self.device)

    # Fold the 3 chunks into the batch dim, encode once, then unfold them
    # back into the sequence dim so chunks appear consecutively.
    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
    y = transformer_encode(feed)
    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)

    return z
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    """Memory-saving CrossAttention.forward.

    Computes attention one slice of the (batch * heads) dimension at a time
    instead of materializing the full attention matrix at once, trading speed
    for peak memory. `mask` is accepted for signature compatibility but ignored.
    """
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)  # self-attention when no context is given
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x  # free activations early to reduce peak memory

    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    limit = k.shape[0]
    att_step = 1  # slice size along the (batch * heads) dimension
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))

    # Reversed so that pop() in the loop below consumes chunks in original
    # order while letting each consumed chunk be garbage-collected.
    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()
    # Output accumulator, filled slice-by-slice below.
    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v
    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        # Scaled dot-product scores for this slice only.
        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale

        del k_buffer, q_buffer
        # attention, what we cannot get enough of, by chunks

        sim_buffer = sim_buffer.softmax(dim=-1)

        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
        del v_buffer
        sim[i:i + att_step, :, :] = sim_buffer

        del sim_buffer
    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(sim)
|
generation/control/oldm/lora.py
ADDED
|
@@ -0,0 +1,1119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This script is retrived from lora available at:
|
| 3 |
+
https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
|
| 4 |
+
|
| 5 |
+
Original Author: Simo Ryu
|
| 6 |
+
License: Apache License 2.0
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import json
|
| 10 |
+
import math
|
| 11 |
+
from itertools import groupby
|
| 12 |
+
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
| 13 |
+
|
| 14 |
+
import pickle
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import PIL
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
# Prefer the real safetensors implementation; fall back to a project-local
# read-only shim when the library is missing (saving then raises at call time).
try:
    from safetensors.torch import safe_open
    from safetensors.torch import save_file as safe_save

    safetensors_available = True
except ImportError:
    from .safe_open import safe_open

    def safe_save(
        tensors: Dict[str, torch.Tensor],
        filename: str,
        metadata: Optional[Dict[str, str]] = None,
    ) -> None:
        """Stub used when safetensors is unavailable; always raises."""
        raise EnvironmentError(
            "Saving safetensors requires the safetensors library. Please install with pip or similar."
        )

    safetensors_available = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class LoraInjectedLinear(nn.Module):
    """An nn.Linear wrapped with a trainable low-rank (LoRA) residual:

        y = linear(x) + scale * lora_up(selector(lora_down(x)))

    lora_up is zero-initialized, so a freshly built wrapper behaves exactly
    like the plain linear layer.
    """

    def __init__(
        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
    ):
        super().__init__()

        max_rank = min(in_features, out_features)
        if r > max_rank:
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
            )
        self.r = r
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        # Registered but not applied in forward (the dropout path is commented
        # out upstream); kept so state_dict layout stays unchanged.
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        self.selector = nn.Identity()

        # Small random init for the down-projection; zeros for the
        # up-projection so the LoRA path starts as a no-op.
        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        lora_residual = self.lora_up(self.selector(self.lora_down(input)))
        return self.linear(input) + lora_residual * self.scale

    def realize_as_lora(self):
        """Return the (up * scale, down) weight tensors of the LoRA pair."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Install a diagonal selector between down- and up-projections.

        `diag` must be a 1-D tensor of shape (r,).
        """
        assert diag.shape == (self.r,)
        self.selector = nn.Linear(self.r, self.r, bias=False)
        self.selector.weight.data = (
            torch.diag(diag)
            .to(self.lora_up.weight.device)
            .to(self.lora_up.weight.dtype)
        )
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class LoraInjectedConv2d(nn.Module):
    """An nn.Conv2d wrapped with a trainable low-rank (LoRA) residual:

        y = conv(x) + scale * lora_up(selector(lora_down(x)))

    lora_down mirrors the wrapped convolution's geometry but produces `r`
    channels; lora_up is a 1x1 convolution mapping r -> out_channels and is
    zero-initialized, so a freshly built wrapper behaves exactly like the
    plain convolution.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        r: int = 4,
        dropout_p: float = 0.1,
        scale: float = 1.0,
    ):
        super().__init__()
        if r > min(in_channels, out_channels):
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
            )
        self.r = r
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        self.lora_down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=r,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        # Registered but not applied in forward (the dropout path is commented
        # out upstream); kept so state_dict layout stays unchanged.
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv2d(
            in_channels=r,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.selector = nn.Identity()
        self.scale = scale

        # Small random init for the down conv; zeros for the up conv so the
        # LoRA path starts as a no-op.
        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        # BUG FIX: the original called self.linear(input) — an attribute this
        # class does not have (copy-paste from LoraInjectedLinear), raising
        # AttributeError on every call. The wrapped layer here is self.conv.
        return (
            self.conv(input) + self.lora_up(self.selector(self.lora_down(input))) * self.scale
        )

    def realize_as_lora(self):
        """Return the (up * scale, down) weight tensors of the LoRA pair."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Install a diagonal 1x1-conv selector between down and up convs.

        `diag` must be a 1-D tensor of shape (r,).
        """
        assert diag.shape == (self.r,)
        self.selector = nn.Conv2d(
            in_channels=self.r,
            out_channels=self.r,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # BUG FIX: a Conv2d weight is rank-4 (out, in, kH, kW); assigning the
        # raw 2-D torch.diag result broke the subsequent forward pass.
        self.selector.weight.data = torch.diag(diag).view(self.r, self.r, 1, 1)

        # same device + dtype as lora_up
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
# Module-class names under which Linear/Conv2d layers get LoRA by default (UNet).
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}

# Extended UNet target set that additionally covers ResBlock layers.
UNET_EXTENDED_TARGET_REPLACE = {"ResBlock", "CrossAttention", "Attention", "GEGLU"}

# Target module-class names for the text encoder.
TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}

# Currently identical to the default text-encoder target set.
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}

DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE

# Metadata marker used in safetensors files to tag stored token embeddings.
EMBED_FLAG = "<embed>"
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _find_children(
|
| 183 |
+
model,
|
| 184 |
+
search_class: List[Type[nn.Module]] = [nn.Linear],
|
| 185 |
+
):
|
| 186 |
+
"""
|
| 187 |
+
Find all modules of a certain class (or union of classes).
|
| 188 |
+
Returns all matching modules, along with the parent of those moduless and the
|
| 189 |
+
names they are referenced by.
|
| 190 |
+
"""
|
| 191 |
+
result = []
|
| 192 |
+
for parent in model.modules():
|
| 193 |
+
for name, module in parent.named_children():
|
| 194 |
+
if any([isinstance(module, _class) for _class in search_class]):
|
| 195 |
+
result.append((parent, name, module)) # Append the result to the list
|
| 196 |
+
|
| 197 |
+
return result # Return the list instead of using 'yield'
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def _find_modules_v2(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [
        LoraInjectedLinear,
        LoraInjectedConv2d,
    ],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Returns a list of (parent, name, module) triples, where `parent` is the
    module's direct parent and `name` the attribute name it is referenced by,
    so callers can reassign `parent._modules[name]` to swap the module out.
    Modules whose direct parent is in `exclude_children_of` (i.e. already
    LoRA-injected) are skipped.
    """

    # Get the targets we should replace all linears under
    if ancestor_class is not None:
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # this, incase you want to naively iterate over all modules.
        ancestors = [module for module in model.modules()]

    results = []
    # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if any([isinstance(module, _class) for _class in search_class]):
                # Find the direct parent if this is a descendant, not a child, of target
                *path, name = fullname.split(".")
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                # Skip this linear if it's a child of a LoraInjectedLinear
                if exclude_children_of and any(
                    [isinstance(parent, _class) for _class in exclude_children_of]
                ):
                    continue
                results.append((parent, name, module))  # Append the result to the list

    return results  # Return the list instead of using 'yield'
|
| 245 |
+
|
| 246 |
+
def _find_modules_old(
    model,
    ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
):
    """Legacy module search kept for reference (see _find_modules alias below).

    Returns (ancestor, name, child) triples for every descendant whose class
    is exactly in `search_class`, under each module whose class name is in
    `ancestor_class`.

    NOTE(review): unlike _find_modules_v2, the first tuple element is the
    matched *ancestor* (not the direct parent, and `name` may be dotted),
    and the `exclude_children_of` parameter is never used in the body.
    """
    ret = []
    for _module in model.modules():
        if _module.__class__.__name__ in ancestor_class:

            for name, _child_module in _module.named_modules():
                if _child_module.__class__ in search_class:
                    ret.append((_module, name, _child_module))
    # print(ret)
    return ret
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
# Active implementation used by the inject/extract helpers below;
# _find_modules_old is kept only for reference.
_find_modules = _find_modules_v2
# _find_modules = _find_modules_old
|
| 265 |
+
|
| 266 |
+
def inject_trainable_lora(
    model: nn.Module,
    target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
    r: int = 4,
    loras=None,  # path to lora .pt
    verbose: bool = False,
    dropout_p: float = 0.0,
    scale: float = 1.0,
):
    """Replace every nn.Linear found under the targeted module classes with a
    LoraInjectedLinear that reuses the original weight and bias.

    Args:
        model: model to modify in place.
        target_replace_module: class names under which linears are replaced.
        r: LoRA rank.
        loras: optional path to a .pt file holding pre-trained LoRA weights
            as a flat [up, down, up, down, ...] list.
        verbose: print details of each injection.
        dropout_p, scale: forwarded to LoraInjectedLinear.

    Returns:
        (require_grad_params, names): per-layer trainable parameter iterators
        (lora_up then lora_down, per replaced layer) and the attribute names
        of the replaced modules.
    """

    require_grad_params = []
    names = []

    if loras is not None:
        # NOTE: torch.load unpickles arbitrary objects — only load trusted files.
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear]
    ):
        weight = _child_module.weight
        bias = _child_module.bias
        if verbose:
            print("LoRA Injection : injecting lora into ", name)
            print("LoRA Injection : weight shape", weight.shape)
        _tmp = LoraInjectedLinear(
            _child_module.in_features,
            _child_module.out_features,
            _child_module.bias is not None,
            r=r,
            dropout_p=dropout_p,
            scale=scale,
        )
        # Reuse the original parameters so the wrapped layer starts identical.
        _tmp.linear.weight = weight
        if bias is not None:
            _tmp.linear.bias = bias

        # switch the module
        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras is not None:
            # Consume pre-trained weights in the same flat order they were saved.
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def inject_trainable_lora_extended(
    model: nn.Module,
    target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
    r: int = 4,
    loras=None,  # path to lora .pt
):
    """Like inject_trainable_lora, but wraps both nn.Linear (LoraInjectedLinear)
    and nn.Conv2d (LoraInjectedConv2d) layers.

    Returns (require_grad_params, names) with the same layout as
    inject_trainable_lora.

    NOTE(review): the branches below compare __class__ for exact equality, so
    a subclass of nn.Linear/nn.Conv2d would match search_class (isinstance)
    but leave `_tmp` unbound — confirm no such subclasses appear in practice.
    """

    require_grad_params = []
    names = []

    if loras != None:
        # torch.load unpickles arbitrary objects — only load trusted files.
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
    ):
        if _child_module.__class__ == nn.Linear:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedLinear(
                _child_module.in_features,
                _child_module.out_features,
                _child_module.bias is not None,
                r=r,
            )
            # Reuse the original parameters so the wrapped layer starts identical.
            _tmp.linear.weight = weight
            if bias is not None:
                _tmp.linear.bias = bias
        elif _child_module.__class__ == nn.Conv2d:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedConv2d(
                _child_module.in_channels,
                _child_module.out_channels,
                _child_module.kernel_size,
                _child_module.stride,
                _child_module.padding,
                _child_module.dilation,
                _child_module.groups,
                _child_module.bias is not None,
                r=r,
            )

            _tmp.conv.weight = weight
            if bias is not None:
                _tmp.conv.bias = bias

        # switch the module
        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        if bias is not None:
            _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)

        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras != None:
            # Consume pre-trained weights in the same flat order they were saved.
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
    """Collect the (lora_up, lora_down) module pairs from every injected
    LoRA layer under the targeted module classes.

    Raises ValueError when the model contains no injected LoRA layers.
    """
    pairs = [
        (child.lora_up, child.lora_down)
        for _, _, child in _find_modules(
            model,
            target_replace_module,
            search_class=[LoraInjectedLinear, LoraInjectedConv2d],
        )
    ]

    if not pairs:
        raise ValueError("No lora injected.")

    return pairs
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def extract_lora_as_tensor(
    model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
):
    """Collect the (up, down) weight tensors (via realize_as_lora) from every
    injected LoRA layer, optionally cast to float16.

    Raises ValueError when the model contains no injected LoRA layers.
    """
    collected = []
    for _m, _n, injected in _find_modules(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):
        up, down = injected.realize_as_lora()
        if as_fp16:
            up, down = up.to(torch.float16), down.to(torch.float16)
        collected.append((up, down))

    if not collected:
        raise ValueError("No lora injected.")

    return collected
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def save_lora_weight(
    model,
    path="./lora.pt",
    target_replace_module=DEFAULT_TARGET_REPLACE,
):
    """Serialize all injected LoRA weights to `path` as a flat
    [up, down, up, down, ...] list of fp16 CPU tensors."""
    weights = []
    for _up, _down in extract_lora_ups_down(
        model, target_replace_module=target_replace_module
    ):
        weights.extend(
            (
                _up.weight.to("cpu").to(torch.float16),
                _down.weight.to("cpu").to(torch.float16),
            )
        )

    torch.save(weights, path)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def save_lora_as_json(model, path="./lora.json"):
    """Dump all injected LoRA weights to a JSON file as nested lists, in the
    same flat [up, down, up, down, ...] order as save_lora_weight."""
    weights = []
    for _up, _down in extract_lora_ups_down(model):
        weights.append(_up.weight.detach().cpu().numpy().tolist())
        weights.append(_down.weight.detach().cpu().numpy().tolist())

    # Removed a redundant function-local `import json`; the module already
    # imports json at the top of the file.
    with open(path, "w") as f:
        json.dump(weights, f)
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def save_safeloras_with_embeds(
    modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
    embeds: Dict[str, torch.Tensor] = {},
    outpath="./lora.safetensors",
):
    """
    Saves the Lora from multiple modules in a single safetensor file.
    modelmap is a dictionary of {
        "module name": (module, target_replace_module)
    }

    Each LoRA pair is stored as "{name}:{i}:up" / "{name}:{i}:down" with its
    rank in the metadata; each entry of `embeds` is stored under its token
    name and flagged with EMBED_FLAG in the metadata.

    NOTE(review): the mutable default arguments are never mutated here, but
    consider `None` defaults to be safe.
    """
    weights = {}
    metadata = {}

    for name, (model, target_replace_module) in modelmap.items():
        # Record the replace-target set so loaders can re-inject correctly.
        metadata[name] = json.dumps(list(target_replace_module))

        for i, (_up, _down) in enumerate(
            extract_lora_as_tensor(model, target_replace_module)
        ):
            rank = _down.shape[0]

            metadata[f"{name}:{i}:rank"] = str(rank)
            weights[f"{name}:{i}:up"] = _up
            weights[f"{name}:{i}:down"] = _down

    for token, tensor in embeds.items():
        metadata[token] = EMBED_FLAG
        weights[token] = tensor

    print(f"Saving weights to {outpath}")
    safe_save(weights, outpath, metadata)
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
def save_safeloras(
    modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
    outpath="./lora.safetensors",
):
    """Save Loras from several modules into one safetensor file (no TI embeds)."""
    result = save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
    return result
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def convert_loras_to_safeloras_with_embeds(
    modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
    embeds: Dict[str, torch.Tensor] = {},
    outpath="./lora.safetensors",
):
    """
    Converts the Lora from multiple pytorch .pt files into a single safetensor file.

    modelmap is a dictionary of {
        "module name": (pytorch_model_path, target_replace_module, rank)
    }

    The .pt files are expected to hold a flat list of tensors alternating
    up, down, up, down, ... as written by `save_lora_weight`.
    """

    weights = {}
    metadata = {}

    for name, (path, target_replace_module, r) in modelmap.items():
        metadata[name] = json.dumps(list(target_replace_module))

        # NOTE(review): torch.load unpickles arbitrary objects — only feed it
        # trusted checkpoint files.
        lora = torch.load(path)
        for i, weight in enumerate(lora):
            # Even positions are up-projections, odd positions down-projections.
            is_up = i % 2 == 0
            i = i // 2  # pair index used in the safetensor key

            if is_up:
                # A single rank `r` is assumed for every pair in this file.
                metadata[f"{name}:{i}:rank"] = str(r)
                weights[f"{name}:{i}:up"] = weight
            else:
                weights[f"{name}:{i}:down"] = weight

    # Textual Inversion embeds are stored alongside, flagged via metadata.
    for token, tensor in embeds.items():
        metadata[token] = EMBED_FLAG
        weights[token] = tensor

    print(f"Saving weights to {outpath}")
    safe_save(weights, outpath, metadata)
|
| 539 |
+
|
| 540 |
+
|
| 541 |
+
def convert_loras_to_safeloras(
    modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
    outpath="./lora.safetensors",
):
    """Convert .pt Lora files described by `modelmap` into one safetensor file."""
    convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def parse_safeloras(
    safeloras,
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
    """
    Converts a loaded safetensor file that contains a set of module Loras
    into Parameters and other information

    Output is a dictionary of {
        "module name": (
            [list of weights],
            [list of ranks],
            target_replacement_modules
        )
    }

    Weights come back interleaved up, down, up, down, ... — the layout the
    monkeypatch_* functions consume.
    """
    loras = {}
    metadata = safeloras.metadata()

    # Keys look like "<module name>:<pair index>:<up|down>".
    get_name = lambda k: k.split(":")[0]

    keys = list(safeloras.keys())
    # groupby only merges adjacent keys, so sort by the group key first.
    keys.sort(key=get_name)

    for name, module_keys in groupby(keys, get_name):
        info = metadata.get(name)

        if not info:
            raise ValueError(
                f"Tensor {name} has no metadata - is this a Lora safetensor?"
            )

        # Skip Textual Inversion embeds
        if info == EMBED_FLAG:
            continue

        # Handle Loras
        # Extract the targets
        target = json.loads(info)

        # Build the result lists - Python needs us to preallocate lists to insert into them
        module_keys = list(module_keys)
        ranks = [4] * (len(module_keys) // 2)  # default rank 4 if metadata missing
        weights = [None] * len(module_keys)

        for key in module_keys:
            # Split the model name and index out of the key
            _, idx, direction = key.split(":")
            idx = int(idx)

            # Add the rank
            ranks[idx] = int(metadata[f"{name}:{idx}:rank"])

            # Insert the weight into the list
            # (pair index -> flat index: up at 2*idx, down at 2*idx + 1)
            idx = idx * 2 + (1 if direction == "down" else 0)
            weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))

        loras[name] = (weights, ranks, target)

    return loras
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def parse_safeloras_embeds(
    safeloras,
) -> Dict[str, torch.Tensor]:
    """
    Collect the Textual Inversion embeds stored in a loaded safetensor file
    as a mapping of embed_token -> Tensor.

    Entries are identified by the EMBED_FLAG marker in the file metadata;
    everything else (Lora weights) is ignored.
    """
    metadata = safeloras.metadata()
    embeds = {}

    for key in safeloras.keys():
        if metadata.get(key) == EMBED_FLAG:
            embeds[key] = safeloras.get_tensor(key)

    return embeds
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
def load_safeloras(path, device="cpu"):
    """Open a Lora safetensor file at `path` and parse its per-module weights."""
    handle = safe_open(path, framework="pt", device=device)
    return parse_safeloras(handle)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def load_safeloras_embeds(path, device="cpu"):
    """Open a safetensor file at `path` and return only its TI embeds."""
    handle = safe_open(path, framework="pt", device=device)
    return parse_safeloras_embeds(handle)
|
| 637 |
+
|
| 638 |
+
|
| 639 |
+
def load_safeloras_both(path, device="cpu"):
    """Open a safetensor file once and return (loras, embeds) parsed from it."""
    handle = safe_open(path, framework="pt", device=device)
    loras = parse_safeloras(handle)
    embeds = parse_safeloras_embeds(handle)
    return loras, embeds
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def collapse_lora(model, alpha=1.0):
    """
    Fold every injected Lora delta directly into the wrapped layer's weight,
    in place: W <- W + alpha * (up @ down).

    After collapsing, the Lora layers still wrap the originals, but the base
    weight now contains the Lora contribution scaled by `alpha`.
    """

    for _module, name, _child_module in _find_modules(
        model,
        UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):

        if isinstance(_child_module, LoraInjectedLinear):
            print("Collapsing Lin Lora in", name)

            # up (out, r) @ down (r, in) has the same shape as the linear weight.
            _child_module.linear.weight = nn.Parameter(
                _child_module.linear.weight.data
                + alpha
                * (
                    _child_module.lora_up.weight.data
                    @ _child_module.lora_down.weight.data
                )
                .type(_child_module.linear.weight.dtype)
                .to(_child_module.linear.weight.device)
            )

        else:
            print("Collapsing Conv Lora in", name)
            # Conv weights are 4-D; flatten all but the leading dim so the
            # product can be taken, then reshape back to the conv kernel shape.
            _child_module.conv.weight = nn.Parameter(
                _child_module.conv.weight.data
                + alpha
                * (
                    _child_module.lora_up.weight.data.flatten(start_dim=1)
                    @ _child_module.lora_down.weight.data.flatten(start_dim=1)
                )
                .reshape(_child_module.conv.weight.data.shape)
                .type(_child_module.conv.weight.dtype)
                .to(_child_module.conv.weight.device)
            )
|
| 679 |
+
|
| 680 |
+
|
| 681 |
+
def monkeypatch_or_replace_lora(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    r: Union[int, List[int]] = 4,
):
    """
    Replace every matching nn.Linear (or existing LoraInjectedLinear) in
    `model` with a fresh LoraInjectedLinear whose up/down weights are taken
    from `loras`.

    NOTE: `loras` and (when a list) `r` are consumed destructively via pop(0);
    entries must be ordered exactly as _find_modules visits the layers.
    """
    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
    ):
        # Unwrap an already-injected layer so we copy from the base linear.
        _source = (
            _child_module.linear
            if isinstance(_child_module, LoraInjectedLinear)
            else _child_module
        )

        weight = _source.weight
        bias = _source.bias
        _tmp = LoraInjectedLinear(
            _source.in_features,
            _source.out_features,
            _source.bias is not None,
            r=r.pop(0) if isinstance(r, list) else r,
        )
        # Share the original base weight/bias rather than copying them.
        _tmp.linear.weight = weight

        if bias is not None:
            _tmp.linear.bias = bias

        # switch the module
        _module._modules[name] = _tmp

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype)
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype)
        )

        # Move the whole injected layer onto the base weight's device.
        _module._modules[name].to(weight.device)
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def monkeypatch_or_replace_lora_extended(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    r: Union[int, List[int]] = 4,
):
    """
    Like monkeypatch_or_replace_lora, but also handles Conv2d layers.

    The rank of the next tensor in `loras` decides which layer kind it belongs
    to: 2-D tensors patch Linear layers, 4-D tensors patch Conv2d layers;
    non-matching layers are skipped without consuming from `loras`.

    NOTE: `loras` and (when a list) `r` are consumed destructively via pop(0).
    """
    for _module, name, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
    ):

        if (_child_module.__class__ == nn.Linear) or (
            _child_module.__class__ == LoraInjectedLinear
        ):
            # Next stored weight is not a linear Lora — leave this layer alone.
            if len(loras[0].shape) != 2:
                continue

            _source = (
                _child_module.linear
                if isinstance(_child_module, LoraInjectedLinear)
                else _child_module
            )

            weight = _source.weight
            bias = _source.bias
            _tmp = LoraInjectedLinear(
                _source.in_features,
                _source.out_features,
                _source.bias is not None,
                r=r.pop(0) if isinstance(r, list) else r,
            )
            _tmp.linear.weight = weight

            if bias is not None:
                _tmp.linear.bias = bias

        elif (_child_module.__class__ == nn.Conv2d) or (
            _child_module.__class__ == LoraInjectedConv2d
        ):
            # Next stored weight is not a conv Lora — leave this layer alone.
            if len(loras[0].shape) != 4:
                continue
            _source = (
                _child_module.conv
                if isinstance(_child_module, LoraInjectedConv2d)
                else _child_module
            )

            weight = _source.weight
            bias = _source.bias
            _tmp = LoraInjectedConv2d(
                _source.in_channels,
                _source.out_channels,
                _source.kernel_size,
                _source.stride,
                _source.padding,
                _source.dilation,
                _source.groups,
                _source.bias is not None,
                r=r.pop(0) if isinstance(r, list) else r,
            )

            _tmp.conv.weight = weight

            if bias is not None:
                _tmp.conv.bias = bias

        # switch the module
        _module._modules[name] = _tmp

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype)
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype)
        )

        _module._modules[name].to(weight.device)
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def monkeypatch_or_replace_safeloras(models, safeloras):
    """Apply every module Lora found in `safeloras` to the attribute of the
    same name on `models` (e.g. a diffusers pipeline). Missing attributes are
    reported and skipped."""
    parsed = parse_safeloras(safeloras)

    for name, (lora, ranks, target) in parsed.items():
        model = getattr(models, name, None)

        if not model:
            print(f"No model provided for {name}, contained in Lora")
            continue

        monkeypatch_or_replace_lora_extended(model, lora, target, ranks)
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
def monkeypatch_remove_lora(model):
    """
    Strip every injected Lora wrapper from `model`, restoring plain
    nn.Linear / nn.Conv2d layers that share the original base weights.

    The Lora up/down contribution is discarded, not collapsed — use
    collapse_lora first if it should be kept.
    """
    for _module, name, _child_module in _find_modules(
        model, search_class=[LoraInjectedLinear, LoraInjectedConv2d]
    ):
        if isinstance(_child_module, LoraInjectedLinear):
            _source = _child_module.linear
            weight, bias = _source.weight, _source.bias

            _tmp = nn.Linear(
                _source.in_features, _source.out_features, bias is not None
            )

            # Re-attach the original parameters rather than re-initializing.
            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias

        else:
            _source = _child_module.conv
            weight, bias = _source.weight, _source.bias

            _tmp = nn.Conv2d(
                in_channels=_source.in_channels,
                out_channels=_source.out_channels,
                kernel_size=_source.kernel_size,
                stride=_source.stride,
                padding=_source.padding,
                dilation=_source.dilation,
                groups=_source.groups,
                bias=bias is not None,
            )

            _tmp.weight = weight
            if bias is not None:
                _tmp.bias = bias

        _module._modules[name] = _tmp
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
def monkeypatch_add_lora(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    alpha: float = 1.0,
    beta: float = 1.0,
):
    """
    Blend new Lora weights into already-injected linear layers:
    new = alpha * incoming + beta * existing, for both up and down weights.

    Only LoraInjectedLinear layers are touched; `loras` is consumed
    destructively via pop(0) in visit order.
    """
    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[LoraInjectedLinear]
    ):
        weight = _child_module.linear.weight

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype).to(weight.device) * alpha
            + _module._modules[name].lora_up.weight.to(weight.device) * beta
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype).to(weight.device) * alpha
            + _module._modules[name].lora_down.weight.to(weight.device) * beta
        )

        _module._modules[name].to(weight.device)
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
def tune_lora_scale(model, alpha: float = 1.0):
    """Set the Lora blending scale on every injected Lora layer in `model`."""
    lora_class_names = ("LoraInjectedLinear", "LoraInjectedConv2d")
    for layer in model.modules():
        if type(layer).__name__ in lora_class_names:
            layer.scale = alpha
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
def set_lora_diag(model, diag: torch.Tensor):
    """Push a diagonal selector tensor into every injected Lora layer of `model`."""
    lora_class_names = ("LoraInjectedLinear", "LoraInjectedConv2d")
    for layer in model.modules():
        if type(layer).__name__ in lora_class_names:
            layer.set_selector_from_diag(diag)
|
| 896 |
+
|
| 897 |
+
|
| 898 |
+
def _text_lora_path(path: str) -> str:
|
| 899 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
| 900 |
+
return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
def _ti_lora_path(path: str) -> str:
|
| 904 |
+
assert path.endswith(".pt"), "Only .pt files are supported"
|
| 905 |
+
return ".".join(path.split(".")[:-1] + ["ti", "pt"])
|
| 906 |
+
|
| 907 |
+
|
| 908 |
+
def apply_learned_embed_in_clip(
    learned_embeds,
    text_encoder,
    tokenizer,
    token: Optional[Union[str, List[str]]] = None,
    idempotent=False,
):
    """
    Register learned Textual Inversion embeddings in a CLIP text encoder.

    learned_embeds maps token string -> embedding tensor. If `token` is given
    it overrides the stored token name(s). When `idempotent` is False and a
    token already exists in the tokenizer, a numeric suffix is appended
    (e.g. "<tok>" -> "<tok-1>") until a free name is found; when True, the
    existing token's embedding is replaced instead.

    Returns the last token name actually used.
    """
    if isinstance(token, str):
        trained_tokens = [token]
    elif isinstance(token, list):
        assert len(learned_embeds.keys()) == len(
            token
        ), "The number of tokens and the number of embeds should be the same"
        trained_tokens = token
    else:
        trained_tokens = list(learned_embeds.keys())

    for token in trained_tokens:
        print(token)
        embeds = learned_embeds[token]

        # cast to dtype of text_encoder
        dtype = text_encoder.get_input_embeddings().weight.dtype
        num_added_tokens = tokenizer.add_tokens(token)

        i = 1
        if not idempotent:
            # Keep renaming "<tok>" -> "<tok-1>", "<tok-2>", ... until the
            # tokenizer accepts it as a new token.
            while num_added_tokens == 0:
                print(f"The tokenizer already contains the token {token}.")
                token = f"{token[:-1]}-{i}>"
                print(f"Attempting to add the token {token}.")
                num_added_tokens = tokenizer.add_tokens(token)
                i += 1
        elif num_added_tokens == 0 and idempotent:
            print(f"The tokenizer already contains the token {token}.")
            print(f"Replacing {token} embedding.")

        # resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # get the id for the token and assign the embeds
        token_id = tokenizer.convert_tokens_to_ids(token)
        text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
def load_learned_embed_in_clip(
    learned_embeds_path,
    text_encoder,
    tokenizer,
    token: Optional[Union[str, List[str]]] = None,
    idempotent=False,
):
    """
    Load Textual Inversion embeddings from a .pt file and register them in
    the CLIP text encoder / tokenizer.

    Returns the token name actually used (apply_learned_embed_in_clip may
    rename it when the tokenizer already contains the requested token).

    Fix: previously the return value of apply_learned_embed_in_clip was
    discarded, so callers such as patch_pipe — which does
    `token = load_learned_embed_in_clip(...)` — always received None.
    """
    # NOTE: torch.load unpickles arbitrary objects — only load trusted files.
    learned_embeds = torch.load(learned_embeds_path)
    return apply_learned_embed_in_clip(
        learned_embeds, text_encoder, tokenizer, token, idempotent
    )
|
| 965 |
+
|
| 966 |
+
|
| 967 |
+
def patch_pipe(
    pipe,
    maybe_unet_path,
    token: Optional[str] = None,
    r: int = 4,
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    idempotent_token=True,
    unet_target_replace_module=DEFAULT_TARGET_REPLACE,
    text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
    """
    Patch a diffusion pipeline with Lora weights and TI embeds from disk.

    Accepts either the legacy three-file .pt layout (unet + .text_encoder.pt
    + .ti.pt companions, any of which may be passed in) or a single
    .safetensors bundle. For safetensors input the parsed embed dict is
    returned; for .pt input nothing is returned.
    """
    if maybe_unet_path.endswith(".pt"):
        # torch format

        # Normalize whichever companion file was passed to the base unet path.
        if maybe_unet_path.endswith(".ti.pt"):
            unet_path = maybe_unet_path[:-6] + ".pt"
        elif maybe_unet_path.endswith(".text_encoder.pt"):
            unet_path = maybe_unet_path[:-16] + ".pt"
        else:
            unet_path = maybe_unet_path

        ti_path = _ti_lora_path(unet_path)
        text_path = _text_lora_path(unet_path)

        if patch_unet:
            print("LoRA : Patching Unet")
            monkeypatch_or_replace_lora(
                pipe.unet,
                torch.load(unet_path),
                r=r,
                target_replace_module=unet_target_replace_module,
            )

        if patch_text:
            print("LoRA : Patching text encoder")
            monkeypatch_or_replace_lora(
                pipe.text_encoder,
                torch.load(text_path),
                target_replace_module=text_target_replace_module,
                r=r,
            )
        if patch_ti:
            print("LoRA : Patching token input")
            token = load_learned_embed_in_clip(
                ti_path,
                pipe.text_encoder,
                pipe.tokenizer,
                token=token,
                idempotent=idempotent_token,
            )

    elif maybe_unet_path.endswith(".safetensors"):
        safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
        monkeypatch_or_replace_safeloras(pipe, safeloras)
        tok_dict = parse_safeloras_embeds(safeloras)
        if patch_ti:
            apply_learned_embed_in_clip(
                tok_dict,
                pipe.text_encoder,
                pipe.tokenizer,
                token=token,
                idempotent=idempotent_token,
            )
        return tok_dict
|
| 1032 |
+
|
| 1033 |
+
|
| 1034 |
+
@torch.no_grad()
def inspect_lora(model):
    """Report, per injected Lora layer, the mean absolute value of its
    effective delta weight (up @ down). Returns {module name: [value]}."""
    report = {}

    for name, layer in model.named_modules():
        if type(layer).__name__ not in ("LoraInjectedLinear", "LoraInjectedConv2d"):
            continue

        up = layer.lora_up.weight.data.clone()
        down = layer.lora_down.weight.data.clone()

        delta: torch.Tensor = up.flatten(1) @ down.flatten(1)
        magnitude = delta.flatten().abs().mean().item()

        report.setdefault(name, []).append(magnitude)

    return report
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
def save_all(
    unet,
    text_encoder,
    save_path,
    placeholder_token_ids=None,
    placeholder_tokens=None,
    save_lora=True,
    save_ti=True,
    target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    target_replace_module_unet=DEFAULT_TARGET_REPLACE,
    safe_form=True,
):
    """
    Save unet/text-encoder Loras and Textual Inversion embeds.

    With safe_form=True everything goes into one .safetensors bundle at
    `save_path`; otherwise the legacy three-file .pt layout is written
    (save_path, save_path.text_encoder.pt, save_path.ti.pt).

    placeholder_tokens / placeholder_token_ids are required when save_ti
    is True.
    """
    if not safe_form:
        # save ti
        if save_ti:
            ti_path = _ti_lora_path(save_path)
            learned_embeds_dict = {}
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}:, id {tok_id} ",
                    learned_embeds[:4],
                )
                learned_embeds_dict[tok] = learned_embeds.detach().cpu()

            torch.save(learned_embeds_dict, ti_path)
            print("Ti saved to ", ti_path)

        # save text encoder
        if save_lora:

            save_lora_weight(
                unet, save_path, target_replace_module=target_replace_module_unet
            )
            print("Unet saved to ", save_path)

            save_lora_weight(
                text_encoder,
                _text_lora_path(save_path),
                target_replace_module=target_replace_module_text,
            )
            print("Text Encoder saved to ", _text_lora_path(save_path))

    else:
        assert save_path.endswith(
            ".safetensors"
        ), f"Save path : {save_path} should end with .safetensors"

        loras = {}
        embeds = {}

        if save_lora:

            loras["unet"] = (unet, target_replace_module_unet)
            loras["text_encoder"] = (text_encoder, target_replace_module_text)

        if save_ti:
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}:, id {tok_id} ",
                    learned_embeds[:4],
                )
                embeds[tok] = learned_embeds.detach().cpu()

        save_safeloras_with_embeds(loras, embeds, save_path)
|
generation/control/oldm/lora_ldm.py
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import einops
|
| 2 |
+
import torch
|
| 3 |
+
import torch as th
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
from ldm.modules.diffusionmodules.util import (
|
| 10 |
+
conv_nd,
|
| 11 |
+
linear,
|
| 12 |
+
zero_module,
|
| 13 |
+
timestep_embedding,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
from torchvision.utils import make_grid
|
| 18 |
+
from ldm.modules.attention import SpatialTransformer
|
| 19 |
+
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Upsample, Downsample, AttentionBlock, normalization
|
| 20 |
+
from ldm.models.diffusion.ddpm import LatentDiffusion
|
| 21 |
+
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
| 22 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
| 23 |
+
|
| 24 |
+
from cldm.lora import inject_trainable_lora, extract_lora_ups_down, inject_trainable_lora_extended
|
| 25 |
+
|
| 26 |
+
def count_parameters(params):
    """Return the total number of elements in `params`, in millions,
    rounded to one decimal place."""
    total = 0
    for p in params:
        total += p.numel()
    return round(total / 1e6, 1)
|
| 29 |
+
|
| 30 |
+
def set_requires_grad(model, requires_grad=True):
    """Toggle gradient tracking for every parameter of `model`, in place."""
    for p in model.parameters():
        p.requires_grad = requires_grad
|
| 33 |
+
|
| 34 |
+
class ControlledUnetModel(UNetModel):
    """UNet variant whose forward can mix a control signal into the encoder.

    Unlike the reference ControlNet, the control tensor here is added once —
    to the output of the first input block — rather than per-block.
    """

    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
        # NOTE(review): `only_mid_control` is accepted but never used in this
        # implementation — confirm whether that is intentional.
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            if control is not None:
                # Inject the control residual after the first encoder block,
                # then drop it so the remaining blocks run unmodified.
                h = module(h, emb, context)
                h += control
                control = None
            else:
                h = module(h, emb, context)
            hs.append(h)  # kept for skip connections into the decoder
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # Standard UNet skip: concatenate the matching encoder activation.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)

        return self.out(h)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ControlNet(nn.Module):
    """Reduced ControlNet: encodes a conditioning hint image into a single
    feature map at the UNet's model_channels width.

    NOTE(review): only the time embedding and the hint-encoding stem are
    built here; the many UNet-style constructor arguments are validated and
    stored but not otherwise used in this visible code — presumably kept for
    config compatibility with the full ControlNet. forward() returns just the
    encoded hint, which ControlledUnetModel adds after its first input block.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        # Cross-attention conditioning requires the spatial transformer, and
        # vice versa — both directions are asserted.
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        # An int means the same res-block count at every resolution level.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # Sinusoidal timestep embedding -> 2-layer MLP, as in the LDM UNet.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Hint stem: three stride-2 convs downsample the hint 8x while growing
        # channels 16 -> 32 -> 96 -> 256; the final conv is zero-initialized so
        # the control contribution starts at zero.
        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )


    def forward(self, x, hint, timesteps, context, **kwargs):
        # `x` is accepted for interface parity but only the hint is encoded here.
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        # print('guided_hint', len(guided_hint), guided_hint[0].shape)
        # sys.exit()

        return guided_hint
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class ControlLDM(LatentDiffusion):
    """Latent diffusion model augmented with a ControlNet branch.

    The control model encodes a spatial conditioning image (e.g. an edge or
    pose map, read from ``batch[control_key]``) and its output is fed to the
    UNet as the ``control`` argument during ``apply_model``.
    """

    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
        # Remaining *args/**kwargs are forwarded to LatentDiffusion.
        super().__init__(*args, **kwargs)
        # Control branch built from its own config section.
        self.control_model = instantiate_from_config(control_stage_config)
        # Batch key holding the conditioning image.
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        # Per-level scales; currently unused (the scaling line in
        # apply_model is commented out).
        self.control_scales = [1.0] * 13

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        # NOTE(review): parameter `k` is ignored; self.first_stage_key is
        # always used instead.
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        control = batch[self.control_key]
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        # Channels-last HWC input from the dataloader -> NCHW for conv nets.
        control = einops.rearrange(control, 'b h w c -> b c h w')
        control = control.to(memory_format=torch.contiguous_format).float()
        return x, dict(c_crossattn=[c], c_concat=[control])

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        """Run one denoising step, injecting the ControlNet signal when a
        concat (spatial) condition is present."""
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)

        if cond['c_concat'] is None:
            # No spatial condition -> plain (uncontrolled) UNet pass.
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
            # control = [c * scale for c, scale in zip(control, self.control_scales)]
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

        return eps

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        # Empty-prompt embedding repeated N times, for classifier-free guidance.
        return self.get_learned_conditioning([""] * N)

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                   use_ema_scope=True, num_samples=1,
                   **kwargs):
        """Build a dict of images for the logger.

        NOTE(review): requires kwargs['split'] to be provided by the caller
        ('train' vs anything else selects the sampling branch); raises
        KeyError otherwise.
        """
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        # Map control image from [0, 1] to [-1, 1] for display —
        # assumes the hint is stored in [0, 1]; TODO confirm.
        log["control"] = c_cat * 2.0 - 1.0
        log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)

        if plot_diffusion_rows:
            # get diffusion row: decode progressively noisier versions of z.
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row: unguided samples.
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if kwargs['split'] == 'train':
            # Training: one guided sample per batch element.
            if unconditional_guidance_scale > 1.0:
                uc_cross = self.get_unconditional_conditioning(N)
                # Unconditional branch keeps the spatial hint, drops the text.
                uc_cat = c_cat  # torch.zeros_like(c_cat)

                uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        else:
            # Validation/test: num_samples guided samples, all from the
            # first batch element's hint and prompt.
            if unconditional_guidance_scale > 1.0:
                # uc_cross = self.get_unconditional_conditioning(N)
                # uc_cat = c_cat # torch.zeros_like(c_cat)

                c_cat = torch.stack([c_cat[0] for _ in range(num_samples)], dim=0).clone()

                cond = {"c_concat": [c_cat], "c_crossattn": [self.get_learned_conditioning([batch['txt'][0]] * num_samples)]}
                uc_full = {"c_concat": [c_cat], "c_crossattn": [self.get_learned_conditioning([''] * num_samples)]}

                samples_cfg, _ = self.sample_log(cond=cond,  # cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=num_samples, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        """Draw samples with DDIM in latent space."""
        ddim_sampler = DDIMSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        # Latent shape assumes an 8x-downsampling first stage — TODO confirm.
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
        return samples, intermediates

    def configure_optimizers(self):
        """Collect trainable parameters (control model + any unfrozen UNet
        params) into a single AdamW optimizer."""
        lr = self.learning_rate
        params = list(self.control_model.parameters())
        names = []
        for name, param in self.model.diffusion_model.named_parameters():
            if param.requires_grad:
                params.append(param)
                names.append(name)

        # params += self.unet_lora_params
        if not self.sd_locked:
            # Also train the SD decoder half when unlocked.
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        opt = torch.optim.AdamW(params, lr=lr)

        # NOTE(review): this unfreezes the entire diffusion model AFTER the
        # optimizer was built from only the previously-trainable params, so
        # newly-unfrozen params are not optimized — confirm this is intended.
        set_requires_grad(self.model.diffusion_model, True)

        num_params = count_parameters(params)
        print()
        print()
        print(f"Total number of trainable parameters: {num_params} M!")
        print()
        print()
        return opt

    def low_vram_shift(self, is_diffusing):
        """Swap model halves between CPU and GPU to fit in limited VRAM:
        diffusion models on GPU while diffusing, VAE/text encoder otherwise."""
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
|
generation/control/oldm/model.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from omegaconf import OmegaConf
|
| 5 |
+
from ldm.util import instantiate_from_config
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_state_dict(d):
    """Return the raw weight mapping from a checkpoint dict.

    PyTorch Lightning checkpoints nest the weights under a 'state_dict'
    key; plain state dicts are returned unchanged.
    """
    if 'state_dict' in d:
        return d['state_dict']
    return d
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_state_dict(ckpt_path, location='cpu'):
    """Load a model state dict from *ckpt_path*.

    Supports both ``.safetensors`` files and regular torch checkpoints.
    Lightning-style checkpoints that wrap the weights under a
    ``'state_dict'`` key are unwrapped transparently.

    Args:
        ckpt_path: Path to the checkpoint file.
        location: Device to map tensors to (default ``'cpu'``).

    Returns:
        dict: The (unwrapped) state dict.
    """
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        # Imported lazily so the dependency is only needed for .safetensors.
        import safetensors.torch
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = torch.load(ckpt_path, map_location=torch.device(location))
    # Unwrap a possible {'state_dict': ...} container once, uniformly for
    # both branches. (The original applied get_state_dict twice on the
    # torch branch and once on the safetensors branch; the second
    # application is redundant for ordinary checkpoints.)
    state_dict = get_state_dict(state_dict)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def create_model(config_path):
    """Build the model defined in the YAML config at *config_path*.

    The model is instantiated from the config's ``model`` section via
    ``instantiate_from_config`` and placed on CPU; the caller is expected
    to move it to the target device after loading weights.
    """
    conf = OmegaConf.load(config_path)
    model = instantiate_from_config(conf.model).cpu()
    print(f'Loaded model config from [{config_path}]')
    return model
|
generation/control/oldm/oft_ldm.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import einops
|
| 2 |
+
import torch
|
| 3 |
+
import torch as th
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
from ldm.modules.diffusionmodules.util import (
|
| 10 |
+
conv_nd,
|
| 11 |
+
linear,
|
| 12 |
+
zero_module,
|
| 13 |
+
timestep_embedding,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from einops import rearrange, repeat
|
| 17 |
+
from torchvision.utils import make_grid
|
| 18 |
+
from ldm.modules.attention import SpatialTransformer
|
| 19 |
+
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Upsample, Downsample, AttentionBlock, normalization
|
| 20 |
+
from ldm.models.diffusion.ddpm import LatentDiffusion
|
| 21 |
+
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
| 22 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def count_parameters(params):
    """Return the effective trainable-parameter count of *params*, in millions.

    Tensors shaped (N, D, D) — batches of square matrices — contribute
    N * D * (D - 1) / 2 entries each (the free entries of a skew-symmetric
    D x D matrix; presumably these are OFT rotation blocks — TODO confirm).
    Every other tensor contributes its full element count.

    The total is divided by 1e6 and rounded to one decimal place.
    """
    total = 0
    for tensor in params:
        dims = tuple(tensor.shape)
        if len(dims) == 3 and dims[1] == dims[2]:
            n_blocks, block_dim = dims[0], dims[1]
            total += n_blocks * block_dim * (block_dim - 1) // 2
        else:
            total += tensor.numel()
    return round(total / 1e6, 1)
|
| 36 |
+
|
| 37 |
+
def set_requires_grad(model, requires_grad=True):
    """Toggle the ``requires_grad`` flag on every parameter of *model*.

    Used to freeze (``False``) or unfreeze (``True``) a whole sub-network
    in one call.
    """
    for parameter in model.parameters():
        parameter.requires_grad = requires_grad
|
| 40 |
+
|
| 41 |
+
class ControlledUnetModel(UNetModel):
    """UNet whose first encoder activation can be offset by an external
    control signal before the usual skip-connection flow."""

    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
        """Run the UNet, adding *control* to the output of the first input
        block (only once).

        Note: ``only_mid_control`` is accepted for interface compatibility
        but is not consulted in this variant.
        """
        emb = self.time_embed(
            timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        )

        skips = []
        h = x.type(self.dtype)
        for block in self.input_blocks:
            h = block(h, emb, context)
            if control is not None:
                # Inject the control signal right after the first encoder
                # block, then disable it for the rest of the pass.
                h = h + control
                control = None
            skips.append(h)

        h = self.middle_block(h, emb, context)

        for block in self.output_blocks:
            h = block(th.cat([h, skips.pop()], dim=1), emb, context)

        return self.out(h.type(x.dtype))
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class ControlNet(nn.Module):
    """Hint encoder for a controlled latent-diffusion UNet.

    This variant only builds the timestep embedding and the convolutional
    hint encoder (``input_hint_block``); ``forward`` returns the encoded
    hint feature map. Several constructor arguments (e.g. ``out_channels``,
    ``use_scale_shift_norm``, ``transformer_depth``) are accepted for
    config compatibility with the full UNet but are not used here.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,    # custom transformer support
        context_dim=None,    # custom transformer support
        n_embed=None,    # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        # Spatial transformer and cross-attention context must be enabled
        # together; these asserts mirror the upstream UNetModel checks.
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                # Normalize an OmegaConf list to a plain Python list.
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        # Exactly one of num_heads / num_head_channels may stay at -1.
        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        if isinstance(num_res_blocks, int):
            # A single int means the same block count at every resolution level.
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        # Timestep MLP: model_channels -> 4*model_channels -> 4*model_channels.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Hint encoder: three stride-2 stages (8x spatial downsample) ending
        # in a zero-initialized conv so the control signal starts as a no-op.
        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

    def forward(self, x, hint, timesteps, context, **kwargs):
        """Encode *hint* into a feature map; *x* is accepted but unused here."""
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        # print('guided_hint', len(guided_hint), guided_hint[0].shape, guided_hint.max(), guided_hint.min())
        # sys.exit()

        return guided_hint
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class ControlLDM(LatentDiffusion):
    """Latent diffusion model with an auxiliary ControlNet branch.

    The control model encodes a spatial conditioning image (read from
    ``batch[control_key]``) and its output is passed to the UNet as the
    ``control`` argument in ``apply_model``.
    """

    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
        # Remaining *args/**kwargs go to LatentDiffusion.
        super().__init__(*args, **kwargs)
        # Control branch built from its own config section.
        self.control_model = instantiate_from_config(control_stage_config)
        # Batch key holding the conditioning image.
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        # Per-level scales; currently unused (the scaling line in
        # apply_model is commented out).
        self.control_scales = [1.0] * 13

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        # NOTE(review): parameter `k` is ignored; self.first_stage_key is
        # always used instead.
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        control = batch[self.control_key]
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        # Channels-last HWC from the dataloader -> NCHW for conv nets.
        control = einops.rearrange(control, 'b h w c -> b c h w')
        control = control.to(memory_format=torch.contiguous_format).float()
        return x, dict(c_crossattn=[c], c_concat=[control])

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        """Run one denoising step, injecting the ControlNet output when a
        concat (spatial) condition is present."""
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)

        if cond['c_concat'] is None:
            # No spatial condition -> plain (uncontrolled) UNet pass.
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
            # control = [c * scale for c, scale in zip(control, self.control_scales)]
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

        return eps

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        # Empty-prompt embedding repeated N times, for classifier-free guidance.
        return self.get_learned_conditioning([""] * N)

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                   use_ema_scope=True, num_samples=1,
                   **kwargs):
        """Build a dict of images for the logger.

        NOTE(review): requires kwargs['split'] ('train' vs anything else
        selects the sampling branch); raises KeyError otherwise.
        """
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        # Map control image from [0, 1] to [-1, 1] for display —
        # assumes the hint is stored in [0, 1]; TODO confirm.
        log["control"] = c_cat * 2.0 - 1.0
        log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)

        if plot_diffusion_rows:
            # get diffusion row: decode progressively noisier versions of z.
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row: unguided samples.
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if kwargs['split'] == 'train':
            # Training: one guided sample per batch element.
            if unconditional_guidance_scale > 1.0:
                uc_cross = self.get_unconditional_conditioning(N)
                # Unconditional branch keeps the spatial hint, drops the text.
                uc_cat = c_cat  # torch.zeros_like(c_cat)

                uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        else:
            # Validation/test: num_samples guided samples, all built from the
            # first batch element's hint and prompt.
            if unconditional_guidance_scale > 1.0:
                # uc_cross = self.get_unconditional_conditioning(N)
                # uc_cat = c_cat # torch.zeros_like(c_cat)

                c_cat = torch.stack([c_cat[0] for _ in range(num_samples)], dim=0).clone()

                cond = {"c_concat": [c_cat], "c_crossattn": [self.get_learned_conditioning([batch['txt'][0]] * num_samples)]}
                uc_full = {"c_concat": [c_cat], "c_crossattn": [self.get_learned_conditioning([''] * num_samples)]}

                samples_cfg, _ = self.sample_log(cond=cond,  # cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=num_samples, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        """Draw samples with DDIM in latent space."""
        ddim_sampler = DDIMSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        # Latent shape assumes an 8x-downsampling first stage — TODO confirm.
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
        return samples, intermediates

    def configure_optimizers(self):
        """Collect trainable parameters (control model + any unfrozen UNet
        params) into a single AdamW optimizer."""
        lr = self.learning_rate
        params = list(self.control_model.parameters())

        names = []
        for name, param in self.model.diffusion_model.named_parameters():
            if param.requires_grad:
                params.append(param)
                names.append(name)
                # print(name, param.shape)

        # params += self.unet_lora_params
        if not self.sd_locked:
            # Also train the SD decoder half when unlocked.
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        opt = torch.optim.AdamW(params, lr=lr)

        # NOTE(review): this unfreezes the entire diffusion model AFTER the
        # optimizer was built from only the previously-trainable params, so
        # newly-unfrozen params are not optimized — confirm this is intended.
        set_requires_grad(self.model.diffusion_model, True)

        num_params = count_parameters(params)
        print()
        print()
        print(f"Total number of trainable parameters: {num_params} M!")
        print()
        print()

        return opt

    def low_vram_shift(self, is_diffusing):
        """Swap model halves between CPU and GPU to fit limited VRAM:
        diffusion models on GPU while diffusing, VAE/text encoder otherwise."""
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
|
generation/subject/download_dreambooth.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Clone Google's official DreamBooth dataset repository (subject images)
# into ./dreambooth in the current working directory.

echo -e "\nDownloading dreambooth dataset..."
git clone https://github.com/google/dreambooth.git
|
generation/subject/evaluate.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import hashlib
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import warnings
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
from functools import reduce
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
import torch.utils.checkpoint
|
| 29 |
+
import transformers
|
| 30 |
+
from packaging import version
|
| 31 |
+
from PIL import Image
|
| 32 |
+
from torch.utils.data import Dataset, DataLoader
|
| 33 |
+
from torchvision import transforms
|
| 34 |
+
from tqdm.auto import tqdm
|
| 35 |
+
from transformers import AutoTokenizer, PretrainedConfig, ViTFeatureExtractor, ViTModel
|
| 36 |
+
|
| 37 |
+
import lpips
|
| 38 |
+
import json
|
| 39 |
+
from PIL import Image
|
| 40 |
+
import requests
|
| 41 |
+
from transformers import AutoProcessor, AutoTokenizer, CLIPModel
|
| 42 |
+
import torchvision.transforms.functional as TF
|
| 43 |
+
from torch.nn.functional import cosine_similarity
|
| 44 |
+
from torchvision.transforms import Compose, ToTensor, Normalize, Resize, ToPILImage
|
| 45 |
+
import re
|
| 46 |
+
|
| 47 |
+
def get_prompt(subject_name, prompt_idx):
    """Return the evaluation prompt for one DreamBooth subject.

    Args:
        subject_name: one of the 30 DreamBooth subject folder names.
        prompt_idx: index (int or numeric string) into the 25 standard
            evaluation prompt templates.

    Returns:
        The prompt string with the rare identifier token "qwe" and the
        subject's class noun filled in.

    Raises:
        ValueError: if subject_name is not a known subject.
        IndexError: if prompt_idx is outside 0..24.
    """
    # Subject folder names and their class nouns, kept as parallel tuples
    # so an unknown subject raises ValueError via .index(), as before.
    names = (
        "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
        "candle", "cat", "cat2", "clock", "colorful_sneaker",
        "dog", "dog2", "dog3", "dog5", "dog6",
        "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
        "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
        "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie",
    )
    classes = (
        "backpack", "backpack", "stuffed animal", "bowl", "can",
        "candle", "cat", "cat", "clock", "sneaker",
        "dog", "dog", "dog", "dog", "dog",
        "dog", "dog", "toy", "boot", "stuffed animal",
        "toy", "glasses", "toy", "toy", "cartoon",
        "toy", "sneaker", "teapot", "vase", "stuffed animal",
    )
    class_token = classes[names.index(subject_name)]

    # The 25 standard DreamBooth evaluation prompt templates; {} receives
    # the class noun.
    templates = (
        "a qwe {} in the jungle",
        "a qwe {} in the snow",
        "a qwe {} on the beach",
        "a qwe {} on a cobblestone street",
        "a qwe {} on top of pink fabric",
        "a qwe {} on top of a wooden floor",
        "a qwe {} with a city in the background",
        "a qwe {} with a mountain in the background",
        "a qwe {} with a blue house in the background",
        "a qwe {} on top of a purple rug in a forest",
        "a qwe {} wearing a red hat",
        "a qwe {} wearing a santa hat",
        "a qwe {} wearing a rainbow scarf",
        "a qwe {} wearing a black top hat and a monocle",
        "a qwe {} in a chef outfit",
        "a qwe {} in a firefighter outfit",
        "a qwe {} in a police outfit",
        "a qwe {} wearing pink glasses",
        "a qwe {} wearing a yellow shirt",
        "a qwe {} in a purple wizard outfit",
        "a red qwe {}",
        "a purple qwe {}",
        "a shiny qwe {}",
        "a wet qwe {}",
        "a cube shaped qwe {}",
    )
    return templates[int(prompt_idx)].format(class_token)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class PromptDatasetCLIP(Dataset):
    """Pairs each generated image in `data_dir_B/<epoch>` with its text prompt,
    preprocessed for CLIP, to score image-text (CLIP-T) similarity.

    Blank (all-zero) images are reported as (None, None) so callers can
    skip them.
    """

    def __init__(self, subject_name, data_dir_B, tokenizer, processor, epoch=None):
        self.data_dir_B = data_dir_B

        # subject_name is "<subject>-<prompt index>".
        name, prompt_idx = subject_name.split('-')

        epoch_dir = os.path.join(self.data_dir_B, str(epoch))
        self.image_lst = [
            os.path.join(epoch_dir, fname)
            for fname in os.listdir(epoch_dir)
            if fname.endswith(".png")
        ]
        # Every image of this epoch shares the same prompt.
        self.prompt_lst = [get_prompt(name, prompt_idx)] * len(self.image_lst)

        self.tokenizer = tokenizer
        self.processor = processor
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __len__(self):
        return len(self.image_lst)

    def __getitem__(self, idx):
        image = Image.open(self.image_lst[idx])
        prompt = self.prompt_lst[idx]

        # An image whose every band is constant zero is treated as a failed
        # generation and skipped by the caller.
        if all(lo == hi == 0 for lo, hi in image.getextrema()):
            return None, None

        prompt_inputs = self.tokenizer([prompt], padding=True, return_tensors="pt")
        image_inputs = self.processor(images=image, return_tensors="pt")
        return image_inputs, prompt_inputs
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class PairwiseImageDatasetCLIP(Dataset):
    """Cartesian product of real subject images (dir A, .jpg) and generated
    images (dir B/<epoch>, .png), preprocessed for CLIP image encoding.

    Blank (all-zero) images on either side yield (None, None).
    """

    def __init__(self, subject_name, data_dir_A, data_dir_B, processor, epoch):
        self.data_dir_A = data_dir_A
        self.data_dir_B = data_dir_B

        # subject_name is "<subject>-<prompt index>"; only the subject part
        # selects the real-image folder.
        name, prompt_idx = subject_name.split('-')

        self.data_dir_A = os.path.join(self.data_dir_A, name)
        self.image_files_A = [
            os.path.join(self.data_dir_A, fname)
            for fname in os.listdir(self.data_dir_A)
            if fname.endswith(".jpg")
        ]

        epoch_dir = os.path.join(self.data_dir_B, str(epoch))
        self.image_files_B = [
            os.path.join(epoch_dir, fname)
            for fname in os.listdir(epoch_dir)
            if fname.endswith(".png")
        ]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = processor

    def __len__(self):
        # One item per (real, generated) pair.
        return len(self.image_files_A) * len(self.image_files_B)

    def __getitem__(self, index):
        # Flat index -> (real image index, generated image index).
        idx_a, idx_b = divmod(index, len(self.image_files_B))

        image_A = Image.open(self.image_files_A[idx_a])  # intentionally not converted to RGB
        image_B = Image.open(self.image_files_B[idx_b])  # intentionally not converted to RGB

        def _is_blank(img):
            # Every band constant zero => failed generation.
            return all(lo == hi == 0 for lo, hi in img.getextrema())

        if _is_blank(image_A) or _is_blank(image_B):
            return None, None

        inputs_A = self.processor(images=image_A, return_tensors="pt")
        inputs_B = self.processor(images=image_B, return_tensors="pt")
        return inputs_A, inputs_B
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class PairwiseImageDatasetDINO(Dataset):
    """Cartesian product of real subject images (dir A, .jpg) and generated
    images (dir B/<epoch>, .png), preprocessed with a ViT feature extractor
    for DINO similarity scoring.

    Blank (all-zero) images on either side yield (None, None).
    """

    def __init__(self, subject_name, data_dir_A, data_dir_B, feature_extractor, epoch):
        self.data_dir_A = data_dir_A
        self.data_dir_B = data_dir_B

        # subject_name is "<subject>-<prompt index>".
        name, prompt_idx = subject_name.split('-')

        self.data_dir_A = os.path.join(self.data_dir_A, name)
        self.image_files_A = [
            os.path.join(self.data_dir_A, fname)
            for fname in os.listdir(self.data_dir_A)
            if fname.endswith(".jpg")
        ]

        epoch_dir = os.path.join(self.data_dir_B, str(epoch))
        self.image_files_B = [
            os.path.join(epoch_dir, fname)
            for fname in os.listdir(epoch_dir)
            if fname.endswith(".png")
        ]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = feature_extractor

    def __len__(self):
        # One item per (real, generated) pair.
        return len(self.image_files_A) * len(self.image_files_B)

    def __getitem__(self, index):
        # Flat index -> (real image index, generated image index).
        idx_a, idx_b = divmod(index, len(self.image_files_B))

        image_A = Image.open(self.image_files_A[idx_a])  # intentionally not converted to RGB
        image_B = Image.open(self.image_files_B[idx_b])  # intentionally not converted to RGB

        def _is_blank(img):
            # Every band constant zero => failed generation.
            return all(lo == hi == 0 for lo, hi in img.getextrema())

        if _is_blank(image_A) or _is_blank(image_B):
            return None, None

        inputs_A = self.feature_extractor(images=image_A, return_tensors="pt")
        inputs_B = self.feature_extractor(images=image_B, return_tensors="pt")
        return inputs_A, inputs_B
|
| 205 |
+
|
| 206 |
+
class PairwiseImageDatasetLPIPS(Dataset):
|
| 207 |
+
def __init__(self, subject_name, data_dir_A, data_dir_B, epoch):
|
| 208 |
+
self.data_dir_A = data_dir_A
|
| 209 |
+
self.data_dir_B = data_dir_B
|
| 210 |
+
|
| 211 |
+
subject_name, prompt_idx = subject_name.split('-')
|
| 212 |
+
|
| 213 |
+
self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
|
| 214 |
+
self.image_files_A = [os.path.join(self.data_dir_A, f) for f in os.listdir(self.data_dir_A) if f.endswith(".jpg")]
|
| 215 |
+
|
| 216 |
+
data_dir_B = os.path.join(self.data_dir_B, str(epoch))
|
| 217 |
+
self.image_files_B = [os.path.join(data_dir_B, f) for f in os.listdir(data_dir_B) if f.endswith(".png")]
|
| 218 |
+
|
| 219 |
+
self.transform = Compose([
|
| 220 |
+
Resize((512, 512)),
|
| 221 |
+
ToTensor(),
|
| 222 |
+
Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
| 223 |
+
])
|
| 224 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 225 |
+
|
| 226 |
+
def __len__(self):
|
| 227 |
+
return len(self.image_files_A) * len(self.image_files_B)
|
| 228 |
+
|
| 229 |
+
def __getitem__(self, index):
|
| 230 |
+
index_A = index // len(self.image_files_B)
|
| 231 |
+
index_B = index % len(self.image_files_B)
|
| 232 |
+
|
| 233 |
+
image_A = Image.open(self.image_files_A[index_A]) # .convert("RGB")
|
| 234 |
+
image_B = Image.open(self.image_files_B[index_B]) # .convert("RGB")
|
| 235 |
+
|
| 236 |
+
extrema_A = image_A.getextrema()
|
| 237 |
+
extrema_B = image_B.getextrema()
|
| 238 |
+
if all(min_val == max_val == 0 for min_val, max_val in extrema_A) or all(min_val == max_val == 0 for min_val, max_val in extrema_B):
|
| 239 |
+
return None, None
|
| 240 |
+
else:
|
| 241 |
+
if self.transform:
|
| 242 |
+
image_A = self.transform(image_A)
|
| 243 |
+
image_B = self.transform(image_B)
|
| 244 |
+
|
| 245 |
+
return image_A, image_B
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def clip_text(subject_name, image_dir):
    """Compute CLIP-T: per-epoch mean cosine similarity between each generated
    image's CLIP embedding and the embedding of its generation prompt.

    Args:
        subject_name: "<subject>-<prompt index>" identifier.
        image_dir: directory with one sub-folder of generated .png images
            per epoch (folder names are integer epochs).

    Returns:
        List with one mean similarity per epoch (0 for epochs where every
        image was blank), in ascending epoch order.
    """
    criterion = 'clip_text'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    model.eval()  # evaluation only: disable dropout etc.
    # Text tokenizer and image preprocessor for the same CLIP checkpoint.
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

    epochs = sorted(int(epoch) for epoch in os.listdir(image_dir))
    best_mean_similarity = 0
    mean_similarity_list = []
    for epoch in epochs:
        similarity = []
        dataset = PromptDatasetCLIP(subject_name, image_dir, tokenizer, processor, epoch)
        # Fix: the previous version also built an unused DataLoader here;
        # items are fetched directly from the dataset.
        for i in range(len(dataset)):
            image_inputs, prompt_inputs = dataset[i]
            if image_inputs is None or prompt_inputs is None:
                continue  # blank image: skip
            image_inputs['pixel_values'] = image_inputs['pixel_values'].to(device)
            prompt_inputs['input_ids'] = prompt_inputs['input_ids'].to(device)
            prompt_inputs['attention_mask'] = prompt_inputs['attention_mask'].to(device)

            # Fix: run inference under no_grad so no autograd graph is
            # retained across hundreds of images (memory growth).
            with torch.no_grad():
                image_features = model.get_image_features(**image_inputs)
                text_features = model.get_text_features(**prompt_inputs)

            # Raw cosine similarity; the CLIP logit scale is deliberately
            # not applied.
            sim = cosine_similarity(image_features, text_features)
            similarity.append(sim.item())

        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            mean_similarity_list.append(mean_similarity)
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            # Every image of this epoch was blank.
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def clip_image(subject_name, image_dir, dreambooth_dir='dreambooth/dataset'):
    """Compute CLIP-I: per-epoch mean pairwise cosine similarity between CLIP
    embeddings of real subject images and generated images.

    Args:
        subject_name: "<subject>-<prompt index>" identifier.
        image_dir: directory with one sub-folder of generated .png images per epoch.
        dreambooth_dir: root of the real DreamBooth subject images.

    Returns:
        List with one mean similarity per epoch (0 for epochs where every
        pair contained a blank image), in ascending epoch order.
    """
    criterion = 'clip_image'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    model.eval()  # evaluation only: disable dropout etc.
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

    epochs = sorted(int(epoch) for epoch in os.listdir(image_dir))
    best_mean_similarity = 0
    mean_similarity_list = []
    for epoch in epochs:
        similarity = []
        dataset = PairwiseImageDatasetCLIP(subject_name, dreambooth_dir, image_dir, processor, epoch)

        for i in range(len(dataset)):
            inputs_A, inputs_B = dataset[i]
            if inputs_A is None or inputs_B is None:
                continue  # pair contained a blank image: skip
            inputs_A['pixel_values'] = inputs_A['pixel_values'].to(device)
            inputs_B['pixel_values'] = inputs_B['pixel_values'].to(device)

            # Fix: run inference under no_grad so no autograd graph is
            # retained across all image pairs (memory growth).
            with torch.no_grad():
                image_A_features = model.get_image_features(**inputs_A)
                image_B_features = model.get_image_features(**inputs_B)

            image_A_features = image_A_features / image_A_features.norm(p=2, dim=-1, keepdim=True)
            image_B_features = image_B_features / image_B_features.norm(p=2, dim=-1, keepdim=True)

            # Cosine similarity of unit vectors; the CLIP logit scale was
            # computed but never applied before, so it is removed.
            sim = torch.matmul(image_A_features, image_B_features.t())
            similarity.append(sim.item())

        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            mean_similarity_list.append(mean_similarity)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            # Every pair of this epoch contained a blank image.
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
def dino(subject_name, image_dir, dreambooth_dir='dreambooth/dataset'):
    """Compute DINO score: per-epoch mean pairwise cosine similarity between
    the ViT CLS embeddings of real subject images and generated images.

    Args:
        subject_name: "<subject>-<prompt index>" identifier.
        image_dir: directory with one sub-folder of generated .png images per epoch.
        dreambooth_dir: root of the real DreamBooth subject images.

    Returns:
        List with one mean similarity per epoch (0 for epochs where every
        pair contained a blank image), in ascending epoch order.
    """
    criterion = 'dino'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTModel.from_pretrained('facebook/dino-vits16').to(device)
    model.eval()  # evaluation only: disable dropout etc.
    feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vits16')

    epochs = sorted(int(epoch) for epoch in os.listdir(image_dir))
    best_mean_similarity = 0
    mean_similarity_list = []
    for epoch in epochs:
        similarity = []
        dataset = PairwiseImageDatasetDINO(subject_name, dreambooth_dir, image_dir, feature_extractor, epoch)

        for i in range(len(dataset)):
            inputs_A, inputs_B = dataset[i]
            if inputs_A is None or inputs_B is None:
                continue  # pair contained a blank image: skip
            inputs_A['pixel_values'] = inputs_A['pixel_values'].to(device)
            inputs_B['pixel_values'] = inputs_B['pixel_values'].to(device)

            # Fix: run inference under no_grad so no autograd graph is
            # retained across all image pairs (memory growth).
            with torch.no_grad():
                # CLS token (position 0) embedding is the global image descriptor.
                image_A_features = model(**inputs_A).last_hidden_state[:, 0, :]
                image_B_features = model(**inputs_B).last_hidden_state[:, 0, :]

            image_A_features = image_A_features / image_A_features.norm(p=2, dim=-1, keepdim=True)
            image_B_features = image_B_features / image_B_features.norm(p=2, dim=-1, keepdim=True)

            sim = torch.matmul(image_A_features, image_B_features.t())
            similarity.append(sim.item())

        # Fix: guard the empty case like clip_text/clip_image do — an epoch
        # where every pair was blank previously produced a NaN mean.
        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            mean_similarity_list.append(mean_similarity)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
def lpips_image(subject_name, image_dir, dreambooth_dir='dreambooth/dataset'):
    """Compute per-epoch mean LPIPS distance between real and generated images.

    NOTE(review): LPIPS is a distance (lower = perceptually closer); the local
    list is still called `similarity` to mirror the sibling metric functions,
    and `best_mean_similarity` tracks the *largest* mean distance as before.

    Args:
        subject_name: "<subject>-<prompt index>" identifier.
        image_dir: directory with one sub-folder of generated .png images per epoch.
        dreambooth_dir: root of the real DreamBooth subject images.

    Returns:
        List with one mean distance per epoch (0 for epochs where every pair
        contained a blank image), in ascending epoch order.
    """
    criterion = 'lpips_image'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Set up the LPIPS model (net='vgg' uses the VGG-based model from the paper).
    loss_fn = lpips.LPIPS(net='vgg').to(device)

    # Some epochs may not have finished generating all images.
    epochs = sorted(int(epoch) for epoch in os.listdir(image_dir))
    mean_similarity_list = []
    best_mean_similarity = 0
    for epoch in epochs:
        similarity = []
        dataset = PairwiseImageDatasetLPIPS(subject_name, dreambooth_dir, image_dir, epoch)

        for i in range(len(dataset)):
            image_A, image_B = dataset[i]
            if image_A is None or image_B is None:
                continue  # pair contained a blank image: skip
            image_A = image_A.to(device)
            image_B = image_B.to(device)

            # Fix: run inference under no_grad so no autograd graph is
            # retained across all image pairs (memory growth).
            with torch.no_grad():
                # LPIPS perceptual distance between the two images.
                distance = loss_fn(image_A, image_B)

            similarity.append(distance.item())

        # Fix: guard the empty case like clip_text/clip_image do — an epoch
        # where every pair was blank previously produced a NaN mean.
        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            mean_similarity_list.append(mean_similarity)
            print(f'epoch: {epoch}, criterion: LPIPS distance, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: LPIPS distance, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list
|
| 410 |
+
|
| 411 |
+
if __name__ == "__main__":
    image_dir = 'log_hra/lr_1e-4_r_8/'

    # Each subject sub-directory of image_dir holds one folder of generated
    # images per epoch.
    subject_dirs, subject_names = [], []
    for name in os.listdir(image_dir):
        if os.path.isdir(os.path.join(image_dir, name)):
            subject_dirs.append(os.path.join(image_dir, name))
            subject_names.append(name)

    results_path = os.path.join(image_dir, 'true_results.json')
    # JSON-lines file; each line is one subject's result, e.g.
    # {'backpack-0':{'DINO':[x, ...], 'CLIP-I':[x, ...], 'CLIP-T':[x, ...], 'LPIPS':[x, ...],}}

    # Load already-computed results so the run is resumable.
    # Fix: plain file iteration replaces the manual __iter__/next/StopIteration
    # loop; the completion message is preserved.
    results_dict = dict()
    if os.path.exists(results_path):
        with open(results_path, 'r') as f:
            for line in f:
                results_dict.update(json.loads(line))
        print("finish extraction.")

    for subject_name, subject_dir in zip(subject_names, subject_dirs):
        if subject_name in results_dict:
            continue  # already evaluated

        print(f'evaluating {subject_dir}')
        dino_sim = dino(subject_name, subject_dir)
        clip_i_sim = clip_image(subject_name, subject_dir)
        clip_t_sim = clip_text(subject_name, subject_dir)
        lpips_sim = lpips_image(subject_name, subject_dir)

        subject_result = {'DINO': dino_sim, 'CLIP-I': clip_i_sim, 'CLIP-T': clip_t_sim, 'LPIPS': lpips_sim}
        print(subject_result)

        # Append immediately so a crash mid-run loses at most one subject.
        with open(results_path, 'a') as f:
            json_string = json.dumps({subject_name: subject_result})
            f.write(json_string + "\n")
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
|
generation/subject/get_result.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import math
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from functools import reduce
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
if __name__ == "__main__":
    image_dir = 'log_hra/lr_1e-4_r_8/'

    results_path = os.path.join(image_dir, 'true_results.json')
    # JSON-lines file; each line is one subject's result, e.g.
    # {'backpack-0':{'DINO':[x, ...], 'CLIP-I':[x, ...], 'CLIP-T':[x, ...], 'LPIPS':[x, ...],}}

    # Fix: plain file iteration replaces the manual __iter__/next/StopIteration
    # loop; the completion message is preserved.
    results_dict = dict()
    if os.path.exists(results_path):
        with open(results_path, 'r') as f:
            for line in f:
                results_dict.update(json.loads(line))
        print("finish extraction.")

    # For each subject, pick the epoch with the best combined
    # (range-normalised) score across all metrics, then average that epoch's
    # raw metric values over subjects.
    # NOTE(review): LPIPS is a distance (lower is better) but is summed with
    # the similarity metrics before argmax, as in the original rule — confirm
    # this is intended.
    total_result = np.zeros(4)
    metric_name_list = ['DINO', 'CLIP-I', 'CLIP-T', 'LPIPS']
    for subject_name, subject_results in results_dict.items():

        metric_results_percent = None
        for metric_name, metric_results in subject_results.items():
            metric_results = [0 if np.isnan(r) else r for r in metric_results]
            # Fix: guard the zero-range case — a metric that is constant
            # across epochs previously caused division by zero (inf/NaN
            # poisoning the argmax below).
            value_range = max(metric_results) - min(metric_results)
            if value_range:
                metric_results_norm = np.array(metric_results) / value_range
            else:
                # Constant metric: contributes no epoch preference.
                metric_results_norm = np.zeros(len(metric_results))
            if metric_results_percent is None:
                metric_results_percent = metric_results_norm
            else:
                metric_results_percent += metric_results_norm

        subject_results_max_idx = np.argmax(metric_results_percent)
        for idx, metric_name in enumerate(metric_name_list):
            total_result[idx] += subject_results[metric_name][subject_results_max_idx]

    # Fix: avoid 0/0 when no results were loaded.
    if results_dict:
        total_result /= len(results_dict)
    print(f'DINO: {total_result[0]}, CLIP-I: {total_result[1]}, CLIP-T: {total_result[2]}, LPIPS: {total_result[3]}')
|
| 62 |
+
|
generation/subject/oft_utils/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mhe import MHE_db, MHE_OFT, MHE_LoRA
|
| 2 |
+
|
generation/subject/oft_utils/attention_processor.py
ADDED
|
@@ -0,0 +1,1036 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import Callable, Optional, Union
|
| 15 |
+
import math
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
from torch import nn
|
| 19 |
+
from torch.autograd import Function
|
| 20 |
+
|
| 21 |
+
from diffusers.utils import deprecate, logging
|
| 22 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
if is_xformers_available():
|
| 28 |
+
import xformers
|
| 29 |
+
import xformers.ops
|
| 30 |
+
else:
|
| 31 |
+
xformers = None
|
| 32 |
+
|
class Attention(nn.Module):
    r"""
    A cross attention layer.
    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        # Softmax temperature: 1/sqrt(dim_head) when scale_qk, otherwise unscaled.
        self.scale = dim_head**-0.5 if scale_qk else 1.0

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
        else:
            self.group_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
        # NOTE(review): AttnProcessor2_0, LoRAAttnProcessor, LoRAXFormersAttnProcessor and
        # XFormersAttnProcessor are referenced in this class but not defined in this chunk of
        # the file — presumably defined elsewhere in this module; verify before refactoring.
        if processor is None:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
            )
        self.set_processor(processor)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """Switch the attention processor to/from xformers memory-efficient attention.

        When enabling, validates that xformers is installed and usable on CUDA; when a
        LoRA processor is currently set, its trained weights are copied into the new
        (xformers or plain) LoRA processor so training state is preserved.
        """
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
        )

        if use_memory_efficient_attention_xformers:
            if self.added_kv_proj_dim is not None:
                # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                raise NotImplementedError(
                    "Memory efficient attention with `xformers` is currently not supported when"
                    " `self.added_kv_proj_dim` is defined."
                )
            elif not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e

            if is_lora:
                # Re-create the LoRA processor in its xformers flavor, carrying over weights.
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                processor = LoRAAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            else:
                processor = AttnProcessor()

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        """Select a sliced attention processor that splits the score computation
        across the batch axis in chunks of `slice_size` heads to save memory."""
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            processor = AttnProcessor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        """Install `processor` as the attention implementation used by `forward`."""
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """Delegate the attention computation to the currently installed processor."""
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        """Fold the head dimension back into the channel dimension:
        (batch*heads, seq, dim) -> (batch, seq, dim*heads)."""
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, out_dim=3):
        """Split channels into heads: (batch, seq, dim) -> (batch*heads, seq, dim/heads)
        when out_dim == 3, or (batch, heads, seq, dim/heads) when out_dim == 4."""
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        """Compute softmax(Q @ K^T * scale + mask) via a fused `torch.baddbmm`.

        The additive mask, when provided, is folded in as the `baddbmm` bias
        (beta=1); otherwise an uninitialized buffer is used with beta=0.
        Optionally upcasts to float32 for the matmul and/or the softmax.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
        """Pad the additive attention mask to `target_length` and repeat it per head.

        Returns None unchanged when no mask is given.
        """
        if batch_size is None:
            deprecate(
                "batch_size=None",
                "0.0.15",
                (
                    "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
                    " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
                    " `prepare_attention_mask` when preparing the attention_mask."
                ),
            )
            batch_size = 1

        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # NOTE(review): pads by `target_length` rather than
                # `target_length - current_length`; this mirrors the upstream
                # diffusers code of this vintage — confirm it is intended.
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        """Apply `self.norm_cross` (LayerNorm or GroupNorm) to the encoder states."""
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states
| 352 |
+
|
| 353 |
+
|
| 354 |
+
class AttnProcessor:
|
| 355 |
+
def __call__(
|
| 356 |
+
self,
|
| 357 |
+
attn: Attention,
|
| 358 |
+
hidden_states,
|
| 359 |
+
encoder_hidden_states=None,
|
| 360 |
+
attention_mask=None,
|
| 361 |
+
):
|
| 362 |
+
batch_size, sequence_length, _ = (
|
| 363 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 364 |
+
)
|
| 365 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 366 |
+
query = attn.to_q(hidden_states)
|
| 367 |
+
|
| 368 |
+
if encoder_hidden_states is None:
|
| 369 |
+
encoder_hidden_states = hidden_states
|
| 370 |
+
elif attn.norm_cross:
|
| 371 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 372 |
+
|
| 373 |
+
key = attn.to_k(encoder_hidden_states)
|
| 374 |
+
value = attn.to_v(encoder_hidden_states)
|
| 375 |
+
|
| 376 |
+
query = attn.head_to_batch_dim(query)
|
| 377 |
+
key = attn.head_to_batch_dim(key)
|
| 378 |
+
value = attn.head_to_batch_dim(value)
|
| 379 |
+
|
| 380 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
| 381 |
+
hidden_states = torch.bmm(attention_probs, value)
|
| 382 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 383 |
+
|
| 384 |
+
# linear proj
|
| 385 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 386 |
+
# dropout
|
| 387 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 388 |
+
|
| 389 |
+
return hidden_states
|
| 390 |
+
|
| 391 |
+
|
class HRALinearLayer(nn.Module):
    """Householder Reflection Adaptation (HRA) wrapper for a frozen linear layer.

    The frozen weight W of the wrapped layer is right-multiplied by a chain of
    `r` Householder reflections (I - 2 * u_i @ u_i^T) built from the trainable
    matrix `hra_u`, giving the effective weight W @ H_1 @ ... @ H_r. Each
    reflection is orthogonal, so the adaptation rotates/reflects the input space
    without changing the spectrum of W.

    Args:
        in_features: input dimension of the wrapped linear layer.
        out_features: output dimension of the wrapped linear layer.
        bias: unused; kept for signature compatibility.
        r: number of Householder reflections; must be even (vectors are
            initialized in identical pairs so the initial product is identity).
        apply_GS: if True, Gram-Schmidt-orthonormalize the reflection vectors
            and apply all reflections in one step as I - 2 * Q @ Q^T.
    """

    def __init__(self, in_features, out_features, bias=False, r=8, apply_GS=False):
        super(HRALinearLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Buffers so the dims travel with the state dict (used by checkpoint tooling).
        self.register_buffer('cross_attention_dim', torch.tensor(in_features))
        self.register_buffer('hidden_size', torch.tensor(out_features))

        self.r = r
        self.apply_GS = apply_GS

        # Initialize r/2 random vectors and duplicate each: two identical
        # reflections cancel, so the initial effective weight equals W exactly.
        half_u = torch.zeros(in_features, r // 2)
        nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
        self.hra_u = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1), requires_grad=True)

    def forward(self, attn, x):
        """Apply the wrapped linear layer with the HRA-adapted weight.

        Args:
            attn: the frozen linear layer (e.g. `nn.Linear`) whose weight and
                bias are reused.
            x: input tensor of shape (..., in_features).

        Returns:
            The linear transform of `x` under the reflected weight plus the
            original bias (if any).
        """
        orig_weight = attn.weight.data
        # Loop-invariant identity matrix, hoisted out of the reflection loop.
        eye = torch.eye(self.in_features, device=x.device)
        if self.apply_GS:
            # Gram-Schmidt: build an orthonormal basis from the columns of
            # hra_u, then apply all r reflections at once as I - 2 * Q @ Q^T.
            weight = [(self.hra_u[:, 0] / self.hra_u[:, 0].norm()).view(-1, 1)]
            for i in range(1, self.r):
                ui = self.hra_u[:, i].view(-1, 1)
                for j in range(i):
                    ui = ui - (weight[j].t() @ ui) * weight[j]
                weight.append((ui / ui.norm()).view(-1, 1))
            weight = torch.cat(weight, dim=1)
            new_weight = orig_weight @ (eye - 2 * weight @ weight.t())
        else:
            # Apply the r Householder reflections sequentially (unit-normalized columns).
            new_weight = orig_weight
            hra_u_norm = self.hra_u / self.hra_u.norm(dim=0)
            for i in range(self.r):
                ui = hra_u_norm[:, i].view(-1, 1)
                new_weight = torch.mm(new_weight, eye - 2 * ui @ ui.t())

        out = nn.functional.linear(input=x, weight=new_weight, bias=attn.bias)
        return out
| 447 |
+
|
class HRAAttnProcessor(nn.Module):
    """Attention processor that routes q/k/v/out projections through HRA layers.

    Each projection reuses the frozen weight of the corresponding `Attention`
    linear layer, adapted by a trainable chain of Householder reflections.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, r=8, apply_GS=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.r = r

        kv_in_dim = cross_attention_dim or hidden_size
        self.to_q_hra = HRALinearLayer(hidden_size, hidden_size, r=r, apply_GS=apply_GS)
        self.to_k_hra = HRALinearLayer(kv_in_dim, hidden_size, r=r, apply_GS=apply_GS)
        self.to_v_hra = HRALinearLayer(kv_in_dim, hidden_size, r=r, apply_GS=apply_GS)
        self.to_out_hra = HRALinearLayer(hidden_size, hidden_size, r=r, apply_GS=apply_GS)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        # Mask is sized against the key sequence (encoder states when cross-attending).
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Query projection through the HRA-adapted weight of attn.to_q.
        query = attn.head_to_batch_dim(self.to_q_hra(attn.to_q, hidden_states))

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.head_to_batch_dim(self.to_k_hra(attn.to_k, encoder_hidden_states))
        value = attn.head_to_batch_dim(self.to_v_hra(attn.to_v, encoder_hidden_states))

        probs = attn.get_attention_scores(query, key, attention_mask)
        out = attn.batch_to_head_dim(torch.bmm(probs, value))

        # HRA-adapted output projection, then dropout.
        out = self.to_out_hra(attn.to_out[0], out)
        out = attn.to_out[1](out)

        return out
| 496 |
+
|
| 497 |
+
|
def project(R, eps):
    """Project R onto the Frobenius-norm ball of radius `eps` around the zero matrix.

    Used by COFT to bound the skew-symmetric parameter R *before* the Cayley
    transform; the reference point is the zero matrix (named `I` upstream
    because Cayley maps zero to the identity rotation). Returns R unchanged
    when it already lies inside the ball.
    """
    origin = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    offset = R - origin
    distance = torch.norm(offset)
    if distance <= eps:
        return R
    # Rescale the offset onto the ball's surface.
    return origin + eps * (offset / distance)
| 506 |
+
|
def project_batch(R, eps=1e-5):
    """Batched COFT projection: clamp each block of R to a norm ball around zero.

    The radius is shrunk by 1/sqrt(num_blocks) so the combined block-diagonal
    deviation stays bounded. Blocks already inside the ball pass through
    unchanged.
    """
    # scaling factor for each of the smaller block matrix
    radius = eps * 1 / torch.sqrt(torch.tensor(R.shape[0]))
    origin = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
    offset = R - origin
    distances = torch.norm(R - origin, dim=(1, 2), keepdim=True)
    inside = (distances <= radius).bool()
    # torch.where keeps in-ball blocks and rescales the rest onto the surface.
    return torch.where(inside, R, origin + radius * (offset / distances))
| 516 |
+
|
| 517 |
+
|
class OFTLinearLayer(nn.Module):
    """Orthogonal Finetuning (OFT) wrapper for a frozen linear layer.

    The frozen weight is multiplied by a trainable block-diagonal orthogonal
    matrix obtained via the Cayley transform of skew-symmetrized parameter
    blocks R. With `is_coft=True`, R is additionally projected (in place,
    without gradients) onto a norm ball so the rotation stays close to the
    identity (constrained OFT).
    """

    def __init__(self, in_features, out_features, bias=False, block_share=False, eps=6e-5, r=4, is_coft=False):
        # Note: `bias` is accepted for signature compatibility but not used here;
        # the wrapped layer's own bias is applied in `forward`.
        super(OFTLinearLayer, self).__init__()

        # Define the reduction rate (number of diagonal blocks):
        self.r = r

        # Check whether to use the constrained variant COFT
        self.is_coft = is_coft

        assert in_features % self.r == 0, "in_features must be divisible by r"

        self.in_features=in_features
        self.out_features=out_features

        # Buffers so the dims travel with the state dict (used by checkpoint tooling).
        self.register_buffer('cross_attention_dim', torch.tensor(in_features))
        self.register_buffer('hidden_size', torch.tensor(out_features))

        self.fix_filt_shape = [in_features, out_features]

        self.block_share = block_share
        # Define the trainable matrix parameter: R.
        # R starts at zero, which the Cayley transform maps to the identity
        # rotation, so the initial effective weight equals the frozen weight.
        if self.block_share:
            # One block shared across all r positions.
            self.R_shape = [in_features // self.r, in_features // self.r]
            self.R = nn.Parameter(torch.zeros(self.R_shape[0], self.R_shape[0]), requires_grad=True)

            # COFT radius scaled by the block's element count.
            self.eps = eps * self.R_shape[0] * self.R_shape[0]
        else:
            # One independent block per position.
            self.R_shape = [self.r, in_features // self.r, in_features // self.r]
            R = torch.zeros(self.R_shape[1], self.R_shape[1])
            R = torch.stack([R] * self.r)
            self.R = nn.Parameter(R, requires_grad=True)
            self.eps = eps * self.R_shape[1] * self.R_shape[1]

        # Scratch slot (was used for debugging R's updates); kept for compatibility.
        self.tmp = None

    def forward(self, attn, x):
        """Apply the wrapped linear layer with the OFT-rotated weight.

        Args:
            attn: the frozen linear layer whose weight/bias are reused.
            x: input tensor of shape (..., in_features).
        """
        orig_dtype = x.dtype
        dtype = self.R.dtype

        if self.block_share:
            if self.is_coft:
                # In-place, gradient-free projection keeps R inside the COFT ball.
                with torch.no_grad():
                    self.R.copy_(project(self.R, eps=self.eps))
            orth_rotate = self.cayley(self.R)
        else:
            if self.is_coft:
                with torch.no_grad():
                    self.R.copy_(project_batch(self.R, eps=self.eps))
            # Without this cayley_batch step, self.R would receive no gradient
            # and would never be updated.
            orth_rotate = self.cayley_batch(self.R)

        # Block-diagonal parametrization of the full rotation matrix.
        block_diagonal_matrix = self.block_diagonal(orth_rotate)

        # Rotate the frozen ("fixed") filter: filt = (Q @ W^T)^T.
        fix_filt = attn.weight.data
        fix_filt = torch.transpose(fix_filt, 0, 1)
        filt = torch.mm(block_diagonal_matrix, fix_filt.to(dtype))
        filt = torch.transpose(filt, 0, 1)

        # Reuse the wrapped layer's bias, cast back to the input dtype.
        bias_term = attn.bias.data if attn.bias is not None else None
        if bias_term is not None:
            bias_term = bias_term.to(orig_dtype)

        out = nn.functional.linear(input=x.to(orig_dtype), weight=filt.to(orig_dtype), bias=bias_term)

        return out

    def cayley(self, data):
        """Cayley transform of a single block: returns an orthogonal Q from
        the skew-symmetric part of `data`."""
        r, c = list(data.shape)
        # Ensure the input matrix is skew-symmetric
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device)
        # Perform the Cayley parametrization
        Q = torch.mm(I - skew, torch.inverse(I + skew))

        return Q

    def cayley_batch(self, data):
        """Batched Cayley transform over (b, r, c) blocks; assumes r == c."""
        b, r, c = data.shape
        # Ensure the input matrix is skew-symmetric
        skew = 0.5 * (data - data.transpose(1, 2))
        I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c)

        # Perform the Cayley parametrization
        Q = torch.bmm(I - skew, torch.inverse(I + skew))

        return Q

    def block_diagonal(self, R):
        """Assemble the full rotation as a block-diagonal matrix from R
        (a single shared block, or a stack of r blocks)."""
        if len(R.shape) == 2:
            # Create a list of R repeated block_count times
            blocks = [R] * self.r
        else:
            # Create a list of R slices along the third dimension
            blocks = [R[i, ...] for i in range(self.r)]

        # Use torch.block_diag to create the block diagonal matrix
        A = torch.block_diag(*blocks)

        return A

    def is_orthogonal(self, R, eps=1e-5):
        """Debug helper: True if R^T @ R is the identity within `eps`."""
        with torch.no_grad():
            RtR = torch.matmul(R.t(), R)
            diff = torch.abs(RtR - torch.eye(R.shape[1], dtype=R.dtype, device=R.device))
            return torch.all(diff < eps)

    def is_identity_matrix(self, tensor):
        """Debug helper: True if `tensor` is exactly a square identity matrix."""
        if not torch.is_tensor(tensor):
            raise TypeError("Input must be a PyTorch tensor.")
        if tensor.ndim != 2 or tensor.shape[0] != tensor.shape[1]:
            return False
        identity = torch.eye(tensor.shape[0], device=tensor.device)
        return torch.all(torch.eq(tensor, identity))
| 655 |
+
|
| 656 |
+
|
class OFTAttnProcessor(nn.Module):
    """Attention processor that routes q/k/v/out projections through OFT layers.

    Each projection multiplies the frozen weight of the corresponding
    `Attention` linear layer by a trainable block-diagonal orthogonal matrix
    (Cayley-parametrized), optionally constrained to stay near identity (COFT).
    """

    def __init__(self, hidden_size, cross_attention_dim=None, eps=2e-5, r=4, is_coft=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.r = r
        self.is_coft = is_coft

        kv_in_dim = cross_attention_dim or hidden_size
        self.to_q_oft = OFTLinearLayer(hidden_size, hidden_size, eps=eps, r=r, is_coft=is_coft)
        self.to_k_oft = OFTLinearLayer(kv_in_dim, hidden_size, eps=eps, r=r, is_coft=is_coft)
        self.to_v_oft = OFTLinearLayer(kv_in_dim, hidden_size, eps=eps, r=r, is_coft=is_coft)
        self.to_out_oft = OFTLinearLayer(hidden_size, hidden_size, eps=eps, r=r, is_coft=is_coft)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        # Mask is sized against the key sequence (encoder states when cross-attending).
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Query projection through the OFT-rotated weight of attn.to_q.
        query = attn.head_to_batch_dim(self.to_q_oft(attn.to_q, hidden_states))

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.head_to_batch_dim(self.to_k_oft(attn.to_k, encoder_hidden_states))
        value = attn.head_to_batch_dim(self.to_v_oft(attn.to_v, encoder_hidden_states))

        probs = attn.get_attention_scores(query, key, attention_mask)
        out = attn.batch_to_head_dim(torch.bmm(probs, value))

        # OFT-rotated output projection, then dropout.
        out = self.to_out_oft(attn.to_out[0], out)
        out = attn.to_out[1](out)

        return out
| 706 |
+
|
| 707 |
+
|
| 708 |
+
class AttnAddedKVProcessor:
    """Attention processor for blocks with extra "added" key/value projections.

    Spatial ``hidden_states`` arrive as (batch, channel, *spatial); they are
    flattened to a token sequence, group-normalised, attended against the
    added k/v projections (optionally concatenated with self-attention k/v),
    and returned in the original spatial shape with a residual connection.
    """

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        batch_size = hidden_states.shape[0]
        # (b, c, *spatial) -> (b, tokens, c)
        hidden_states = hidden_states.view(batch_size, hidden_states.shape[1], -1).transpose(1, 2)
        sequence_length = hidden_states.shape[1]

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        added_key = attn.head_to_batch_dim(attn.add_k_proj(encoder_hidden_states))
        added_value = attn.head_to_batch_dim(attn.add_v_proj(encoder_hidden_states))

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            key = attn.head_to_batch_dim(attn.to_k(hidden_states))
            value = attn.head_to_batch_dim(attn.to_v(hidden_states))
            key = torch.cat([added_key, key], dim=1)
            value = torch.cat([added_value, value], dim=1)

        probs = attn.get_attention_scores(query, key, attention_mask)
        out = attn.batch_to_head_dim(torch.bmm(probs, value))

        out = attn.to_out[0](out)  # output projection
        out = attn.to_out[1](out)  # dropout

        out = out.transpose(-1, -2).reshape(residual.shape)
        return out + residual
|
| 755 |
+
|
| 756 |
+
|
| 757 |
+
class AttnAddedKVProcessor2_0:
    """Added-KV attention processor backed by ``F.scaled_dot_product_attention``.

    Computes the same thing as ``AttnAddedKVProcessor`` but keeps tensors in
    4-D (batch, heads, seq, head_dim) layout and uses the fused PyTorch 2.0
    kernel instead of explicit score/bmm steps.
    """

    def __init__(self):
        # Fail fast on PyTorch < 2.0, where the fused SDPA kernel is missing.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        batch_size = hidden_states.shape[0]
        # (b, c, *spatial) -> (b, tokens, c)
        hidden_states = hidden_states.view(batch_size, hidden_states.shape[1], -1).transpose(1, 2)
        sequence_length = hidden_states.shape[1]

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states), out_dim=4)

        added_key = attn.head_to_batch_dim(attn.add_k_proj(encoder_hidden_states), out_dim=4)
        added_value = attn.head_to_batch_dim(attn.add_v_proj(encoder_hidden_states), out_dim=4)

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            key = attn.head_to_batch_dim(attn.to_k(hidden_states), out_dim=4)
            value = attn.head_to_batch_dim(attn.to_v(hidden_states), out_dim=4)
            # In 4-D layout the token axis is dim=2.
            key = torch.cat([added_key, key], dim=2)
            value = torch.cat([added_value, value], dim=2)

        # SDPA output: (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        out = out.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        out = attn.to_out[0](out)  # output projection
        out = attn.to_out[1](out)  # dropout

        out = out.transpose(-1, -2).reshape(residual.shape)
        return out + residual
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
class XFormersAttnProcessor:
    """Attention processor that delegates to xFormers' memory-efficient kernel.

    ``attention_op`` may name a specific xFormers operator; ``None`` lets
    xFormers pick one automatically.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        # xFormers requires contiguous (batch*heads, seq, head_dim) inputs.
        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        out = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        out = attn.batch_to_head_dim(out.to(query.dtype))

        out = attn.to_out[0](out)  # output projection
        return attn.to_out[1](out)  # dropout
|
| 851 |
+
|
| 852 |
+
|
| 853 |
+
class AttnProcessor2_0:
    """Default attention processor using PyTorch 2.0's fused
    ``F.scaled_dot_product_attention`` kernel."""

    def __init__(self):
        # Fail fast on PyTorch < 2.0.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # SDPA expects the mask shaped (batch, heads, source_len, target_len).
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        head_dim = inner_dim // attn.heads
        # (b, seq, inner) -> (b, heads, seq, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # SDPA output: (batch, num_heads, seq_len, head_dim).
        # TODO: add support for attn.scale when we move to Torch 2.1
        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        out = out.to(query.dtype)

        out = attn.to_out[0](out)  # output projection
        return attn.to_out[1](out)  # dropout
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
class SlicedAttnProcessor:
    """Memory-saving attention that processes the (batch*heads) axis in
    chunks of ``slice_size`` instead of all at once."""

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.head_to_batch_dim(attn.to_k(encoder_hidden_states))
        value = attn.head_to_batch_dim(attn.to_v(encoder_hidden_states))

        batch_size_attention, query_tokens, _ = query.shape
        out = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # NOTE(review): assumes batch_size_attention is divisible by
        # slice_size; a remainder chunk would be silently skipped.
        for slice_idx in range(batch_size_attention // self.slice_size):
            lo = slice_idx * self.slice_size
            hi = lo + self.slice_size

            mask_slice = attention_mask[lo:hi] if attention_mask is not None else None
            probs = attn.get_attention_scores(query[lo:hi], key[lo:hi], mask_slice)
            out[lo:hi] = torch.bmm(probs, value[lo:hi])

        out = attn.batch_to_head_dim(out)

        out = attn.to_out[0](out)  # output projection
        out = attn.to_out[1](out)  # dropout

        return out
|
| 952 |
+
|
| 953 |
+
|
| 954 |
+
class SlicedAttnAddedKVProcessor:
    """Added-KV attention computed in slices along the (batch*heads) axis to
    bound peak memory; see ``AttnAddedKVProcessor`` for the base computation."""

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        batch_size = hidden_states.shape[0]
        # (b, c, *spatial) -> (b, tokens, c)
        hidden_states = hidden_states.view(batch_size, hidden_states.shape[1], -1).transpose(1, 2)
        sequence_length = hidden_states.shape[1]

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        added_key = attn.head_to_batch_dim(attn.add_k_proj(encoder_hidden_states))
        added_value = attn.head_to_batch_dim(attn.add_v_proj(encoder_hidden_states))

        if attn.only_cross_attention:
            key, value = added_key, added_value
        else:
            key = attn.head_to_batch_dim(attn.to_k(hidden_states))
            value = attn.head_to_batch_dim(attn.to_v(hidden_states))
            key = torch.cat([added_key, key], dim=1)
            value = torch.cat([added_value, value], dim=1)

        batch_size_attention, query_tokens, _ = query.shape
        out = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # NOTE(review): assumes batch_size_attention is divisible by slice_size.
        for slice_idx in range(batch_size_attention // self.slice_size):
            lo = slice_idx * self.slice_size
            hi = lo + self.slice_size

            mask_slice = attention_mask[lo:hi] if attention_mask is not None else None
            probs = attn.get_attention_scores(query[lo:hi], key[lo:hi], mask_slice)
            out[lo:hi] = torch.bmm(probs, value[lo:hi])

        out = attn.batch_to_head_dim(out)

        out = attn.to_out[0](out)  # output projection
        out = attn.to_out[1](out)  # dropout

        out = out.transpose(-1, -2).reshape(residual.shape)
        return out + residual
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
# Union of every attention processor variant this module can install on an
# Attention block (stock, PyTorch-2.0 SDPA, xFormers, sliced, added-KV, and
# the OFT/HRA finetuning processors).
AttentionProcessor = Union[
    AttnProcessor,
    AttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    OFTAttnProcessor,
    HRAAttnProcessor,
]
|
generation/subject/oft_utils/mhe.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
import copy
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
class MHE_LoRA(nn.Module):
    """Minimum hyperspherical energy (MHE) of a LoRA-finetuned model.

    On construction every attention weight with matching LoRA factors is
    merged as ``W + up @ down``, so :meth:`calculate_mhe` measures the energy
    of the effective (merged) weights rather than the frozen base weights.
    """

    def __init__(self, model):
        super(MHE_LoRA, self).__init__()
        self.model = self.copy_without_grad(model)

        # Snapshot all weights so the merge below never mutates the live model.
        self.extracted_params = {}
        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

        keys_to_delete = []
        for name in self.extracted_params:
            if 'attn' not in name or 'processor' in name or 'weight' not in name:
                continue
            if 'to_q' in name:
                lora_down = name.replace('to_q', 'processor.to_q_lora.down')
                lora_up = name.replace('to_q', 'processor.to_q_lora.up')
            elif 'to_k' in name:
                lora_down = name.replace('to_k', 'processor.to_k_lora.down')
                lora_up = name.replace('to_k', 'processor.to_k_lora.up')
            elif 'to_v' in name:
                lora_down = name.replace('to_v', 'processor.to_v_lora.down')
                lora_up = name.replace('to_v', 'processor.to_v_lora.up')
            elif 'to_out' in name:
                lora_down = name.replace('to_out.0', 'processor.to_out_lora.down')
                lora_up = name.replace('to_out.0', 'processor.to_out_lora.up')
            else:
                # Bug fix: the original `pass`ed here and then merged the
                # *previous* iteration's LoRA factors into this weight.
                continue
            if lora_up not in self.extracted_params or lora_down not in self.extracted_params:
                # No LoRA factors for this weight; nothing to merge.
                continue

            with torch.no_grad():
                base = self.extracted_params[name]
                # Merge on the base weight's device (the original forced .cuda(),
                # which broke CPU-resident state dicts).
                up = self.extracted_params[lora_up].to(base.device)
                down = self.extracted_params[lora_down].to(base.device)
                base += up @ down
            keys_to_delete.append(lora_up)
            keys_to_delete.append(lora_down)

        for key in keys_to_delete:
            del self.extracted_params[key]

    def copy_without_grad(self, model):
        """Deep-copy ``model`` with every parameter detached and grad-frozen."""
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    @staticmethod
    def mhe_loss(filt):
        """Half-space minimum hyperspherical energy of a weight tensor.

        2-D weights are treated as (n_filters, fan_in); 4-D conv weights are
        flattened to that shape first. Each filter and its negation are
        normalised onto the hypersphere and the mean pairwise
        inverse-distance energy over distinct pairs is returned.
        """
        if len(filt.shape) == 4:
            filt = filt.reshape(filt.shape[0], -1)
        n_filt = filt.shape[0]
        # Columns are filters; append antipodal copies (half-space MHE).
        filt = torch.transpose(filt, 0, 1)
        filt = torch.cat((filt, filt * (-1)), dim=1)
        n_filt *= 2

        # Cosine similarities between (approximately) normalised filters;
        # the 1e-4 keeps zero filters from dividing by zero.
        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt) / norm_mat

        # 2 - 2cos is the squared chordal distance; adding the identity keeps
        # the diagonal away from zero before the inverse square root.
        # Device/dtype follow `filt` (the original hard-coded `.cuda()`).
        eye = torch.eye(n_filt, device=filt.device, dtype=filt.dtype)
        final = torch.pow(2.0 - 2.0 * inner_pro + eye, -0.5)
        final = final - torch.tril(final)  # strict upper triangle: each pair once
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        """Sum the MHE loss over every 2-D (linear) and 4-D (conv) weight."""
        mhe_loss = []
        with torch.no_grad():
            for weight in self.extracted_params.values():
                if len(weight.shape) in (2, 4):
                    mhe_loss.append(self.mhe_loss(weight).cpu().detach().item())
        return np.array(mhe_loss).sum()
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def project(R, eps):
    """Project ``R`` onto the Frobenius ball of radius ``eps`` around zero.

    This is the COFT constraint applied to the skew-symmetric generator: the
    identity rotation corresponds to a zero generator, so (matching the
    original OFT code) the ball is centred at the zero matrix even though the
    variable is conventionally named ``I``.
    """
    center = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    diff = R - center
    norm_diff = torch.norm(diff)
    # Inside the ball: unchanged. Outside: rescale onto the boundary.
    return R if norm_diff <= eps else center + eps * (diff / norm_diff)
|
| 119 |
+
|
| 120 |
+
def project_batch(R, eps=1e-5):
    """Batched COFT projection for a (b, r, r) stack of generator blocks.

    Each block is clamped into a Frobenius ball around the zero matrix; the
    radius is shrunk by 1/sqrt(b) so the total deviation across the smaller
    per-block matrices stays comparable to the unbatched case.
    """
    # Scaling factor for each of the smaller block matrices.
    eps = eps / torch.sqrt(torch.tensor(R.shape[0]))
    center = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
    diff = R - center
    norm_diff = torch.norm(diff, dim=(1, 2), keepdim=True)
    inside = (norm_diff <= eps).bool()
    # Per-block: keep blocks inside the ball, rescale the rest onto it.
    return torch.where(inside, R, center + eps * (diff / norm_diff))
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class MHE_OFT(nn.Module):
    """Minimum hyperspherical energy (MHE) of an OFT-finetuned model.

    On construction, every attention weight with an associated OFT rotation
    parameter ``R`` is merged as ``W @ block_diag(cayley(project(R)))`` so
    that :meth:`calculate_mhe` measures the energy of the effective rotated
    weights.

    Args:
        model: model whose ``state_dict`` holds both the base attention
            weights and the ``processor.*_oft.R`` parameters.
        eps: base COFT radius, scaled by the rotation matrix size below.
        r: number of diagonal blocks used when ``R`` is a single 2-D matrix.
    """

    def __init__(self, model, eps=6e-5, r=4):
        super(MHE_OFT, self).__init__()
        self.r = r

        # Snapshot all weights so merging never mutates the live model.
        self.extracted_params = {}
        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

        keys_to_delete = []
        for name in self.extracted_params:
            if 'attn' not in name or 'processor' in name or 'weight' not in name:
                continue
            if 'to_q' in name:
                oft_R = name.replace('to_q.weight', 'processor.to_q_oft.R')
            elif 'to_k' in name:
                oft_R = name.replace('to_k.weight', 'processor.to_k_oft.R')
            elif 'to_v' in name:
                oft_R = name.replace('to_v.weight', 'processor.to_v_oft.R')
            elif 'to_out' in name:
                oft_R = name.replace('to_out.0.weight', 'processor.to_out_oft.R')
            else:
                # Bug fix: the original `pass`ed here and then reused the
                # *previous* iteration's oft_R for this weight.
                continue
            if oft_R not in self.extracted_params:
                # No OFT rotation stored for this weight; nothing to merge.
                continue

            base = self.extracted_params[name]
            # Merge on the base weight's device (original forced .cuda()).
            R = self.extracted_params[oft_R].to(base.device)

            with torch.no_grad():
                if len(R.shape) == 2:
                    # Single shared block, repeated self.r times on the diagonal.
                    self.eps = eps * R.shape[0] * R.shape[0]
                    R.copy_(project(R, eps=self.eps))
                    orth_rotate = self.cayley(R)
                else:
                    # One block per leading index.
                    # NOTE(review): radius scales with shape[1]*shape[0]
                    # (block size x batch) as in the original — confirm intent.
                    self.eps = eps * R.shape[1] * R.shape[0]
                    R.copy_(project_batch(R, eps=self.eps))
                    orth_rotate = self.cayley_batch(R)

                self.extracted_params[name] = base @ self.block_diagonal(orth_rotate)
            keys_to_delete.append(oft_R)

        for key in keys_to_delete:
            del self.extracted_params[key]

    def is_orthogonal(self, R, eps=1e-5):
        """True when ``R.t() @ R`` is elementwise within ``eps`` of identity."""
        with torch.no_grad():
            RtR = torch.matmul(R.t(), R)
            diff = torch.abs(RtR - torch.eye(R.shape[1], dtype=R.dtype, device=R.device))
            return torch.all(diff < eps)

    def block_diagonal(self, R):
        """Build the full rotation: repeat a 2-D ``R`` ``self.r`` times, or
        lay the slices of a 3-D ``R`` along the diagonal."""
        if len(R.shape) == 2:
            blocks = [R] * self.r
        else:
            blocks = [R[i, ...] for i in range(R.shape[0])]
        return torch.block_diag(*blocks)

    def copy_without_grad(self, model):
        """Deep-copy ``model`` with every parameter detached and grad-frozen."""
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    def cayley(self, data):
        """Cayley transform: Q = (I + A)(I - A)^-1 with A the skew-symmetric
        part of ``data``; Q is orthogonal."""
        r, c = list(data.shape)
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device)
        return torch.mm(I + skew, torch.inverse(I - skew))

    def cayley_batch(self, data):
        """Batched Cayley transform over a (b, r, c) stack of square blocks."""
        b, r, c = data.shape
        skew = 0.5 * (data - data.transpose(1, 2))
        I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c)
        return torch.bmm(I + skew, torch.inverse(I - skew))

    @staticmethod
    def mhe_loss(filt):
        """Half-space minimum hyperspherical energy of a weight tensor.

        2-D weights are treated as (n_filters, fan_in); 4-D conv weights are
        flattened to that shape first. Each filter and its negation are
        normalised onto the hypersphere and the mean pairwise
        inverse-distance energy over distinct pairs is returned.
        """
        if len(filt.shape) == 4:
            filt = filt.reshape(filt.shape[0], -1)
        n_filt = filt.shape[0]
        # Columns are filters; append antipodal copies (half-space MHE).
        filt = torch.transpose(filt, 0, 1)
        filt = torch.cat((filt, filt * (-1)), dim=1)
        n_filt *= 2

        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt) / norm_mat

        # Device/dtype follow `filt` (the original hard-coded `.cuda()`).
        eye = torch.eye(n_filt, device=filt.device, dtype=filt.dtype)
        final = torch.pow(2.0 - 2.0 * inner_pro + eye, -0.5)
        final = final - torch.tril(final)  # strict upper triangle: each pair once
        cnt = n_filt * (n_filt - 1) / 2.0
        return torch.sum(final) / cnt

    def calculate_mhe(self):
        """Sum the MHE loss over every 2-D (linear) and 4-D (conv) weight."""
        mhe_loss = []
        with torch.no_grad():
            for weight in self.extracted_params.values():
                if len(weight.shape) in (2, 4):
                    mhe_loss.append(self.mhe_loss(weight).cpu().detach().item())
        return np.array(mhe_loss).sum()

    def is_identity_matrix(self, tensor):
        """True when ``tensor`` is a square matrix exactly equal to identity."""
        if not torch.is_tensor(tensor):
            raise TypeError("Input must be a PyTorch tensor.")
        if tensor.ndim != 2 or tensor.shape[0] != tensor.shape[1]:
            return False
        identity = torch.eye(tensor.shape[0], device=tensor.device)
        return torch.all(torch.eq(tensor, identity))
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class MHE_db:
|
| 296 |
+
def __init__(self, model):
|
| 297 |
+
# self.model = copy.deepcopy(model)
|
| 298 |
+
# self.model.load_state_dict(model.state_dict())
|
| 299 |
+
# self.model = self.copy_without_grad(model)
|
| 300 |
+
|
| 301 |
+
#self.extracted_params = {}
|
| 302 |
+
#for name, param in model.named_parameters():
|
| 303 |
+
# self.extracted_params[name] = param
|
| 304 |
+
|
| 305 |
+
self.extracted_params = {}
|
| 306 |
+
for name, tensor in model.state_dict().items():
|
| 307 |
+
self.extracted_params[name] = tensor.detach().clone()
|
| 308 |
+
|
| 309 |
+
@staticmethod
|
| 310 |
+
def mhe_loss(filt):
|
| 311 |
+
if len(filt.shape) == 2:
|
| 312 |
+
n_filt, _ = filt.shape
|
| 313 |
+
filt = torch.transpose(filt, 0, 1)
|
| 314 |
+
filt_neg = filt * (-1)
|
| 315 |
+
filt = torch.cat((filt, filt_neg), dim=1)
|
| 316 |
+
n_filt *= 2
|
| 317 |
+
|
| 318 |
+
filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
|
| 319 |
+
norm_mat = torch.matmul(filt_norm.t(), filt_norm)
|
| 320 |
+
inner_pro = torch.matmul(filt.t(), filt)
|
| 321 |
+
inner_pro /= norm_mat
|
| 322 |
+
|
| 323 |
+
cross_terms = (2.0 - 2.0 * inner_pro + torch.diag(torch.tensor([1.0] * n_filt)).cuda())
|
| 324 |
+
final = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
|
| 325 |
+
final -= torch.tril(final)
|
| 326 |
+
cnt = n_filt * (n_filt - 1) / 2.0
|
| 327 |
+
MHE_loss = 1 * torch.sum(final) / cnt
|
| 328 |
+
|
| 329 |
+
else:
|
| 330 |
+
n_filt, _, _, _ = filt.shape
|
| 331 |
+
filt = filt.reshape(n_filt, -1)
|
| 332 |
+
filt = torch.transpose(filt, 0, 1)
|
| 333 |
+
filt_neg = filt * -1
|
| 334 |
+
filt = torch.cat((filt, filt_neg), dim=1)
|
| 335 |
+
n_filt *= 2
|
| 336 |
+
|
| 337 |
+
filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
|
| 338 |
+
norm_mat = torch.matmul(filt_norm.t(), filt_norm)
|
| 339 |
+
inner_pro = torch.matmul(filt.t(), filt)
|
| 340 |
+
inner_pro /= norm_mat
|
| 341 |
+
|
| 342 |
+
cross_terms = (2.0 - 2.0 * inner_pro + torch.diag(torch.tensor([1.0] * n_filt)).cuda())
|
| 343 |
+
final = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
|
| 344 |
+
final -= torch.tril(final)
|
| 345 |
+
cnt = n_filt * (n_filt - 1) / 2.0
|
| 346 |
+
MHE_loss = 1 * torch.sum(final) / cnt
|
| 347 |
+
|
| 348 |
+
return MHE_loss
|
| 349 |
+
|
| 350 |
+
def calculate_mhe(self):
    """Sum the MHE energy over every extracted linear/conv weight.

    Only 2-D (linear) and 4-D (conv) weights are scored; biases, norms and
    other 1-D parameters are skipped.  Returns a numpy scalar.
    """
    per_layer = []
    with torch.no_grad():
        for weight in self.extracted_params.values():
            # Linear layers are 2-D, conv layers 4-D; everything else
            # carries no filter bank to measure.
            if weight.dim() in (2, 4):
                per_layer.append(self.mhe_loss(weight).cpu().detach().item())
    return np.array(per_layer).sum()
|
generation/subject/train_dreambooth_hra.py
ADDED
|
@@ -0,0 +1,1123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding=utf-8
|
| 3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import hashlib
|
| 18 |
+
import logging
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import warnings
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
import torch.utils.checkpoint
|
| 28 |
+
import transformers
|
| 29 |
+
from accelerate import Accelerator
|
| 30 |
+
from accelerate.logging import get_logger
|
| 31 |
+
from accelerate.utils import ProjectConfiguration, set_seed
|
| 32 |
+
from huggingface_hub import create_repo, upload_folder
|
| 33 |
+
from packaging import version
|
| 34 |
+
from PIL import Image
|
| 35 |
+
from torch.utils.data import Dataset
|
| 36 |
+
from torchvision import transforms
|
| 37 |
+
from tqdm.auto import tqdm
|
| 38 |
+
from transformers import AutoTokenizer, PretrainedConfig
|
| 39 |
+
|
| 40 |
+
import diffusers
|
| 41 |
+
from diffusers import (
|
| 42 |
+
AutoencoderKL,
|
| 43 |
+
DDPMScheduler,
|
| 44 |
+
DiffusionPipeline,
|
| 45 |
+
DPMSolverMultistepScheduler,
|
| 46 |
+
UNet2DConditionModel,
|
| 47 |
+
)
|
| 48 |
+
from diffusers.loaders import AttnProcsLayers
|
| 49 |
+
from oft_utils.attention_processor import HRAAttnProcessor
|
| 50 |
+
from diffusers.optimization import get_scheduler
|
| 51 |
+
from diffusers.utils import check_min_version, is_wandb_available
|
| 52 |
+
from diffusers.utils.import_utils import is_xformers_available
|
| 53 |
+
from oft_utils.mhe import MHE_OFT as MHE
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 57 |
+
check_min_version("0.16.0.dev0")
|
| 58 |
+
|
| 59 |
+
logger = get_logger(__name__)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def save_model_card(repo_id: str, images=None, base_model: str = "", prompt: str = "", repo_folder=None):
    """Write sample images and a README model card into ``repo_folder``.

    Args:
        repo_id: Hub repository id, used as the card title.
        images: optional list of PIL images; each is saved as
            ``image_{i}.png`` and linked from the card.  ``None`` is
            treated as "no images" (the old code crashed on the default).
        base_model: name of the base model the adapter was trained from.
            (The old defaults were the ``str`` type object itself, which
            rendered as ``<class 'str'>`` in the card.)
        prompt: the instance prompt used for training.
        repo_folder: destination directory; must already exist.
    """
    img_str = ""
    # Guard against images=None instead of crashing in enumerate().
    for i, image in enumerate(images or []):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        # NOTE(review): the markdown link was stripped in the source dump;
        # reconstructed to reference the file saved above — confirm format.
        img_str += f"![img_{i}](./image_{i}.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- oft
inference: true
---
"""
    model_card = f"""
# OFT DreamBooth - {repo_id}

These are OFT adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Resolve the text-encoder class for a pretrained pipeline.

    Reads the ``text_encoder`` sub-config of the given model repo and maps
    its declared architecture to the matching class, importing it lazily.

    Raises:
        ValueError: if the architecture is neither CLIP nor AltDiffusion's
            Roberta variant.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    if model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
            RobertaSeriesModelWithTransformation,
        )

        return RobertaSeriesModelWithTransformation
    raise ValueError(f"{model_class} is not supported.")
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def parse_args(input_args=None):
    """Parse command-line options for HRA DreamBooth fine-tuning.

    Args:
        input_args: optional list of argument strings; when ``None`` the
            arguments are read from ``sys.argv``.

    Returns:
        ``argparse.Namespace`` with all training options, after syncing
        ``local_rank`` with the ``LOCAL_RANK`` env var and validating the
        prior-preservation flags.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default='runwayml/stable-diffusion-v1-5',
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default='../data/dreambooth/backpack',
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default='data/class_data/backpack',
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default='a photo of qwe backpack',
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default='a photo of backpack',
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default='a qwe backpack in the jungle',
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--test_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is keeps class prior.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    # NOTE(review): action="store_true" combined with default=True means this
    # flag is ALWAYS True and cannot be disabled from the CLI — confirm
    # whether that is intended (BooleanOptionalAction would allow --no-...).
    parser.add_argument(
        "--with_prior_preservation",
        default=True,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=200,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="log_hra/backpack-0",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for sampling images.",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=2005,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=5000,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        default=False,
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=6e-05,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    # NOTE(review): the help text says tensorboard is the default, but the
    # actual default here is "wandb" — confirm which is intended.
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    # NOTE(review): launchers conventionally use -1 as the "unset" local_rank;
    # the default of 6 looks machine-specific — confirm before reuse.
    parser.add_argument("--local_rank", type=int, default=6, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--name",
        type=str,
        default='backpack-0',
        help=(
            "The name of the current experiment run, consists of [data]-[prompt]"
        ),
    )
    parser.add_argument(
        "--hra_r",
        type=int,
        default=8,
        help=(
            "The rank of HRA across different layers. It is best to set 'r' to an even number; otherwise, the default initialization method will not work."
        ),
    )
    parser.add_argument(
        "--hra_apply_GS",
        action='store_true',
        default=False,
        help=(
            "Whether to apply Gram-Schmidt orthogonalization."
        ),
    )
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # The torch launcher exports LOCAL_RANK; it wins over the CLI default.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Prior preservation needs both the class data and its prompt.
    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        # logger is not available yet
        if args.class_data_dir is not None:
            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

    return args
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and the tokenizes prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
    ):
        # size: square output resolution; center_crop switches the crop from
        # random (augmentation) to deterministic center crop.
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            # Prior-preservation branch: class (regularization) images are
            # served alongside the instance images.
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            if class_num is not None:
                # Cap the number of class images actually used.
                self.num_class_images = min(len(self.class_images_path), class_num)
            else:
                self.num_class_images = len(self.class_images_path)
            # Length follows the larger set; the smaller one is cycled via
            # modulo indexing in __getitem__.
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                # Map pixel values from [0, 1] to [-1, 1] for the VAE.
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        # Modulo indexing cycles the shorter list when instance and class
        # image counts differ.
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                truncation=True,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids

        return example
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def collate_fn(examples, with_prior_preservation=False):
    """Merge per-example dicts into one training batch.

    With prior preservation enabled, the class examples are appended after
    the instance examples so a single forward pass covers both halves.
    Returns a dict with stacked ``pixel_values`` and concatenated
    ``input_ids``.
    """
    input_ids = []
    pixel_values = []
    for example in examples:
        input_ids.append(example["instance_prompt_ids"])
        pixel_values.append(example["instance_images"])

    if with_prior_preservation:
        # Class examples ride along in the same batch to avoid a second
        # forward pass.
        for example in examples:
            input_ids.append(example["class_prompt_ids"])
            pixel_values.append(example["class_images"])

    stacked_pixels = torch.stack(pixel_values)
    stacked_pixels = stacked_pixels.to(memory_format=torch.contiguous_format).float()

    return {
        "input_ids": torch.cat(input_ids, dim=0),
        "pixel_values": stacked_pixels,
    }
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        # Same prompt repeated num_samples times; index distinguishes items.
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        return {"prompt": self.prompt, "index": index}
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
def main(args):
|
| 587 |
+
logging_dir = Path(args.output_dir, args.logging_dir)
|
| 588 |
+
|
| 589 |
+
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) # total_limit=args.checkpoints_total_limit)
|
| 590 |
+
|
| 591 |
+
wandb_init = {
|
| 592 |
+
"wandb": {
|
| 593 |
+
"name": args.name,
|
| 594 |
+
# "project": args.project,
|
| 595 |
+
}
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
accelerator = Accelerator(
|
| 599 |
+
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
| 600 |
+
mixed_precision=args.mixed_precision,
|
| 601 |
+
log_with=args.report_to,
|
| 602 |
+
project_config=accelerator_project_config,
|
| 603 |
+
)
|
| 604 |
+
|
| 605 |
+
if args.report_to == "wandb":
|
| 606 |
+
if not is_wandb_available():
|
| 607 |
+
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
|
| 608 |
+
import wandb
|
| 609 |
+
|
| 610 |
+
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
| 611 |
+
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
| 612 |
+
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
| 613 |
+
# Make one log on every process with the configuration for debugging.
|
| 614 |
+
logging.basicConfig(
|
| 615 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 616 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
| 617 |
+
level=logging.INFO,
|
| 618 |
+
)
|
| 619 |
+
logger.info(accelerator.state, main_process_only=False)
|
| 620 |
+
if accelerator.is_local_main_process:
|
| 621 |
+
transformers.utils.logging.set_verbosity_warning()
|
| 622 |
+
diffusers.utils.logging.set_verbosity_info()
|
| 623 |
+
else:
|
| 624 |
+
transformers.utils.logging.set_verbosity_error()
|
| 625 |
+
diffusers.utils.logging.set_verbosity_error()
|
| 626 |
+
|
| 627 |
+
# If passed along, set the training seed now.
|
| 628 |
+
if args.seed is not None:
|
| 629 |
+
set_seed(args.seed)
|
| 630 |
+
|
| 631 |
+
# Generate class images if prior preservation is enabled.
|
| 632 |
+
if args.with_prior_preservation:
|
| 633 |
+
class_images_dir = Path(args.class_data_dir)
|
| 634 |
+
if not class_images_dir.exists():
|
| 635 |
+
class_images_dir.mkdir(parents=True)
|
| 636 |
+
cur_class_images = len(list(class_images_dir.iterdir()))
|
| 637 |
+
|
| 638 |
+
if cur_class_images < args.num_class_images:
|
| 639 |
+
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
| 640 |
+
if args.prior_generation_precision == "fp32":
|
| 641 |
+
torch_dtype = torch.float32
|
| 642 |
+
elif args.prior_generation_precision == "fp16":
|
| 643 |
+
torch_dtype = torch.float16
|
| 644 |
+
elif args.prior_generation_precision == "bf16":
|
| 645 |
+
torch_dtype = torch.bfloat16
|
| 646 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
| 647 |
+
args.pretrained_model_name_or_path,
|
| 648 |
+
torch_dtype=torch_dtype,
|
| 649 |
+
safety_checker=None,
|
| 650 |
+
revision=args.revision,
|
| 651 |
+
)
|
| 652 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 653 |
+
|
| 654 |
+
num_new_images = args.num_class_images - cur_class_images
|
| 655 |
+
logger.info(f"Number of class images to sample: {num_new_images}.")
|
| 656 |
+
|
| 657 |
+
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
| 658 |
+
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
|
| 659 |
+
|
| 660 |
+
sample_dataloader = accelerator.prepare(sample_dataloader)
|
| 661 |
+
pipeline.to(accelerator.device)
|
| 662 |
+
|
| 663 |
+
for example in tqdm(
|
| 664 |
+
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
|
| 665 |
+
):
|
| 666 |
+
images = pipeline(example["prompt"]).images
|
| 667 |
+
|
| 668 |
+
for i, image in enumerate(images):
|
| 669 |
+
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
| 670 |
+
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
| 671 |
+
image.save(image_filename)
|
| 672 |
+
|
| 673 |
+
del pipeline
|
| 674 |
+
if torch.cuda.is_available():
|
| 675 |
+
torch.cuda.empty_cache()
|
| 676 |
+
|
| 677 |
+
# Handle the repository creation
|
| 678 |
+
if accelerator.is_main_process:
|
| 679 |
+
if args.output_dir is not None:
|
| 680 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 681 |
+
|
| 682 |
+
if args.push_to_hub:
|
| 683 |
+
repo_id = create_repo(
|
| 684 |
+
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
| 685 |
+
).repo_id
|
| 686 |
+
|
| 687 |
+
# Load the tokenizer
|
| 688 |
+
if args.tokenizer_name:
|
| 689 |
+
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
|
| 690 |
+
elif args.pretrained_model_name_or_path:
|
| 691 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 692 |
+
args.pretrained_model_name_or_path,
|
| 693 |
+
subfolder="tokenizer",
|
| 694 |
+
revision=args.revision,
|
| 695 |
+
use_fast=False,
|
| 696 |
+
)
|
| 697 |
+
|
| 698 |
+
# import correct text encoder class
|
| 699 |
+
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
|
| 700 |
+
|
| 701 |
+
# Load scheduler and models
|
| 702 |
+
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
| 703 |
+
text_encoder = text_encoder_cls.from_pretrained(
|
| 704 |
+
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
| 705 |
+
)
|
| 706 |
+
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
|
| 707 |
+
unet = UNet2DConditionModel.from_pretrained(
|
| 708 |
+
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
# We only train the additional adapter OFT layers
|
| 712 |
+
vae.requires_grad_(False)
|
| 713 |
+
text_encoder.requires_grad_(False)
|
| 714 |
+
unet.requires_grad_(False)
|
| 715 |
+
|
| 716 |
+
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
| 717 |
+
# as these models are only used for inference, keeping weights in full precision is not required.
|
| 718 |
+
weight_dtype = torch.float32
|
| 719 |
+
if accelerator.mixed_precision == "fp16":
|
| 720 |
+
weight_dtype = torch.float16
|
| 721 |
+
elif accelerator.mixed_precision == "bf16":
|
| 722 |
+
weight_dtype = torch.bfloat16
|
| 723 |
+
|
| 724 |
+
# Move unet, vae and text_encoder to device and cast to weight_dtype
|
| 725 |
+
unet.to(accelerator.device, dtype=weight_dtype)
|
| 726 |
+
vae.to(accelerator.device, dtype=weight_dtype)
|
| 727 |
+
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
| 728 |
+
|
| 729 |
+
if args.enable_xformers_memory_efficient_attention:
|
| 730 |
+
if is_xformers_available():
|
| 731 |
+
import xformers
|
| 732 |
+
|
| 733 |
+
xformers_version = version.parse(xformers.__version__)
|
| 734 |
+
if xformers_version == version.parse("0.0.16"):
|
| 735 |
+
logger.warn(
|
| 736 |
+
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
|
| 737 |
+
)
|
| 738 |
+
unet.enable_xformers_memory_efficient_attention()
|
| 739 |
+
else:
|
| 740 |
+
raise ValueError("xformers is not available. Make sure it is installed correctly")
|
| 741 |
+
|
| 742 |
+
# now we will add new COT weights to the attention layers
|
| 743 |
+
# It's important to realize here how many attention weights will be added and of which sizes
|
| 744 |
+
# The sizes of the attention layers consist only of two different variables:
|
| 745 |
+
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
|
| 746 |
+
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
|
| 747 |
+
|
| 748 |
+
# Let's first see how many attention processors we will have to set.
|
| 749 |
+
# For Stable Diffusion, it should be equal to:
|
| 750 |
+
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
|
| 751 |
+
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
|
| 752 |
+
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
|
| 753 |
+
# => 32 layers
|
| 754 |
+
|
| 755 |
+
# Set correct oft layers
|
| 756 |
+
oft_attn_procs = {}
|
| 757 |
+
for name in unet.attn_processors.keys():
|
| 758 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
| 759 |
+
if name.startswith("mid_block"):
|
| 760 |
+
hidden_size = unet.config.block_out_channels[-1]
|
| 761 |
+
elif name.startswith("up_blocks"):
|
| 762 |
+
block_id = int(name[len("up_blocks.")])
|
| 763 |
+
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
| 764 |
+
elif name.startswith("down_blocks"):
|
| 765 |
+
block_id = int(name[len("down_blocks.")])
|
| 766 |
+
hidden_size = unet.config.block_out_channels[block_id]
|
| 767 |
+
|
| 768 |
+
oft_attn_procs[name] = HRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, r=args.hra_r, apply_GS=args.hra_apply_GS)
|
| 769 |
+
|
| 770 |
+
unet.set_attn_processor(oft_attn_procs)
|
| 771 |
+
print(f'Total parameters requiring grad: {sum([p.numel() for p in unet.parameters() if p.requires_grad == True])}')
|
| 772 |
+
|
| 773 |
+
oft_layers = AttnProcsLayers(unet.attn_processors)
|
| 774 |
+
|
| 775 |
+
accelerator.register_for_checkpointing(oft_layers)
|
| 776 |
+
|
| 777 |
+
# Enable TF32 for faster training on Ampere GPUs,
|
| 778 |
+
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
| 779 |
+
if args.allow_tf32:
|
| 780 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
| 781 |
+
|
| 782 |
+
if args.scale_lr:
|
| 783 |
+
args.learning_rate = (
|
| 784 |
+
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
|
| 785 |
+
)
|
| 786 |
+
|
| 787 |
+
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
|
| 788 |
+
if args.use_8bit_adam:
|
| 789 |
+
try:
|
| 790 |
+
import bitsandbytes as bnb
|
| 791 |
+
except ImportError:
|
| 792 |
+
raise ImportError(
|
| 793 |
+
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
| 794 |
+
)
|
| 795 |
+
|
| 796 |
+
optimizer_class = bnb.optim.AdamW8bit
|
| 797 |
+
else:
|
| 798 |
+
optimizer_class = torch.optim.AdamW
|
| 799 |
+
|
| 800 |
+
# Optimizer creation
|
| 801 |
+
optimizer = optimizer_class(
|
| 802 |
+
oft_layers.parameters(),
|
| 803 |
+
lr=args.learning_rate,
|
| 804 |
+
betas=(args.adam_beta1, args.adam_beta2),
|
| 805 |
+
weight_decay=args.adam_weight_decay,
|
| 806 |
+
eps=args.adam_epsilon,
|
| 807 |
+
)
|
| 808 |
+
|
| 809 |
+
# Dataset and DataLoaders creation:
|
| 810 |
+
train_dataset = DreamBoothDataset(
|
| 811 |
+
instance_data_root=args.instance_data_dir,
|
| 812 |
+
instance_prompt=args.instance_prompt,
|
| 813 |
+
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
| 814 |
+
class_prompt=args.class_prompt,
|
| 815 |
+
class_num=args.num_class_images,
|
| 816 |
+
tokenizer=tokenizer,
|
| 817 |
+
size=args.resolution,
|
| 818 |
+
center_crop=args.center_crop,
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
train_dataloader = torch.utils.data.DataLoader(
|
| 822 |
+
train_dataset,
|
| 823 |
+
batch_size=args.train_batch_size,
|
| 824 |
+
shuffle=True,
|
| 825 |
+
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
|
| 826 |
+
num_workers=args.dataloader_num_workers,
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
# Scheduler and math around the number of training steps.
|
| 830 |
+
overrode_max_train_steps = False
|
| 831 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 832 |
+
if args.max_train_steps is None:
|
| 833 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 834 |
+
overrode_max_train_steps = True
|
| 835 |
+
|
| 836 |
+
lr_scheduler = get_scheduler(
|
| 837 |
+
args.lr_scheduler,
|
| 838 |
+
optimizer=optimizer,
|
| 839 |
+
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
| 840 |
+
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
| 841 |
+
num_cycles=args.lr_num_cycles,
|
| 842 |
+
power=args.lr_power,
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
# Prepare everything with our `accelerator`.
|
| 846 |
+
oft_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
| 847 |
+
oft_layers, optimizer, train_dataloader, lr_scheduler
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
+
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
| 851 |
+
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
| 852 |
+
if overrode_max_train_steps:
|
| 853 |
+
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
| 854 |
+
# Afterwards we recalculate our number of training epochs
|
| 855 |
+
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
| 856 |
+
|
| 857 |
+
# We need to initialize the trackers we use, and also store our configuration.
|
| 858 |
+
# The trackers initializes automatically on the main process.
|
| 859 |
+
if accelerator.is_main_process:
|
| 860 |
+
accelerator.init_trackers("dreambooth-oft", config=vars(args), init_kwargs=wandb_init)
|
| 861 |
+
|
| 862 |
+
# Train!
|
| 863 |
+
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
| 864 |
+
|
| 865 |
+
logger.info("***** Running training *****")
|
| 866 |
+
logger.info(f" Num examples = {len(train_dataset)}")
|
| 867 |
+
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
|
| 868 |
+
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
| 869 |
+
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
| 870 |
+
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
| 871 |
+
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
| 872 |
+
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
| 873 |
+
global_step = 0
|
| 874 |
+
first_epoch = 0
|
| 875 |
+
|
| 876 |
+
# Potentially load in the weights and states from a previous save
|
| 877 |
+
if args.resume_from_checkpoint:
|
| 878 |
+
if args.resume_from_checkpoint != "latest":
|
| 879 |
+
path = os.path.basename(args.resume_from_checkpoint)
|
| 880 |
+
else:
|
| 881 |
+
# Get the mos recent checkpoint
|
| 882 |
+
dirs = os.listdir(args.output_dir)
|
| 883 |
+
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
| 884 |
+
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
| 885 |
+
path = dirs[-1] if len(dirs) > 0 else None
|
| 886 |
+
|
| 887 |
+
if path is None:
|
| 888 |
+
accelerator.print(
|
| 889 |
+
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
| 890 |
+
)
|
| 891 |
+
args.resume_from_checkpoint = None
|
| 892 |
+
else:
|
| 893 |
+
accelerator.print(f"Resuming from checkpoint {path}")
|
| 894 |
+
accelerator.load_state(os.path.join(args.output_dir, path))
|
| 895 |
+
global_step = int(path.split("-")[1])
|
| 896 |
+
|
| 897 |
+
resume_global_step = global_step * args.gradient_accumulation_steps
|
| 898 |
+
first_epoch = global_step // num_update_steps_per_epoch
|
| 899 |
+
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
| 900 |
+
|
| 901 |
+
# Only show the progress bar once on each machine.
|
| 902 |
+
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
| 903 |
+
progress_bar.set_description("Steps")
|
| 904 |
+
|
| 905 |
+
# calculate the hyperspherical energy fine-tuning
|
| 906 |
+
# mhe = MHE(unet, eps=args.eps, r=args.r)
|
| 907 |
+
# mhe_loss = mhe.calculate_mhe()
|
| 908 |
+
# accelerator.log({"mhe_loss": mhe_loss}, step=0)
|
| 909 |
+
accelerator.log({"hra_r": args.hra_r}, step=0)
|
| 910 |
+
accelerator.log({"hra_apply_GS": args.hra_apply_GS}, step=0)
|
| 911 |
+
# accelerator.log({"COFT": 1 if args.coft else 0}, step=0)
|
| 912 |
+
|
| 913 |
+
for epoch in range(first_epoch, args.num_train_epochs):
|
| 914 |
+
unet.train()
|
| 915 |
+
for step, batch in enumerate(train_dataloader):
|
| 916 |
+
# Skip steps until we reach the resumed step
|
| 917 |
+
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
| 918 |
+
if step % args.gradient_accumulation_steps == 0:
|
| 919 |
+
progress_bar.update(1)
|
| 920 |
+
continue
|
| 921 |
+
|
| 922 |
+
with accelerator.accumulate(unet):
|
| 923 |
+
# Convert images to latent space
|
| 924 |
+
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
| 925 |
+
latents = latents * vae.config.scaling_factor
|
| 926 |
+
|
| 927 |
+
# Sample noise that we'll add to the latents
|
| 928 |
+
noise = torch.randn_like(latents)
|
| 929 |
+
bsz = latents.shape[0]
|
| 930 |
+
# Sample a random timestep for each image
|
| 931 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
| 932 |
+
timesteps = timesteps.long()
|
| 933 |
+
|
| 934 |
+
# Add noise to the latents according to the noise magnitude at each timestep
|
| 935 |
+
# (this is the forward diffusion process)
|
| 936 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 937 |
+
|
| 938 |
+
# Get the text embedding for conditioning
|
| 939 |
+
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
| 940 |
+
|
| 941 |
+
# Predict the noise residual
|
| 942 |
+
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
| 943 |
+
|
| 944 |
+
# Get the target for loss depending on the prediction type
|
| 945 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 946 |
+
target = noise
|
| 947 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 948 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 949 |
+
else:
|
| 950 |
+
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
| 951 |
+
|
| 952 |
+
if args.with_prior_preservation:
|
| 953 |
+
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
| 954 |
+
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
| 955 |
+
target, target_prior = torch.chunk(target, 2, dim=0)
|
| 956 |
+
|
| 957 |
+
# Compute instance loss
|
| 958 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
| 959 |
+
|
| 960 |
+
# Compute prior loss
|
| 961 |
+
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
|
| 962 |
+
|
| 963 |
+
# Add the prior loss to the instance loss.
|
| 964 |
+
loss = loss + args.prior_loss_weight * prior_loss
|
| 965 |
+
else:
|
| 966 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
| 967 |
+
|
| 968 |
+
# --------------------------------------------------------
|
| 969 |
+
# orthogonality regularizer
|
| 970 |
+
# for name, param in unet.named_parameters():
|
| 971 |
+
# if 'hra_u' in name:
|
| 972 |
+
# device = param.device
|
| 973 |
+
# hra_u_norm = param / (param.norm(dim=0))
|
| 974 |
+
# orth_loss = torch.norm(torch.eye(8, device=device) - hra_u_norm.t() @ hra_u_norm)
|
| 975 |
+
# loss = loss + 1e-5 * orth_loss
|
| 976 |
+
# --------------------------------------------------------
|
| 977 |
+
|
| 978 |
+
accelerator.backward(loss)
|
| 979 |
+
if accelerator.sync_gradients:
|
| 980 |
+
params_to_clip = oft_layers.parameters()
|
| 981 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 982 |
+
optimizer.step()
|
| 983 |
+
lr_scheduler.step()
|
| 984 |
+
optimizer.zero_grad()
|
| 985 |
+
|
| 986 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 987 |
+
if accelerator.sync_gradients:
|
| 988 |
+
progress_bar.update(1)
|
| 989 |
+
global_step += 1
|
| 990 |
+
|
| 991 |
+
if global_step % args.checkpointing_steps == 0:
|
| 992 |
+
if accelerator.is_main_process:
|
| 993 |
+
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
| 994 |
+
accelerator.save_state(save_path)
|
| 995 |
+
logger.info(f"Saved state to {save_path}")
|
| 996 |
+
|
| 997 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 998 |
+
progress_bar.set_postfix(**logs)
|
| 999 |
+
accelerator.log(logs, step=global_step)
|
| 1000 |
+
|
| 1001 |
+
if global_step >= args.max_train_steps:
|
| 1002 |
+
break
|
| 1003 |
+
|
| 1004 |
+
if accelerator.is_main_process:
|
| 1005 |
+
if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # and epoch > 1:
|
| 1006 |
+
logger.info(
|
| 1007 |
+
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
|
| 1008 |
+
f" {args.validation_prompt}."
|
| 1009 |
+
)
|
| 1010 |
+
|
| 1011 |
+
# mhe = MHE(unet, eps=args.eps, r=args.r)
|
| 1012 |
+
# mhe_loss = mhe.calculate_mhe()
|
| 1013 |
+
# accelerator.log({"mhe_loss": mhe_loss}, step=global_step)
|
| 1014 |
+
|
| 1015 |
+
# create pipeline
|
| 1016 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
| 1017 |
+
args.pretrained_model_name_or_path,
|
| 1018 |
+
unet=accelerator.unwrap_model(unet),
|
| 1019 |
+
text_encoder=accelerator.unwrap_model(text_encoder),
|
| 1020 |
+
revision=args.revision,
|
| 1021 |
+
torch_dtype=weight_dtype,
|
| 1022 |
+
)
|
| 1023 |
+
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
| 1024 |
+
pipeline = pipeline.to(accelerator.device)
|
| 1025 |
+
pipeline.set_progress_bar_config(disable=True)
|
| 1026 |
+
|
| 1027 |
+
# run inference
|
| 1028 |
+
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
| 1029 |
+
images = [
|
| 1030 |
+
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
| 1031 |
+
for _ in range(args.num_validation_images)
|
| 1032 |
+
]
|
| 1033 |
+
|
| 1034 |
+
for tracker in accelerator.trackers:
|
| 1035 |
+
if tracker.name == "tensorboard":
|
| 1036 |
+
np_images = np.stack([np.asarray(img) for img in images])
|
| 1037 |
+
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
| 1038 |
+
if tracker.name == "wandb":
|
| 1039 |
+
tracker.log(
|
| 1040 |
+
{
|
| 1041 |
+
"validation": [
|
| 1042 |
+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
| 1043 |
+
for i, image in enumerate(images)
|
| 1044 |
+
]
|
| 1045 |
+
}
|
| 1046 |
+
)
|
| 1047 |
+
|
| 1048 |
+
# Create the output directory if it doesn't exist
|
| 1049 |
+
tmp_dir = os.path.join(args.output_dir, str(epoch))
|
| 1050 |
+
if not os.path.exists(tmp_dir):
|
| 1051 |
+
os.makedirs(tmp_dir)
|
| 1052 |
+
|
| 1053 |
+
for i, image in enumerate(images):
|
| 1054 |
+
np_image = np.array(image)
|
| 1055 |
+
pil_image = Image.fromarray(np_image)
|
| 1056 |
+
pil_image.save(os.path.join(args.output_dir, str(epoch), f"image_{i}.png"))
|
| 1057 |
+
|
| 1058 |
+
del pipeline
|
| 1059 |
+
torch.cuda.empty_cache()
|
| 1060 |
+
|
| 1061 |
+
|
| 1062 |
+
# Save the oft layers
|
| 1063 |
+
accelerator.wait_for_everyone()
|
| 1064 |
+
# if accelerator.is_main_process:
|
| 1065 |
+
# unet = unet.to(torch.float32)
|
| 1066 |
+
# unet.save_attn_procs(args.output_dir)
|
| 1067 |
+
|
| 1068 |
+
# # Final inference
|
| 1069 |
+
# # Load previous pipeline
|
| 1070 |
+
# pipeline = DiffusionPipeline.from_pretrained(
|
| 1071 |
+
# args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
|
| 1072 |
+
# )
|
| 1073 |
+
# pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
| 1074 |
+
# pipeline = pipeline.to(accelerator.device)
|
| 1075 |
+
|
| 1076 |
+
# # load attention processors
|
| 1077 |
+
# pipeline.unet.load_attn_procs(args.output_dir)
|
| 1078 |
+
|
| 1079 |
+
# # run inference
|
| 1080 |
+
# if args.validation_prompt and args.num_validation_images > 0:
|
| 1081 |
+
# generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
| 1082 |
+
# images = [
|
| 1083 |
+
# pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
|
| 1084 |
+
# for _ in range(args.num_validation_images)
|
| 1085 |
+
# ]
|
| 1086 |
+
|
| 1087 |
+
# for tracker in accelerator.trackers:
|
| 1088 |
+
# if tracker.name == "tensorboard":
|
| 1089 |
+
# np_images = np.stack([np.asarray(img) for img in images])
|
| 1090 |
+
# tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
|
| 1091 |
+
# if tracker.name == "wandb":
|
| 1092 |
+
# tracker.log(
|
| 1093 |
+
# {
|
| 1094 |
+
# "test": [
|
| 1095 |
+
# wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
|
| 1096 |
+
# for i, image in enumerate(images)
|
| 1097 |
+
# ]
|
| 1098 |
+
# }
|
| 1099 |
+
# )
|
| 1100 |
+
|
| 1101 |
+
# if args.push_to_hub:
|
| 1102 |
+
# save_model_card(
|
| 1103 |
+
# repo_id,
|
| 1104 |
+
# images=images,
|
| 1105 |
+
# base_model=args.pretrained_model_name_or_path,
|
| 1106 |
+
# prompt=args.instance_prompt,
|
| 1107 |
+
# repo_folder=args.output_dir,
|
| 1108 |
+
# )
|
| 1109 |
+
# upload_folder(
|
| 1110 |
+
# repo_id=repo_id,
|
| 1111 |
+
# folder_path=args.output_dir,
|
| 1112 |
+
# commit_message="End of training",
|
| 1113 |
+
# ignore_patterns=["step_*", "epoch_*"],
|
| 1114 |
+
# )
|
| 1115 |
+
|
| 1116 |
+
accelerator.end_training()
|
| 1117 |
+
|
| 1118 |
+
|
| 1119 |
+
if __name__ == "__main__":
|
| 1120 |
+
args = parse_args()
|
| 1121 |
+
for arg in vars(args):
|
| 1122 |
+
print(f'{arg}: {getattr(args, arg)}')
|
| 1123 |
+
main(args)
|
generation/subject/train_dreambooth_hra.sh
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
prompt_idx=$1
|
| 3 |
+
class_idx=$2
|
| 4 |
+
lr=1e-4
|
| 5 |
+
hra_r=8
|
| 6 |
+
|
| 7 |
+
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
|
| 8 |
+
|
| 9 |
+
# Define the unique_token, class_tokens, and subject_names
|
| 10 |
+
unique_token="qwe"
|
| 11 |
+
subject_names=(
|
| 12 |
+
"backpack" "backpack_dog" "bear_plushie" "berry_bowl" "can"
|
| 13 |
+
"candle" "cat" "cat2" "clock" "colorful_sneaker"
|
| 14 |
+
"dog" "dog2" "dog3" "dog5" "dog6"
|
| 15 |
+
"dog7" "dog8" "duck_toy" "fancy_boot" "grey_sloth_plushie"
|
| 16 |
+
"monster_toy" "pink_sunglasses" "poop_emoji" "rc_car" "red_cartoon"
|
| 17 |
+
"robot_toy" "shiny_sneaker" "teapot" "vase" "wolf_plushie"
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
class_tokens=(
|
| 21 |
+
"backpack" "backpack" "stuffed animal" "bowl" "can"
|
| 22 |
+
"candle" "cat" "cat" "clock" "sneaker"
|
| 23 |
+
"dog" "dog" "dog" "dog" "dog"
|
| 24 |
+
"dog" "dog" "toy" "boot" "stuffed animal"
|
| 25 |
+
"toy" "glasses" "toy" "toy" "cartoon"
|
| 26 |
+
"toy" "sneaker" "teapot" "vase" "stuffed animal"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
echo "prompt_idx: $prompt_idx, class_idx: $class_idx"
|
| 30 |
+
|
| 31 |
+
class_token=${class_tokens[$class_idx]}
|
| 32 |
+
selected_subject=${subject_names[$class_idx]}
|
| 33 |
+
|
| 34 |
+
if [[ $class_idx =~ ^(0|1|2|3|4|5|8|9|17|18|19|20|21|22|23|24|25|26|27|28|29)$ ]]; then
|
| 35 |
+
prompt_list=(
|
| 36 |
+
"a ${unique_token} ${class_token} in the jungle"
|
| 37 |
+
"a ${unique_token} ${class_token} in the snow"
|
| 38 |
+
"a ${unique_token} ${class_token} on the beach"
|
| 39 |
+
"a ${unique_token} ${class_token} on a cobblestone street"
|
| 40 |
+
"a ${unique_token} ${class_token} on top of pink fabric"
|
| 41 |
+
"a ${unique_token} ${class_token} on top of a wooden floor"
|
| 42 |
+
"a ${unique_token} ${class_token} with a city in the background"
|
| 43 |
+
"a ${unique_token} ${class_token} with a mountain in the background"
|
| 44 |
+
"a ${unique_token} ${class_token} with a blue house in the background"
|
| 45 |
+
"a ${unique_token} ${class_token} on top of a purple rug in a forest"
|
| 46 |
+
"a ${unique_token} ${class_token} with a wheat field in the background"
|
| 47 |
+
"a ${unique_token} ${class_token} with a tree and autumn leaves in the background"
|
| 48 |
+
"a ${unique_token} ${class_token} with the Eiffel Tower in the background"
|
| 49 |
+
"a ${unique_token} ${class_token} floating on top of water"
|
| 50 |
+
"a ${unique_token} ${class_token} floating in an ocean of milk"
|
| 51 |
+
"a ${unique_token} ${class_token} on top of green grass with sunflowers around it"
|
| 52 |
+
"a ${unique_token} ${class_token} on top of a mirror"
|
| 53 |
+
"a ${unique_token} ${class_token} on top of the sidewalk in a crowded street"
|
| 54 |
+
"a ${unique_token} ${class_token} on top of a dirt road"
|
| 55 |
+
"a ${unique_token} ${class_token} on top of a white rug"
|
| 56 |
+
"a red ${unique_token} ${class_token}"
|
| 57 |
+
"a purple ${unique_token} ${class_token}"
|
| 58 |
+
"a shiny ${unique_token} ${class_token}"
|
| 59 |
+
"a wet ${unique_token} ${class_token}"
|
| 60 |
+
"a cube shaped ${unique_token} ${class_token}"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
prompt_test_list=(
|
| 64 |
+
"a ${class_token} in the jungle"
|
| 65 |
+
"a ${class_token} in the snow"
|
| 66 |
+
"a ${class_token} on the beach"
|
| 67 |
+
"a ${class_token} on a cobblestone street"
|
| 68 |
+
"a ${class_token} on top of pink fabric"
|
| 69 |
+
"a ${class_token} on top of a wooden floor"
|
| 70 |
+
"a ${class_token} with a city in the background"
|
| 71 |
+
"a ${class_token} with a mountain in the background"
|
| 72 |
+
"a ${class_token} with a blue house in the background"
|
| 73 |
+
"a ${class_token} on top of a purple rug in a forest"
|
| 74 |
+
"a ${class_token} with a wheat field in the background"
|
| 75 |
+
"a ${class_token} with a tree and autumn leaves in the background"
|
| 76 |
+
"a ${class_token} with the Eiffel Tower in the background"
|
| 77 |
+
"a ${class_token} floating on top of water"
|
| 78 |
+
"a ${class_token} floating in an ocean of milk"
|
| 79 |
+
"a ${class_token} on top of green grass with sunflowers around it"
|
| 80 |
+
"a ${class_token} on top of a mirror"
|
| 81 |
+
"a ${class_token} on top of the sidewalk in a crowded street"
|
| 82 |
+
"a ${class_token} on top of a dirt road"
|
| 83 |
+
"a ${class_token} on top of a white rug"
|
| 84 |
+
"a red ${class_token}"
|
| 85 |
+
"a purple ${class_token}"
|
| 86 |
+
"a shiny ${class_token}"
|
| 87 |
+
"a wet ${class_token}"
|
| 88 |
+
"a cube shaped ${class_token}"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
else
|
| 92 |
+
prompt_list=(
|
| 93 |
+
"a ${unique_token} ${class_token} in the jungle"
|
| 94 |
+
"a ${unique_token} ${class_token} in the snow"
|
| 95 |
+
"a ${unique_token} ${class_token} on the beach"
|
| 96 |
+
"a ${unique_token} ${class_token} on a cobblestone street"
|
| 97 |
+
"a ${unique_token} ${class_token} on top of pink fabric"
|
| 98 |
+
"a ${unique_token} ${class_token} on top of a wooden floor"
|
| 99 |
+
"a ${unique_token} ${class_token} with a city in the background"
|
| 100 |
+
"a ${unique_token} ${class_token} with a mountain in the background"
|
| 101 |
+
"a ${unique_token} ${class_token} with a blue house in the background"
|
| 102 |
+
"a ${unique_token} ${class_token} on top of a purple rug in a forest"
|
| 103 |
+
"a ${unique_token} ${class_token} wearing a red hat"
|
| 104 |
+
"a ${unique_token} ${class_token} wearing a santa hat"
|
| 105 |
+
"a ${unique_token} ${class_token} wearing a rainbow scarf"
|
| 106 |
+
"a ${unique_token} ${class_token} wearing a black top hat and a monocle"
|
| 107 |
+
"a ${unique_token} ${class_token} in a chef outfit"
|
| 108 |
+
"a ${unique_token} ${class_token} in a firefighter outfit"
|
| 109 |
+
"a ${unique_token} ${class_token} in a police outfit"
|
| 110 |
+
"a ${unique_token} ${class_token} wearing pink glasses"
|
| 111 |
+
"a ${unique_token} ${class_token} wearing a yellow shirt"
|
| 112 |
+
"a ${unique_token} ${class_token} in a purple wizard outfit"
|
| 113 |
+
"a red ${unique_token} ${class_token}"
|
| 114 |
+
"a purple ${unique_token} ${class_token}"
|
| 115 |
+
"a shiny ${unique_token} ${class_token}"
|
| 116 |
+
"a wet ${unique_token} ${class_token}"
|
| 117 |
+
"a cube shaped ${unique_token} ${class_token}"
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
prompt_test_list=(
|
| 121 |
+
"a ${class_token} in the jungle"
|
| 122 |
+
"a ${class_token} in the snow"
|
| 123 |
+
"a ${class_token} on the beach"
|
| 124 |
+
"a ${class_token} on a cobblestone street"
|
| 125 |
+
"a ${class_token} on top of pink fabric"
|
| 126 |
+
"a ${class_token} on top of a wooden floor"
|
| 127 |
+
"a ${class_token} with a city in the background"
|
| 128 |
+
"a ${class_token} with a mountain in the background"
|
| 129 |
+
"a ${class_token} with a blue house in the background"
|
| 130 |
+
"a ${class_token} on top of a purple rug in a forest"
|
| 131 |
+
"a ${class_token} wearing a red hat"
|
| 132 |
+
"a ${class_token} wearing a santa hat"
|
| 133 |
+
"a ${class_token} wearing a rainbow scarf"
|
| 134 |
+
"a ${class_token} wearing a black top hat and a monocle"
|
| 135 |
+
"a ${class_token} in a chef outfit"
|
| 136 |
+
"a ${class_token} in a firefighter outfit"
|
| 137 |
+
"a ${class_token} in a police outfit"
|
| 138 |
+
"a ${class_token} wearing pink glasses"
|
| 139 |
+
"a ${class_token} wearing a yellow shirt"
|
| 140 |
+
"a ${class_token} in a purple wizard outfit"
|
| 141 |
+
"a red ${class_token}"
|
| 142 |
+
"a purple ${class_token}"
|
| 143 |
+
"a shiny ${class_token}"
|
| 144 |
+
"a wet ${class_token}"
|
| 145 |
+
"a cube shaped ${class_token}"
|
| 146 |
+
)
|
| 147 |
+
fi
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
validation_prompt=${prompt_list[$prompt_idx]}
|
| 151 |
+
test_prompt=${prompt_test_list[$prompt_idx]}
|
| 152 |
+
name="${selected_subject}-${prompt_idx}"
|
| 153 |
+
instance_prompt="a photo of ${unique_token} ${class_token}"
|
| 154 |
+
class_prompt="a photo of ${class_token}"
|
| 155 |
+
|
| 156 |
+
export OUTPUT_DIR="log_hra/lr_${lr}_r_${hra_r}/${name}"
|
| 157 |
+
export INSTANCE_DIR="dreambooth/dataset/${selected_subject}"
|
| 158 |
+
export CLASS_DIR="class_data/${class_token}"
|
| 159 |
+
|
| 160 |
+
if [ -d "$OUTPUT_DIR" ]; then
|
| 161 |
+
echo "该目录已存在:$OUTPUT_DIR"
|
| 162 |
+
fi
|
| 163 |
+
|
| 164 |
+
accelerate launch train_dreambooth_hra.py \
|
| 165 |
+
--pretrained_model_name_or_path=$MODEL_NAME \
|
| 166 |
+
--instance_data_dir=$INSTANCE_DIR \
|
| 167 |
+
--class_data_dir="$CLASS_DIR" \
|
| 168 |
+
--output_dir="$OUTPUT_DIR" \
|
| 169 |
+
--instance_prompt="$instance_prompt" \
|
| 170 |
+
--with_prior_preservation --prior_loss_weight=1.0 \
|
| 171 |
+
--class_prompt="$class_prompt" \
|
| 172 |
+
--resolution=512 \
|
| 173 |
+
--train_batch_size=1 \
|
| 174 |
+
--gradient_accumulation_steps=1 \
|
| 175 |
+
--checkpointing_steps=5000 \
|
| 176 |
+
--learning_rate=$lr \
|
| 177 |
+
--report_to="wandb" \
|
| 178 |
+
--lr_scheduler="constant" \
|
| 179 |
+
--lr_warmup_steps=0 \
|
| 180 |
+
--max_train_steps=2005 \
|
| 181 |
+
--validation_prompt="$validation_prompt" \
|
| 182 |
+
--validation_epochs=1 \
|
| 183 |
+
--seed="0" \
|
| 184 |
+
--name="$name" \
|
| 185 |
+
--num_class_images=200 \
|
| 186 |
+
--hra_r=$hra_r
|
llama/data/MATH_test.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llama/data/gsm8k_test.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llama/data/oft/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .config import OFTConfig
|
| 16 |
+
from .layer import Conv2d, Linear, OFTLayer
|
| 17 |
+
from .model import OFTModel
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
__all__ = ["OFTConfig", "OFTModel", "Conv2d", "Linear", "OFTLayer"]
|
llama/data/oft/config.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from dataclasses import dataclass, field
|
| 16 |
+
from typing import List, Optional, Union
|
| 17 |
+
|
| 18 |
+
from peft.tuners.lycoris_utils import LycorisConfig
|
| 19 |
+
from peft.utils import PeftType
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@dataclass
|
| 23 |
+
class OFTConfig(LycorisConfig):
|
| 24 |
+
"""
|
| 25 |
+
This is the configuration class to store the configuration of a [`OFTModel`].
|
| 26 |
+
|
| 27 |
+
Args:
|
| 28 |
+
r (`int`): OFT rank.
|
| 29 |
+
module_dropout (`int`): The dropout probability for disabling OFT modules during training.
|
| 30 |
+
target_modules (`Optional[Union[List[str], str]]`):
|
| 31 |
+
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
|
| 32 |
+
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
|
| 33 |
+
strings, either an exact match will be performed or it is checked if the name of the module ends with any
|
| 34 |
+
of the passed strings. If this is specified as 'all-linear', then all linear modules are chosen, excluding
|
| 35 |
+
the output layer. If this is not specified, modules will be chosen according to the model architecture. If
|
| 36 |
+
the architecture is not known, an error will be raised -- in this case, you should specify the target
|
| 37 |
+
modules manually.
|
| 38 |
+
init_weights (`bool`):
|
| 39 |
+
Whether to perform initialization of OFT weights.
|
| 40 |
+
layers_to_transform (`Union[List[int], int]`):
|
| 41 |
+
The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
|
| 42 |
+
that are specified in this list. If a single integer is passed, it will apply the transformations on the
|
| 43 |
+
layer at this index.
|
| 44 |
+
layers_pattern (`str`):
|
| 45 |
+
The layer pattern name, used only if `layers_to_transform` is different from `None`.
|
| 46 |
+
rank_pattern (`dict`):
|
| 47 |
+
The mapping from layer names or regexp expression to ranks which are different from the default rank
|
| 48 |
+
specified by `r`.
|
| 49 |
+
modules_to_save (`List[str]`):
|
| 50 |
+
List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
|
| 51 |
+
coft (`bool`):
|
| 52 |
+
Whether to use the constrained variant of OFT or not, off by default.
|
| 53 |
+
eps (`float`):
|
| 54 |
+
The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True.
|
| 55 |
+
block_share (`bool`):
|
| 56 |
+
Whether to share the OFT parameters between blocks or not. This is `False` by default.
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
r: int = field(default=8, metadata={"help": "OFT rank"})
|
| 60 |
+
module_dropout: float = field(
|
| 61 |
+
default=0.0, metadata={"help": "The dropout probability for disabling OFT modules during training"}
|
| 62 |
+
)
|
| 63 |
+
target_modules: Optional[Union[List[str], str]] = field(
|
| 64 |
+
default=None,
|
| 65 |
+
metadata={
|
| 66 |
+
"help": "List of module names or regex expression of the module names to replace with OFT."
|
| 67 |
+
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
|
| 68 |
+
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer."
|
| 69 |
+
},
|
| 70 |
+
)
|
| 71 |
+
init_weights: bool = field(
|
| 72 |
+
default=True,
|
| 73 |
+
metadata={
|
| 74 |
+
"help": (
|
| 75 |
+
"Whether to initialize the weights of the OFT layers with their default initialization. Don't change "
|
| 76 |
+
"this setting, except if you know exactly what you're doing."
|
| 77 |
+
),
|
| 78 |
+
},
|
| 79 |
+
)
|
| 80 |
+
layers_to_transform: Optional[Union[List[int], int]] = field(
|
| 81 |
+
default=None,
|
| 82 |
+
metadata={
|
| 83 |
+
"help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
|
| 84 |
+
},
|
| 85 |
+
)
|
| 86 |
+
layers_pattern: Optional[str] = field(
|
| 87 |
+
default=None,
|
| 88 |
+
metadata={
|
| 89 |
+
"help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern."
|
| 90 |
+
},
|
| 91 |
+
)
|
| 92 |
+
modules_to_save: Optional[List[str]] = field(
|
| 93 |
+
default=None,
|
| 94 |
+
metadata={
|
| 95 |
+
"help": "List of modules apart from OFT layers to be set as trainable and saved in the final checkpoint. "
|
| 96 |
+
"For example, in Sequence Classification or Token Classification tasks, "
|
| 97 |
+
"the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
|
| 98 |
+
},
|
| 99 |
+
)
|
| 100 |
+
coft: bool = field(
|
| 101 |
+
default=False,
|
| 102 |
+
metadata={"help": "Whether to use the constrained variant of OFT or not."},
|
| 103 |
+
)
|
| 104 |
+
eps: float = field(
|
| 105 |
+
default=6e-5,
|
| 106 |
+
metadata={
|
| 107 |
+
"help": "The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True."
|
| 108 |
+
},
|
| 109 |
+
)
|
| 110 |
+
block_share: bool = field(
|
| 111 |
+
default=False,
|
| 112 |
+
metadata={"help": "Whether to share the OFT parameters between blocks or not."},
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def __post_init__(self):
|
| 116 |
+
self.peft_type = PeftType.OFT
|
| 117 |
+
self.target_modules = (
|
| 118 |
+
set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
|
| 119 |
+
)
|
llama/data/oft/layer.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
import warnings
|
| 17 |
+
from typing import Any, List, Optional, Set, Tuple
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
|
| 22 |
+
from peft.tuners.lycoris_utils import LycorisLayer, check_adapters_to_merge
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class OFTLayer(nn.Module, LycorisLayer):
|
| 26 |
+
# All names of layers that may contain adapter weights
|
| 27 |
+
adapter_layer_names = ("oft_r",)
|
| 28 |
+
# other_param_names is defined on parent class
|
| 29 |
+
|
| 30 |
+
def __init__(self, base_layer: nn.Module):
|
| 31 |
+
super().__init__()
|
| 32 |
+
LycorisLayer.__init__(self, base_layer)
|
| 33 |
+
|
| 34 |
+
# OFT info
|
| 35 |
+
self.oft_r = nn.ParameterDict({})
|
| 36 |
+
self.coft = {}
|
| 37 |
+
self.eps = {}
|
| 38 |
+
self.block_share = {}
|
| 39 |
+
|
| 40 |
+
@property
|
| 41 |
+
def _available_adapters(self) -> Set[str]:
|
| 42 |
+
return {*self.oft_r}
|
| 43 |
+
|
| 44 |
+
def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...], block_share: bool):
|
| 45 |
+
if block_share:
|
| 46 |
+
self.oft_r[adapter_name] = nn.Parameter(torch.empty(1, math.ceil(shape[0] / r), math.ceil(shape[0] / r)))
|
| 47 |
+
else:
|
| 48 |
+
self.oft_r[adapter_name] = nn.Parameter(torch.empty(r, math.ceil(shape[0] / r), math.ceil(shape[0] / r)))
|
| 49 |
+
|
| 50 |
+
def reset_adapter_parameters(self, adapter_name: str):
|
| 51 |
+
nn.init.zeros_(self.oft_r[adapter_name])
|
| 52 |
+
|
| 53 |
+
def reset_adapter_parameters_random(self, adapter_name: str):
|
| 54 |
+
nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=math.sqrt(5))
|
| 55 |
+
|
| 56 |
+
def update_layer(
|
| 57 |
+
self,
|
| 58 |
+
adapter_name: str,
|
| 59 |
+
r: int,
|
| 60 |
+
module_dropout: float,
|
| 61 |
+
init_weights: bool,
|
| 62 |
+
coft: bool = False,
|
| 63 |
+
eps: float = 6e-5,
|
| 64 |
+
block_share: bool = False,
|
| 65 |
+
**kwargs,
|
| 66 |
+
) -> None:
|
| 67 |
+
"""Internal function to create oft adapter
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
adapter_name (`str`): Name for the adapter to add.
|
| 71 |
+
r (`int`): Rank for the added adapter.
|
| 72 |
+
module_dropout (`float`): The dropout probability for disabling adapter during training.
|
| 73 |
+
init_weights (`bool`): Whether to initialize weights.
|
| 74 |
+
coft (`bool`): Whether to use the constrained variant of OFT or not.
|
| 75 |
+
eps (`float`):
|
| 76 |
+
The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True.
|
| 77 |
+
block_share (`bool`): Whether to share the OFT parameters between blocks or not.
|
| 78 |
+
"""
|
| 79 |
+
if r <= 0:
|
| 80 |
+
raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")
|
| 81 |
+
|
| 82 |
+
self.r[adapter_name] = r
|
| 83 |
+
self.module_dropout[adapter_name] = module_dropout
|
| 84 |
+
self.coft[adapter_name] = coft
|
| 85 |
+
self.block_share[adapter_name] = block_share
|
| 86 |
+
|
| 87 |
+
# Determine shape of OFT weights
|
| 88 |
+
base_layer = self.get_base_layer()
|
| 89 |
+
if isinstance(base_layer, nn.Linear):
|
| 90 |
+
shape = tuple(base_layer.weight.shape)
|
| 91 |
+
elif isinstance(base_layer, nn.Conv2d):
|
| 92 |
+
shape = (
|
| 93 |
+
base_layer.out_channels,
|
| 94 |
+
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
|
| 95 |
+
)
|
| 96 |
+
else:
|
| 97 |
+
raise TypeError(f"OFT is not implemented for base layers of type {type(base_layer).__name__}")
|
| 98 |
+
|
| 99 |
+
self.eps[adapter_name] = eps * math.ceil(shape[0] / r) * math.ceil(shape[0] / r)
|
| 100 |
+
|
| 101 |
+
# Create weights with provided shape
|
| 102 |
+
self.create_adapter_parameters(adapter_name, r, shape, block_share)
|
| 103 |
+
|
| 104 |
+
# Initialize weights
|
| 105 |
+
if init_weights:
|
| 106 |
+
self.reset_adapter_parameters(adapter_name)
|
| 107 |
+
else:
|
| 108 |
+
self.reset_adapter_parameters_random(adapter_name)
|
| 109 |
+
|
| 110 |
+
# Move new weights to device
|
| 111 |
+
weight = getattr(self.get_base_layer(), "weight", None)
|
| 112 |
+
if weight is not None:
|
| 113 |
+
# the layer is already completely initialized, this is an update
|
| 114 |
+
if weight.dtype.is_floating_point or weight.dtype.is_complex:
|
| 115 |
+
self.to(weight.device, dtype=weight.dtype)
|
| 116 |
+
else:
|
| 117 |
+
self.to(weight.device)
|
| 118 |
+
self.set_adapter(self.active_adapters)
|
| 119 |
+
|
| 120 |
+
def unscale_layer(self, scale=None) -> None:
|
| 121 |
+
# scale is not used
|
| 122 |
+
pass
|
| 123 |
+
|
| 124 |
+
def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
|
| 125 |
+
"""
|
| 126 |
+
Merge the active adapter weights into the base weights
|
| 127 |
+
|
| 128 |
+
Args:
|
| 129 |
+
safe_merge (`bool`, *optional*):
|
| 130 |
+
If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
|
| 131 |
+
before merging the weights. This is useful if you want to check if the merge operation will produce
|
| 132 |
+
NaNs. Defaults to `False`.
|
| 133 |
+
adapter_names (`List[str]`, *optional*):
|
| 134 |
+
The list of adapter names that should be merged. If `None`, all active adapters will be merged.
|
| 135 |
+
Defaults to `None`.
|
| 136 |
+
"""
|
| 137 |
+
adapter_names = check_adapters_to_merge(self, adapter_names)
|
| 138 |
+
if not adapter_names:
|
| 139 |
+
# no adapter to merge
|
| 140 |
+
return
|
| 141 |
+
|
| 142 |
+
for active_adapter in adapter_names:
|
| 143 |
+
if active_adapter in self._available_adapters:
|
| 144 |
+
base_layer = self.get_base_layer()
|
| 145 |
+
|
| 146 |
+
orig_weights = base_layer.weight.data
|
| 147 |
+
if isinstance(base_layer, nn.Linear):
|
| 148 |
+
orig_weights = torch.transpose(orig_weights, 0, 1)
|
| 149 |
+
elif isinstance(base_layer, nn.Conv2d):
|
| 150 |
+
orig_weights = orig_weights.view(
|
| 151 |
+
[
|
| 152 |
+
base_layer.out_channels,
|
| 153 |
+
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
|
| 154 |
+
]
|
| 155 |
+
)
|
| 156 |
+
orig_weights = torch.transpose(orig_weights, 0, 1)
|
| 157 |
+
delta_weight = self.get_delta_weight(active_adapter)
|
| 158 |
+
if orig_weights.shape[1] != delta_weight.shape[1]:
|
| 159 |
+
# when in channels is not divisible by r
|
| 160 |
+
delta_weight = delta_weight[: orig_weights.shape[1], : orig_weights.shape[1]]
|
| 161 |
+
new_weights = torch.mm(orig_weights, delta_weight)
|
| 162 |
+
if isinstance(base_layer, nn.Linear):
|
| 163 |
+
new_weights = torch.transpose(new_weights, 0, 1)
|
| 164 |
+
elif isinstance(base_layer, nn.Conv2d):
|
| 165 |
+
new_weights = torch.transpose(new_weights, 0, 1)
|
| 166 |
+
new_weights = new_weights.view(
|
| 167 |
+
[
|
| 168 |
+
base_layer.out_channels,
|
| 169 |
+
base_layer.in_channels,
|
| 170 |
+
base_layer.kernel_size[0],
|
| 171 |
+
base_layer.kernel_size[1],
|
| 172 |
+
]
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
if safe_merge and not torch.isfinite(new_weights).all():
|
| 176 |
+
raise ValueError(
|
| 177 |
+
f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
base_layer.weight.data = new_weights
|
| 181 |
+
self.merged_adapters.append(active_adapter)
|
| 182 |
+
|
| 183 |
+
def unmerge(self) -> None:
|
| 184 |
+
"""
|
| 185 |
+
This method unmerges all merged adapter layers from the base weights.
|
| 186 |
+
"""
|
| 187 |
+
if not self.merged:
|
| 188 |
+
warnings.warn("Already unmerged. Nothing to do.")
|
| 189 |
+
return
|
| 190 |
+
while len(self.merged_adapters) > 0:
|
| 191 |
+
active_adapter = self.merged_adapters.pop()
|
| 192 |
+
if active_adapter in self._available_adapters:
|
| 193 |
+
base_layer = self.get_base_layer()
|
| 194 |
+
new_weights = base_layer.weight.data
|
| 195 |
+
if isinstance(base_layer, nn.Linear):
|
| 196 |
+
new_weights = torch.transpose(new_weights, 0, 1)
|
| 197 |
+
elif isinstance(base_layer, nn.Conv2d):
|
| 198 |
+
new_weights = new_weights.view(
|
| 199 |
+
[
|
| 200 |
+
base_layer.out_channels,
|
| 201 |
+
base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
|
| 202 |
+
]
|
| 203 |
+
)
|
| 204 |
+
new_weights = torch.transpose(new_weights, 0, 1)
|
| 205 |
+
delta_weight = self.get_delta_weight(active_adapter)
|
| 206 |
+
if new_weights.shape[1] != delta_weight.shape[1]:
|
| 207 |
+
# when in channels is not divisible by r
|
| 208 |
+
delta_weight = delta_weight[: new_weights.shape[1], : new_weights.shape[1]]
|
| 209 |
+
delta_inv = torch.inverse(delta_weight)
|
| 210 |
+
orig_weights = torch.mm(new_weights, delta_inv)
|
| 211 |
+
|
| 212 |
+
if isinstance(base_layer, nn.Linear):
|
| 213 |
+
orig_weights = torch.transpose(orig_weights, 0, 1)
|
| 214 |
+
elif isinstance(base_layer, nn.Conv2d):
|
| 215 |
+
orig_weights = torch.transpose(orig_weights, 0, 1)
|
| 216 |
+
orig_weights = orig_weights.reshape(
|
| 217 |
+
[
|
| 218 |
+
base_layer.out_channels,
|
| 219 |
+
base_layer.in_channels,
|
| 220 |
+
base_layer.kernel_size[0],
|
| 221 |
+
base_layer.kernel_size[1],
|
| 222 |
+
]
|
| 223 |
+
)
|
| 224 |
+
base_layer.weight.data = orig_weights
|
| 225 |
+
|
| 226 |
+
def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
|
| 227 |
+
rank = self.r[adapter_name]
|
| 228 |
+
coft = self.coft[adapter_name]
|
| 229 |
+
eps = self.eps[adapter_name]
|
| 230 |
+
opt_r = self.oft_r[adapter_name]
|
| 231 |
+
|
| 232 |
+
if coft:
|
| 233 |
+
with torch.no_grad():
|
| 234 |
+
opt_r.copy_(self._project_batch(opt_r, eps=eps))
|
| 235 |
+
|
| 236 |
+
orth_rotate = self._cayley_batch(opt_r)
|
| 237 |
+
weight = self._block_diagonal(orth_rotate, rank)
|
| 238 |
+
|
| 239 |
+
return weight
|
| 240 |
+
|
| 241 |
+
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L144
|
| 242 |
+
def _cayley_batch(self, data: torch.Tensor) -> torch.Tensor:
|
| 243 |
+
b, r, c = data.shape
|
| 244 |
+
# Ensure the input matrix is skew-symmetric
|
| 245 |
+
skew = 0.5 * (data - data.transpose(1, 2))
|
| 246 |
+
I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) # noqa: E741
|
| 247 |
+
|
| 248 |
+
# Perform the Cayley parametrization
|
| 249 |
+
Q = torch.bmm(I - skew, torch.inverse(I + skew))
|
| 250 |
+
|
| 251 |
+
return Q
|
| 252 |
+
|
| 253 |
+
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155
|
| 254 |
+
def _block_diagonal(self, oft_r: torch.Tensor, rank: int) -> torch.Tensor:
|
| 255 |
+
if oft_r.shape[0] == 1:
|
| 256 |
+
# block share
|
| 257 |
+
blocks = [oft_r[0, ...] for i in range(rank)]
|
| 258 |
+
else:
|
| 259 |
+
blocks = [oft_r[i, ...] for i in range(rank)]
|
| 260 |
+
|
| 261 |
+
# Use torch.block_diag to create the block diagonal matrix
|
| 262 |
+
A = torch.block_diag(*blocks)
|
| 263 |
+
|
| 264 |
+
return A
|
| 265 |
+
|
| 266 |
+
# Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52
|
| 267 |
+
def _project_batch(self, oft_r, eps=1e-5):
|
| 268 |
+
# scaling factor for each of the smaller block matrix
|
| 269 |
+
eps = eps * 1 / torch.sqrt(torch.tensor(oft_r.shape[0]))
|
| 270 |
+
I = ( # noqa: E741
|
| 271 |
+
torch.zeros((oft_r.size(1), oft_r.size(1)), device=oft_r.device, dtype=oft_r.dtype)
|
| 272 |
+
.unsqueeze(0)
|
| 273 |
+
.expand_as(oft_r)
|
| 274 |
+
)
|
| 275 |
+
diff = oft_r - I
|
| 276 |
+
norm_diff = torch.norm(oft_r - I, dim=(1, 2), keepdim=True)
|
| 277 |
+
mask = (norm_diff <= eps).bool()
|
| 278 |
+
out = torch.where(mask, oft_r, I + eps * (diff / norm_diff))
|
| 279 |
+
return out
|
| 280 |
+
|
| 281 |
+
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
|
| 282 |
+
previous_dtype = x.dtype
|
| 283 |
+
|
| 284 |
+
if self.disable_adapters:
|
| 285 |
+
if self.merged:
|
| 286 |
+
self.unmerge()
|
| 287 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 288 |
+
elif self.merged:
|
| 289 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 290 |
+
else:
|
| 291 |
+
result = self.base_layer(x, *args, **kwargs)
|
| 292 |
+
if len(result.shape) == 4:
|
| 293 |
+
result = result.permute(0, 2, 3, 1)
|
| 294 |
+
|
| 295 |
+
base_layer = self.get_base_layer()
|
| 296 |
+
base_bias = base_layer.bias
|
| 297 |
+
if base_bias is not None:
|
| 298 |
+
# Bias should be added after OFT forward
|
| 299 |
+
result = result - base_bias.data
|
| 300 |
+
|
| 301 |
+
# Execute all the adapters
|
| 302 |
+
for active_adapter in self.active_adapters:
|
| 303 |
+
if active_adapter not in self._available_adapters:
|
| 304 |
+
continue
|
| 305 |
+
|
| 306 |
+
module_dropout = self.module_dropout[active_adapter]
|
| 307 |
+
|
| 308 |
+
# Modify current execution weights
|
| 309 |
+
if (not self.training) or (self.training and torch.rand(1) > module_dropout):
|
| 310 |
+
result = self._get_delta_activations(active_adapter, result, *args, **kwargs)
|
| 311 |
+
|
| 312 |
+
if base_bias is not None:
|
| 313 |
+
result = result + base_bias.data
|
| 314 |
+
if len(result.shape) == 4:
|
| 315 |
+
result = result.permute(0, 3, 1, 2)
|
| 316 |
+
|
| 317 |
+
result = result.to(previous_dtype)
|
| 318 |
+
return result
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class Linear(OFTLayer):
|
| 322 |
+
"""OFT implemented in Linear layer"""
|
| 323 |
+
|
| 324 |
+
def __init__(
|
| 325 |
+
self,
|
| 326 |
+
base_layer: nn.Module,
|
| 327 |
+
adapter_name: str = "default",
|
| 328 |
+
r: int = 0,
|
| 329 |
+
module_dropout: float = 0.0,
|
| 330 |
+
init_weights: bool = True,
|
| 331 |
+
**kwargs,
|
| 332 |
+
):
|
| 333 |
+
super().__init__(base_layer)
|
| 334 |
+
|
| 335 |
+
# Create adapter and set it active
|
| 336 |
+
self._active_adapter = adapter_name
|
| 337 |
+
self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs)
|
| 338 |
+
|
| 339 |
+
def _get_delta_activations(
|
| 340 |
+
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
|
| 341 |
+
) -> torch.Tensor:
|
| 342 |
+
delta_weight = self.get_delta_weight(adapter_name)
|
| 343 |
+
|
| 344 |
+
base_layer = self.get_base_layer()
|
| 345 |
+
base_weight = base_layer.weight.data
|
| 346 |
+
delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]]
|
| 347 |
+
|
| 348 |
+
# don't add bias here, because the bias will be added after OFT forward
|
| 349 |
+
return torch.matmul(input, delta_weight)
|
| 350 |
+
|
| 351 |
+
def __repr__(self) -> str:
|
| 352 |
+
rep = super().__repr__()
|
| 353 |
+
return "oft." + rep
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class Conv2d(OFTLayer):
|
| 357 |
+
"""OFT implemented in Conv2d layer"""
|
| 358 |
+
|
| 359 |
+
def __init__(
|
| 360 |
+
self,
|
| 361 |
+
base_layer: nn.Module,
|
| 362 |
+
adapter_name: str = "default",
|
| 363 |
+
r: int = 0,
|
| 364 |
+
module_dropout: float = 0.0,
|
| 365 |
+
init_weights: bool = True,
|
| 366 |
+
**kwargs,
|
| 367 |
+
):
|
| 368 |
+
super().__init__(base_layer)
|
| 369 |
+
|
| 370 |
+
# Create adapter and set it active
|
| 371 |
+
self._active_adapter = adapter_name
|
| 372 |
+
self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs)
|
| 373 |
+
|
| 374 |
+
def _get_delta_activations(
|
| 375 |
+
self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
|
| 376 |
+
) -> torch.Tensor:
|
| 377 |
+
delta_weight = self.get_delta_weight(adapter_name)
|
| 378 |
+
|
| 379 |
+
base_layer = self.get_base_layer()
|
| 380 |
+
base_weight = base_layer.weight.data
|
| 381 |
+
delta_weight = delta_weight[: base_weight.shape[0], : base_weight.shape[0]]
|
| 382 |
+
|
| 383 |
+
# don't add bias here, because the bias will be added after OFT forward
|
| 384 |
+
return torch.matmul(input, delta_weight)
|
| 385 |
+
|
| 386 |
+
def __repr__(self) -> str:
|
| 387 |
+
rep = super().__repr__()
|
| 388 |
+
return "oft." + rep
|
llama/data/oft/model.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023-present the HuggingFace Inc. team.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import re
|
| 16 |
+
from typing import Dict, Type, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
from torch import nn
|
| 20 |
+
|
| 21 |
+
from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
|
| 22 |
+
|
| 23 |
+
from .layer import Conv2d, Linear, OFTLayer
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class OFTModel(LycorisTuner):
|
| 27 |
+
"""
|
| 28 |
+
Creates Orthogonal Finetuning model from a pretrained model. The method is described in
|
| 29 |
+
https://arxiv.org/abs/2306.07280
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached.
|
| 33 |
+
config ([`OFTConfig`]): The configuration of the OFT model.
|
| 34 |
+
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
`torch.nn.Module`: The OFT model.
|
| 38 |
+
|
| 39 |
+
Example:
|
| 40 |
+
```py
|
| 41 |
+
>>> from diffusers import StableDiffusionPipeline
|
| 42 |
+
>>> from peft import OFTModel, OFTConfig
|
| 43 |
+
|
| 44 |
+
>>> config_te = OFTConfig(
|
| 45 |
+
... r=8,
|
| 46 |
+
... target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
|
| 47 |
+
... module_dropout=0.0,
|
| 48 |
+
... init_weights=True,
|
| 49 |
+
... )
|
| 50 |
+
>>> config_unet = OFTConfig(
|
| 51 |
+
... r=8,
|
| 52 |
+
... target_modules=[
|
| 53 |
+
... "proj_in",
|
| 54 |
+
... "proj_out",
|
| 55 |
+
... "to_k",
|
| 56 |
+
... "to_q",
|
| 57 |
+
... "to_v",
|
| 58 |
+
... "to_out.0",
|
| 59 |
+
... "ff.net.0.proj",
|
| 60 |
+
... "ff.net.2",
|
| 61 |
+
... ],
|
| 62 |
+
... module_dropout=0.0,
|
| 63 |
+
... init_weights=True,
|
| 64 |
+
... )
|
| 65 |
+
|
| 66 |
+
>>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
|
| 67 |
+
>>> model.text_encoder = OFTModel(model.text_encoder, config_te, "default")
|
| 68 |
+
>>> model.unet = OFTModel(model.unet, config_unet, "default")
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
**Attributes**:
|
| 72 |
+
- **model** ([`~torch.nn.Module`]) -- The model to be adapted.
|
| 73 |
+
- **peft_config** ([`OFTConfig`]): The configuration of the OFT model.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
prefix: str = "oft_"
|
| 77 |
+
layers_mapping: Dict[Type[torch.nn.Module], Type[OFTLayer]] = {
|
| 78 |
+
torch.nn.Conv2d: Conv2d,
|
| 79 |
+
torch.nn.Linear: Linear,
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
def _create_and_replace(
    self,
    config: LycorisConfig,
    adapter_name: str,
    target: Union[OFTLayer, nn.Module],
    target_name: str,
    parent: nn.Module,
    current_key: str,
) -> None:
    """Create an OFT adapter for ``target``, or update an existing one in place."""
    # Find a rank_pattern key whose pattern matches the end of the full module
    # path; fall back to the bare target name when nothing matches.
    matches = (key for key in config.rank_pattern if re.match(rf"(.*\.)?{key}$", current_key))
    target_name_key = next(matches, target_name)

    kwargs = config.to_dict()
    kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)

    if isinstance(target, OFTLayer):
        # Module is already wrapped: register/refresh this adapter on it.
        target.update_layer(adapter_name, **kwargs)
    else:
        # Fresh wrap: build the adapter module and splice it into the parent.
        replacement = self._create_new_module(config, adapter_name, target, **kwargs)
        self._replace_module(parent, target_name, replacement, target)
|
llama/finetune_32.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import field, dataclass
|
| 3 |
+
from typing import Sequence, Literal, List, Dict, Tuple
|
| 4 |
+
|
| 5 |
+
import transformers
|
| 6 |
+
from transformers import Trainer
|
| 7 |
+
from transformers.modeling_utils import *
|
| 8 |
+
from transformers.trainer import _is_peft_model
|
| 9 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 10 |
+
from transformers.data.data_collator import DataCollator
|
| 11 |
+
|
| 12 |
+
from transformers.training_args import TrainingArguments
|
| 13 |
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
| 14 |
+
from transformers.trainer_callback import TrainerCallback
|
| 15 |
+
from transformers.trainer_utils import EvalPrediction
|
| 16 |
+
from torch.utils.data import Dataset, IterableDataset
|
| 17 |
+
|
| 18 |
+
from datasets import load_dataset
|
| 19 |
+
from peft import LoraConfig, get_peft_model, PeftModel, OFTConfig
|
| 20 |
+
from datetime import datetime
|
| 21 |
+
|
| 22 |
+
IGNORE_INDEX = -100
|
| 23 |
+
DEFAULT_PAD_TOKEN = "[PAD]"
|
| 24 |
+
DEFAULT_EOS_TOKEN = "</s>"
|
| 25 |
+
DEFAULT_BOS_TOKEN = "</s>"
|
| 26 |
+
DEFAULT_UNK_TOKEN = "</s>"
|
| 27 |
+
PROMPT = (
|
| 28 |
+
"Below is an instruction that describes a task. "
|
| 29 |
+
"Write a response that appropriately completes the request.\n\n"
|
| 30 |
+
"### Instruction:\n{instruction}\n\n### Response:"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class MyTrainer(Trainer):
    """Trainer that caches the OFT adapter parameters and carries an
    orthogonality-regularization weight ``lamda``.

    NOTE(review): the orthogonality penalty itself is currently disabled —
    ``compute_loss`` returns the plain model loss; ``lamda`` and
    ``self.oft_params`` are kept for when the regularizer is re-enabled.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        lamda: float = 1e-4,
    ):
        print('optimizers', optimizers)
        # Bug fix: was `optimizers == None`; identity comparison is the correct
        # (and PEP 8 mandated) way to test for None.
        if optimizers is None:
            optimizers = (None, None)
        super().__init__(model=model, args=args, data_collator=data_collator, train_dataset=train_dataset,
                         eval_dataset=eval_dataset, processing_class=tokenizer, model_init=model_init,
                         compute_metrics=compute_metrics, callbacks=callbacks,
                         optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics)
        self.lamda = lamda

        # Cache the OFT rotation parameters once so a (future) regularizer does
        # not have to walk named_parameters() every step.
        self.oft_params: List[torch.nn.Parameter] = [
            p for n, p in self.model.named_parameters() if "oft_r" in n
        ]

    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch: Optional[torch.Tensor] = None,):
        """Compute the training loss (mirrors Trainer.compute_loss).

        Subclassed so an orthogonality penalty on ``self.oft_params`` can be
        added on top of the model loss; that penalty is currently disabled.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        if self.model_accepts_loss_kwargs:
            kwargs = {}
            if num_items_in_batch is not None:
                kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **kwargs}
        outputs = model(**inputs)
        # Save past state if it exists (mirrors upstream Trainer behavior).
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            # Causal-LM heads need the labels shifted by one position.
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        return (loss, outputs) if return_outputs else loss

    def _move_model_to_device(self, model, device):
        # Intentionally a no-op: the model is placed via `device_map` at load
        # time, so the Trainer must not move it again.
        pass
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # NOTE(review): this subclass shadows transformers.TrainingArguments in the
    # module namespace; every field below extends the upstream arguments with
    # script-specific options.
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    # Optional subfolder of `model_name_or_path` holding a PEFT adapter to resume from.
    adapter_name_or_path: Optional[str] = field(default=None)
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    dataset_split: str = field(
        default="train[:100000]", metadata={"help": "(`['train', 'test', 'eval']`):"}
    )
    # Two entries: [query column name, response column name] of the dataset.
    dataset_field: List[str] = field(
        default=None, metadata={"help": "Fields of dataset input and output."}
    )
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512, metadata={
        "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, )
    hrft_r: int = field(default=8, metadata={
        "help": "The rank of the adapter. When passing `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."})
    init_a: float = field(default=1e-4, metadata={"help": "The initial weights"})
    eps: float = field(default=1e-4, metadata={"help": "The control strength of COFT. The freedom of rotation."})
    # Weight of the orthogonality regularizer consumed by MyTrainer.
    lamda: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    add_orth: str = field(default='none', metadata={"help": ""})
    init_weights: Literal[True, "pissa"] = field(
        default=True,
        metadata={
            "help": (
                "Passing True (default) results in the LoRA initialization."
                "Passing `pissa` results in PiSSA initialization."
            ),
        },
    )
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Copy the model's state dict to CPU and persist it through the trainer."""
    full_state = trainer.model.state_dict()
    if not trainer.args.should_save:
        return
    on_cpu = {name: tensor.cpu() for name, tensor in full_state.items()}
    del full_state  # drop device-side references before writing to disk
    trainer._save(output_dir, state_dict=on_cpu)  # noqa
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens and resize the embeddings to match.

    Newly appended embedding rows (input and output) are initialised to the
    mean of the pre-existing rows.

    Note: this is the unoptimized version that may make your embedding size
    not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens <= 0:
        return

    in_emb = model.get_input_embeddings().weight.data
    out_emb = model.get_output_embeddings().weight.data

    # Average of the original vocabulary rows, broadcast into the new rows.
    in_emb[-num_new_tokens:] = in_emb[:-num_new_tokens].mean(dim=0, keepdim=True)
    out_emb[-num_new_tokens:] = out_emb[:-num_new_tokens].mean(dim=0, keepdim=True)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string independently and record its non-padding length."""
    ids = []
    lens = []
    for text in strings:
        encoded = tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        ids.append(encoded.input_ids[0])
        lens.append(encoded.input_ids.ne(tokenizer.pad_token_id).sum().item())
    # For causal-LM supervision the inputs double as the labels, so the same
    # list object is returned under both keys (as in the original).
    return dict(
        input_ids=ids,
        labels=ids,
        input_ids_lens=lens,
        labels_lens=lens,
    )
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt+answer pairs, masking the prompt part of the labels."""
    # Full examples are prompt immediately followed by the answer.
    examples = [src + tgt for src, tgt in zip(sources, targets)]
    examples_tokenized = _tokenize_fn(examples, tokenizer)
    sources_tokenized = _tokenize_fn(sources, tokenizer)

    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    # Mask out the prompt tokens so the loss is only taken on the answer.
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Pads input_ids with the tokenizer's pad id and labels with IGNORE_INDEX,
    and derives the attention mask from the padded input_ids.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = self.tokenizer.pad_token_id
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(inst["input_ids"]) for inst in instances],
            batch_first=True,
            padding_value=pad_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(inst["labels"]) for inst in instances],
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(pad_id),
        )
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def train_tokenize_function(examples, tokenizer, query, response):
    """datasets.map callback: build prompts from `query`, answers from `response`."""
    sources = [PROMPT.format_map(dict(instruction=instruction)) for instruction in examples[query]]
    # Append EOS so the model learns where an answer terminates.
    targets = [f"{output}{tokenizer.eos_token}" for output in examples[response]]
    return preprocess(sources, targets, tokenizer)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def train():
    """Entry point: parse arguments, build the (optionally OFT-adapted) model,
    tokenize the training data, and run supervised fine-tuning.
    """
    parser = transformers.HfArgumentParser(TrainingArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    model = transformers.AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        device_map={"": 0},  # pin the whole model on GPU 0 (was: device_map="auto")
    )
    if script_args.adapter_name_or_path is not None:
        # Resume from an existing PEFT adapter stored as a subfolder of the base repo.
        print(f"Load {script_args.init_weights} from {script_args.adapter_name_or_path}: ", )
        model = PeftModel.from_pretrained(model, script_args.model_name_or_path,
                                          subfolder=script_args.adapter_name_or_path, is_trainable=True)
    elif script_args.hrft_r is not None:
        print(f"Initialized {script_args.init_weights} layers")  # typo fix: was "Initilized"
        hra_config = OFTConfig(
            r=script_args.hrft_r,
            eps=script_args.eps,
            init_weights=script_args.init_weights,
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, hra_config)
    else:
        print("Full Parameter Fine-Tuning")

    print(model)
    # Bug fix: only PEFT-wrapped models expose print_trainable_parameters();
    # the full fine-tuning branch used to crash on this call.
    if hasattr(model, "print_trainable_parameters"):
        model.print_trainable_parameters()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        script_args.model_name_or_path,
        model_max_length=script_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        # Adds [PAD] and resizes embeddings (new rows initialised to the mean).
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )

    if "llama" in script_args.model_name_or_path:
        # Llama checkpoints historically shipped without these special tokens.
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )

    raw_train_datasets = load_dataset("json", data_files=script_args.data_path,
                                      split=script_args.dataset_split, name="metamath_qa",)
    train_dataset = raw_train_datasets.map(
        train_tokenize_function,
        batched=True,
        batch_size=3000,
        num_proc=32,
        remove_columns=raw_train_datasets.column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on train dataset",
        fn_kwargs={"tokenizer": tokenizer, "query": script_args.dataset_field[0],
                   "response": script_args.dataset_field[1]}
    )

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_module = dict(train_dataset=train_dataset, data_collator=data_collator)
    trainer = MyTrainer(model=model, tokenizer=tokenizer, lamda=script_args.lamda,
                        args=script_args, **data_module)
    model.config.use_cache = False  # disable KV cache while training

    start_time = datetime.now()
    print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S"))

    trainer.train()

    end_time = datetime.now()
    print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)

    # Save the tokenizer and (adapter) weights under <output_dir>/ft.
    tokenizer.save_pretrained(os.path.join(script_args.output_dir, 'ft'))
    model.save_pretrained(os.path.join(script_args.output_dir, 'ft'))
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
if __name__ == "__main__":
|
| 367 |
+
|
| 368 |
+
train()
|
llama/inference/MATH_inference.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import pdb
|
| 4 |
+
import jsonlines
|
| 5 |
+
|
| 6 |
+
import util
|
| 7 |
+
from vllm import LLM, SamplingParams
|
| 8 |
+
import sys
|
| 9 |
+
MAX_INT = sys.maxsize
|
| 10 |
+
INVALID_ANS = "[invalid]"
|
| 11 |
+
|
| 12 |
+
invalid_outputs = []
|
| 13 |
+
def remove_boxed(s):
    """Strip a LaTeX ``\\boxed{...}`` wrapper and return the inner text.

    Returns None when ``s`` is not exactly of the form ``\\boxed{<content>}``
    (including when ``s`` is None or not a string).
    """
    left = "\\boxed{"
    # Explicit checks instead of assert + bare except: asserts are stripped
    # under `python -O`, and the bare except also swallowed unrelated errors.
    if not isinstance(s, str) or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]
|
| 21 |
+
|
| 22 |
+
def process_results(doc, completion, answer):
    """Return True when `completion` contains an extractable answer equal to `answer`.

    Completions lacking the 'The answer is: ' marker are appended to the
    module-level `invalid_outputs` list and counted as wrong.
    """
    parts = completion.split('The answer is: ')
    if len(parts) <= 1:
        invalid_outputs.append({'question': doc, 'output': completion, 'answer': answer})
        return False

    # Take text after the last marker, cut at the first ".\n", tidy whitespace.
    candidate = parts[-1].split('.\n')[0].strip()
    if candidate.endswith('.'):
        candidate = candidate[:-1]  # drop a trailing period
    candidate = candidate.strip()

    return True if util.is_equiv(candidate, answer) else False
|
| 41 |
+
def batch_data(data_list, batch_size=1):
    """Split `data_list` into batches; the final batch absorbs any remainder.

    With n = len(data_list) // batch_size, produces n-1 full batches followed
    by one final batch holding everything from index (n-1)*batch_size onward
    (so the last batch can be larger than `batch_size`).
    """
    n = len(data_list) // batch_size
    batches = [data_list[i * batch_size:(i + 1) * batch_size] for i in range(n - 1)]
    batches.append(data_list[(n - 1) * batch_size:])
    return batches
|
| 53 |
+
|
| 54 |
+
def test_hendrycks_math(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    """Evaluate a model on the MATH test set with greedy vLLM decoding.

    Reads jsonl records with 'instruction' and 'output' fields, extracts the
    gold answer from the solution's \\boxed{...}, generates completions, and
    prints the accuracy computed by process_results().
    """
    hendrycks_math_ins = []
    hendrycks_math_answers = []
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('promt =====', problem_prompt)
    with open(data_path, "r+", encoding="utf8") as f:
        for idx, item in enumerate(jsonlines.Reader(f)):
            temp_instr = problem_prompt.format(instruction=item["instruction"])
            hendrycks_math_ins.append(temp_instr)
            solution = item['output']
            # Gold answer is the content of the last \boxed{...} in the solution.
            temp_ans = remove_boxed(util.last_boxed_only_string(solution))
            hendrycks_math_answers.append(temp_ans)

    print('total length ===', len(hendrycks_math_ins))
    # Restrict evaluation to the [start, end) slice of the test set.
    hendrycks_math_ins = hendrycks_math_ins[start:end]
    hendrycks_math_answers = hendrycks_math_answers[start:end]
    print('lenght ====', len(hendrycks_math_ins))
    # NOTE(review): the batched list is built but never used below —
    # llm.generate() is called on the full unbatched list.
    batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size)

    # Stop generation when the model starts a new Instruction/Response section.
    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
    # temperature=0 -> greedy decoding, deterministic per model/version.
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=2048, stop=stop_tokens)
    print('sampleing =====', sampling_params)
    llm = LLM(model=model,tensor_parallel_size=tensor_parallel_size)
    outputs = llm.generate(hendrycks_math_ins, sampling_params)
    res_completions = [output.outputs[0].text for output in outputs]

    results = []
    for idx, (prompt, completion, prompt_answer) in enumerate(zip(hendrycks_math_ins, res_completions, hendrycks_math_answers)):
        res = process_results(prompt, completion, prompt_answer)
        results.append(res)

    acc = sum(results) / len(results)
    # NOTE(review): 'valid_outputs' below re-prints the *invalid* count; the
    # intended value was probably len(results) - len(invalid_outputs).
    print('len invalid outputs ====', len(invalid_outputs), ', valid_outputs===', len(invalid_outputs))
    # print('start===', start, ', end====',end)
    print('length====', len(results), ', acc====', acc)
|
| 93 |
+
|
| 94 |
+
def parse_args():
    """Parse command-line arguments for the MATH evaluation script."""
    parser = argparse.ArgumentParser()
    # Bug fix: default was the int 0 on a `type=str` option; None is the
    # conventional "not provided" sentinel for a path argument.
    parser.add_argument("--model", type=str, default=None)  # model path
    parser.add_argument("--data_file", type=str, default='data/MATH_test.jsonl')  # data path
    parser.add_argument("--start", type=int, default=0)  # start index
    parser.add_argument("--end", type=int, default=MAX_INT)  # end index
    parser.add_argument("--batch_size", type=int, default=50)  # batch_size
    parser.add_argument("--tensor_parallel_size", type=int, default=1)  # tensor_parallel_size
    return parser.parse_args()
|
| 103 |
+
|
| 104 |
+
if __name__ == "__main__":
|
| 105 |
+
args = parse_args()
|
| 106 |
+
test_hendrycks_math(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)
|
| 107 |
+
|
| 108 |
+
|
llama/inference/grader.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
|
| 3 |
+
- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
|
| 4 |
+
"""
|
| 5 |
+
import multiprocessing
|
| 6 |
+
from math import isclose
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
from sympy import simplify, N
|
| 10 |
+
from sympy.parsing.sympy_parser import parse_expr
|
| 11 |
+
from sympy.parsing.latex import parse_latex
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def is_digit(s):
    """True when `s`, with thousands separators removed, parses as a float."""
    text = str(s).replace(",", "")
    try:
        float(text)
    except ValueError:
        return False
    return True
|
| 20 |
+
|
| 21 |
+
def math_equal(prediction: Union[bool, float, str],
               reference: Union[float, str],
               include_percentage: bool = True,
               is_close: bool = True,
               timeout: bool = False,
               ) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal
    """
    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = float(str(prediction).replace(",", ""))
            reference = float(str(reference).replace(",", ""))
            # number questions: also accept the answer interpreted as a
            # percentage (x/100) or scaled by 100.
            if include_percentage:
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        # Tolerant float comparison (rel_tol 1e-4).
                        if isclose(item, prediction, rel_tol=1e-4):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except:
        pass

    # Empty-ish predictions fail, but 0 and False are legitimate answers.
    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## deal with [], (), {}: compare with bracket characters stripped first
    pred_str, ref_str = prediction, reference
    if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
        (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ['{', "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True

    ## [a, b] vs. [c, d], return a==c and b==d (element-wise recursion)
    if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \
        (prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
                return True

    # symbolic equal with sympy; with timeout=True the comparison runs in a
    # subprocess so a hanging simplify() cannot stall the caller.
    if timeout:
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def math_equal_process(param):
    """Pool/map adapter: compare the last two elements of `param`."""
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def symbolic_equal(a, b):
    """True when `a` and `b` parse to symbolically or numerically equal expressions.

    Each side is parsed as LaTeX first, then as plain sympy syntax; if both
    parsers fail the raw string is kept.
    """
    def _to_expr(text):
        for parser in (parse_latex, parse_expr):
            try:
                return parser(text)
            except:  # noqa: E722 — preserve the original catch-all behavior
                continue
        return text

    lhs = _to_expr(a)
    rhs = _to_expr(b)

    # Symbolic check: difference simplifies to zero.
    try:
        if simplify(lhs - rhs) == 0:
            return True
    except:  # noqa: E722
        pass

    # Numeric fallback: evaluate both sides and compare with tolerance.
    try:
        if isclose(N(lhs), N(rhs), rel_tol=1e-3):
            return True
    except:  # noqa: E722
        pass
    return False
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def symbolic_equal_process(a, b, output_queue):
    """Worker-process wrapper: push symbolic_equal's verdict onto the queue."""
    output_queue.put(symbolic_equal(a, b))
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run *func* in a subprocess with a wall-clock *timeout* (seconds).

    *func* receives ``*args`` plus a trailing multiprocessing.Queue it must put
    its result on, and ``**kwargs``. Returns the queued result, or False when
    the subprocess times out or exits without producing a result.
    """
    import queue as queue_module  # stdlib; local import keeps the interface unchanged

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        # Still running past the deadline: kill it and report failure.
        process.terminate()
        process.join()
        return False

    # Bug fix: if the worker raised before putting anything on the queue, the
    # original bare output_queue.get() blocked forever. A crashed worker has a
    # nonzero exitcode; otherwise bound the get as a final safety net.
    if process.exitcode != 0:
        return False
    try:
        return output_queue.get(timeout=1)
    except queue_module.Empty:
        return False
|
| 141 |
+
|
llama/inference/gsm8k_inference.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import re
|
| 4 |
+
import jsonlines
|
| 5 |
+
from fraction import Fraction
|
| 6 |
+
from vllm import LLM, SamplingParams
|
| 7 |
+
import sys
|
| 8 |
+
from grader import math_equal
|
| 9 |
+
MAX_INT = sys.maxsize
|
| 10 |
+
|
| 11 |
+
def is_number(s):
    """Return True when *s* parses as a float, or is a single numeric unicode char."""
    try:
        float(s)
    except ValueError:
        pass
    else:
        return True
    # float() rejects things like unicode vulgar fractions; unicodedata.numeric
    # handles single numeric characters (TypeError for multi-char strings).
    try:
        import unicodedata
        unicodedata.numeric(s)
    except (TypeError, ValueError):
        return False
    return True
|
| 24 |
+
|
| 25 |
+
def extract_answer_number(completion):
    """Pull the final numeric answer out of a model completion.

    Looks for text after the last 'The answer is: ' marker, extracts the first
    number-like token (optionally signed, with ',', '.' or '/'), and returns it
    rounded to the nearest int. Returns None when no marker or no number is
    found, when a fraction is malformed, or when the value is infinite.
    """
    parts = completion.split('The answer is: ')
    if len(parts) <= 1:
        return None
    tail = parts[-1].strip()
    match = re.search(r'[\-+]?\d*[\.,/]?\d+', tail)
    if match is None:
        return None
    token = match.group()
    if '/' in token:
        numerator = token.split('/')[0]
        denominator = token.split('/')[1]
        if not (is_number(denominator) and is_number(numerator)):
            return None
        if denominator == '0':
            # Division by zero: fall back to the numerator alone.
            return round(float(numerator.replace(',', '')))
        frac = Fraction(token.replace(',', ''))
        return round(float(frac.numerator / frac.denominator))
    value = float(token.replace(',', ''))
    if value == float('inf'):
        return None
    return round(value)
|
| 52 |
+
|
| 53 |
+
def batch_data(data_list, batch_size=1):
    """Chunk *data_list* into batches of *batch_size*.

    The final batch absorbs any remainder, so it may be longer than
    *batch_size*; when the list is shorter than one batch, a single batch
    holding everything is returned.
    """
    num_full = len(data_list) // batch_size
    batches = [data_list[i * batch_size:(i + 1) * batch_size] for i in range(num_full - 1)]
    # Last batch runs to the end of the list, picking up leftover items.
    batches.append(data_list[(num_full - 1) * batch_size:])
    return batches
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    """Evaluate a model checkpoint on the GSM8K test set and print accuracy.

    Reads jsonlines records with 'question' and 'answer' keys from *data_path*,
    builds Alpaca-style prompts, generates greedy completions with vLLM, and
    scores the extracted numeric answer against the gold answer.

    *batch_size* is kept for CLI compatibility but unused: vLLM batches
    internally, so all prompts are generated in a single call.
    """
    gsm8k_ins = []
    gsm8k_answers = []
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('promt =====', problem_prompt)
    # The dataset is only read, so open read-only (was "r+").
    with open(data_path, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            gsm8k_ins.append(problem_prompt.format(instruction=item["question"]))
            # Gold answer follows the '#### ' marker; strip thousands separators.
            temp_ans = item['answer'].split('#### ')[1]
            gsm8k_answers.append(int(temp_ans.replace(',', '')))

    gsm8k_ins = gsm8k_ins[start:end]
    gsm8k_answers = gsm8k_answers[start:end]
    print('lenght ====', len(gsm8k_ins))

    # Stop generation when the model starts a new instruction/response block.
    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=1024, stop=stop_tokens)
    print('sampleing =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)
    result = []

    outputs = llm.generate(gsm8k_ins, sampling_params)
    res_completions = [output.outputs[0].text for output in outputs]

    invalid_outputs = []
    for prompt, completion, prompt_answer in zip(gsm8k_ins, res_completions, gsm8k_answers):
        y_pred = extract_answer_number(completion)
        if y_pred is not None:
            result.append(float(y_pred) == float(prompt_answer) or math_equal(y_pred, prompt_answer))
        else:
            # No parsable answer: count as wrong and keep the sample for inspection.
            result.append(False)
            invalid_outputs.append({'question': prompt, 'output': completion, 'answer': prompt_answer})
    acc = sum(result) / len(result)
    # Bug fix: the second count used to print len(invalid_outputs) again;
    # report the number of completions that DID yield a parsable answer.
    print('len invalid outputs ====', len(invalid_outputs), ', valid_outputs===', len(result) - len(invalid_outputs))
    print('gsm8k length====', len(result), ', gsm8k acc====', acc)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def parse_args():
    """Build and parse the command-line options for GSM8K evaluation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str)  # path to the model checkpoint
    parser.add_argument("--data_file", type=str, default='data/gsm8k_test.jsonl')  # GSM8K test set
    parser.add_argument("--start", type=int, default=0)  # first example index
    parser.add_argument("--end", type=int, default=MAX_INT)  # one-past-last example index
    parser.add_argument("--batch_size", type=int, default=60)  # nominal batch size
    parser.add_argument("--tensor_parallel_size", type=int, default=1)  # GPUs for vLLM
    return parser.parse_args()
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run the GSM8K evaluation.
    args = parse_args()
    gsm8k_test(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)
|
llama/inference/util.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pprint
|
| 2 |
+
from grader import math_equal
|
| 3 |
+
|
| 4 |
+
def last_boxed_only(sample):
    """Map a (question, answer) pair to (question, last boxed substring of answer).

    Returns None when the answer contains no \\boxed/\\fbox expression.
    """
    question, answer = sample
    boxed = last_boxed_only_string(answer)
    if boxed is None:
        return None
    return (question, boxed)
|
| 10 |
+
|
| 11 |
+
def last_boxed_only_string(string):
    """Return the last '\\boxed{...}' (or '\\fbox{...}') substring of *string*.

    Braces are matched by depth, so nested braces inside the box are kept.
    Returns None when no boxed expression exists or its braces never close.
    """
    start = string.rfind("\\boxed")
    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    depth = 0
    end = None
    for pos in range(start, len(string)):
        ch = string[pos]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                end = pos
                break

    if end is None:
        return None
    return string[start:end + 1]
|
| 37 |
+
|
| 38 |
+
def only_until_first_boxed_from_tokens(string, tokens):
    """Return the prefix of *tokens* whose text lies before the first boxed expr.

    Finds the character offset of the first \\boxed/\\fbox in *string*, walks
    *tokens* accumulating their lengths, and cuts before the token that reaches
    that offset. Returns None when no boxed expression exists.
    """
    idx = string.find("\\boxed")
    if idx < 0:
        idx = string.find("\\fbox")
        if idx < 0:
            return None

    consumed = 0
    for position, token in enumerate(tokens):
        consumed += len(token)
        if consumed >= idx:
            break

    return tokens[:position]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def clean_numbers(sample):
    """Apply _clean_numbers to every element of *sample*, returning a tuple.

    Falsy input (e.g. None or an empty pair) maps to None.
    """
    if not sample:
        return None
    return tuple(_clean_numbers(element) for element in sample)
|
| 63 |
+
|
| 64 |
+
def _clean_numbers(string):
|
| 65 |
+
"""
|
| 66 |
+
Clean Numbers in the given string
|
| 67 |
+
|
| 68 |
+
>>> _clean_numbers(None, "Hello 123")
|
| 69 |
+
'Hello 123'
|
| 70 |
+
>>> _clean_numbers(None, "Hello 1234")
|
| 71 |
+
'Hello 1,234'
|
| 72 |
+
>>> _clean_numbers(None, "Hello 1234324asdasd")
|
| 73 |
+
'Hello 1,234,324asdasd'
|
| 74 |
+
"""
|
| 75 |
+
num_prev_digits = 0
|
| 76 |
+
new_string = ""
|
| 77 |
+
for i, c in enumerate(string):
|
| 78 |
+
# isdigit() doesnt work here because of weird unicode chars.
|
| 79 |
+
if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}:
|
| 80 |
+
num_prev_digits += 1
|
| 81 |
+
else:
|
| 82 |
+
if num_prev_digits > 3:
|
| 83 |
+
# Some fixing
|
| 84 |
+
string_number = new_string[-num_prev_digits:]
|
| 85 |
+
new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
|
| 86 |
+
num_prev_digits = 0
|
| 87 |
+
new_string += c
|
| 88 |
+
|
| 89 |
+
if num_prev_digits > 3:
|
| 90 |
+
# Some fixing
|
| 91 |
+
string_number = new_string[-num_prev_digits:]
|
| 92 |
+
new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
|
| 93 |
+
|
| 94 |
+
return new_string
|
| 95 |
+
|
| 96 |
+
def fix_fracs(string):
    """Normalize bare LaTeX \\frac arguments to braced form.

    '\\frac12' -> '\\frac{1}{2}', '\\frac1{72}' -> '\\frac{1}{72}'; already
    braced fractions are left alone. The original input is returned unchanged
    when any \\frac is malformed (too few argument characters).
    """
    substrs = string.split("\\frac")
    new_str = substrs[0]
    if len(substrs) > 1:
        substrs = substrs[1:]
        for substr in substrs:
            new_str += "\\frac"
            # Bug fix: a trailing "\\frac" yields an empty substr and the
            # original code raised IndexError on substr[0].
            if len(substr) == 0:
                return string
            if substr[0] == "{":
                # Numerator already braced: keep the rest verbatim.
                new_str += substr
            else:
                if len(substr) < 2:
                    # Lone char after \frac: cannot form numerator/denominator.
                    return string
                a = substr[0]
                b = substr[1]
                if b != "{":
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}{" + b + "}" + post_substr
                    else:
                        new_str += "{" + a + "}{" + b + "}"
                else:
                    # Denominator already braced: only brace the numerator.
                    if len(substr) > 2:
                        post_substr = substr[2:]
                        new_str += "{" + a + "}" + b + post_substr
                    else:
                        new_str += "{" + a + "}" + b
    string = new_str
    return string
|
| 126 |
+
|
| 127 |
+
def fix_a_slash_b(string):
    """Rewrite a plain integer fraction 'a/b' as '\\frac{a}{b}'.

    Only exact <int>/<int> strings are rewritten; anything else (extra
    slashes, signs/spacing that would not round-trip, non-numeric parts)
    is returned unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    a, b = parts
    try:
        a_int = int(a)
        b_int = int(b)
        # Reject inputs whose canonical form differs (e.g. ' 1/2', '+1/2').
        if string != "{}/{}".format(a_int, b_int):
            return string
        return "\\frac{" + str(a_int) + "}{" + str(b_int) + "}"
    except ValueError:
        # Bug fix: non-numeric parts used to raise an uncaught ValueError
        # (only AssertionError was handled); treat them as "leave unchanged".
        return string
|
| 140 |
+
|
| 141 |
+
def remove_right_units(string):
    """Drop a trailing LaTeX '\\text{ ...}' unit annotation.

    "\\text{ " only ever appears (at least in the MATH val set) when
    describing units, so everything from its first occurrence onward is
    discarded. Robustness fix: the original asserted exactly one occurrence
    and crashed on answers with several '\\text{ ' spans.
    """
    if "\\text{ " in string:
        return string.split("\\text{ ")[0]
    return string
|
| 149 |
+
|
| 150 |
+
def fix_sqrt(string):
    """Brace bare \\sqrt arguments: '\\sqrtN...' -> '\\sqrt{N}...'.

    Already-braced '\\sqrt{...}' occurrences are left untouched.
    """
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        # Bug fix: a trailing "\\sqrt" produces an empty split; the original
        # indexed split[0] and raised IndexError. Emit it unchanged instead.
        if split and split[0] != "{":
            # Only the first character becomes the braced argument.
            new_string += "\\sqrt{" + split[0] + "}" + split[1:]
        else:
            new_string += "\\sqrt" + split
    return new_string
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def strip_string(string):
    """Normalize a MATH answer string for comparison.

    Applies a fixed, order-sensitive pipeline of LaTeX cleanups (whitespace,
    \\frac/\\sqrt bracing, unit/percent/dollar removal, leading-zero fixes).
    The order of the replace calls matters; do not reorder them.
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right), e.g. "5\text{ cm}" -> "5"
    string = remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace("\%", "") # noqa: W605

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # to consider: get rid of e.g. "k = " or "q = " at beginning
    # (keeps only the right-hand side when the left side is at most 2 chars)
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = fix_a_slash_b(string)

    return string
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def is_equiv(str1, str2, verbose=False):
    """Compare two MATH answer strings after LaTeX normalization.

    Returns True when both inputs are None (with a warning), False when
    exactly one is None. Otherwise both strings are normalized with
    strip_string and compared via math_equal / string equality; if
    normalization raises, the raw strings are compared instead.
    """
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = strip_string(str1)
        ss2 = strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return math_equal(ss1, ss2) or ss1 == ss2
    except Exception:
        # Normalization failed: fall back to the raw strings.
        # Bug fix: this path used to call math_equal(str1, str1) — comparing
        # str1 against itself, which made the fallback effectively always True.
        return math_equal(str1, str2) or str1 == str2
|
| 250 |
+
|
| 251 |
+
class NotEqual:
    """Sentinel that compares unequal to everything, including itself."""

    def __eq__(self, other):
        # Unconditional inequality; note NotEqual() == NotEqual() is also False.
        return False
|
llama/merge_adapter_to_base_model.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import argparse
import torch

# Merge a PEFT adapter (OFT config here) into its base model and save the
# result as a standalone checkpoint that no longer needs the peft runtime.
parser = argparse.ArgumentParser(description='Merge Adapter to Base Model')
# NOTE(review): flag is spelled '--base_mode' (not '--base_model'); kept as-is
# since existing shell scripts may depend on it.
parser.add_argument('--base_mode', type=str)
parser.add_argument('--adapter', type=str)
parser.add_argument('--output_path', type=str)
args = parser.parse_args()

# Load base weights on CPU in bfloat16; merging does not require a GPU.
model = AutoModelForCausalLM.from_pretrained(args.base_mode, torch_dtype=torch.bfloat16, device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained(args.base_mode, device_map='auto')
#
# tokenizer = AutoTokenizer.from_pretrained(args.adapter)
# Grow the embedding matrix to 32001 rows — presumably to cover an added
# [PAD] token at id 32000 used during fine-tuning; TODO confirm against the
# adapter's tokenizer files.
model.resize_token_embeddings(32001)
print('len', len(tokenizer))
print(f"Base model vocab size after resize: {model.get_input_embeddings().weight.shape[0]}")
#
# Load the adapter config, force weight initialization, attach and merge.
lora_config = PeftConfig.from_pretrained(args.adapter)
lora_config.init_oft_weights=True
model = PeftModel.from_pretrained(model, args.adapter, config=lora_config)
model = model.merge_and_unload()

# safe_serialization=False writes pytorch_model.bin instead of safetensors.
model.save_pretrained(args.output_path, safe_serialization=False)
tokenizer.save_pretrained(args.output_path)
|
| 27 |
+
|
llama/output/cp1e4/ft/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: meta-llama/Llama-2-7b-hf
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.10.0
|
llama/output/cp1e4/ft/adapter_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
| 5 |
+
"block_share": false,
|
| 6 |
+
"coft": false,
|
| 7 |
+
"eps": 0.0001,
|
| 8 |
+
"inference_mode": true,
|
| 9 |
+
"init_weights": true,
|
| 10 |
+
"layers_pattern": null,
|
| 11 |
+
"layers_to_transform": null,
|
| 12 |
+
"module_dropout": 0.0,
|
| 13 |
+
"modules_to_save": null,
|
| 14 |
+
"peft_type": "OFT",
|
| 15 |
+
"r": 32,
|
| 16 |
+
"rank_pattern": {},
|
| 17 |
+
"revision": null,
|
| 18 |
+
"target_modules": [
|
| 19 |
+
"v_proj",
|
| 20 |
+
"q_proj"
|
| 21 |
+
],
|
| 22 |
+
"task_type": "CAUSAL_LM"
|
| 23 |
+
}
|
llama/output/cp1e4/ft/added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"[PAD]": 32000
|
| 3 |
+
}
|
llama/output/cp1e4/ft/special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "</s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "[PAD]",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "</s>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
llama/output/cp1e4/ft/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llama/output/cp1e4/ft/tokenizer_config.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": null,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<unk>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "</s>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
},
|
| 30 |
+
"32000": {
|
| 31 |
+
"content": "[PAD]",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false,
|
| 36 |
+
"special": true
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
"bos_token": "</s>",
|
| 40 |
+
"clean_up_tokenization_spaces": false,
|
| 41 |
+
"eos_token": "</s>",
|
| 42 |
+
"extra_special_tokens": {},
|
| 43 |
+
"legacy": false,
|
| 44 |
+
"model_max_length": 512,
|
| 45 |
+
"pad_token": "[PAD]",
|
| 46 |
+
"padding_side": "right",
|
| 47 |
+
"sp_model_kwargs": {},
|
| 48 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 49 |
+
"unk_token": "</s>",
|
| 50 |
+
"use_default_system_prompt": false
|
| 51 |
+
}
|
llama/output/cp1e5/ft/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: facebook/opt-125m
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.10.0
|
llama/output/cp1e5/ft/adapter_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "facebook/opt-125m",
|
| 5 |
+
"block_share": false,
|
| 6 |
+
"coft": false,
|
| 7 |
+
"eps": 0.0001,
|
| 8 |
+
"inference_mode": true,
|
| 9 |
+
"init_weights": true,
|
| 10 |
+
"layers_pattern": null,
|
| 11 |
+
"layers_to_transform": null,
|
| 12 |
+
"module_dropout": 0.0,
|
| 13 |
+
"modules_to_save": null,
|
| 14 |
+
"peft_type": "OFT",
|
| 15 |
+
"r": 8,
|
| 16 |
+
"rank_pattern": {},
|
| 17 |
+
"revision": null,
|
| 18 |
+
"target_modules": [
|
| 19 |
+
"v_proj",
|
| 20 |
+
"q_proj"
|
| 21 |
+
],
|
| 22 |
+
"task_type": "CAUSAL_LM"
|
| 23 |
+
}
|
llama/output/cp1e5/trainer_state.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_metric": null,
|
| 3 |
+
"best_model_checkpoint": null,
|
| 4 |
+
"epoch": 2e-05,
|
| 5 |
+
"eval_steps": 500,
|
| 6 |
+
"global_step": 2,
|
| 7 |
+
"is_hyper_param_search": false,
|
| 8 |
+
"is_local_process_zero": true,
|
| 9 |
+
"is_world_process_zero": true,
|
| 10 |
+
"log_history": [
|
| 11 |
+
{
|
| 12 |
+
"epoch": 2e-05,
|
| 13 |
+
"step": 2,
|
| 14 |
+
"total_flos": 368078929920.0,
|
| 15 |
+
"train_loss": 2.170874834060669,
|
| 16 |
+
"train_runtime": 0.7734,
|
| 17 |
+
"train_samples_per_second": 2.586,
|
| 18 |
+
"train_steps_per_second": 2.586
|
| 19 |
+
}
|
| 20 |
+
],
|
| 21 |
+
"logging_steps": 1000,
|
| 22 |
+
"max_steps": 2,
|
| 23 |
+
"num_input_tokens_seen": 0,
|
| 24 |
+
"num_train_epochs": 1,
|
| 25 |
+
"save_steps": 0,
|
| 26 |
+
"total_flos": 368078929920.0,
|
| 27 |
+
"train_batch_size": 1,
|
| 28 |
+
"trial_name": null,
|
| 29 |
+
"trial_params": null
|
| 30 |
+
}
|
llama/output/cp1e5N/ft/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: meta-llama/Llama-2-7b-hf
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.10.0
|
llama/output/cp1e5N/ft/adapter_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
| 5 |
+
"block_share": false,
|
| 6 |
+
"coft": false,
|
| 7 |
+
"eps": 0.0001,
|
| 8 |
+
"inference_mode": true,
|
| 9 |
+
"init_weights": true,
|
| 10 |
+
"layers_pattern": null,
|
| 11 |
+
"layers_to_transform": null,
|
| 12 |
+
"module_dropout": 0.0,
|
| 13 |
+
"modules_to_save": null,
|
| 14 |
+
"peft_type": "OFT",
|
| 15 |
+
"r": 32,
|
| 16 |
+
"rank_pattern": {},
|
| 17 |
+
"revision": null,
|
| 18 |
+
"target_modules": [
|
| 19 |
+
"q_proj",
|
| 20 |
+
"v_proj"
|
| 21 |
+
],
|
| 22 |
+
"task_type": "CAUSAL_LM"
|
| 23 |
+
}
|
llama/output/cp1e5N/ft/added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"[PAD]": 32000
|
| 3 |
+
}
|
llama/output/cp1e5N/ft/special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "</s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "[PAD]",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "</s>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
llama/output/cp1e5N/ft/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llama/output/cp1e5N/ft/tokenizer_config.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": null,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<unk>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "</s>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
},
|
| 30 |
+
"32000": {
|
| 31 |
+
"content": "[PAD]",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false,
|
| 36 |
+
"special": true
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
"bos_token": "</s>",
|
| 40 |
+
"clean_up_tokenization_spaces": false,
|
| 41 |
+
"eos_token": "</s>",
|
| 42 |
+
"extra_special_tokens": {},
|
| 43 |
+
"legacy": false,
|
| 44 |
+
"model_max_length": 512,
|
| 45 |
+
"pad_token": "[PAD]",
|
| 46 |
+
"padding_side": "right",
|
| 47 |
+
"sp_model_kwargs": {},
|
| 48 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 49 |
+
"unk_token": "</s>",
|
| 50 |
+
"use_default_system_prompt": false
|
| 51 |
+
}
|
llama/output/cp3e5/ft/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: meta-llama/Llama-2-7b-hf
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.10.0
|
llama/output/cp3e5/ft/adapter_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
| 5 |
+
"block_share": false,
|
| 6 |
+
"coft": false,
|
| 7 |
+
"eps": 0.0001,
|
| 8 |
+
"inference_mode": true,
|
| 9 |
+
"init_weights": true,
|
| 10 |
+
"layers_pattern": null,
|
| 11 |
+
"layers_to_transform": null,
|
| 12 |
+
"module_dropout": 0.0,
|
| 13 |
+
"modules_to_save": null,
|
| 14 |
+
"peft_type": "OFT",
|
| 15 |
+
"r": 32,
|
| 16 |
+
"rank_pattern": {},
|
| 17 |
+
"revision": null,
|
| 18 |
+
"target_modules": [
|
| 19 |
+
"q_proj",
|
| 20 |
+
"v_proj"
|
| 21 |
+
],
|
| 22 |
+
"task_type": "CAUSAL_LM"
|
| 23 |
+
}
|
llama/output/cp3e5/trainer_state.json
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"best_metric": null,
|
| 3 |
+
"best_model_checkpoint": null,
|
| 4 |
+
"epoch": 2.0,
|
| 5 |
+
"eval_steps": 500,
|
| 6 |
+
"global_step": 6250,
|
| 7 |
+
"is_hyper_param_search": false,
|
| 8 |
+
"is_local_process_zero": true,
|
| 9 |
+
"is_world_process_zero": true,
|
| 10 |
+
"log_history": [
|
| 11 |
+
{
|
| 12 |
+
"epoch": 0.32,
|
| 13 |
+
"grad_norm": 3.8161606788635254,
|
| 14 |
+
"learning_rate": 2.8241524699899885e-05,
|
| 15 |
+
"loss": 0.3071,
|
| 16 |
+
"step": 1000
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"epoch": 0.64,
|
| 20 |
+
"grad_norm": 3.0841195583343506,
|
| 21 |
+
"learning_rate": 2.317615224686078e-05,
|
| 22 |
+
"loss": 0.2366,
|
| 23 |
+
"step": 2000
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"epoch": 0.96,
|
| 27 |
+
"grad_norm": 2.72049880027771,
|
| 28 |
+
"learning_rate": 1.6067682497434074e-05,
|
| 29 |
+
"loss": 0.215,
|
| 30 |
+
"step": 3000
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"epoch": 1.28,
|
| 34 |
+
"grad_norm": 2.600498914718628,
|
| 35 |
+
"learning_rate": 8.692414973449614e-06,
|
| 36 |
+
"loss": 0.186,
|
| 37 |
+
"step": 4000
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"epoch": 1.6,
|
| 41 |
+
"grad_norm": 2.350414752960205,
|
| 42 |
+
"learning_rate": 2.893317942061826e-06,
|
| 43 |
+
"loss": 0.1763,
|
| 44 |
+
"step": 5000
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"epoch": 1.92,
|
| 48 |
+
"grad_norm": 2.4736688137054443,
|
| 49 |
+
"learning_rate": 1.194984047782738e-07,
|
| 50 |
+
"loss": 0.1733,
|
| 51 |
+
"step": 6000
|
| 52 |
+
},
|
| 53 |
+
{
|
| 54 |
+
"epoch": 2.0,
|
| 55 |
+
"step": 6250,
|
| 56 |
+
"total_flos": 3.546740381344727e+18,
|
| 57 |
+
"train_loss": 0.21396501647949218,
|
| 58 |
+
"train_runtime": 35607.6251,
|
| 59 |
+
"train_samples_per_second": 5.617,
|
| 60 |
+
"train_steps_per_second": 0.176
|
| 61 |
+
}
|
| 62 |
+
],
|
| 63 |
+
"logging_steps": 1000,
|
| 64 |
+
"max_steps": 6250,
|
| 65 |
+
"num_input_tokens_seen": 0,
|
| 66 |
+
"num_train_epochs": 2,
|
| 67 |
+
"save_steps": 0,
|
| 68 |
+
"total_flos": 3.546740381344727e+18,
|
| 69 |
+
"train_batch_size": 8,
|
| 70 |
+
"trial_name": null,
|
| 71 |
+
"trial_params": null
|
| 72 |
+
}
|
llama/output/cp3e5N/ft/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: meta-llama/Llama-2-7b-hf
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.10.0
|
llama/output/cp3e5N/ft/adapter_config.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
|
| 5 |
+
"block_share": false,
|
| 6 |
+
"coft": false,
|
| 7 |
+
"eps": 0.0001,
|
| 8 |
+
"inference_mode": true,
|
| 9 |
+
"init_weights": true,
|
| 10 |
+
"layers_pattern": null,
|
| 11 |
+
"layers_to_transform": null,
|
| 12 |
+
"module_dropout": 0.0,
|
| 13 |
+
"modules_to_save": null,
|
| 14 |
+
"peft_type": "OFT",
|
| 15 |
+
"r": 32,
|
| 16 |
+
"rank_pattern": {},
|
| 17 |
+
"revision": null,
|
| 18 |
+
"target_modules": [
|
| 19 |
+
"q_proj",
|
| 20 |
+
"v_proj"
|
| 21 |
+
],
|
| 22 |
+
"task_type": "CAUSAL_LM"
|
| 23 |
+
}
|
llama/output/cp3e5N/ft/added_tokens.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"[PAD]": 32000
|
| 3 |
+
}
|
llama/output/cp3e5N/ft/special_tokens_map.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "</s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": {
|
| 17 |
+
"content": "[PAD]",
|
| 18 |
+
"lstrip": false,
|
| 19 |
+
"normalized": false,
|
| 20 |
+
"rstrip": false,
|
| 21 |
+
"single_word": false
|
| 22 |
+
},
|
| 23 |
+
"unk_token": {
|
| 24 |
+
"content": "</s>",
|
| 25 |
+
"lstrip": false,
|
| 26 |
+
"normalized": false,
|
| 27 |
+
"rstrip": false,
|
| 28 |
+
"single_word": false
|
| 29 |
+
}
|
| 30 |
+
}
|
llama/output/cp3e5N/ft/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
llama/output/cp3e5N/ft/tokenizer_config.json
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"add_prefix_space": null,
|
| 5 |
+
"added_tokens_decoder": {
|
| 6 |
+
"0": {
|
| 7 |
+
"content": "<unk>",
|
| 8 |
+
"lstrip": false,
|
| 9 |
+
"normalized": false,
|
| 10 |
+
"rstrip": false,
|
| 11 |
+
"single_word": false,
|
| 12 |
+
"special": true
|
| 13 |
+
},
|
| 14 |
+
"1": {
|
| 15 |
+
"content": "<s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false,
|
| 20 |
+
"special": true
|
| 21 |
+
},
|
| 22 |
+
"2": {
|
| 23 |
+
"content": "</s>",
|
| 24 |
+
"lstrip": false,
|
| 25 |
+
"normalized": false,
|
| 26 |
+
"rstrip": false,
|
| 27 |
+
"single_word": false,
|
| 28 |
+
"special": true
|
| 29 |
+
},
|
| 30 |
+
"32000": {
|
| 31 |
+
"content": "[PAD]",
|
| 32 |
+
"lstrip": false,
|
| 33 |
+
"normalized": false,
|
| 34 |
+
"rstrip": false,
|
| 35 |
+
"single_word": false,
|
| 36 |
+
"special": true
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
"bos_token": "</s>",
|
| 40 |
+
"clean_up_tokenization_spaces": false,
|
| 41 |
+
"eos_token": "</s>",
|
| 42 |
+
"extra_special_tokens": {},
|
| 43 |
+
"legacy": false,
|
| 44 |
+
"model_max_length": 512,
|
| 45 |
+
"pad_token": "[PAD]",
|
| 46 |
+
"padding_side": "right",
|
| 47 |
+
"sp_model_kwargs": {},
|
| 48 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 49 |
+
"unk_token": "</s>",
|
| 50 |
+
"use_default_system_prompt": false
|
| 51 |
+
}
|
llama/output/cpr1/ft/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: meta-llama/Llama-2-7b-hf
|
| 3 |
+
library_name: peft
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Model Card for Model ID
|
| 7 |
+
|
| 8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
## Model Details
|
| 13 |
+
|
| 14 |
+
### Model Description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
- **Developed by:** [More Information Needed]
|
| 21 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 22 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 23 |
+
- **Model type:** [More Information Needed]
|
| 24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 25 |
+
- **License:** [More Information Needed]
|
| 26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 27 |
+
|
| 28 |
+
### Model Sources [optional]
|
| 29 |
+
|
| 30 |
+
<!-- Provide the basic links for the model. -->
|
| 31 |
+
|
| 32 |
+
- **Repository:** [More Information Needed]
|
| 33 |
+
- **Paper [optional]:** [More Information Needed]
|
| 34 |
+
- **Demo [optional]:** [More Information Needed]
|
| 35 |
+
|
| 36 |
+
## Uses
|
| 37 |
+
|
| 38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 39 |
+
|
| 40 |
+
### Direct Use
|
| 41 |
+
|
| 42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 43 |
+
|
| 44 |
+
[More Information Needed]
|
| 45 |
+
|
| 46 |
+
### Downstream Use [optional]
|
| 47 |
+
|
| 48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 49 |
+
|
| 50 |
+
[More Information Needed]
|
| 51 |
+
|
| 52 |
+
### Out-of-Scope Use
|
| 53 |
+
|
| 54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 55 |
+
|
| 56 |
+
[More Information Needed]
|
| 57 |
+
|
| 58 |
+
## Bias, Risks, and Limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
[More Information Needed]
|
| 63 |
+
|
| 64 |
+
### Recommendations
|
| 65 |
+
|
| 66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 67 |
+
|
| 68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 69 |
+
|
| 70 |
+
## How to Get Started with the Model
|
| 71 |
+
|
| 72 |
+
Use the code below to get started with the model.
|
| 73 |
+
|
| 74 |
+
[More Information Needed]
|
| 75 |
+
|
| 76 |
+
## Training Details
|
| 77 |
+
|
| 78 |
+
### Training Data
|
| 79 |
+
|
| 80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 81 |
+
|
| 82 |
+
[More Information Needed]
|
| 83 |
+
|
| 84 |
+
### Training Procedure
|
| 85 |
+
|
| 86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 87 |
+
|
| 88 |
+
#### Preprocessing [optional]
|
| 89 |
+
|
| 90 |
+
[More Information Needed]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
#### Training Hyperparameters
|
| 94 |
+
|
| 95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 96 |
+
|
| 97 |
+
#### Speeds, Sizes, Times [optional]
|
| 98 |
+
|
| 99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 100 |
+
|
| 101 |
+
[More Information Needed]
|
| 102 |
+
|
| 103 |
+
## Evaluation
|
| 104 |
+
|
| 105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 106 |
+
|
| 107 |
+
### Testing Data, Factors & Metrics
|
| 108 |
+
|
| 109 |
+
#### Testing Data
|
| 110 |
+
|
| 111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 112 |
+
|
| 113 |
+
[More Information Needed]
|
| 114 |
+
|
| 115 |
+
#### Factors
|
| 116 |
+
|
| 117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 118 |
+
|
| 119 |
+
[More Information Needed]
|
| 120 |
+
|
| 121 |
+
#### Metrics
|
| 122 |
+
|
| 123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 124 |
+
|
| 125 |
+
[More Information Needed]
|
| 126 |
+
|
| 127 |
+
### Results
|
| 128 |
+
|
| 129 |
+
[More Information Needed]
|
| 130 |
+
|
| 131 |
+
#### Summary
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
## Model Examination [optional]
|
| 136 |
+
|
| 137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 138 |
+
|
| 139 |
+
[More Information Needed]
|
| 140 |
+
|
| 141 |
+
## Environmental Impact
|
| 142 |
+
|
| 143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 144 |
+
|
| 145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 146 |
+
|
| 147 |
+
- **Hardware Type:** [More Information Needed]
|
| 148 |
+
- **Hours used:** [More Information Needed]
|
| 149 |
+
- **Cloud Provider:** [More Information Needed]
|
| 150 |
+
- **Compute Region:** [More Information Needed]
|
| 151 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 152 |
+
|
| 153 |
+
## Technical Specifications [optional]
|
| 154 |
+
|
| 155 |
+
### Model Architecture and Objective
|
| 156 |
+
|
| 157 |
+
[More Information Needed]
|
| 158 |
+
|
| 159 |
+
### Compute Infrastructure
|
| 160 |
+
|
| 161 |
+
[More Information Needed]
|
| 162 |
+
|
| 163 |
+
#### Hardware
|
| 164 |
+
|
| 165 |
+
[More Information Needed]
|
| 166 |
+
|
| 167 |
+
#### Software
|
| 168 |
+
|
| 169 |
+
[More Information Needed]
|
| 170 |
+
|
| 171 |
+
## Citation [optional]
|
| 172 |
+
|
| 173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 174 |
+
|
| 175 |
+
**BibTeX:**
|
| 176 |
+
|
| 177 |
+
[More Information Needed]
|
| 178 |
+
|
| 179 |
+
**APA:**
|
| 180 |
+
|
| 181 |
+
[More Information Needed]
|
| 182 |
+
|
| 183 |
+
## Glossary [optional]
|
| 184 |
+
|
| 185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 186 |
+
|
| 187 |
+
[More Information Needed]
|
| 188 |
+
|
| 189 |
+
## More Information [optional]
|
| 190 |
+
|
| 191 |
+
[More Information Needed]
|
| 192 |
+
|
| 193 |
+
## Model Card Authors [optional]
|
| 194 |
+
|
| 195 |
+
[More Information Needed]
|
| 196 |
+
|
| 197 |
+
## Model Card Contact
|
| 198 |
+
|
| 199 |
+
[More Information Needed]
|
| 200 |
+
### Framework versions
|
| 201 |
+
|
| 202 |
+
- PEFT 0.10.0
|