nvan13 commited on
Commit
f4dcc30
·
verified ·
1 Parent(s): 51cbdf4

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. generation/control/oldm/hack.py +111 -0
  2. generation/control/oldm/lora.py +1119 -0
  3. generation/control/oldm/lora_ldm.py +343 -0
  4. generation/control/oldm/model.py +28 -0
  5. generation/control/oldm/oft_ldm.py +353 -0
  6. generation/subject/download_dreambooth.sh +4 -0
  7. generation/subject/evaluate.py +462 -0
  8. generation/subject/get_result.py +62 -0
  9. generation/subject/oft_utils/__init__.py +2 -0
  10. generation/subject/oft_utils/attention_processor.py +1036 -0
  11. generation/subject/oft_utils/mhe.py +360 -0
  12. generation/subject/train_dreambooth_hra.py +1123 -0
  13. generation/subject/train_dreambooth_hra.sh +186 -0
  14. llama/data/MATH_test.jsonl +0 -0
  15. llama/data/gsm8k_test.jsonl +0 -0
  16. llama/data/oft/__init__.py +20 -0
  17. llama/data/oft/config.py +119 -0
  18. llama/data/oft/layer.py +388 -0
  19. llama/data/oft/model.py +106 -0
  20. llama/finetune_32.py +368 -0
  21. llama/inference/MATH_inference.py +108 -0
  22. llama/inference/grader.py +141 -0
  23. llama/inference/gsm8k_inference.py +127 -0
  24. llama/inference/util.py +253 -0
  25. llama/merge_adapter_to_base_model.py +27 -0
  26. llama/output/cp1e4/ft/README.md +202 -0
  27. llama/output/cp1e4/ft/adapter_config.json +23 -0
  28. llama/output/cp1e4/ft/added_tokens.json +3 -0
  29. llama/output/cp1e4/ft/special_tokens_map.json +30 -0
  30. llama/output/cp1e4/ft/tokenizer.json +0 -0
  31. llama/output/cp1e4/ft/tokenizer_config.json +51 -0
  32. llama/output/cp1e5/ft/README.md +202 -0
  33. llama/output/cp1e5/ft/adapter_config.json +23 -0
  34. llama/output/cp1e5/trainer_state.json +30 -0
  35. llama/output/cp1e5N/ft/README.md +202 -0
  36. llama/output/cp1e5N/ft/adapter_config.json +23 -0
  37. llama/output/cp1e5N/ft/added_tokens.json +3 -0
  38. llama/output/cp1e5N/ft/special_tokens_map.json +30 -0
  39. llama/output/cp1e5N/ft/tokenizer.json +0 -0
  40. llama/output/cp1e5N/ft/tokenizer_config.json +51 -0
  41. llama/output/cp3e5/ft/README.md +202 -0
  42. llama/output/cp3e5/ft/adapter_config.json +23 -0
  43. llama/output/cp3e5/trainer_state.json +72 -0
  44. llama/output/cp3e5N/ft/README.md +202 -0
  45. llama/output/cp3e5N/ft/adapter_config.json +23 -0
  46. llama/output/cp3e5N/ft/added_tokens.json +3 -0
  47. llama/output/cp3e5N/ft/special_tokens_map.json +30 -0
  48. llama/output/cp3e5N/ft/tokenizer.json +0 -0
  49. llama/output/cp3e5N/ft/tokenizer_config.json +51 -0
  50. llama/output/cpr1/ft/README.md +202 -0
generation/control/oldm/hack.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import einops
3
+
4
+ import ldm.modules.encoders.modules
5
+ import ldm.modules.attention
6
+
7
+ from transformers import logging
8
+ from ldm.modules.attention import default
9
+
10
+
11
def disable_verbosity():
    """Silence transformers warnings by raising its log level to ERROR."""
    logging.set_verbosity_error()
    print('logging improved.')
15
+
16
+
17
def enable_sliced_attention():
    # Globally monkey-patch CrossAttention so every attention layer in the LDM
    # model uses the chunked (memory-saving) implementation defined below.
    ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
    print('Enabled sliced_attention.')
    return
21
+
22
+
23
def hack_everything(clip_skip=0):
    # Apply all CLIP-related monkey patches at once: quiet transformers
    # logging, plus the long-prompt CLIP forward defined below.
    disable_verbosity()
    ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
    # clip_skip > 1 makes _hacked_clip_forward take an earlier hidden state
    # instead of the final transformer output (see transformer_encode there).
    ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
    print('Enabled clip hacks.')
    return
29
+
30
+
31
+ # Written by Lvmin
32
# Written by Lvmin
def _hacked_clip_forward(self, text):
    """Encode prompts longer than CLIP's 77-token window.

    The prompt is tokenized without special tokens, split into three 75-token
    segments, each wrapped with BOS/EOS and padded to 77, encoded separately,
    and the three encodings are concatenated along the sequence axis.
    Tokens beyond 3 * 75 are silently dropped.
    """
    PAD = self.tokenizer.pad_token_id
    EOS = self.tokenizer.eos_token_id
    BOS = self.tokenizer.bos_token_id

    def tokenize(t):
        # No truncation and no BOS/EOS here; they are added per-segment below.
        return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]

    def transformer_encode(t):
        if self.clip_skip > 1:
            # Take an earlier hidden state and re-apply the final layer norm
            # ("clip skip" as popularized by A1111-style UIs).
            rt = self.transformer(input_ids=t, output_hidden_states=True)
            return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
        else:
            return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state

    def split(x):
        # Three consecutive 75-token windows (75 content + BOS + EOS = 77).
        return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]

    def pad(x, p, i):
        # Right-pad list x with p up to length i (or truncate to i).
        return x[:i] if len(x) >= i else x + [p] * (i - len(x))

    raw_tokens_list = tokenize(text)
    tokens_list = []

    for raw_tokens in raw_tokens_list:
        raw_tokens_123 = split(raw_tokens)
        raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
        raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
        tokens_list.append(raw_tokens_123)

    tokens_list = torch.IntTensor(tokens_list).to(self.device)

    # (batch, 3, 77) -> (batch*3, 77): encode the three segments as a batch.
    feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
    y = transformer_encode(feed)
    # (batch*3, 77, dim) -> (batch, 231, dim): concatenate segments in sequence.
    z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)

    return z
69
+
70
+
71
+ # Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
72
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
    """Memory-efficient CrossAttention.forward replacement.

    Computes attention one (batch*head) slice at a time so only a single
    slice's attention matrix is alive at once, trading speed for peak memory.
    NOTE: the `mask` argument is accepted for signature compatibility but is
    ignored by this implementation.
    """
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)  # self-attention when no context is given
    k = self.to_k(context)
    v = self.to_v(context)
    del context, x  # free references early to reduce peak memory

    # (b, n, h*d) -> (b*h, n, d)
    q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    limit = k.shape[0]
    att_step = 1  # slices of size 1 along the (b*h) axis
    q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
    k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
    v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))

    # Reverse so pop() (O(1) from the end) yields chunks in original order.
    q_chunks.reverse()
    k_chunks.reverse()
    v_chunks.reverse()
    sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
    del k, q, v  # only the chunk lists are needed from here on
    for i in range(0, limit, att_step):
        q_buffer = q_chunks.pop()
        k_buffer = k_chunks.pop()
        v_buffer = v_chunks.pop()
        # Scaled dot-product scores for this slice only.
        sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale

        del k_buffer, q_buffer
        # attention, what we cannot get enough of, by chunks

        sim_buffer = sim_buffer.softmax(dim=-1)

        # Weighted sum of values for this slice.
        sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
        del v_buffer
        sim[i:i + att_step, :, :] = sim_buffer

        del sim_buffer
    # (b*h, n, d) -> (b, n, h*d)
    sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(sim)
generation/control/oldm/lora.py ADDED
@@ -0,0 +1,1119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This script is retrived from lora available at:
3
+ https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
4
+
5
+ Original Author: Simo Ryu
6
+ License: Apache License 2.0
7
+ """
8
+
9
+ import json
10
+ import math
11
+ from itertools import groupby
12
+ from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
13
+
14
+ import pickle
15
+
16
+ import numpy as np
17
+ import PIL
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ try:
23
+ from safetensors.torch import safe_open
24
+ from safetensors.torch import save_file as safe_save
25
+
26
+ safetensors_available = True
27
+ except ImportError:
28
+ from .safe_open import safe_open
29
+
30
+ def safe_save(
31
+ tensors: Dict[str, torch.Tensor],
32
+ filename: str,
33
+ metadata: Optional[Dict[str, str]] = None,
34
+ ) -> None:
35
+ raise EnvironmentError(
36
+ "Saving safetensors requires the safetensors library. Please install with pip or similar."
37
+ )
38
+
39
+ safetensors_available = False
40
+
41
+
42
class LoraInjectedLinear(nn.Module):
    """A frozen-capable nn.Linear plus a trainable low-rank (LoRA) residual.

    Output is ``linear(x) + lora_up(selector(lora_down(x))) * scale``.
    ``lora_up`` starts at zero, so a freshly injected module behaves exactly
    like the wrapped linear layer.
    """

    def __init__(
        self, in_features, out_features, bias=False, r=4, dropout_p=0.1, scale=1.0
    ):
        super().__init__()

        # The low-rank factorization only makes sense for r <= min(dims).
        if r > min(in_features, out_features):
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(in_features, out_features)}"
            )
        self.r = r
        self.linear = nn.Linear(in_features, out_features, bias)
        self.lora_down = nn.Linear(in_features, r, bias=False)
        # Registered for checkpoint/interface compatibility; forward() does
        # not apply it (the dropout path is disabled upstream as well).
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Linear(r, out_features, bias=False)
        self.scale = scale
        self.selector = nn.Identity()

        # down: small random init; up: zeros -> identity behavior at start.
        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        base = self.linear(input)
        low_rank = self.lora_up(self.selector(self.lora_down(input)))
        return base + low_rank * self.scale

    def realize_as_lora(self):
        """Return ``(up, down)`` weight tensors with ``scale`` folded into up."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Install a diagonal rank-selector from a 1-D tensor of shape (r,)."""
        assert diag.shape == (self.r,)
        selector = nn.Linear(self.r, self.r, bias=False)
        selector.weight.data = (
            torch.diag(diag)
            .to(self.lora_up.weight.device)
            .to(self.lora_up.weight.dtype)
        )
        self.selector = selector
81
+
82
+
83
class LoraInjectedConv2d(nn.Module):
    """A frozen-capable nn.Conv2d plus a trainable low-rank (LoRA) residual.

    Output is ``conv(x) + lora_up(selector(lora_down(x))) * scale`` where
    ``lora_down`` mirrors the wrapped conv's geometry with r output channels
    and ``lora_up`` is a 1x1 conv back to out_channels. ``lora_up`` starts at
    zero, so a freshly injected module behaves exactly like the wrapped conv.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups: int = 1,
        bias: bool = True,
        r: int = 4,
        dropout_p: float = 0.1,
        scale: float = 1.0,
    ):
        super().__init__()
        if r > min(in_channels, out_channels):
            raise ValueError(
                f"LoRA rank {r} must be less or equal than {min(in_channels, out_channels)}"
            )
        self.r = r
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

        # Same geometry as the wrapped conv so spatial dims line up with it.
        self.lora_down = nn.Conv2d(
            in_channels=in_channels,
            out_channels=r,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=False,
        )
        # Kept for checkpoint/interface compatibility; unused in forward().
        self.dropout = nn.Dropout(dropout_p)
        self.lora_up = nn.Conv2d(
            in_channels=r,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.selector = nn.Identity()
        self.scale = scale

        # down: small random init; up: zeros -> identity behavior at start.
        nn.init.normal_(self.lora_down.weight, std=1 / r)
        nn.init.zeros_(self.lora_up.weight)

    def forward(self, input):
        # BUG FIX: the original (copied from the Linear variant) called
        # self.linear(input), which does not exist on this class and raised
        # AttributeError. The frozen path of a conv wrapper is self.conv.
        return (
            self.conv(input) + self.lora_up(self.selector(self.lora_down(input))) * self.scale
        )

    def realize_as_lora(self):
        """Return ``(up, down)`` weight tensors with ``scale`` folded into up."""
        return self.lora_up.weight.data * self.scale, self.lora_down.weight.data

    def set_selector_from_diag(self, diag: torch.Tensor):
        """Install a diagonal rank-selector from a 1-D tensor of shape (r,)."""
        assert diag.shape == (self.r,)
        self.selector = nn.Conv2d(
            in_channels=self.r,
            out_channels=self.r,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        # BUG FIX: a Conv2d weight is 4-D (out, in, kH, kW); the original
        # assigned the raw 2-D torch.diag(diag) matrix, which broke forward().
        self.selector.weight.data = torch.diag(diag).view(self.r, self.r, 1, 1)

        # same device + dtype as lora_up
        self.selector.weight.data = self.selector.weight.data.to(
            self.lora_up.weight.device
        ).to(self.lora_up.weight.dtype)
167
+
168
+
169
# Class names of modules whose inner Linear/Conv layers receive LoRA injection.
UNET_DEFAULT_TARGET_REPLACE = {"CrossAttention", "Attention", "GEGLU"}

# Extended set: additionally adapts layers inside the UNet ResBlocks.
UNET_EXTENDED_TARGET_REPLACE = {"ResBlock", "CrossAttention", "Attention", "GEGLU"}

TEXT_ENCODER_DEFAULT_TARGET_REPLACE = {"CLIPAttention"}

# NOTE: currently identical to the default text-encoder set.
TEXT_ENCODER_EXTENDED_TARGET_REPLACE = {"CLIPAttention"}

DEFAULT_TARGET_REPLACE = UNET_DEFAULT_TARGET_REPLACE

# Metadata marker identifying Textual Inversion embeds in safetensors files.
EMBED_FLAG = "<embed>"
180
+
181
+
182
+ def _find_children(
183
+ model,
184
+ search_class: List[Type[nn.Module]] = [nn.Linear],
185
+ ):
186
+ """
187
+ Find all modules of a certain class (or union of classes).
188
+ Returns all matching modules, along with the parent of those moduless and the
189
+ names they are referenced by.
190
+ """
191
+ result = []
192
+ for parent in model.modules():
193
+ for name, module in parent.named_children():
194
+ if any([isinstance(module, _class) for _class in search_class]):
195
+ result.append((parent, name, module)) # Append the result to the list
196
+
197
+ return result # Return the list instead of using 'yield'
198
+
199
+
200
def _find_modules_v2(
    model,
    ancestor_class: Optional[Set[str]] = None,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [
        LoraInjectedLinear,
        LoraInjectedConv2d,
    ],
):
    """
    Find all modules of a certain class (or union of classes) that are direct or
    indirect descendants of other modules of a certain class (or union of classes).

    Returns all matching modules, along with the parent of those modules and the
    names they are referenced by.
    """

    # Get the targets we should replace all linears under.
    if ancestor_class is not None:
        ancestors = (
            module
            for module in model.modules()
            if module.__class__.__name__ in ancestor_class
        )
    else:
        # Fall back to naively iterating over every module in the model.
        ancestors = [module for module in model.modules()]

    results = []
    # For each target find every search_class module that isn't a child of an
    # already-injected Lora module.
    for ancestor in ancestors:
        for fullname, module in ancestor.named_modules():
            if any([isinstance(module, _class) for _class in search_class]):
                # Walk down from the ancestor to this module's direct parent.
                *path, name = fullname.split(".")
                parent = ancestor
                while path:
                    parent = parent.get_submodule(path.pop(0))
                # Skip if the direct parent is an excluded (e.g. Lora-injected)
                # module; only the immediate parent is checked, not all ancestors.
                if exclude_children_of and any(
                    [isinstance(parent, _class) for _class in exclude_children_of]
                ):
                    continue
                results.append((parent, name, module))  # Append the result to the list

    return results  # Return the list instead of using 'yield'
245
+
246
def _find_modules_old(
    model,
    ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
    search_class: List[Type[nn.Module]] = [nn.Linear],
    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
):
    """Legacy module search: (ancestor, child_name, child) for every submodule
    of a target-named ancestor whose class is exactly in ``search_class``.

    NOTE: ``exclude_children_of`` is accepted for signature parity with
    _find_modules_v2 but is not consulted here.
    """
    matches = []
    for candidate in model.modules():
        if candidate.__class__.__name__ not in ancestor_class:
            continue
        for child_name, child in candidate.named_modules():
            if child.__class__ in search_class:
                matches.append((candidate, child_name, child))
    return matches
261
+
262
+
263
# Active search implementation; the legacy variant is kept for reference.
_find_modules = _find_modules_v2
# _find_modules = _find_modules_old
265
+
266
def inject_trainable_lora(
    model: nn.Module,
    target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
    r: int = 4,
    loras=None,  # path to lora .pt
    verbose: bool = False,
    dropout_p: float = 0.0,
    scale: float = 1.0,
):
    """
    Inject LoRA into every nn.Linear found under the target modules, in place.

    Returns (require_grad_params, names): generators over the new trainable
    lora_up/lora_down parameters, and the replaced child names.
    If ``loras`` is given it must be a .pt path holding a flat
    [up0, down0, up1, down1, ...] list (as written by save_lora_weight);
    weights are consumed in module-discovery order.
    """

    require_grad_params = []
    names = []

    # Idiom fix: compare against None with `is not`, not `!=`.
    if loras is not None:
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear]
    ):

        weight = _child_module.weight
        bias = _child_module.bias
        if verbose:
            print("LoRA Injection : injecting lora into ", name)
            print("LoRA Injection : weight shape", weight.shape)
        _tmp = LoraInjectedLinear(
            _child_module.in_features,
            _child_module.out_features,
            _child_module.bias is not None,
            r=r,
            dropout_p=dropout_p,
            scale=scale,
        )
        # Reuse the original frozen weights so behavior is unchanged at start.
        _tmp.linear.weight = weight
        if bias is not None:
            _tmp.linear.bias = bias

        # switch the module
        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras is not None:
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names
322
+
323
+
324
def inject_trainable_lora_extended(
    model: nn.Module,
    target_replace_module: Set[str] = UNET_EXTENDED_TARGET_REPLACE,
    r: int = 4,
    loras=None,  # path to lora .pt
):
    """
    Inject LoRA into every nn.Linear AND nn.Conv2d found under the target
    modules, in place. Returns (require_grad_params, names) like
    inject_trainable_lora. ``loras``, if given, is a .pt path holding a flat
    [up0, down0, up1, down1, ...] list consumed in discovery order.
    """

    require_grad_params = []
    names = []

    if loras != None:
        loras = torch.load(loras)

    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, nn.Conv2d]
    ):
        if _child_module.__class__ == nn.Linear:
            weight = _child_module.weight
            bias = _child_module.bias
            _tmp = LoraInjectedLinear(
                _child_module.in_features,
                _child_module.out_features,
                _child_module.bias is not None,
                r=r,
            )
            # Reuse the frozen weights of the replaced layer.
            _tmp.linear.weight = weight
            if bias is not None:
                _tmp.linear.bias = bias
        elif _child_module.__class__ == nn.Conv2d:
            weight = _child_module.weight
            bias = _child_module.bias
            # Mirror the conv geometry so the injected module is drop-in.
            _tmp = LoraInjectedConv2d(
                _child_module.in_channels,
                _child_module.out_channels,
                _child_module.kernel_size,
                _child_module.stride,
                _child_module.padding,
                _child_module.dilation,
                _child_module.groups,
                _child_module.bias is not None,
                r=r,
            )

            _tmp.conv.weight = weight
            if bias is not None:
                _tmp.conv.bias = bias

        # switch the module
        # NOTE(review): if _child_module were neither class, _tmp/weight would
        # be unbound here — safe only because search_class restricts matches.
        _tmp.to(_child_module.weight.device).to(_child_module.weight.dtype)
        if bias is not None:
            _tmp.to(_child_module.bias.device).to(_child_module.bias.dtype)

        _module._modules[name] = _tmp

        require_grad_params.append(_module._modules[name].lora_up.parameters())
        require_grad_params.append(_module._modules[name].lora_down.parameters())

        if loras != None:
            _module._modules[name].lora_up.weight = loras.pop(0)
            _module._modules[name].lora_down.weight = loras.pop(0)

        _module._modules[name].lora_up.weight.requires_grad = True
        _module._modules[name].lora_down.weight.requires_grad = True
        names.append(name)

    return require_grad_params, names
393
+
394
+
395
def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
    """Collect (lora_up, lora_down) module pairs from every injected layer.

    Raises ValueError if the model contains no injected Lora modules.
    """
    pairs = [
        (child.lora_up, child.lora_down)
        for _parent, _name, child in _find_modules(
            model,
            target_replace_module,
            search_class=[LoraInjectedLinear, LoraInjectedConv2d],
        )
    ]

    if not pairs:
        raise ValueError("No lora injected.")

    return pairs
410
+
411
+
412
def extract_lora_as_tensor(
    model, target_replace_module=DEFAULT_TARGET_REPLACE, as_fp16=True
):
    """Collect (up, down) weight tensors (scale folded in) from every injected
    layer, optionally cast to float16.

    Raises ValueError if the model contains no injected Lora modules.
    """
    extracted = []

    for _parent, _name, child in _find_modules(
        model,
        target_replace_module,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):
        up, down = child.realize_as_lora()
        if as_fp16:
            up, down = up.to(torch.float16), down.to(torch.float16)
        extracted.append((up, down))

    if not extracted:
        raise ValueError("No lora injected.")

    return extracted
434
+
435
+
436
def save_lora_weight(
    model,
    path="./lora.pt",
    target_replace_module=DEFAULT_TARGET_REPLACE,
):
    """Serialize all injected LoRA weights to *path* as a flat
    [up0, down0, up1, down1, ...] list of CPU float16 tensors."""
    weights = []
    for up_mod, down_mod in extract_lora_ups_down(
        model, target_replace_module=target_replace_module
    ):
        weights.extend(
            [
                up_mod.weight.to("cpu").to(torch.float16),
                down_mod.weight.to("cpu").to(torch.float16),
            ]
        )

    torch.save(weights, path)
449
+
450
+
451
def save_lora_as_json(model, path="./lora.json"):
    """Dump all injected LoRA up/down weights to *path* as nested JSON lists,
    flat-ordered as [up0, down0, up1, down1, ...]."""
    weights = []
    for _up, _down in extract_lora_ups_down(model):
        weights.append(_up.weight.detach().cpu().numpy().tolist())
        weights.append(_down.weight.detach().cpu().numpy().tolist())

    # Fix: dropped the redundant function-local `import json`; the module
    # already imports json at the top of the file.
    with open(path, "w") as f:
        json.dump(weights, f)
461
+
462
+
463
def save_safeloras_with_embeds(
    modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
    embeds: Dict[str, torch.Tensor] = {},
    outpath="./lora.safetensors",
):
    """
    Saves the Lora from multiple modules in a single safetensor file.

    modelmap is a dictionary of {
        "module name": (module, target_replace_module)
    }
    Tensor keys are "<name>:<i>:up" / "<name>:<i>:down"; each pair's rank is
    stored in metadata as "<name>:<i>:rank", and each embed token is stored
    under its own key with metadata value EMBED_FLAG.
    """
    weights = {}
    metadata = {}

    for name, (model, target_replace_module) in modelmap.items():
        # The target set is recorded so loading can re-inject in the same places.
        metadata[name] = json.dumps(list(target_replace_module))

        for i, (_up, _down) in enumerate(
            extract_lora_as_tensor(model, target_replace_module)
        ):
            # down has shape (r, in_features), so dim 0 is the rank.
            rank = _down.shape[0]

            metadata[f"{name}:{i}:rank"] = str(rank)
            weights[f"{name}:{i}:up"] = _up
            weights[f"{name}:{i}:down"] = _down

    for token, tensor in embeds.items():
        metadata[token] = EMBED_FLAG
        weights[token] = tensor

    print(f"Saving weights to {outpath}")
    safe_save(weights, outpath, metadata)
495
+
496
+
497
def save_safeloras(
    modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
    outpath="./lora.safetensors",
):
    # Convenience wrapper: save Loras only, with no Textual Inversion embeds.
    return save_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
502
+
503
+
504
def convert_loras_to_safeloras_with_embeds(
    modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
    embeds: Dict[str, torch.Tensor] = {},
    outpath="./lora.safetensors",
):
    """
    Converts the Lora from multiple pytorch .pt files into a single safetensor file.

    modelmap is a dictionary of {
        "module name": (pytorch_model_path, target_replace_module, rank)
    }
    Each .pt file is expected to hold a flat [up0, down0, up1, down1, ...]
    list as written by save_lora_weight.
    """

    weights = {}
    metadata = {}

    for name, (path, target_replace_module, r) in modelmap.items():
        metadata[name] = json.dumps(list(target_replace_module))

        lora = torch.load(path)
        for i, weight in enumerate(lora):
            # Even positions are up weights, odd positions are down weights.
            is_up = i % 2 == 0
            i = i // 2  # pair index used in the tensor key

            if is_up:
                # The caller-supplied rank r applies uniformly to every pair.
                metadata[f"{name}:{i}:rank"] = str(r)
                weights[f"{name}:{i}:up"] = weight
            else:
                weights[f"{name}:{i}:down"] = weight

    for token, tensor in embeds.items():
        metadata[token] = EMBED_FLAG
        weights[token] = tensor

    print(f"Saving weights to {outpath}")
    safe_save(weights, outpath, metadata)
539
+
540
+
541
def convert_loras_to_safeloras(
    modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
    outpath="./lora.safetensors",
):
    # Convenience wrapper: convert .pt Loras only, with no embeds.
    convert_loras_to_safeloras_with_embeds(modelmap=modelmap, outpath=outpath)
546
+
547
+
548
def parse_safeloras(
    safeloras,
) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
    """
    Converts a loaded safetensor file that contains a set of module Loras
    into Parameters and other information.

    Output is a dictionary of {
        "module name": (
            [list of weights],
            [list of ranks],
            target_replacement_modules
        )
    }
    Weight lists are flat-ordered [up0, down0, up1, down1, ...], matching
    what monkeypatch_or_replace_lora* consume.
    """
    loras = {}
    metadata = safeloras.metadata()

    # Keys look like "<module name>:<pair index>:<up|down>".
    get_name = lambda k: k.split(":")[0]

    keys = list(safeloras.keys())
    keys.sort(key=get_name)  # groupby requires keys sorted by group key

    for name, module_keys in groupby(keys, get_name):
        info = metadata.get(name)

        if not info:
            raise ValueError(
                f"Tensor {name} has no metadata - is this a Lora safetensor?"
            )

        # Skip Textual Inversion embeds
        if info == EMBED_FLAG:
            continue

        # Handle Loras
        # Extract the targets
        target = json.loads(info)

        # Build the result lists - Python needs us to preallocate lists to insert into them
        module_keys = list(module_keys)
        ranks = [4] * (len(module_keys) // 2)  # 4 is only a placeholder; real rank read below
        weights = [None] * len(module_keys)

        for key in module_keys:
            # Split the model name and index out of the key
            _, idx, direction = key.split(":")
            idx = int(idx)

            # Add the rank
            ranks[idx] = int(metadata[f"{name}:{idx}:rank"])

            # Insert the weight into the list: up at 2*idx, down at 2*idx + 1
            idx = idx * 2 + (1 if direction == "down" else 0)
            weights[idx] = nn.parameter.Parameter(safeloras.get_tensor(key))

        loras[name] = (weights, ranks, target)

    return loras
606
+
607
+
608
def parse_safeloras_embeds(
    safeloras,
) -> Dict[str, torch.Tensor]:
    """
    Converts a loaded safetensor file that contains Textual Inversion embeds
    into a dictionary of embed_token: Tensor. Keys whose metadata is not
    exactly EMBED_FLAG (including keys with no metadata) are ignored.
    """
    metadata = safeloras.metadata()
    return {
        key: safeloras.get_tensor(key)
        for key in safeloras.keys()
        if metadata.get(key) == EMBED_FLAG
    }
627
+
628
+
629
def load_safeloras(path, device="cpu"):
    # Open a Lora safetensors file and return parsed module Loras only.
    safeloras = safe_open(path, framework="pt", device=device)
    return parse_safeloras(safeloras)
632
+
633
+
634
def load_safeloras_embeds(path, device="cpu"):
    # Open a Lora safetensors file and return Textual Inversion embeds only.
    safeloras = safe_open(path, framework="pt", device=device)
    return parse_safeloras_embeds(safeloras)
637
+
638
+
639
def load_safeloras_both(path, device="cpu"):
    # Open once; return (module Loras, Textual Inversion embeds) together.
    safeloras = safe_open(path, framework="pt", device=device)
    return parse_safeloras(safeloras), parse_safeloras_embeds(safeloras)
642
+
643
+
644
def collapse_lora(model, alpha=1.0):
    """Fold every injected LoRA residual into its frozen base weight, in place.

    After this, base_weight += alpha * (up @ down); the Lora modules remain
    attached but effectively double-count unless alpha or their weights are
    reset afterwards. ``alpha`` scales the merged delta.
    NOTE(review): the stored per-module ``scale`` is NOT applied here — the
    raw lora_up.weight is used, so scale != 1 models merge differently than
    they ran. Confirm against the intended checkpoint format.
    """

    for _module, name, _child_module in _find_modules(
        model,
        UNET_EXTENDED_TARGET_REPLACE | TEXT_ENCODER_EXTENDED_TARGET_REPLACE,
        search_class=[LoraInjectedLinear, LoraInjectedConv2d],
    ):

        if isinstance(_child_module, LoraInjectedLinear):
            print("Collapsing Lin Lora in", name)

            _child_module.linear.weight = nn.Parameter(
                _child_module.linear.weight.data
                + alpha
                * (
                    _child_module.lora_up.weight.data
                    @ _child_module.lora_down.weight.data
                )
                .type(_child_module.linear.weight.dtype)
                .to(_child_module.linear.weight.device)
            )

        else:
            print("Collapsing Conv Lora in", name)
            # Flatten the 4-D conv kernels to 2-D, multiply as matrices, then
            # reshape the product back to the conv weight's shape.
            _child_module.conv.weight = nn.Parameter(
                _child_module.conv.weight.data
                + alpha
                * (
                    _child_module.lora_up.weight.data.flatten(start_dim=1)
                    @ _child_module.lora_down.weight.data.flatten(start_dim=1)
                )
                .reshape(_child_module.conv.weight.data.shape)
                .type(_child_module.conv.weight.dtype)
                .to(_child_module.conv.weight.device)
            )
680
+
681
def monkeypatch_or_replace_lora(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    r: Union[int, List[int]] = 4,
):
    """Install the given LoRA weights into the model's Linear layers, in place.

    Replaces each matching nn.Linear (or re-replaces an existing
    LoraInjectedLinear, preserving its frozen base weights) and loads weights
    from ``loras``, a flat [up0, down0, ...] list consumed in discovery order.
    ``r`` may be a single rank or a per-layer list (consumed in order too).
    """
    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
    ):
        # If already injected, re-wrap the inner frozen linear, not the wrapper.
        _source = (
            _child_module.linear
            if isinstance(_child_module, LoraInjectedLinear)
            else _child_module
        )

        weight = _source.weight
        bias = _source.bias
        _tmp = LoraInjectedLinear(
            _source.in_features,
            _source.out_features,
            _source.bias is not None,
            r=r.pop(0) if isinstance(r, list) else r,
        )
        _tmp.linear.weight = weight

        if bias is not None:
            _tmp.linear.bias = bias

        # switch the module
        _module._modules[name] = _tmp

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        # Cast the loaded weights to the base layer's dtype before installing.
        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype)
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype)
        )

        _module._modules[name].to(weight.device)
724
+
725
def monkeypatch_or_replace_lora_extended(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    r: Union[int, List[int]] = 4,
):
    """Install the given LoRA weights into Linear AND Conv2d layers, in place.

    Like monkeypatch_or_replace_lora but also handles Conv2d /
    LoraInjectedConv2d. The next pending weight's ndim (2 vs 4) is used to
    decide whether the current layer kind should consume it; mismatching
    layers are skipped without consuming weights.
    """
    for _module, name, _child_module in _find_modules(
        model,
        target_replace_module,
        search_class=[nn.Linear, LoraInjectedLinear, nn.Conv2d, LoraInjectedConv2d],
    ):

        if (_child_module.__class__ == nn.Linear) or (
            _child_module.__class__ == LoraInjectedLinear
        ):
            # Linear weights are 2-D; skip if the next pending weight isn't.
            if len(loras[0].shape) != 2:
                continue

            _source = (
                _child_module.linear
                if isinstance(_child_module, LoraInjectedLinear)
                else _child_module
            )

            weight = _source.weight
            bias = _source.bias
            _tmp = LoraInjectedLinear(
                _source.in_features,
                _source.out_features,
                _source.bias is not None,
                r=r.pop(0) if isinstance(r, list) else r,
            )
            _tmp.linear.weight = weight

            if bias is not None:
                _tmp.linear.bias = bias

        elif (_child_module.__class__ == nn.Conv2d) or (
            _child_module.__class__ == LoraInjectedConv2d
        ):
            # Conv weights are 4-D; skip if the next pending weight isn't.
            if len(loras[0].shape) != 4:
                continue
            _source = (
                _child_module.conv
                if isinstance(_child_module, LoraInjectedConv2d)
                else _child_module
            )

            weight = _source.weight
            bias = _source.bias
            _tmp = LoraInjectedConv2d(
                _source.in_channels,
                _source.out_channels,
                _source.kernel_size,
                _source.stride,
                _source.padding,
                _source.dilation,
                _source.groups,
                _source.bias is not None,
                r=r.pop(0) if isinstance(r, list) else r,
            )

            _tmp.conv.weight = weight

            if bias is not None:
                _tmp.conv.bias = bias

        # switch the module
        _module._modules[name] = _tmp

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        _module._modules[name].lora_up.weight = nn.Parameter(
            up_weight.type(weight.dtype)
        )
        _module._modules[name].lora_down.weight = nn.Parameter(
            down_weight.type(weight.dtype)
        )

        _module._modules[name].to(weight.device)
806
+
807
+
808
def monkeypatch_or_replace_safeloras(models, safeloras):
    """Apply every LoRA stored in a safetensors archive to its matching model.

    `models` is an object (e.g. a pipeline) whose attributes are named after
    the entries in the archive ("unet", "text_encoder", ...).
    """
    parsed = parse_safeloras(safeloras)

    for name, payload in parsed.items():
        lora_weights, ranks, target = payload
        target_model = getattr(models, name, None)

        if not target_model:
            print(f"No model provided for {name}, contained in Lora")
            continue

        monkeypatch_or_replace_lora_extended(target_model, lora_weights, target, ranks)
819
+
820
+
821
def monkeypatch_remove_lora(model):
    """Strip every LoRA wrapper from `model`, restoring plain layers.

    Each `LoraInjectedLinear`/`LoraInjectedConv2d` is replaced by a fresh
    `nn.Linear`/`nn.Conv2d` that reuses the wrapped layer's weight and bias
    Parameters (no copy) — the low-rank residual is discarded.
    """
    for _module, name, _child_module in _find_modules(
        model, search_class=[LoraInjectedLinear, LoraInjectedConv2d]
    ):
        if isinstance(_child_module, LoraInjectedLinear):
            inner = _child_module.linear
            replacement = nn.Linear(
                inner.in_features, inner.out_features, inner.bias is not None
            )
        else:
            inner = _child_module.conv
            replacement = nn.Conv2d(
                in_channels=inner.in_channels,
                out_channels=inner.out_channels,
                kernel_size=inner.kernel_size,
                stride=inner.stride,
                padding=inner.padding,
                dilation=inner.dilation,
                groups=inner.groups,
                bias=inner.bias is not None,
            )

        # Re-attach the original Parameters so the restored layer is identical.
        replacement.weight = inner.weight
        if inner.bias is not None:
            replacement.bias = inner.bias

        _module._modules[name] = replacement
857
+
858
+
859
def monkeypatch_add_lora(
    model,
    loras,
    target_replace_module=DEFAULT_TARGET_REPLACE,
    alpha: float = 1.0,
    beta: float = 1.0,
):
    """Blend pre-trained LoRA weights into already-injected linear layers.

    For each injected layer, the new weights are
    ``alpha * loaded + beta * existing`` for both `lora_up` and `lora_down`.
    `loras` is consumed pairwise (up, down) in `_find_modules` order.
    """
    for _module, name, _child_module in _find_modules(
        model, target_replace_module, search_class=[LoraInjectedLinear]
    ):
        base_weight = _child_module.linear.weight
        device = base_weight.device
        dtype = base_weight.dtype

        up_weight = loras.pop(0)
        down_weight = loras.pop(0)

        injected = _module._modules[name]
        injected.lora_up.weight = nn.Parameter(
            up_weight.type(dtype).to(device) * alpha
            + injected.lora_up.weight.to(device) * beta
        )
        injected.lora_down.weight = nn.Parameter(
            down_weight.type(dtype).to(device) * alpha
            + injected.lora_down.weight.to(device) * beta
        )

        injected.to(device)
884
+
885
+
886
def tune_lora_scale(model, alpha: float = 1.0):
    """Set the LoRA residual scale `alpha` on every injected layer in `model`."""
    lora_class_names = ("LoraInjectedLinear", "LoraInjectedConv2d")
    for submodule in model.modules():
        if type(submodule).__name__ in lora_class_names:
            submodule.scale = alpha
890
+
891
+
892
def set_lora_diag(model, diag: torch.Tensor):
    """Push a diagonal selector tensor into every injected LoRA layer of `model`."""
    lora_class_names = ("LoraInjectedLinear", "LoraInjectedConv2d")
    for submodule in model.modules():
        if type(submodule).__name__ in lora_class_names:
            submodule.set_selector_from_diag(diag)
896
+
897
+
898
+ def _text_lora_path(path: str) -> str:
899
+ assert path.endswith(".pt"), "Only .pt files are supported"
900
+ return ".".join(path.split(".")[:-1] + ["text_encoder", "pt"])
901
+
902
+
903
+ def _ti_lora_path(path: str) -> str:
904
+ assert path.endswith(".pt"), "Only .pt files are supported"
905
+ return ".".join(path.split(".")[:-1] + ["ti", "pt"])
906
+
907
+
908
def apply_learned_embed_in_clip(
    learned_embeds,
    text_encoder,
    tokenizer,
    token: Optional[Union[str, List[str]]] = None,
    idempotent=False,
):
    """Register learned textual-inversion embeddings as tokenizer tokens.

    learned_embeds: mapping of token string -> embedding tensor.
    token: one token, a list matching learned_embeds, or None to use the
        dict's own keys.
    idempotent: if True, an existing token's embedding is overwritten in
        place; if False, the token is renamed ("<tok>" -> "<tok-1>" -> ...)
        until the tokenizer accepts it as new.
    Returns the last token actually registered (possibly renamed).
    """
    if isinstance(token, str):
        trained_tokens = [token]
    elif isinstance(token, list):
        assert len(learned_embeds.keys()) == len(
            token
        ), "The number of tokens and the number of embeds should be the same"
        trained_tokens = token
    else:
        trained_tokens = list(learned_embeds.keys())

    for token in trained_tokens:
        print(token)
        embeds = learned_embeds[token]

        # cast to dtype of text_encoder
        # NOTE(review): `dtype` is computed but never applied to `embeds`
        # below — confirm whether the stored embeds already match.
        dtype = text_encoder.get_input_embeddings().weight.dtype
        num_added_tokens = tokenizer.add_tokens(token)

        i = 1
        if not idempotent:
            # add_tokens returns 0 when the token already exists; keep
            # renaming with an increasing suffix until it is accepted.
            while num_added_tokens == 0:
                print(f"The tokenizer already contains the token {token}.")
                token = f"{token[:-1]}-{i}>"
                print(f"Attempting to add the token {token}.")
                num_added_tokens = tokenizer.add_tokens(token)
                i += 1
        elif num_added_tokens == 0 and idempotent:
            print(f"The tokenizer already contains the token {token}.")
            print(f"Replacing {token} embedding.")

        # resize the token embeddings
        text_encoder.resize_token_embeddings(len(tokenizer))

        # get the id for the token and assign the embeds
        token_id = tokenizer.convert_tokens_to_ids(token)
        text_encoder.get_input_embeddings().weight.data[token_id] = embeds
    return token
952
+
953
+
954
def load_learned_embed_in_clip(
    learned_embeds_path,
    text_encoder,
    tokenizer,
    token: Optional[Union[str, List[str]]] = None,
    idempotent=False,
):
    """Load learned token embeddings from disk and install them into CLIP.

    Thin wrapper around ``apply_learned_embed_in_clip`` that first
    deserializes the embedding dict from `learned_embeds_path`.

    Returns the token actually registered (possibly renamed on collision).
    Fix: the original swallowed ``apply_learned_embed_in_clip``'s return
    value, so callers such as ``patch_pipe`` received ``None`` for `token`.
    """
    learned_embeds = torch.load(learned_embeds_path)
    return apply_learned_embed_in_clip(
        learned_embeds, text_encoder, tokenizer, token, idempotent
    )
965
+
966
+
967
def patch_pipe(
    pipe,
    maybe_unet_path,
    token: Optional[str] = None,
    r: int = 4,
    patch_unet=True,
    patch_text=True,
    patch_ti=True,
    idempotent_token=True,
    unet_target_replace_module=DEFAULT_TARGET_REPLACE,
    text_target_replace_module=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
):
    """Patch a diffusion pipeline in place with LoRA weights and learned tokens.

    `maybe_unet_path` may be a ``.pt`` file (any of the unet / ``.text_encoder.pt``
    / ``.ti.pt`` triplet — the sibling paths are derived), or a single
    ``.safetensors`` archive holding everything.

    Returns the token embedding dict only in the safetensors branch; the
    ``.pt`` branch returns None.
    """
    if maybe_unet_path.endswith(".pt"):
        # torch format

        # Normalize whichever member of the triplet was passed back to the
        # base unet path, then derive the sibling paths from it.
        if maybe_unet_path.endswith(".ti.pt"):
            unet_path = maybe_unet_path[:-6] + ".pt"
        elif maybe_unet_path.endswith(".text_encoder.pt"):
            unet_path = maybe_unet_path[:-16] + ".pt"
        else:
            unet_path = maybe_unet_path

        ti_path = _ti_lora_path(unet_path)
        text_path = _text_lora_path(unet_path)

        if patch_unet:
            print("LoRA : Patching Unet")
            monkeypatch_or_replace_lora(
                pipe.unet,
                torch.load(unet_path),
                r=r,
                target_replace_module=unet_target_replace_module,
            )

        if patch_text:
            print("LoRA : Patching text encoder")
            monkeypatch_or_replace_lora(
                pipe.text_encoder,
                torch.load(text_path),
                target_replace_module=text_target_replace_module,
                r=r,
            )
        if patch_ti:
            print("LoRA : Patching token input")
            token = load_learned_embed_in_clip(
                ti_path,
                pipe.text_encoder,
                pipe.tokenizer,
                token=token,
                idempotent=idempotent_token,
            )

    elif maybe_unet_path.endswith(".safetensors"):
        safeloras = safe_open(maybe_unet_path, framework="pt", device="cpu")
        monkeypatch_or_replace_safeloras(pipe, safeloras)
        tok_dict = parse_safeloras_embeds(safeloras)
        if patch_ti:
            apply_learned_embed_in_clip(
                tok_dict,
                pipe.text_encoder,
                pipe.tokenizer,
                token=token,
                idempotent=idempotent_token,
            )
        return tok_dict
1032
+
1033
+
1034
@torch.no_grad()
def inspect_lora(model):
    """Report the mean |ΔW| contributed by each injected LoRA layer.

    Returns a dict mapping module name -> [mean absolute entry of
    lora_up @ lora_down], i.e. how far each low-rank update has moved
    the effective weights.
    """
    moved = {}
    lora_class_names = ("LoraInjectedLinear", "LoraInjectedConv2d")

    for name, submodule in model.named_modules():
        if type(submodule).__name__ not in lora_class_names:
            continue
        up = submodule.lora_up.weight.data.clone()
        down = submodule.lora_down.weight.data.clone()

        # Effective weight delta of the low-rank factorization.
        delta: torch.Tensor = up.flatten(1) @ down.flatten(1)

        dist = delta.flatten().abs().mean().item()
        moved.setdefault(name, []).append(dist)

    return moved
1052
+
1053
+
1054
def save_all(
    unet,
    text_encoder,
    save_path,
    placeholder_token_ids=None,
    placeholder_tokens=None,
    save_lora=True,
    save_ti=True,
    target_replace_module_text=TEXT_ENCODER_DEFAULT_TARGET_REPLACE,
    target_replace_module_unet=DEFAULT_TARGET_REPLACE,
    safe_form=True,
):
    """Persist LoRA weights and learned token embeddings.

    With ``safe_form=True`` (default) everything goes into one
    ``.safetensors`` archive at `save_path`; otherwise separate ``.pt``
    files are written (`save_path`, plus derived ``.text_encoder.pt`` /
    ``.ti.pt`` siblings).
    """
    if not safe_form:
        # save ti
        if save_ti:
            ti_path = _ti_lora_path(save_path)
            learned_embeds_dict = {}
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}:, id {tok_id} ",
                    learned_embeds[:4],
                )
                learned_embeds_dict[tok] = learned_embeds.detach().cpu()

            torch.save(learned_embeds_dict, ti_path)
            print("Ti saved to ", ti_path)

        # save text encoder
        if save_lora:

            save_lora_weight(
                unet, save_path, target_replace_module=target_replace_module_unet
            )
            print("Unet saved to ", save_path)

            save_lora_weight(
                text_encoder,
                _text_lora_path(save_path),
                target_replace_module=target_replace_module_text,
            )
            print("Text Encoder saved to ", _text_lora_path(save_path))

    else:
        assert save_path.endswith(
            ".safetensors"
        ), f"Save path : {save_path} should end with .safetensors"

        loras = {}
        embeds = {}

        if save_lora:
            # Models (not weights) are passed; extraction happens inside
            # save_safeloras_with_embeds.
            loras["unet"] = (unet, target_replace_module_unet)
            loras["text_encoder"] = (text_encoder, target_replace_module_text)

        if save_ti:
            for tok, tok_id in zip(placeholder_tokens, placeholder_token_ids):
                learned_embeds = text_encoder.get_input_embeddings().weight[tok_id]
                print(
                    f"Current Learned Embeddings for {tok}:, id {tok_id} ",
                    learned_embeds[:4],
                )
                embeds[tok] = learned_embeds.detach().cpu()

        save_safeloras_with_embeds(loras, embeds, save_path)
generation/control/oldm/lora_ldm.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ import torch
3
+ import torch as th
4
+ import torch.nn as nn
5
+
6
+ import os
7
+ import sys
8
+
9
+ from ldm.modules.diffusionmodules.util import (
10
+ conv_nd,
11
+ linear,
12
+ zero_module,
13
+ timestep_embedding,
14
+ )
15
+
16
+ from einops import rearrange, repeat
17
+ from torchvision.utils import make_grid
18
+ from ldm.modules.attention import SpatialTransformer
19
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Upsample, Downsample, AttentionBlock, normalization
20
+ from ldm.models.diffusion.ddpm import LatentDiffusion
21
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config
22
+ from ldm.models.diffusion.ddim import DDIMSampler
23
+
24
+ from cldm.lora import inject_trainable_lora, extract_lora_ups_down, inject_trainable_lora_extended
25
+
26
def count_parameters(params):
    """Return the total element count of *params* in millions, 1 decimal place."""
    total = 0
    for p in params:
        total += p.numel()
    return round(total / 1e6, 1)
29
+
30
def set_requires_grad(model, requires_grad=True):
    """Toggle gradient tracking for every parameter of *model*."""
    for p in model.parameters():
        p.requires_grad_(requires_grad)
33
+
34
class ControlledUnetModel(UNetModel):
    """UNet whose first input block output receives an additive control residual."""

    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
        hs = []
        # Sinusoidal timestep features -> learned time embedding.
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            if control is not None:
                h = module(h, emb, context)
                # Inject the control signal once, after the first input block,
                # then clear it so the remaining blocks run unmodified.
                h += control
                control = None
            else:
                h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # U-Net skip connection: concatenate the matching encoder feature.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)

        # NOTE(review): `only_mid_control` is accepted for API parity but
        # unused in this variant — confirm against the original ControlNet.
        return self.out(h)
56
+
57
+
58
class ControlNet(nn.Module):
    """Hint encoder producing an additive control residual for the UNet.

    Unlike the full ControlNet, this variant only runs the hint image through
    ``input_hint_block``; the many UNet-style constructor arguments are
    validated and stored, but no encoder copy is built here.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        # Spatial transformer and cross-attention context must be enabled together.
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        # num_res_blocks: a single int applies to every level; otherwise a
        # per-level list matching channel_mult.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Conv stack that downsamples the hint image 8x (three stride-2 convs)
        # to the UNet's working resolution. The final conv is zero-initialised
        # so training starts with a no-op control signal.
        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

    def forward(self, x, hint, timesteps, context, **kwargs):
        """Encode `hint` into a control tensor.

        NOTE(review): `x` is unused here — kept for interface parity with the
        full ControlNet; confirm callers rely only on the hint encoding.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        return guided_hint
178
+
179
+
180
class ControlLDM(LatentDiffusion):
    """LatentDiffusion whose epsilon prediction is conditioned on a control hint."""

    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Hint encoder (ControlNet above), built from its own config section.
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        # Per-level control multipliers (12 encoder outputs + middle block).
        # Currently unused — see the commented line in apply_model.
        self.control_scales = [1.0] * 13

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        """Return latents plus a conditioning dict with text and control image.

        NOTE(review): `k` is ignored; `self.first_stage_key` is used instead.
        """
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        control = batch[self.control_key]
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        # Dataset layout HWC -> network layout CHW.
        control = einops.rearrange(control, 'b h w c -> b c h w')
        control = control.to(memory_format=torch.contiguous_format).float()
        return x, dict(c_crossattn=[c], c_concat=[control])

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        """Predict noise, injecting the encoded hint when c_concat is present."""
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)

        if cond['c_concat'] is None:
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
            # control = [c * scale for c, scale in zip(control, self.control_scales)]
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

        return eps

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        # Empty-prompt embedding used for classifier-free guidance.
        return self.get_learned_conditioning([""] * N)

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                   use_ema_scope=True, num_samples=1,
                   **kwargs):
        """Build a dict of visualization tensors for logging.

        NOTE(review): requires kwargs['split'] ('train' or otherwise) —
        raises KeyError if the caller omits it; confirm logger always passes it.
        """
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        # Map control image from [0, 1] to [-1, 1] for display.
        log["control"] = c_cat * 2.0 - 1.0
        log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)

        if plot_diffusion_rows:
            # get diffusion row: progressively noisier versions of the input.
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # get denoise row: unguided samples conditioned on text + control.
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if kwargs['split'] == 'train':
            if unconditional_guidance_scale > 1.0:
                uc_cross = self.get_unconditional_conditioning(N)
                # Control stays ON in the unconditional branch; only the text
                # prompt is dropped for classifier-free guidance.
                uc_cat = c_cat  # torch.zeros_like(c_cat)

                uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        else:
            if unconditional_guidance_scale > 1.0:
                # uc_cross = self.get_unconditional_conditioning(N)
                # uc_cat = c_cat # torch.zeros_like(c_cat)

                # Eval: replicate the FIRST control image / prompt num_samples
                # times to draw several samples for one conditioning.
                c_cat = torch.stack([c_cat[0] for _ in range(num_samples)], dim=0).clone()

                cond = {"c_concat": [c_cat], "c_crossattn": [self.get_learned_conditioning([batch['txt'][0]] * num_samples)]}
                uc_full = {"c_concat": [c_cat], "c_crossattn": [self.get_learned_conditioning([''] * num_samples)]}

                samples_cfg, _ = self.sample_log(cond=cond,
                                                 batch_size=num_samples, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        """Run DDIM sampling at 1/8 of the control image's spatial resolution."""
        ddim_sampler = DDIMSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
        return samples, intermediates

    def configure_optimizers(self):
        """AdamW over the control model plus any trainable diffusion params.

        NOTE(review): `self.sd_locked` is not set in this class — presumably
        assigned externally (training script) before optimizers are built.
        Also note set_requires_grad(..., True) runs AFTER the param list is
        collected, so the requires_grad filter sees the pre-existing flags.
        """
        lr = self.learning_rate
        params = list(self.control_model.parameters())
        names = []
        for name, param in self.model.diffusion_model.named_parameters():
            if param.requires_grad:
                params.append(param)
                names.append(name)

        # params += self.unet_lora_params
        if not self.sd_locked:
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        opt = torch.optim.AdamW(params, lr=lr)

        set_requires_grad(self.model.diffusion_model, True)

        num_params = count_parameters(params)
        print()
        print()
        print(f"Total number of trainable parameters: {num_params} M!")
        print()
        print()
        return opt

    def low_vram_shift(self, is_diffusing):
        """Swap model halves between GPU and CPU depending on the phase."""
        if is_diffusing:
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
generation/control/oldm/model.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+
4
+ from omegaconf import OmegaConf
5
+ from ldm.util import instantiate_from_config
6
+
7
+
8
def get_state_dict(d):
    """Unwrap a checkpoint dict: return d['state_dict'] when present, else d."""
    if 'state_dict' in d:
        return d['state_dict']
    return d
10
+
11
+
12
def load_state_dict(ckpt_path, location='cpu'):
    """Load a checkpoint (.safetensors or torch pickle) as a flat state dict."""
    extension = os.path.splitext(ckpt_path)[1]
    if extension.lower() == ".safetensors":
        import safetensors.torch  # optional dependency, imported lazily
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        loaded = torch.load(ckpt_path, map_location=torch.device(location))
        state_dict = get_state_dict(loaded)
    # Unwrap once more in case of a nested 'state_dict' entry.
    state_dict = get_state_dict(state_dict)
    print(f'Loaded state_dict from [{ckpt_path}]')
    return state_dict
22
+
23
+
24
def create_model(config_path):
    """Instantiate the model described by an OmegaConf YAML config, on CPU."""
    config = OmegaConf.load(config_path)
    # config.model holds the target class path and its constructor kwargs.
    model = instantiate_from_config(config.model).cpu()
    print(f'Loaded model config from [{config_path}]')
    return model
generation/control/oldm/oft_ldm.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import einops
2
+ import torch
3
+ import torch as th
4
+ import torch.nn as nn
5
+
6
+ import os
7
+ import sys
8
+
9
+ from ldm.modules.diffusionmodules.util import (
10
+ conv_nd,
11
+ linear,
12
+ zero_module,
13
+ timestep_embedding,
14
+ )
15
+
16
+ from einops import rearrange, repeat
17
+ from torchvision.utils import make_grid
18
+ from ldm.modules.attention import SpatialTransformer
19
+ from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Upsample, Downsample, AttentionBlock, normalization
20
+ from ldm.models.diffusion.ddpm import LatentDiffusion
21
+ from ldm.util import log_txt_as_img, exists, instantiate_from_config
22
+ from ldm.models.diffusion.ddim import DDIMSampler
23
+
24
+
25
def count_parameters(params):
    """Count trainable parameters in millions (1 decimal place).

    A tensor shaped (N, D, D) is treated as a batch of N skew-symmetric
    D x D matrices (the OFT parameterization), contributing only the
    D*(D-1)/2 free entries each; everything else counts all elements.
    """
    total = 0
    for p in params:
        if p.dim() == 3 and p.shape[1] == p.shape[2]:
            n_blocks, dim = p.shape[0], p.shape[1]
            total += n_blocks * dim * (dim - 1) // 2
        else:
            total += p.numel()
    return round(total / 1e6, 1)
36
+
37
def set_requires_grad(model, requires_grad=True):
    """Enable or disable gradient tracking on all parameters of *model*."""
    for parameter in model.parameters():
        parameter.requires_grad_(requires_grad)
40
+
41
class ControlledUnetModel(UNetModel):
    """UNet whose first input block output receives an additive control residual."""

    def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
        hs = []
        # Sinusoidal timestep features -> learned time embedding.
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            if control is not None:
                h = module(h, emb, context)
                # Inject the control signal once, after the first input block,
                # then clear it so the remaining blocks run unmodified.
                h += control
                control = None
            else:
                h = module(h, emb, context)
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # U-Net skip connection: concatenate the matching encoder feature.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)

        # NOTE(review): `only_mid_control` is accepted for API parity but
        # unused in this variant — confirm against the original ControlNet.
        return self.out(h)
63
+
64
+
65
class ControlNet(nn.Module):
    """Hint encoder producing an additive control residual for the UNet (OFT variant).

    Only the hint image is processed (through ``input_hint_block``); the
    UNet-style constructor arguments are validated and stored, but no
    encoder copy is built here.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        hint_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,    # custom transformer support
        transformer_depth=1,              # custom transformer support
        context_dim=None,                 # custom transformer support
        n_embed=None,                     # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
        disable_self_attentions=None,
        num_attention_blocks=None,
        disable_middle_self_attn=False,
        use_linear_in_transformer=False,
    ):
        super().__init__()
        # Spatial transformer and cross-attention context must be enabled together.
        if use_spatial_transformer:
            assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'

        if context_dim is not None:
            assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
            from omegaconf.listconfig import ListConfig
            if type(context_dim) == ListConfig:
                context_dim = list(context_dim)

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'

        if num_head_channels == -1:
            assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

        self.dims = dims
        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        # num_res_blocks: a single int applies to every level; otherwise a
        # per-level list matching channel_mult.
        if isinstance(num_res_blocks, int):
            self.num_res_blocks = len(channel_mult) * [num_res_blocks]
        else:
            if len(num_res_blocks) != len(channel_mult):
                raise ValueError("provide num_res_blocks either as an int (globally constant) or "
                                 "as a list/tuple (per-level) with the same length as channel_mult")
            self.num_res_blocks = num_res_blocks
        if disable_self_attentions is not None:
            # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
            assert len(disable_self_attentions) == len(channel_mult)
        if num_attention_blocks is not None:
            assert len(num_attention_blocks) == len(self.num_res_blocks)
            assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
            print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                  f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                  f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                  f"attention will still not be set.")

        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None

        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        # Conv stack that downsamples the hint image 8x (three stride-2 convs)
        # to the UNet's working resolution. The final conv is zero-initialised
        # so training starts with a no-op control signal.
        self.input_hint_block = TimestepEmbedSequential(
            conv_nd(dims, hint_channels, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 16, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 16, 32, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 32, 32, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 32, 96, 3, padding=1, stride=2),
            nn.SiLU(),
            conv_nd(dims, 96, 96, 3, padding=1),
            nn.SiLU(),
            conv_nd(dims, 96, 256, 3, padding=1, stride=2),
            nn.SiLU(),
            zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
        )

    def forward(self, x, hint, timesteps, context, **kwargs):
        """Encode `hint` into a control tensor.

        NOTE(review): `x` is unused here — kept for interface parity with the
        full ControlNet; confirm callers rely only on the hint encoding.
        """
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        guided_hint = self.input_hint_block(hint, emb, context)

        return guided_hint
185
+
186
+
187
class ControlLDM(LatentDiffusion):
    """Latent diffusion model conditioned on an extra spatial control image.

    A separate ``control_model`` (instantiated from ``control_stage_config``)
    encodes the control image stored under ``batch[control_key]``; the encoded
    features are fed to the UNet through its ``control`` argument in
    :meth:`apply_model`.
    """

    def __init__(self, control_stage_config, control_key, only_mid_control, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.control_model = instantiate_from_config(control_stage_config)
        self.control_key = control_key
        self.only_mid_control = only_mid_control
        # One scale per control tensor (12 UNet levels + middle block); all 1.0
        # here, and the per-scale multiply in apply_model is commented out.
        self.control_scales = [1.0] * 13

    @torch.no_grad()
    def get_input(self, batch, k, bs=None, *args, **kwargs):
        """Return latents plus a cond dict with text and control conditioning.

        NOTE(review): the ``k`` argument is ignored — the parent is always
        queried with ``self.first_stage_key``; confirm this is intentional.
        """
        x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
        control = batch[self.control_key]
        if bs is not None:
            control = control[:bs]
        control = control.to(self.device)
        # Image layout HWC -> torch layout CHW.
        control = einops.rearrange(control, 'b h w c -> b c h w')
        control = control.to(memory_format=torch.contiguous_format).float()
        return x, dict(c_crossattn=[c], c_concat=[control])

    def apply_model(self, x_noisy, t, cond, *args, **kwargs):
        """Predict noise for ``x_noisy``; inject control features when present."""
        assert isinstance(cond, dict)
        diffusion_model = self.model.diffusion_model

        cond_txt = torch.cat(cond['c_crossattn'], 1)

        if cond['c_concat'] is None:
            # No control image: run the plain (uncontrolled) UNet path.
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=None, only_mid_control=self.only_mid_control)
        else:
            control = self.control_model(x=x_noisy, hint=torch.cat(cond['c_concat'], 1), timesteps=t, context=cond_txt)
            # control = [c * scale for c, scale in zip(control, self.control_scales)]
            eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)

        return eps

    @torch.no_grad()
    def get_unconditional_conditioning(self, N):
        # Empty-prompt embeddings used for classifier-free guidance.
        return self.get_learned_conditioning([""] * N)

    @torch.no_grad()
    def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
                   quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
                   plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
                   use_ema_scope=True, num_samples=1,
                   **kwargs):
        """Build a dict of visualisation tensors (reconstruction, control,
        diffusion/denoise rows, CFG samples) for the logger.

        NOTE(review): reads ``kwargs['split']`` unconditionally — a missing key
        raises KeyError; confirm all callers pass ``split``.
        """
        use_ddim = ddim_steps is not None

        log = dict()
        z, c = self.get_input(batch, self.first_stage_key, bs=N)
        c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
        N = min(z.shape[0], N)
        n_row = min(z.shape[0], n_row)
        log["reconstruction"] = self.decode_first_stage(z)
        # Control image rescaled from [0, 1] to [-1, 1] for display.
        log["control"] = c_cat * 2.0 - 1.0
        log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)

        if plot_diffusion_rows:
            # Forward-process row: progressively noisier versions of z_start.
            diffusion_row = list()
            z_start = z[:n_row]
            for t in range(self.num_timesteps):
                if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
                    t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
                    t = t.to(self.device).long()
                    noise = torch.randn_like(z_start)
                    z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
                    diffusion_row.append(self.decode_first_stage(z_noisy))

            diffusion_row = torch.stack(diffusion_row)  # n_log_step, n_row, C, H, W
            diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
            diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
            diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
            log["diffusion_row"] = diffusion_grid

        if sample:
            # Unguided samples plus (optionally) the denoising trajectory.
            samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                     batch_size=N, ddim=use_ddim,
                                                     ddim_steps=ddim_steps, eta=ddim_eta)
            x_samples = self.decode_first_stage(samples)
            log["samples"] = x_samples
            if plot_denoise_rows:
                denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
                log["denoise_row"] = denoise_grid

        if kwargs['split'] == 'train':
            if unconditional_guidance_scale > 1.0:
                # CFG with the empty prompt; the control image is kept for the
                # unconditional branch as well (zeroing it is commented out).
                uc_cross = self.get_unconditional_conditioning(N)
                uc_cat = c_cat  # torch.zeros_like(c_cat)

                uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
                samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=N, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        else:
            if unconditional_guidance_scale > 1.0:
                # uc_cross = self.get_unconditional_conditioning(N)
                # uc_cat = c_cat  # torch.zeros_like(c_cat)

                # Non-train split: replicate the first control image
                # num_samples times and sample the first caption vs. the
                # empty prompt.
                c_cat = torch.stack([c_cat[0] for _ in range(num_samples)], dim=0).clone()

                cond = {"c_concat": [c_cat], "c_crossattn": [self.get_learned_conditioning([batch['txt'][0]] * num_samples)]}
                uc_full = {"c_concat": [c_cat], "c_crossattn": [self.get_learned_conditioning([''] * num_samples)]}

                samples_cfg, _ = self.sample_log(cond=cond,  # cond={"c_concat": [c_cat], "c_crossattn": [c]},
                                                 batch_size=num_samples, ddim=use_ddim,
                                                 ddim_steps=ddim_steps, eta=ddim_eta,
                                                 unconditional_guidance_scale=unconditional_guidance_scale,
                                                 unconditional_conditioning=uc_full,
                                                 )
                x_samples_cfg = self.decode_first_stage(samples_cfg)
                log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg

        return log

    @torch.no_grad()
    def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
        """Run DDIM sampling; the latent shape is the control's spatial size / 8."""
        ddim_sampler = DDIMSampler(self)
        b, c, h, w = cond["c_concat"][0].shape
        shape = (self.channels, h // 8, w // 8)
        samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
        return samples, intermediates

    def configure_optimizers(self):
        """Create AdamW over the control model plus any trainable UNet params.

        NOTE(review): ``self.sd_locked`` is never assigned in ``__init__`` —
        presumably set externally before training starts; confirm, otherwise
        this raises AttributeError.
        """
        lr = self.learning_rate
        params = list(self.control_model.parameters())

        # Also pick up UNet parameters already marked trainable (e.g. adapters).
        names = []
        for name, param in self.model.diffusion_model.named_parameters():
            if param.requires_grad:
                params.append(param)
                names.append(name)
                # print(name, param.shape)

        # params += self.unet_lora_params
        if not self.sd_locked:
            # Unlock the decoder half of the UNet as well.
            params += list(self.model.diffusion_model.output_blocks.parameters())
            params += list(self.model.diffusion_model.out.parameters())
        opt = torch.optim.AdamW(params, lr=lr)

        # NOTE(review): this flips the whole UNet to requires_grad=True *after*
        # the optimizer was built from the previously-trainable subset.
        set_requires_grad(self.model.diffusion_model, True)

        num_params = count_parameters(params)
        print()
        print()
        print(f"Total number of trainable parameters: {num_params} M!")
        print()
        print()

        return opt

    def low_vram_shift(self, is_diffusing):
        """Swap model halves between GPU and CPU to fit in limited VRAM."""
        if is_diffusing:
            # Diffusion step: UNet + control net on GPU, autoencoder/text off.
            self.model = self.model.cuda()
            self.control_model = self.control_model.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model = self.model.cpu()
            self.control_model = self.control_model.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()
generation/subject/download_dreambooth.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
#!/bin/bash

# Fetch the official DreamBooth evaluation dataset (per-subject reference
# images) from Google's repository into ./dreambooth.
echo -e "\nDownloading dreambooth dataset..."
git clone https://github.com/google/dreambooth.git
generation/subject/evaluate.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import hashlib
18
+ import logging
19
+ import math
20
+ import os
21
+ import warnings
22
+ from pathlib import Path
23
+
24
+ from functools import reduce
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn.functional as F
28
+ import torch.utils.checkpoint
29
+ import transformers
30
+ from packaging import version
31
+ from PIL import Image
32
+ from torch.utils.data import Dataset, DataLoader
33
+ from torchvision import transforms
34
+ from tqdm.auto import tqdm
35
+ from transformers import AutoTokenizer, PretrainedConfig, ViTFeatureExtractor, ViTModel
36
+
37
+ import lpips
38
+ import json
39
+ from PIL import Image
40
+ import requests
41
+ from transformers import AutoProcessor, AutoTokenizer, CLIPModel
42
+ import torchvision.transforms.functional as TF
43
+ from torch.nn.functional import cosine_similarity
44
+ from torchvision.transforms import Compose, ToTensor, Normalize, Resize, ToPILImage
45
+ import re
46
+
47
def get_prompt(subject_name, prompt_idx):
    """Return the DreamBooth evaluation prompt for a subject and prompt index.

    Args:
        subject_name: one of the 30 DreamBooth subject folder names.
        prompt_idx: index (int or numeric string) into the 25 prompt templates.

    Returns:
        The prompt string with the subject's class token substituted in
        ("qwe" is the rare identifier token used during fine-tuning).
    """
    subject_names = [
        "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
        "candle", "cat", "cat2", "clock", "colorful_sneaker",
        "dog", "dog2", "dog3", "dog5", "dog6",
        "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
        "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
        "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie",
    ]

    class_tokens = [
        "backpack", "backpack", "stuffed animal", "bowl", "can",
        "candle", "cat", "cat", "clock", "sneaker",
        "dog", "dog", "dog", "dog", "dog",
        "dog", "dog", "toy", "boot", "stuffed animal",
        "toy", "glasses", "toy", "toy", "cartoon",
        "toy", "sneaker", "teapot", "vase", "stuffed animal",
    ]

    # Map each subject to its coarse class token.
    class_token = dict(zip(subject_names, class_tokens))[subject_name]

    # The 25 official DreamBooth evaluation prompt templates.
    templates = [
        "a qwe {0} in the jungle",
        "a qwe {0} in the snow",
        "a qwe {0} on the beach",
        "a qwe {0} on a cobblestone street",
        "a qwe {0} on top of pink fabric",
        "a qwe {0} on top of a wooden floor",
        "a qwe {0} with a city in the background",
        "a qwe {0} with a mountain in the background",
        "a qwe {0} with a blue house in the background",
        "a qwe {0} on top of a purple rug in a forest",
        "a qwe {0} wearing a red hat",
        "a qwe {0} wearing a santa hat",
        "a qwe {0} wearing a rainbow scarf",
        "a qwe {0} wearing a black top hat and a monocle",
        "a qwe {0} in a chef outfit",
        "a qwe {0} in a firefighter outfit",
        "a qwe {0} in a police outfit",
        "a qwe {0} wearing pink glasses",
        "a qwe {0} wearing a yellow shirt",
        "a qwe {0} in a purple wizard outfit",
        "a red qwe {0}",
        "a purple qwe {0}",
        "a shiny qwe {0}",
        "a wet qwe {0}",
        "a cube shaped qwe {0}",
    ]

    return templates[int(prompt_idx)].format(class_token)
98
+
99
+
100
+
101
class PromptDatasetCLIP(Dataset):
    """Generated images of one subject/epoch paired with their text prompt.

    Used for the CLIP-T metric: every .png in ``data_dir_B/<epoch>`` is paired
    with the single evaluation prompt derived from ``subject_name``
    ('<subject>-<prompt_idx>').
    """

    def __init__(self, subject_name, data_dir_B, tokenizer, processor, epoch=None):
        self.data_dir_B = data_dir_B

        subject_name, prompt_idx = subject_name.split('-')

        epoch_dir = os.path.join(self.data_dir_B, str(epoch))
        self.image_lst = [
            os.path.join(epoch_dir, fname)
            for fname in os.listdir(epoch_dir)
            if fname.endswith(".png")
        ]
        # Every image in the epoch folder shares the same prompt.
        self.prompt_lst = [get_prompt(subject_name, prompt_idx)] * len(self.image_lst)

        self.tokenizer = tokenizer
        self.processor = processor
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __len__(self):
        return len(self.image_lst)

    def __getitem__(self, idx):
        image = Image.open(self.image_lst[idx])
        prompt = self.prompt_lst[idx]

        # An all-zero (black) image marks a failed generation; callers must
        # skip the (None, None) sentinel.
        if all(lo == hi == 0 for lo, hi in image.getextrema()):
            return None, None

        prompt_inputs = self.tokenizer([prompt], padding=True, return_tensors="pt")
        image_inputs = self.processor(images=image, return_tensors="pt")
        return image_inputs, prompt_inputs
131
+
132
+
133
class PairwiseImageDatasetCLIP(Dataset):
    """Cartesian product of (reference, generated) images, CLIP-preprocessed.

    Reference images come from ``data_dir_A/<subject>`` (.jpg); generated
    images from ``data_dir_B/<epoch>`` (.png). Used for the CLIP-I metric.
    """

    def __init__(self, subject_name, data_dir_A, data_dir_B, processor, epoch):
        self.data_dir_A = data_dir_A
        self.data_dir_B = data_dir_B

        subject_name, prompt_idx = subject_name.split('-')

        self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
        self.image_files_A = [
            os.path.join(self.data_dir_A, fname)
            for fname in os.listdir(self.data_dir_A)
            if fname.endswith(".jpg")
        ]

        epoch_dir = os.path.join(self.data_dir_B, str(epoch))
        self.image_files_B = [
            os.path.join(epoch_dir, fname)
            for fname in os.listdir(epoch_dir)
            if fname.endswith(".png")
        ]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = processor

    @staticmethod
    def _is_blank(image):
        # An all-zero (black) image signals a failed generation.
        return all(lo == hi == 0 for lo, hi in image.getextrema())

    def __len__(self):
        return len(self.image_files_A) * len(self.image_files_B)

    def __getitem__(self, index):
        # Flattened index over the full A x B cross product.
        index_A, index_B = divmod(index, len(self.image_files_B))

        image_A = Image.open(self.image_files_A[index_A])  # .convert("RGB")
        image_B = Image.open(self.image_files_B[index_B])  # .convert("RGB")

        if self._is_blank(image_A) or self._is_blank(image_B):
            return None, None

        inputs_A = self.processor(images=image_A, return_tensors="pt")
        inputs_B = self.processor(images=image_B, return_tensors="pt")
        return inputs_A, inputs_B
168
+
169
+
170
class PairwiseImageDatasetDINO(Dataset):
    """Cartesian product of (reference, generated) images for the DINO metric.

    Identical pairing logic to the CLIP variant, but images are prepared with
    a ViT feature extractor instead of the CLIP processor.
    """

    def __init__(self, subject_name, data_dir_A, data_dir_B, feature_extractor, epoch):
        self.data_dir_A = data_dir_A
        self.data_dir_B = data_dir_B

        subject_name, prompt_idx = subject_name.split('-')

        self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
        self.image_files_A = [
            os.path.join(self.data_dir_A, fname)
            for fname in os.listdir(self.data_dir_A)
            if fname.endswith(".jpg")
        ]

        epoch_dir = os.path.join(self.data_dir_B, str(epoch))
        self.image_files_B = [
            os.path.join(epoch_dir, fname)
            for fname in os.listdir(epoch_dir)
            if fname.endswith(".png")
        ]

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.feature_extractor = feature_extractor

    @staticmethod
    def _is_blank(image):
        # An all-zero (black) image signals a failed generation.
        return all(lo == hi == 0 for lo, hi in image.getextrema())

    def __len__(self):
        return len(self.image_files_A) * len(self.image_files_B)

    def __getitem__(self, index):
        # Flattened index over the full A x B cross product.
        index_A, index_B = divmod(index, len(self.image_files_B))

        image_A = Image.open(self.image_files_A[index_A])  # .convert("RGB")
        image_B = Image.open(self.image_files_B[index_B])  # .convert("RGB")

        if self._is_blank(image_A) or self._is_blank(image_B):
            return None, None

        inputs_A = self.feature_extractor(images=image_A, return_tensors="pt")
        inputs_B = self.feature_extractor(images=image_B, return_tensors="pt")
        return inputs_A, inputs_B
205
+
206
class PairwiseImageDatasetLPIPS(Dataset):
    """Cartesian product of (reference, generated) images as LPIPS tensors.

    Both images are resized to 512x512 and normalised to [-1, 1], which is
    the input range the LPIPS network expects.
    """

    def __init__(self, subject_name, data_dir_A, data_dir_B, epoch):
        self.data_dir_A = data_dir_A
        self.data_dir_B = data_dir_B

        subject_name, prompt_idx = subject_name.split('-')

        self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
        self.image_files_A = [
            os.path.join(self.data_dir_A, fname)
            for fname in os.listdir(self.data_dir_A)
            if fname.endswith(".jpg")
        ]

        epoch_dir = os.path.join(self.data_dir_B, str(epoch))
        self.image_files_B = [
            os.path.join(epoch_dir, fname)
            for fname in os.listdir(epoch_dir)
            if fname.endswith(".png")
        ]

        # Resize + scale to [-1, 1] as required by LPIPS.
        self.transform = Compose([
            Resize((512, 512)),
            ToTensor(),
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _is_blank(image):
        # An all-zero (black) image signals a failed generation.
        return all(lo == hi == 0 for lo, hi in image.getextrema())

    def __len__(self):
        return len(self.image_files_A) * len(self.image_files_B)

    def __getitem__(self, index):
        # Flattened index over the full A x B cross product.
        index_A, index_B = divmod(index, len(self.image_files_B))

        image_A = Image.open(self.image_files_A[index_A])  # .convert("RGB")
        image_B = Image.open(self.image_files_B[index_B])  # .convert("RGB")

        if self._is_blank(image_A) or self._is_blank(image_B):
            return None, None

        if self.transform:
            image_A = self.transform(image_A)
            image_B = self.transform(image_B)

        return image_A, image_B
246
+
247
+
248
def clip_text(subject_name, image_dir):
    """Per-epoch mean CLIP text-image cosine similarity (CLIP-T).

    Args:
        subject_name: '<subject>-<prompt_idx>' identifier.
        image_dir: directory containing one sub-folder of generated images
            per epoch (folder names are integer epoch numbers).

    Returns:
        List of per-epoch mean similarities; 0 when an epoch had no valid
        (non-blank) images.

    Fix vs. original: the ``DataLoader`` that was constructed and never used
    has been removed (the loop indexes the dataset directly), along with
    dead commented-out code.
    """
    criterion = 'clip_text'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    # Text and image preprocessors for the same CLIP checkpoint.
    tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

    epochs = sorted(int(epoch) for epoch in os.listdir(image_dir))
    best_mean_similarity = 0
    mean_similarity_list = []
    for epoch in epochs:
        similarity = []
        dataset = PromptDatasetCLIP(subject_name, image_dir, tokenizer, processor, epoch)
        for i in range(len(dataset)):
            image_inputs, prompt_inputs = dataset[i]
            # (None, None) marks a blank/failed generation — skip it.
            if image_inputs is not None and prompt_inputs is not None:
                image_inputs['pixel_values'] = image_inputs['pixel_values'].to(device)
                prompt_inputs['input_ids'] = prompt_inputs['input_ids'].to(device)
                prompt_inputs['attention_mask'] = prompt_inputs['attention_mask'].to(device)
                image_features = model.get_image_features(**image_inputs)
                text_features = model.get_text_features(**prompt_inputs)

                sim = cosine_similarity(image_features, text_features)
                similarity.append(sim.item())

        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            mean_similarity_list.append(mean_similarity)
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list
292
+
293
+
294
def clip_image(subject_name, image_dir, dreambooth_dir='dreambooth/dataset'):
    """Per-epoch mean CLIP image-image similarity (CLIP-I).

    Args:
        subject_name: '<subject>-<prompt_idx>' identifier.
        image_dir: directory with one sub-folder of generated images per epoch.
        dreambooth_dir: root of the reference DreamBooth dataset.

    Returns:
        List of per-epoch mean similarities; 0 when an epoch had no valid pairs.

    Fix vs. original: ``logit_scale = model.logit_scale.exp()`` was computed
    on every pair but never used (the multiplication was commented out) — the
    dead computation has been removed.
    """
    criterion = 'clip_image'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(device)
    processor = AutoProcessor.from_pretrained("openai/clip-vit-large-patch14")

    epochs = sorted(int(epoch) for epoch in os.listdir(image_dir))
    best_mean_similarity = 0
    mean_similarity_list = []
    for epoch in epochs:
        similarity = []
        dataset = PairwiseImageDatasetCLIP(subject_name, dreambooth_dir, image_dir, processor, epoch)

        for i in range(len(dataset)):
            inputs_A, inputs_B = dataset[i]
            # (None, None) marks a blank/failed generation — skip the pair.
            if inputs_A is not None and inputs_B is not None:
                inputs_A['pixel_values'] = inputs_A['pixel_values'].to(device)
                inputs_B['pixel_values'] = inputs_B['pixel_values'].to(device)

                image_A_features = model.get_image_features(**inputs_A)
                image_B_features = model.get_image_features(**inputs_B)

                # L2-normalise so the dot product below is a cosine similarity.
                image_A_features = image_A_features / image_A_features.norm(p=2, dim=-1, keepdim=True)
                image_B_features = image_B_features / image_B_features.norm(p=2, dim=-1, keepdim=True)

                sim = torch.matmul(image_A_features, image_B_features.t())
                similarity.append(sim.item())

        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            mean_similarity_list.append(mean_similarity)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list
335
+
336
+
337
def dino(subject_name, image_dir, dreambooth_dir='dreambooth/dataset'):
    """Per-epoch mean DINO [CLS]-embedding similarity between real and generated images.

    Args:
        subject_name: '<subject>-<prompt_idx>' identifier.
        image_dir: directory with one sub-folder of generated images per epoch.
        dreambooth_dir: root of the reference DreamBooth dataset.

    Returns:
        List of per-epoch mean similarities; 0 when an epoch had no valid pairs.

    Fix vs. original: an epoch with no valid pairs previously produced
    ``torch.tensor([]).mean()`` == nan; it now records 0, matching the
    behaviour of clip_text / clip_image.
    """
    criterion = 'dino'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ViTModel.from_pretrained('facebook/dino-vits16').to(device)
    feature_extractor = ViTFeatureExtractor.from_pretrained('facebook/dino-vits16')

    epochs = sorted(int(epoch) for epoch in os.listdir(image_dir))
    best_mean_similarity = 0
    mean_similarity_list = []
    for epoch in epochs:
        similarity = []
        dataset = PairwiseImageDatasetDINO(subject_name, dreambooth_dir, image_dir, feature_extractor, epoch)

        for i in range(len(dataset)):
            inputs_A, inputs_B = dataset[i]
            # (None, None) marks a blank/failed generation — skip the pair.
            if inputs_A is not None and inputs_B is not None:
                inputs_A['pixel_values'] = inputs_A['pixel_values'].to(device)
                inputs_B['pixel_values'] = inputs_B['pixel_values'].to(device)

                # [CLS] token of the final layer as the global image descriptor.
                outputs_A = model(**inputs_A)
                image_A_features = outputs_A.last_hidden_state[:, 0, :]

                outputs_B = model(**inputs_B)
                image_B_features = outputs_B.last_hidden_state[:, 0, :]

                # L2-normalise so the dot product below is a cosine similarity.
                image_A_features = image_A_features / image_A_features.norm(p=2, dim=-1, keepdim=True)
                image_B_features = image_B_features / image_B_features.norm(p=2, dim=-1, keepdim=True)

                sim = torch.matmul(image_A_features, image_B_features.t())
                similarity.append(sim.item())

        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            mean_similarity_list.append(mean_similarity)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            # No valid pairs this epoch: record 0 instead of nan.
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list
376
+
377
+
378
def lpips_image(subject_name, image_dir, dreambooth_dir='dreambooth/dataset'):
    """Per-epoch mean LPIPS (VGG) distance between real and generated images.

    Args:
        subject_name: '<subject>-<prompt_idx>' identifier.
        image_dir: directory with one sub-folder of generated images per epoch.
        dreambooth_dir: root of the reference DreamBooth dataset.

    Returns:
        List of per-epoch mean LPIPS distances; 0 when an epoch had no valid pairs.

    NOTE(review): LPIPS is a *distance* (lower = more similar) but it is
    aggregated with max() into ``best_mean_similarity``, mirroring the other
    metrics — confirm downstream consumers interpret it correctly.

    Fix vs. original: an epoch with no valid pairs previously yielded nan
    from ``torch.tensor([]).mean()``; it now records 0 for consistency with
    clip_text / clip_image.
    """
    criterion = 'lpips_image'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Set up the LPIPS model (net='vgg' is the VGG-based model from the paper).
    loss_fn = lpips.LPIPS(net='vgg').to(device)

    # Some epochs may be missing or incomplete, so enumerate what is on disk.
    epochs = sorted(int(epoch) for epoch in os.listdir(image_dir))
    mean_similarity_list = []
    best_mean_similarity = 0
    for epoch in epochs:
        similarity = []
        dataset = PairwiseImageDatasetLPIPS(subject_name, dreambooth_dir, image_dir, epoch)

        for i in range(len(dataset)):
            image_A, image_B = dataset[i]
            # (None, None) marks a blank/failed generation — skip the pair.
            if image_A is not None and image_B is not None:
                image_A = image_A.to(device)
                image_B = image_B.to(device)

                # Perceptual distance between the two images.
                distance = loss_fn(image_A, image_B)
                similarity.append(distance.item())

        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            mean_similarity_list.append(mean_similarity)
            print(f'epoch: {epoch}, criterion: LPIPS distance, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            # No valid pairs this epoch: record 0 instead of nan.
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: LPIPS distance, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list
410
+
411
if __name__ == "__main__":
    image_dir = 'log_hra/lr_1e-4_r_8/'

    # One sub-directory per evaluated subject, named '<subject>-<prompt_idx>'.
    subject_dirs, subject_names = [], []
    for name in os.listdir(image_dir):
        if os.path.isdir(os.path.join(image_dir, name)):
            subject_dirs.append(os.path.join(image_dir, name))
            subject_names.append(name)

    results_path = os.path.join(image_dir, 'true_results.json')
    # JSON-lines file; each line is
    # {'backpack-0': {'DINO': [x, ...], 'CLIP-I': [x, ...], 'CLIP-T': [x, ...], 'LPIPS': [x, ...]}}

    # Load already-computed results so finished subjects are skipped on re-run.
    # (Replaces the original manual __iter__/next/StopIteration loop with
    # idiomatic line iteration; blank lines are tolerated.)
    results_dict = dict()
    if os.path.exists(results_path):
        with open(results_path, 'r') as f:
            for line in f:
                if line.strip():
                    results_dict.update(json.loads(line))
        print("finish extraction.")

    for subject_name, subject_dir in zip(subject_names, subject_dirs):
        if subject_name in results_dict:
            continue

        print(f'evaluating {subject_dir}')
        dino_sim = dino(subject_name, subject_dir)
        clip_i_sim = clip_image(subject_name, subject_dir)
        clip_t_sim = clip_text(subject_name, subject_dir)
        lpips_sim = lpips_image(subject_name, subject_dir)

        subject_result = {'DINO': dino_sim, 'CLIP-I': clip_i_sim, 'CLIP-T': clip_t_sim, 'LPIPS': lpips_sim}
        print(subject_result)

        # Append this subject's result as one JSON line.
        with open(results_path, 'a') as f:
            json_string = json.dumps({subject_name: subject_result})
            f.write(json_string + "\n")
generation/subject/get_result.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import math
17
+ import os
18
+
19
+ from functools import reduce
20
+ import numpy as np
21
+
22
+ import json
23
+
24
+
25
if __name__ == "__main__":
    image_dir = 'log_hra/lr_1e-4_r_8/'

    results_path = os.path.join(image_dir, 'true_results.json')
    # JSON-lines file; each line is
    # {'backpack-0': {'DINO': [x, ...], 'CLIP-I': [x, ...], 'CLIP-T': [x, ...], 'LPIPS': [x, ...]}}

    # Load per-subject metric curves. (Replaces the original manual
    # __iter__/next/StopIteration loop; blank lines are tolerated.)
    results_dict = dict()
    if os.path.exists(results_path):
        with open(results_path, 'r') as f:
            for line in f:
                if line.strip():
                    results_dict.update(json.loads(line))
        print("finish extraction.")

    total_result = np.zeros(4)
    metric_name_list = ['DINO', 'CLIP-I', 'CLIP-T', 'LPIPS']
    for subject_name, subject_results in results_dict.items():

        # For each subject, pick the epoch maximising the sum of the
        # range-normalised metric curves, then accumulate its raw metrics.
        metric_results_percent = None
        for metric_name, metric_results in subject_results.items():
            metric_results = [0 if np.isnan(r) else r for r in metric_results]
            value_range = max(metric_results) - min(metric_results)
            if value_range == 0:
                # Flat curve contributes nothing to epoch selection
                # (the original divided by zero here, yielding inf/nan).
                metric_results_norm = np.zeros(len(metric_results))
            else:
                metric_results_norm = np.array(metric_results) / value_range
            if metric_results_percent is None:
                metric_results_percent = metric_results_norm
            else:
                metric_results_percent += metric_results_norm

        subject_results_max_idx = np.argmax(metric_results_percent)
        for idx, metric_name in enumerate(metric_name_list):
            total_result[idx] += subject_results[metric_name][subject_results_max_idx]
    total_result /= len(results_dict)
    print(f'DINO: {total_result[0]}, CLIP-I: {total_result[1]}, CLIP-T: {total_result[2]}, LPIPS: {total_result[3]}')
62
+
generation/subject/oft_utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .mhe import MHE_db, MHE_OFT, MHE_LoRA
2
+
generation/subject/oft_utils/attention_processor.py ADDED
@@ -0,0 +1,1036 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Callable, Optional, Union
15
+ import math
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch import nn
19
+ from torch.autograd import Function
20
+
21
+ from diffusers.utils import deprecate, logging
22
+ from diffusers.utils.import_utils import is_xformers_available
23
+
24
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
class Attention(nn.Module):
    r"""
    A cross attention layer.
    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        only_cross_attention: bool = False,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        # 1/sqrt(d) attention scaling; disabled entirely when scale_qk=False.
        self.scale = dim_head**-0.5 if scale_qk else 1.0

        self.heads = heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
        else:
            self.group_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels, num_groups=cross_attention_norm_num_groups, eps=1e-5, affine=True
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
            self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, inner_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, inner_dim)

        # to_out[0] is the output projection, to_out[1] is dropout.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
        if processor is None:
            processor = (
                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and scale_qk else AttnProcessor()
            )
        self.set_processor(processor)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """Toggle xformers memory-efficient attention, preserving any trained
        LoRA processor weights by transferring them to the matching
        xformers/non-xformers LoRA processor variant."""
        is_lora = hasattr(self, "processor") and isinstance(
            self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
        )

        if use_memory_efficient_attention_xformers:
            if self.added_kv_proj_dim is not None:
                # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
                # which uses this type of cross attention ONLY because the attention mask of format
                # [0, ..., -10.000, ..., 0, ...,] is not supported
                raise NotImplementedError(
                    "Memory efficient attention with `xformers` is currently not supported when"
                    " `self.added_kv_proj_dim` is defined."
                )
            elif not is_xformers_available():
                raise ModuleNotFoundError(
                    (
                        "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                        " xformers"
                    ),
                    name="xformers",
                )
            elif not torch.cuda.is_available():
                raise ValueError(
                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                    " only available for GPU "
                )
            else:
                try:
                    # Make sure we can run the memory efficient attention
                    _ = xformers.ops.memory_efficient_attention(
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                        torch.randn((1, 2, 40), device="cuda"),
                    )
                except Exception as e:
                    raise e

            if is_lora:
                processor = LoRAXFormersAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                    attention_op=attention_op,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            else:
                processor = XFormersAttnProcessor(attention_op=attention_op)
        else:
            if is_lora:
                processor = LoRAAttnProcessor(
                    hidden_size=self.processor.hidden_size,
                    cross_attention_dim=self.processor.cross_attention_dim,
                    rank=self.processor.rank,
                )
                processor.load_state_dict(self.processor.state_dict())
                processor.to(self.processor.to_q_lora.up.weight.device)
            else:
                processor = AttnProcessor()

        self.set_processor(processor)

    def set_attention_slice(self, slice_size):
        """Select a (sliced) processor so attention is computed `slice_size`
        heads at a time to reduce peak memory; `None` disables slicing."""
        if slice_size is not None and slice_size > self.sliceable_head_dim:
            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = AttnAddedKVProcessor()
        else:
            processor = AttnProcessor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        """Install `processor` as this layer's attention implementation,
        unregistering a previous nn.Module-based processor if necessary."""
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
            self._modules.pop("processor")

        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """Delegate the attention computation to the installed processor."""
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        """Inverse of `head_to_batch_dim`: (batch*heads, seq, dim_head) ->
        (batch, seq, heads*dim_head)."""
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, out_dim=3):
        """Split the channel dim into heads: (batch, seq, heads*dim_head) ->
        (batch*heads, seq, dim_head) when out_dim=3, or
        (batch, heads, seq, dim_head) when out_dim=4."""
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        """Return softmax(QK^T * scale + mask) in the query's original dtype,
        optionally upcasting the matmul and/or softmax to float32."""
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        # baddbmm fuses the additive mask (beta=1) with the scaled QK^T product.
        if attention_mask is None:
            baddbmm_input = torch.empty(
                query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
            )
            beta = 0
        else:
            baddbmm_input = attention_mask
            beta = 1

        attention_scores = torch.baddbmm(
            baddbmm_input,
            query,
            key.transpose(-1, -2),
            beta=beta,
            alpha=self.scale,
        )

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length, batch_size=None, out_dim=3):
        """Pad `attention_mask` to `target_length` along its last axis and
        expand it across heads so it can be added to attention scores.
        Returns None unchanged."""
        if batch_size is None:
            deprecate(
                "batch_size=None",
                "0.0.15",
                (
                    "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
                    " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
                    " `prepare_attention_mask` when preparing the attention_mask."
                ),
            )
            batch_size = 1

        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def norm_encoder_hidden_states(self, encoder_hidden_states):
        """Apply the configured cross-attention normalization (LayerNorm or
        GroupNorm) to `encoder_hidden_states`."""
        assert self.norm_cross is not None, "self.norm_cross must be defined to call self.norm_encoder_hidden_states"

        if isinstance(self.norm_cross, nn.LayerNorm):
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
        elif isinstance(self.norm_cross, nn.GroupNorm):
            # Group norm norms along the channels dimension and expects
            # input to be in the shape of (N, C, *). In this case, we want
            # to norm along the hidden dimension, so we need to move
            # (batch_size, sequence_length, hidden_size) ->
            # (batch_size, hidden_size, sequence_length)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
            encoder_hidden_states = self.norm_cross(encoder_hidden_states)
            encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
        else:
            assert False

        return encoder_hidden_states
352
+
353
+
354
class AttnProcessor:
    r"""Default attention processor: plain scaled dot-product attention using
    the layer's frozen q/k/v/out projections, no adapters."""

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        # Mask length follows the key/value sequence (context) length.
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        # Self-attention when no context is given; otherwise optionally
        # normalize the cross-attention context.
        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.to_k(context)
        value = attn.to_v(context)

        query, key, value = (attn.head_to_batch_dim(t) for t in (query, key, value))

        probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # Output projection followed by dropout.
        attended = attn.to_out[0](attended)
        attended = attn.to_out[1](attended)

        return attended
+ return hidden_states
390
+
391
+
392
class HRALinearLayer(nn.Module):
    """Householder reflection adaptation (HRA) wrapper for a frozen nn.Linear.

    The frozen weight W of the wrapped layer is right-multiplied by a chain of
    `r` Householder reflections (I - 2 u u^T), each built from one normalized
    column of the trainable parameter `hra_u`.

    Args:
        in_features: input dimension of the wrapped linear layer.
        out_features: output dimension of the wrapped linear layer.
        bias: unused; kept for signature compatibility.
        r: number of Householder reflections; must be even, because the
            columns are initialized in duplicated pairs so that the initial
            reflection product is the identity (the adapter starts as a no-op).
        apply_GS: if True, Gram-Schmidt-orthonormalize the reflection vectors
            before building a single combined reflection.
            NOTE(review): with the duplicated-pair init, Gram-Schmidt makes the
            second vector of each pair degenerate (zero after projection) —
            confirm apply_GS=True is only used with a different init.
    """

    def __init__(self, in_features, out_features, bias=False, r=8, apply_GS=False):
        super(HRALinearLayer, self).__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Stored as buffers so the dimensions travel with the state_dict.
        self.register_buffer('cross_attention_dim', torch.tensor(in_features))
        self.register_buffer('hidden_size', torch.tensor(out_features))

        self.r = r
        self.apply_GS = apply_GS

        # r/2 random directions, each duplicated: two identical reflections
        # cancel, so the initial adapted weight equals W exactly.
        half_u = torch.zeros(in_features, r // 2)
        nn.init.kaiming_uniform_(half_u, a=math.sqrt(5))
        self.hra_u = nn.Parameter(torch.repeat_interleave(half_u, 2, dim=1), requires_grad=True)

    def forward(self, attn, x):
        """Apply the wrapped linear layer `attn` with the HRA-adapted weight.

        Args:
            attn: the frozen nn.Linear whose weight/bias are adapted.
            x: input tensor with last dimension `in_features`.

        Returns:
            `x @ (W @ H_1 @ ... @ H_r)^T + b`, where each H_i is a Householder
            reflection.
        """
        orig_weight = attn.weight.data
        # Hoisted: the original rebuilt this identity matrix inside the loop
        # on every one of the r iterations.
        identity = torch.eye(self.in_features, device=x.device)

        if self.apply_GS:
            # Orthonormalize the columns of hra_u via Gram-Schmidt, then apply
            # all r reflections at once: I - 2 U U^T with orthonormal U.
            basis = [(self.hra_u[:, 0] / self.hra_u[:, 0].norm()).view(-1, 1)]
            for i in range(1, self.r):
                ui = self.hra_u[:, i].view(-1, 1)
                for j in range(i):
                    ui = ui - (basis[j].t() @ ui) * basis[j]
                basis.append((ui / ui.norm()).view(-1, 1))
            basis = torch.cat(basis, dim=1)
            new_weight = orig_weight @ (identity - 2 * basis @ basis.t())
        else:
            # Chain the r reflections one by one.
            # NOTE(review): assumes no column of hra_u has zero norm.
            new_weight = orig_weight
            hra_u_norm = self.hra_u / self.hra_u.norm(dim=0)
            for i in range(self.r):
                ui = hra_u_norm[:, i].view(-1, 1)
                new_weight = torch.mm(new_weight, identity - 2 * ui @ ui.t())

        return nn.functional.linear(input=x, weight=new_weight, bias=attn.bias)
447
+
448
class HRAAttnProcessor(nn.Module):
    """Attention processor that routes the q/k/v/out projections of a frozen
    `Attention` layer through trainable `HRALinearLayer` adapters."""

    def __init__(self, hidden_size, cross_attention_dim=None, r=8, apply_GS=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.r = r

        kv_dim = cross_attention_dim or hidden_size
        self.to_q_hra = HRALinearLayer(hidden_size, hidden_size, r=r, apply_GS=apply_GS)
        self.to_k_hra = HRALinearLayer(kv_dim, hidden_size, r=r, apply_GS=apply_GS)
        self.to_v_hra = HRALinearLayer(kv_dim, hidden_size, r=r, apply_GS=apply_GS)
        self.to_out_hra = HRALinearLayer(hidden_size, hidden_size, r=r, apply_GS=apply_GS)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        # `scale` is accepted for interface parity with LoRA processors but unused.
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.head_to_batch_dim(self.to_q_hra(attn.to_q, hidden_states))

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.head_to_batch_dim(self.to_k_hra(attn.to_k, context))
        value = attn.head_to_batch_dim(self.to_v_hra(attn.to_v, context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # HRA-adapted output projection, then the layer's dropout.
        attended = self.to_out_hra(attn.to_out[0], attended)
        return attn.to_out[1](attended)
496
+
497
+
498
def project(R, eps):
    """Clamp the COFT parameter `R` into a Frobenius ball of radius `eps`.

    Despite the variable naming in the original, the ball is centered at the
    ZERO matrix: `R` is the skew-symmetric parameter (initialized at zero)
    that the Cayley transform later maps to a rotation, so constraining it
    near zero keeps the rotation near the identity.
    """
    origin = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    offset = R - origin
    distance = torch.norm(offset)
    if distance <= eps:
        return R
    return origin + eps * (offset / distance)
506
+
507
def project_batch(R, eps=1e-5):
    """Batched `project`: clamp each block of `R` (shape: blocks x n x n)
    into a Frobenius ball around the zero matrix.

    The radius is scaled by 1/sqrt(num_blocks) so the constraint tightens
    with the number of smaller block matrices.
    """
    # scaling factor for each of the smaller block matrix
    eps = eps * 1 / torch.sqrt(torch.tensor(R.shape[0]))
    origin = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
    offset = R - origin
    distances = torch.norm(offset, dim=(1, 2), keepdim=True)
    inside = (distances <= eps).bool()
    return torch.where(inside, R, origin + eps * (offset / distances))
516
+
517
+
518
class OFTLinearLayer(nn.Module):
    """Orthogonal finetuning (OFT) wrapper for a frozen nn.Linear.

    A trainable skew-symmetric parameter `R` (initialized at zero) is mapped
    through the Cayley transform to a block-diagonal orthogonal matrix that
    left-multiplies the frozen weight. With `is_coft=True`, `R` is projected
    in-place (under no_grad) into an eps-ball around zero before every
    forward, implementing the constrained variant COFT.
    """

    def __init__(self, in_features, out_features, bias=False, block_share=False, eps=6e-5, r=4, is_coft=False):
        # bias: unused; kept for signature compatibility.
        # block_share: share one rotation block across all r blocks.
        # r: number of diagonal blocks (in_features must be divisible by r).
        super(OFTLinearLayer, self).__init__()

        # Define the reduction rate:
        self.r = r

        # Check whether to use the constrained variant COFT
        self.is_coft = is_coft

        assert in_features % self.r == 0, "in_features must be divisible by r"

        # Get the number of available GPUs
        # self.num_gpus = torch.cuda.device_count()
        # Set the device IDs for distributed training
        # self.device_ids = list(range(self.num_gpus))

        self.in_features=in_features
        self.out_features=out_features

        # Buffers so the dimensions travel with the state_dict.
        self.register_buffer('cross_attention_dim', torch.tensor(in_features))
        self.register_buffer('hidden_size', torch.tensor(out_features))

        # Define the fixed Linear layer: v
        # self.OFT = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=bias)

        #self.filt_shape = [in_features, in_features]
        self.fix_filt_shape = [in_features, out_features]

        self.block_share = block_share
        # Define the trainable matrix parameter: R
        if self.block_share:
            # Initialized as an identity matrix
            # (R=0 maps to the identity rotation under the Cayley transform).
            self.R_shape = [in_features // self.r, in_features // self.r]
            self.R = nn.Parameter(torch.zeros(self.R_shape[0], self.R_shape[0]), requires_grad=True)

            # COFT radius scales with the block area.
            self.eps = eps * self.R_shape[0] * self.R_shape[0]
        else:
            # Initialized as an identity matrix
            self.R_shape = [self.r, in_features // self.r, in_features // self.r]
            R = torch.zeros(self.R_shape[1], self.R_shape[1])
            R = torch.stack([R] * self.r)
            self.R = nn.Parameter(R, requires_grad=True)
            self.eps = eps * self.R_shape[1] * self.R_shape[1]

        # Debug leftover; only used by the commented-out tracing code below.
        self.tmp = None

    def forward(self, attn, x):
        """Apply the wrapped linear layer `attn` with an orthogonally rotated
        weight to `x`.

        Args:
            attn: the frozen nn.Linear being adapted.
            x: input tensor with last dimension `in_features`.
        """
        orig_dtype = x.dtype
        dtype = self.R.dtype

        if self.block_share:
            if self.is_coft:
                # In-place projection: keeps the optimizer state attached to R.
                with torch.no_grad():
                    self.R.copy_(project(self.R, eps=self.eps))
            orth_rotate = self.cayley(self.R)
        else:
            if self.is_coft:
                with torch.no_grad():
                    self.R.copy_(project_batch(self.R, eps=self.eps))
            # Without this cayley_batch step, self.R would never be updated.
            orth_rotate = self.cayley_batch(self.R)

        # print('self.tmp[:5, :5]')
        # print(self.tmp[:5, :5])
        # if self.tmp is not None:
        #     print('self.R[0, :5, :5] - self.tmp[0, :5, :5]')
        #     print(self.R[0, :5, :5] - self.tmp[0, :5, :5])
        # self.tmp = self.R.clone()

        # Block-diagonal parametrization
        block_diagonal_matrix = self.block_diagonal(orth_rotate)

        # fix filter: rotate the frozen weight (transposed so the rotation
        # acts on the input dimension).
        fix_filt = attn.weight.data
        fix_filt = torch.transpose(fix_filt, 0, 1)
        filt = torch.mm(block_diagonal_matrix, fix_filt.to(dtype))
        filt = torch.transpose(filt, 0, 1)

        # Apply the trainable identity matrix
        bias_term = attn.bias.data if attn.bias is not None else None
        if bias_term is not None:
            bias_term = bias_term.to(orig_dtype)

        out = nn.functional.linear(input=x.to(orig_dtype), weight=filt.to(orig_dtype), bias=bias_term)
        # out = nn.functional.linear(input=x, weight=fix_filt.transpose(0, 1), bias=bias_term)

        return out

    def cayley(self, data):
        """Cayley transform: map a matrix (skew-symmetrized first) to an
        orthogonal matrix Q = (I - A)(I + A)^-1."""
        r, c = list(data.shape)
        # Ensure the input matrix is skew-symmetric
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device)
        # Perform the Cayley parametrization
        Q = torch.mm(I - skew, torch.inverse(I + skew))

        return Q

    def cayley_batch(self, data):
        """Batched Cayley transform over a stack of square matrices."""
        b, r, c = data.shape
        # Ensure the input matrix is skew-symmetric
        skew = 0.5 * (data - data.transpose(1, 2))
        # I = torch.eye(r, device=data.device).unsqueeze(0).repeat(b, 1, 1)
        I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c)

        # Perform the Cayley parametrization
        Q = torch.bmm(I - skew, torch.inverse(I + skew))

        return Q

    def block_diagonal(self, R):
        """Assemble the full rotation as a block-diagonal matrix from either a
        single shared block (2-D R) or a stack of r blocks (3-D R)."""
        if len(R.shape) == 2:
            # Create a list of R repeated block_count times
            blocks = [R] * self.r
        else:
            # Create a list of R slices along the third dimension
            blocks = [R[i, ...] for i in range(self.r)]

        # Use torch.block_diag to create the block diagonal matrix
        A = torch.block_diag(*blocks)

        return A

    def is_orthogonal(self, R, eps=1e-5):
        """Debug helper: check R^T R is within eps of the identity."""
        with torch.no_grad():
            RtR = torch.matmul(R.t(), R)
            diff = torch.abs(RtR - torch.eye(R.shape[1], dtype=R.dtype, device=R.device))
            return torch.all(diff < eps)

    def is_identity_matrix(self, tensor):
        """Debug helper: exact (element-wise equal) identity-matrix check."""
        if not torch.is_tensor(tensor):
            raise TypeError("Input must be a PyTorch tensor.")
        if tensor.ndim != 2 or tensor.shape[0] != tensor.shape[1]:
            return False
        identity = torch.eye(tensor.shape[0], device=tensor.device)
        return torch.all(torch.eq(tensor, identity))
655
+
656
+
657
class OFTAttnProcessor(nn.Module):
    """Attention processor that routes the q/k/v/out projections of a frozen
    `Attention` layer through trainable `OFTLinearLayer` adapters (orthogonal
    finetuning; `is_coft=True` enables the constrained variant)."""

    def __init__(self, hidden_size, cross_attention_dim=None, eps=2e-5, r=4, is_coft=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.r = r
        self.is_coft = is_coft

        kv_dim = cross_attention_dim or hidden_size
        self.to_q_oft = OFTLinearLayer(hidden_size, hidden_size, eps=eps, r=r, is_coft=is_coft)
        self.to_k_oft = OFTLinearLayer(kv_dim, hidden_size, eps=eps, r=r, is_coft=is_coft)
        self.to_v_oft = OFTLinearLayer(kv_dim, hidden_size, eps=eps, r=r, is_coft=is_coft)
        self.to_out_oft = OFTLinearLayer(hidden_size, hidden_size, eps=eps, r=r, is_coft=is_coft)

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        # `scale` is accepted for interface parity with LoRA processors but unused.
        if encoder_hidden_states is None:
            batch_size, sequence_length, _ = hidden_states.shape
        else:
            batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.head_to_batch_dim(self.to_q_oft(attn.to_q, hidden_states))

        if encoder_hidden_states is None:
            context = hidden_states
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        key = attn.head_to_batch_dim(self.to_k_oft(attn.to_k, context))
        value = attn.head_to_batch_dim(self.to_v_oft(attn.to_v, context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # OFT-adapted output projection, then the layer's dropout.
        attended = self.to_out_oft(attn.to_out[0], attended)
        return attn.to_out[1](attended)
706
+
707
+
708
class AttnAddedKVProcessor:
    """Processor for attention layers with extra key/value projections
    (`add_k_proj`/`add_v_proj`), as used by UnCLIP-style cross attention.
    Takes spatial input (batch, channels, *spatial) and adds a residual."""

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        # Flatten spatial dims: (B, C, ...) -> (B, seq, C).
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.head_to_batch_dim(attn.to_q(hidden_states))

        extra_key = attn.head_to_batch_dim(attn.add_k_proj(encoder_hidden_states))
        extra_value = attn.head_to_batch_dim(attn.add_v_proj(encoder_hidden_states))

        if attn.only_cross_attention:
            key = extra_key
            value = extra_value
        else:
            # Self-attention k/v are concatenated after the projected context.
            self_key = attn.head_to_batch_dim(attn.to_k(hidden_states))
            self_value = attn.head_to_batch_dim(attn.to_v(hidden_states))
            key = torch.cat([extra_key, self_key], dim=1)
            value = torch.cat([extra_value, self_value], dim=1)

        probs = attn.get_attention_scores(query, key, attention_mask)
        attended = attn.batch_to_head_dim(torch.bmm(probs, value))

        # Output projection followed by dropout.
        attended = attn.to_out[0](attended)
        attended = attn.to_out[1](attended)

        # Restore the spatial layout and add the residual connection.
        attended = attended.transpose(-1, -2).reshape(residual.shape)
        return attended + residual
+ return hidden_states
755
+
756
+
757
class AttnAddedKVProcessor2_0:
    """PyTorch-2.0 variant of :class:`AttnAddedKVProcessor` that runs the
    attention itself through ``F.scaled_dot_product_attention``.

    Tensors are kept 4-D (batch, heads, seq, head_dim) via ``out_dim=4`` so
    they can be fed to SDPA directly.
    """

    def __init__(self):
        # SDPA only exists in PyTorch >= 2.0; fail fast otherwise.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnAddedKVProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Keep the original channel-first tensor for the residual add at the end.
        residual = hidden_states
        # (B, C, *spatial) -> (B, seq, C)
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        batch_size, sequence_length, _ = hidden_states.shape

        # out_dim=4: mask shaped for SDPA's (batch, heads, src, tgt) layout.
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size, out_dim=4)

        # Self-attention when no encoder states are given; otherwise optionally
        # normalize the cross-attention context.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Group norm operates channel-first, hence the transpose round-trip.
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query, out_dim=4)

        # Extra key/value projections computed from the (encoder) context.
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj, out_dim=4)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)

        if not attn.only_cross_attention:
            # Concatenate self K/V after the added-projection K/V; dim=2 is the
            # sequence axis in the 4-D (batch, heads, seq, head_dim) layout.
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key, out_dim=4)
            value = attn.head_to_batch_dim(value, out_dim=4)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        # Back to (B, seq, C); residual.shape[1] is the channel count C.
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Restore the original (channel-first) shape and apply the residual.
        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
813
+
814
+
815
class XFormersAttnProcessor:
    """Attention processor that delegates the attention computation to
    xformers' memory-efficient kernel."""

    def __init__(self, attention_op: Optional[Callable] = None):
        # Explicit xformers operator; None lets xformers choose one itself.
        self.attention_op = attention_op

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Mask dimensions follow the attention context (the encoder states
        # when cross-attending, the hidden states otherwise).
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        n_batch, n_tokens, _ = context.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, n_tokens, n_batch)

        q = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        k = attn.to_k(encoder_hidden_states)
        v = attn.to_v(encoder_hidden_states)

        # xformers expects contiguous (batch*heads, seq, head_dim) tensors.
        q = attn.head_to_batch_dim(q).contiguous()
        k = attn.head_to_batch_dim(k).contiguous()
        v = attn.head_to_batch_dim(v).contiguous()

        out = xformers.ops.memory_efficient_attention(
            q, k, v, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
        )
        out = out.to(q.dtype)
        out = attn.batch_to_head_dim(out)

        # Output projection, then dropout.
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out
851
+
852
+
853
class AttnProcessor2_0:
    """Attention processor backed by PyTorch 2.0's
    ``torch.nn.functional.scaled_dot_product_attention``."""

    def __init__(self):
        # SDPA only exists in PyTorch >= 2.0; fail fast otherwise.
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Mask dimensions follow the attention context (the encoder states
        # when cross-attending, the hidden states otherwise).
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        n_batch, n_tokens, _ = context.shape
        inner_dim = hidden_states.shape[-1]

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, n_tokens, n_batch)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(n_batch, attn.heads, -1, attention_mask.shape[-1])

        q = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        k = attn.to_k(encoder_hidden_states)
        v = attn.to_v(encoder_hidden_states)

        # Split heads: (B, seq, inner) -> (B, heads, seq, head_dim).
        head_dim = inner_dim // attn.heads
        q = q.view(n_batch, -1, attn.heads, head_dim).transpose(1, 2)
        k = k.view(n_batch, -1, attn.heads, head_dim).transpose(1, 2)
        v = v.view(n_batch, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )

        # Merge heads back: (B, heads, seq, head_dim) -> (B, seq, inner).
        out = out.transpose(1, 2).reshape(n_batch, -1, attn.heads * head_dim)
        out = out.to(q.dtype)

        # Output projection, then dropout.
        out = attn.to_out[0](out)
        out = attn.to_out[1](out)
        return out
899
+
900
+
901
class SlicedAttnProcessor:
    """Attention processor that computes attention in slices along the
    (batch * heads) dimension to bound peak memory usage."""

    def __init__(self, slice_size):
        # Number of (batch*heads) rows handled per attention slice.
        self.slice_size = slice_size

    def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Mask dimensions follow the attention context (encoder states if present).
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        # Self-attention when no encoder states are given; otherwise optionally
        # normalize the cross-attention context.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Pre-allocated output buffer, filled slice by slice below.
        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # NOTE(review): a trailing remainder (batch_size_attention not divisible
        # by slice_size) would be left as zeros — presumably callers pick a
        # divisor; confirm.
        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
952
+
953
+
954
class SlicedAttnAddedKVProcessor:
    """Sliced-attention variant of :class:`AttnAddedKVProcessor`: extra
    ``add_k_proj``/``add_v_proj`` projections plus slice-wise attention to
    bound peak memory."""

    def __init__(self, slice_size):
        # Number of (batch*heads) rows handled per attention slice.
        self.slice_size = slice_size

    def __call__(self, attn: "Attention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Keep the original channel-first tensor for the residual add at the end.
        residual = hidden_states
        # (B, C, *spatial) -> (B, seq, C)
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        # Self-attention when no encoder states are given; otherwise optionally
        # normalize the cross-attention context.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        # Group norm operates channel-first, hence the transpose round-trip.
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        # Extra key/value projections computed from the (encoder) context.
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        if not attn.only_cross_attention:
            # Concatenate self K/V after the added-projection K/V along the
            # sequence dimension.
            key = attn.to_k(hidden_states)
            value = attn.to_v(hidden_states)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)
            key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
            value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
        else:
            key = encoder_hidden_states_key_proj
            value = encoder_hidden_states_value_proj

        # Pre-allocated output buffer, filled slice by slice below.
        batch_size_attention, query_tokens, _ = query.shape
        hidden_states = torch.zeros(
            (batch_size_attention, query_tokens, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # NOTE(review): a trailing remainder (batch_size_attention not divisible
        # by slice_size) would be left as zeros — presumably callers pick a
        # divisor; confirm.
        for i in range(batch_size_attention // self.slice_size):
            start_idx = i * self.slice_size
            end_idx = (i + 1) * self.slice_size

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)

            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        # Restore the original (channel-first) shape and apply the residual.
        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
1024
+
1025
+
1026
# Type alias covering every attention-processor variant defined in this module
# (including the OFT/HRA fine-tuning processors); used wherever any processor
# implementation is accepted.
AttentionProcessor = Union[
    AttnProcessor,
    AttnProcessor2_0,
    XFormersAttnProcessor,
    SlicedAttnProcessor,
    AttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    OFTAttnProcessor,
    HRAAttnProcessor
]
generation/subject/oft_utils/mhe.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import math
6
+
7
+
8
+ import copy
9
+ import numpy as np
10
+
11
class MHE_LoRA(nn.Module):
    """Minimum hyperspherical energy (MHE) of a LoRA fine-tuned model.

    On construction, every LoRA-adapted attention projection weight
    (``to_q``/``to_k``/``to_v``/``to_out.0``) has its low-rank update
    ``up @ down`` merged into a detached snapshot of the base weight, and the
    LoRA factor tensors are removed from the snapshot. ``calculate_mhe`` then
    scores the merged (effective) weights.

    NOTE(review): the merge and the energy computation move tensors through
    ``.cuda()``, so a CUDA device is required once a LoRA factor or a 2-D/4-D
    weight is encountered — confirm before running on CPU-only hosts.
    """

    def __init__(self, model):
        super(MHE_LoRA, self).__init__()
        # Frozen deep copy of the model (kept for reference; the energy is
        # computed from the extracted state-dict snapshot below).
        self.model = self.copy_without_grad(model)

        self.extracted_params = {}
        keys_to_delete = []

        # Detached snapshot of every tensor in the state dict.
        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

        # Merge LoRA factors into their base attention projection weights.
        for name in self.extracted_params:
            if 'attn' in name and 'processor' not in name:
                if 'weight' in name:
                    if 'to_q' in name:
                        lora_down = name.replace('to_q', 'processor.to_q_lora.down')
                        lora_up = name.replace('to_q', 'processor.to_q_lora.up')
                    elif 'to_k' in name:
                        lora_down = name.replace('to_k', 'processor.to_k_lora.down')
                        lora_up = name.replace('to_k', 'processor.to_k_lora.up')
                    elif 'to_v' in name:
                        lora_down = name.replace('to_v', 'processor.to_v_lora.down')
                        lora_up = name.replace('to_v', 'processor.to_v_lora.up')
                    elif 'to_out' in name:
                        lora_down = name.replace('to_out.0', 'processor.to_out_lora.down')
                        lora_up = name.replace('to_out.0', 'processor.to_out_lora.up')
                    else:
                        # BUGFIX: an attention weight with no LoRA factors
                        # (e.g. add_k_proj/add_v_proj). The previous `pass`
                        # fell through and reused lora_up/lora_down left over
                        # from an earlier iteration (or raised NameError on
                        # the first such weight). Skip it instead.
                        continue
                    with torch.no_grad():
                        # W <- W + up @ down. Assumes the base weight lives on
                        # the same device as the .cuda() product — TODO confirm.
                        self.extracted_params[name] += self.extracted_params[lora_up].cuda() @ self.extracted_params[lora_down].cuda()
                    keys_to_delete.append(lora_up)
                    keys_to_delete.append(lora_down)

        # Drop the merged LoRA factors so only effective weights remain.
        for key in keys_to_delete:
            del self.extracted_params[key]

    def copy_without_grad(self, model):
        """Deep-copy `model` with every parameter detached and frozen."""
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    @staticmethod
    def mhe_loss(filt):
        """Half-space MHE energy of a single weight tensor.

        2-D weights are treated as (out, in) matrices; 4-D conv filters are
        flattened to (out, in*kh*kw) first. Each filter column and its antipode
        are normalized, and the mean pairwise inverse chordal distance over all
        distinct pairs is returned (lower = filters more uniformly spread).

        NOTE(review): builds an intermediate on `.cuda()`; requires a GPU.
        """
        if len(filt.shape) == 2:
            n_filt, _ = filt.shape
            # Columns become filters; append antipodal copies (half-space MHE).
            filt = torch.transpose(filt, 0, 1)
            filt_neg = filt * (-1)
            filt = torch.cat((filt, filt_neg), dim=1)
            n_filt *= 2

            # Normalize by column norms (epsilon keeps the division finite).
            filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
            norm_mat = torch.matmul(filt_norm.t(), filt_norm)
            inner_pro = torch.matmul(filt.t(), filt)
            inner_pro /= norm_mat

            # 1/dist energies; the added diagonal keeps self-terms finite, and
            # the tril subtraction counts each unordered pair once.
            cross_terms = (2.0 - 2.0 * inner_pro + torch.diag(torch.tensor([1.0] * n_filt)).cuda())
            final = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
            final -= torch.tril(final)
            cnt = n_filt * (n_filt - 1) / 2.0
            MHE_loss = 1 * torch.sum(final) / cnt

        else:
            # Conv filters: flatten each output filter before the same
            # computation as the 2-D branch.
            n_filt, _, _, _ = filt.shape
            filt = filt.reshape(n_filt, -1)
            filt = torch.transpose(filt, 0, 1)
            filt_neg = filt * -1
            filt = torch.cat((filt, filt_neg), dim=1)
            n_filt *= 2

            filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
            norm_mat = torch.matmul(filt_norm.t(), filt_norm)
            inner_pro = torch.matmul(filt.t(), filt)
            inner_pro /= norm_mat

            cross_terms = (2.0 - 2.0 * inner_pro + torch.diag(torch.tensor([1.0] * n_filt)).cuda())
            final = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
            final -= torch.tril(final)
            cnt = n_filt * (n_filt - 1) / 2.0
            MHE_loss = 1 * torch.sum(final) / cnt

        return MHE_loss

    def calculate_mhe(self):
        """Sum of MHE energies over all 2-D (linear) and 4-D (conv) weights."""
        mhe_loss = []
        with torch.no_grad():
            for name in self.extracted_params:
                weight = self.extracted_params[name]
                # linear layer or conv layer
                if len(weight.shape) == 2 or len(weight.shape) == 4:
                    loss = self.mhe_loss(weight)
                    mhe_loss.append(loss.cpu().detach().item())
        mhe_loss = np.array(mhe_loss)
        return mhe_loss.sum()
109
+
110
+
111
def project(R, eps):
    """Project a single square matrix onto the Frobenius ball of radius `eps`.

    The ball is centred at the zero matrix (named ``I`` in the original OFT
    code): the skew parameter of an identity rotation. If `R` is already
    inside the ball it is returned unchanged; otherwise it is rescaled onto
    the boundary.
    """
    center = torch.zeros((R.size(0), R.size(0)), dtype=R.dtype, device=R.device)
    offset = R - center
    dist = torch.norm(offset)
    if dist <= eps:
        return R
    return center + eps * (offset / dist)
119
+
120
def project_batch(R, eps=1e-5):
    """Batched variant of `project`.

    Each (r, r) block of the (b, r, r) stack `R` is projected onto a
    Frobenius ball centred at zero whose radius is `eps` shrunk by
    1/sqrt(b), so the stacked constraint matches the single-matrix case.
    """
    # scaling factor for each of the smaller block matrix
    radius = eps * 1 / torch.sqrt(torch.tensor(R.shape[0]))
    center = torch.zeros((R.size(1), R.size(1)), device=R.device, dtype=R.dtype).unsqueeze(0).expand_as(R)
    offset = R - center
    dist = torch.norm(R - center, dim=(1, 2), keepdim=True)
    keep = (dist <= radius).bool()
    return torch.where(keep, R, center + radius * (offset / dist))
129
+
130
+
131
class MHE_OFT(nn.Module):
    """Minimum hyperspherical energy (MHE) of an OFT fine-tuned model.

    On construction, every OFT-adapted attention projection weight
    (``to_q``/``to_k``/``to_v``/``to_out.0``) is multiplied by the
    block-diagonal orthogonal rotation obtained from the learned skew
    parameter ``R`` via Cayley parametrization; the ``R`` tensors are then
    removed from the snapshot so ``calculate_mhe`` scores the effective
    (rotated) weights.

    NOTE(review): ``R`` is moved to ``.cuda()`` and the energy computation
    also allocates on CUDA, so a GPU is required once any OFT factor or
    2-D/4-D weight is encountered — confirm before running on CPU-only hosts.
    """

    def __init__(self, model, eps=6e-5, r=4):
        super(MHE_OFT, self).__init__()

        # Number of diagonal blocks used when a single shared (non-batched)
        # rotation R is tiled into the full block-diagonal matrix.
        self.r = r

        self.extracted_params = {}
        keys_to_delete = []

        # Detached snapshot of every tensor in the state dict.
        for name, tensor in model.state_dict().items():
            self.extracted_params[name] = tensor.detach().clone()

        # Apply each OFT rotation to its base attention projection weight.
        for name in self.extracted_params:
            if 'attn' in name and 'processor' not in name:
                if 'weight' in name:
                    if 'to_q' in name:
                        oft_R = name.replace('to_q.weight', 'processor.to_q_oft.R')
                    elif 'to_k' in name:
                        oft_R = name.replace('to_k.weight', 'processor.to_k_oft.R')
                    elif 'to_v' in name:
                        oft_R = name.replace('to_v.weight', 'processor.to_v_oft.R')
                    elif 'to_out' in name:
                        oft_R = name.replace('to_out.0.weight', 'processor.to_out_oft.R')
                    else:
                        # BUGFIX: an attention weight with no OFT factor
                        # (e.g. add_k_proj/add_v_proj). The previous `pass`
                        # fell through and reused oft_R left over from an
                        # earlier iteration (or raised NameError on the first
                        # such weight). Skip it instead.
                        continue

                    R = self.extracted_params[oft_R].cuda()

                    with torch.no_grad():
                        if len(R.shape) == 2:
                            # Single shared block: clamp R into the eps-ball,
                            # then turn it into an orthogonal matrix.
                            self.eps = eps * R.shape[0] * R.shape[0]
                            R.copy_(project(R, eps=self.eps))
                            orth_rotate = self.cayley(R)
                        else:
                            # One block per batch entry.
                            self.eps = eps * R.shape[1] * R.shape[0]
                            R.copy_(project_batch(R, eps=self.eps))
                            orth_rotate = self.cayley_batch(R)

                    # W <- W @ blockdiag(Q). Assumes the base weight lives on
                    # the same device as the rotation — TODO confirm.
                    self.extracted_params[name] = self.extracted_params[name] @ self.block_diagonal(orth_rotate)
                    keys_to_delete.append(oft_R)

        # Drop the consumed OFT factors so only effective weights remain.
        for key in keys_to_delete:
            del self.extracted_params[key]

    def is_orthogonal(self, R, eps=1e-5):
        """True iff R^T R is the identity to within `eps` (elementwise)."""
        # NOTE: this method was previously defined twice with identical
        # bodies; the duplicate has been removed.
        with torch.no_grad():
            RtR = torch.matmul(R.t(), R)
            diff = torch.abs(RtR - torch.eye(R.shape[1], dtype=R.dtype, device=R.device))
            return torch.all(diff < eps)

    def block_diagonal(self, R):
        """Build the full block-diagonal rotation from R.

        A 2-D R is tiled `self.r` times; a 3-D R contributes one block per
        leading-dimension slice.
        """
        if len(R.shape) == 2:
            # Create a list of R repeated block_count times
            blocks = [R] * self.r
        else:
            # Create a list of R slices along the third dimension
            blocks = [R[i, ...] for i in range(R.shape[0])]

        # Use torch.block_diag to create the block diagonal matrix
        A = torch.block_diag(*blocks)

        return A

    def copy_without_grad(self, model):
        """Deep-copy `model` with every parameter detached and frozen."""
        copied_model = copy.deepcopy(model)
        for param in copied_model.parameters():
            param.requires_grad = False
            param.detach_()
        return copied_model

    def cayley(self, data):
        """Cayley map: skew-symmetrize `data`, return (I + S)(I - S)^-1 (orthogonal)."""
        r, c = list(data.shape)
        # Ensure the input matrix is skew-symmetric
        skew = 0.5 * (data - data.t())
        I = torch.eye(r, device=data.device)
        # Perform the Cayley parametrization
        Q = torch.mm(I + skew, torch.inverse(I - skew))
        return Q

    def cayley_batch(self, data):
        """Batched Cayley map over a (b, r, r) stack of matrices."""
        b, r, c = data.shape
        # Ensure the input matrix is skew-symmetric
        skew = 0.5 * (data - data.transpose(1, 2))
        # I = torch.eye(r, device=data.device).unsqueeze(0).repeat(b, 1, 1)
        I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c)

        # Perform the Cayley parametrization
        Q = torch.bmm(I + skew, torch.inverse(I - skew))

        return Q

    @staticmethod
    def mhe_loss(filt):
        """Half-space MHE energy of a single weight tensor.

        2-D weights are treated as (out, in) matrices; 4-D conv filters are
        flattened to (out, in*kh*kw) first. Each filter column and its antipode
        are normalized, and the mean pairwise inverse chordal distance over all
        distinct pairs is returned (lower = filters more uniformly spread).

        NOTE(review): builds an intermediate on `.cuda()`; requires a GPU.
        """
        if len(filt.shape) == 2:
            n_filt, _ = filt.shape
            # Columns become filters; append antipodal copies (half-space MHE).
            filt = torch.transpose(filt, 0, 1)
            filt_neg = filt * (-1)
            filt = torch.cat((filt, filt_neg), dim=1)
            n_filt *= 2

            # Normalize by column norms (epsilon keeps the division finite).
            filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
            norm_mat = torch.matmul(filt_norm.t(), filt_norm)
            inner_pro = torch.matmul(filt.t(), filt)
            inner_pro /= norm_mat

            # 1/dist energies; the added diagonal keeps self-terms finite, and
            # the tril subtraction counts each unordered pair once.
            cross_terms = (2.0 - 2.0 * inner_pro + torch.diag(torch.tensor([1.0] * n_filt)).cuda())
            final = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
            final -= torch.tril(final)
            cnt = n_filt * (n_filt - 1) / 2.0
            MHE_loss = 1 * torch.sum(final) / cnt

        else:
            # Conv filters: flatten each output filter before the same
            # computation as the 2-D branch.
            n_filt, _, _, _ = filt.shape
            filt = filt.reshape(n_filt, -1)
            filt = torch.transpose(filt, 0, 1)
            filt_neg = filt * -1
            filt = torch.cat((filt, filt_neg), dim=1)
            n_filt *= 2

            filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
            norm_mat = torch.matmul(filt_norm.t(), filt_norm)
            inner_pro = torch.matmul(filt.t(), filt)
            inner_pro /= norm_mat

            cross_terms = (2.0 - 2.0 * inner_pro + torch.diag(torch.tensor([1.0] * n_filt)).cuda())
            final = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
            final -= torch.tril(final)
            cnt = n_filt * (n_filt - 1) / 2.0
            MHE_loss = 1 * torch.sum(final) / cnt

        return MHE_loss

    def calculate_mhe(self):
        """Sum of MHE energies over all 2-D (linear) and 4-D (conv) weights."""
        mhe_loss = []
        with torch.no_grad():
            for name in self.extracted_params:
                weight = self.extracted_params[name]
                # linear layer or conv layer
                if len(weight.shape) == 2 or len(weight.shape) == 4:
                    loss = self.mhe_loss(weight)
                    mhe_loss.append(loss.cpu().detach().item())
        mhe_loss = np.array(mhe_loss)
        return mhe_loss.sum()

    def is_identity_matrix(self, tensor):
        """True iff `tensor` is a square 2-D tensor exactly equal to I."""
        if not torch.is_tensor(tensor):
            raise TypeError("Input must be a PyTorch tensor.")
        if tensor.ndim != 2 or tensor.shape[0] != tensor.shape[1]:
            return False
        identity = torch.eye(tensor.shape[0], device=tensor.device)
        return torch.all(torch.eq(tensor, identity))
292
+
293
+
294
+
295
class MHE_db:
    """Minimum hyperspherical energy (MHE) statistics for a plain (fully
    fine-tuned) model snapshot — no adapter factors to merge."""

    def __init__(self, model):
        # Detached copy of every tensor in the model's state dict.
        self.extracted_params = {
            name: tensor.detach().clone() for name, tensor in model.state_dict().items()
        }

    @staticmethod
    def mhe_loss(filt):
        """Half-space MHE energy of one weight matrix / conv filter bank.

        NOTE(review): allocates an intermediate via `.cuda()`, so a CUDA
        device is required whenever this is reached.
        """
        if len(filt.shape) == 2:
            n_filt = filt.shape[0]
        else:
            # Conv filters: flatten each output filter to a row first.
            n_filt, _, _, _ = filt.shape
            filt = filt.reshape(n_filt, -1)

        # Columns become filters; append the antipodal copies (half-space MHE).
        filt = torch.transpose(filt, 0, 1)
        filt = torch.cat((filt, filt * (-1)), dim=1)
        n_filt *= 2

        # Cosine-similarity matrix (epsilon keeps the norms finite).
        filt_norm = torch.sqrt(torch.sum(filt * filt, dim=0, keepdim=True) + 1e-4)
        norm_mat = torch.matmul(filt_norm.t(), filt_norm)
        inner_pro = torch.matmul(filt.t(), filt)
        inner_pro /= norm_mat

        # Pairwise 1/dist energies; the added diagonal keeps self-terms
        # finite, and the tril subtraction counts each unordered pair once.
        cross_terms = 2.0 - 2.0 * inner_pro + torch.diag(torch.tensor([1.0] * n_filt)).cuda()
        energy = torch.pow(cross_terms, torch.ones_like(cross_terms) * (-0.5))
        energy -= torch.tril(energy)
        pair_count = n_filt * (n_filt - 1) / 2.0
        return 1 * torch.sum(energy) / pair_count

    def calculate_mhe(self):
        """Sum of MHE energies over all 2-D (linear) and 4-D (conv) weights."""
        per_layer = []
        with torch.no_grad():
            for weight in self.extracted_params.values():
                if weight.dim() in (2, 4):
                    per_layer.append(self.mhe_loss(weight).cpu().detach().item())
        return np.array(per_layer).sum()
generation/subject/train_dreambooth_hra.py ADDED
@@ -0,0 +1,1123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+
16
+ import argparse
17
+ import hashlib
18
+ import logging
19
+ import math
20
+ import os
21
+ import warnings
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import torch.utils.checkpoint
28
+ import transformers
29
+ from accelerate import Accelerator
30
+ from accelerate.logging import get_logger
31
+ from accelerate.utils import ProjectConfiguration, set_seed
32
+ from huggingface_hub import create_repo, upload_folder
33
+ from packaging import version
34
+ from PIL import Image
35
+ from torch.utils.data import Dataset
36
+ from torchvision import transforms
37
+ from tqdm.auto import tqdm
38
+ from transformers import AutoTokenizer, PretrainedConfig
39
+
40
+ import diffusers
41
+ from diffusers import (
42
+ AutoencoderKL,
43
+ DDPMScheduler,
44
+ DiffusionPipeline,
45
+ DPMSolverMultistepScheduler,
46
+ UNet2DConditionModel,
47
+ )
48
+ from diffusers.loaders import AttnProcsLayers
49
+ from oft_utils.attention_processor import HRAAttnProcessor
50
+ from diffusers.optimization import get_scheduler
51
+ from diffusers.utils import check_min_version, is_wandb_available
52
+ from diffusers.utils.import_utils import is_xformers_available
53
+ from oft_utils.mhe import MHE_OFT as MHE
54
+
55
+
56
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
57
+ check_min_version("0.16.0.dev0")
58
+
59
+ logger = get_logger(__name__)
60
+
61
+
62
def save_model_card(repo_id: str, images=None, base_model: str = "", prompt: str = "", repo_folder=None):
    """Write a Hub model card (README.md) plus sample images into `repo_folder`.

    Args:
        repo_id: Hub repository id rendered into the card title.
        images: Optional list of PIL-style images (each must support
            ``.save(path)``); saved as ``image_<i>.png`` and linked in the card.
        base_model: Identifier of the base model the weights adapt.
        prompt: Instance prompt the weights were trained on.
        repo_folder: Destination directory (must already exist).

    BUGFIX: the defaults were ``base_model=str, prompt=str`` — binding the
    *type object* ``str``, so omitting them rendered ``<class 'str'>`` into
    the card. They now default to empty strings. ``images=None`` also crashed
    ``enumerate(None)``; it is now treated as an empty list.
    """
    img_str = ""
    for i, image in enumerate(images or []):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"![img_{i}](./image_{i}.png)\n"

    # YAML front matter consumed by the Hugging Face Hub.
    yaml = f"""
---
license: creativeml-openrail-m
base_model: {base_model}
instance_prompt: {prompt}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
- oft
inference: true
---
"""
    model_card = f"""
# OFT DreamBooth - {repo_id}

These are OFT adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)
90
+
91
+
92
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    """Resolve the text-encoder class used by a pretrained pipeline.

    Reads the ``text_encoder`` sub-config of the pipeline and maps its
    declared architecture name to the corresponding model class, importing
    it lazily. Raises ValueError for unsupported architectures.
    """
    config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    architecture = config.architectures[0]

    if architecture == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel

    if architecture == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation

    raise ValueError(f"{architecture} is not supported.")
110
+
111
+
112
def parse_args(input_args=None):
    """Parse command-line arguments for HRA DreamBooth training.

    Args:
        input_args: Optional explicit argv list; when None, sys.argv is used.

    Returns:
        The parsed argparse.Namespace, with `local_rank` reconciled against the
        LOCAL_RANK environment variable and prior-preservation options validated.

    Raises:
        ValueError: if prior preservation is enabled but class data dir/prompt
        is missing.
    """
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default='runwayml/stable-diffusion-v1-5',
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default='../data/dreambooth/backpack',
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default='data/class_data/backpack',
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default='a photo of qwe backpack',
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default='a photo of backpack',
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default='a qwe backpack in the jungle',
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--test_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is keeps class prior.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=1,
        help=(
            "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    # NOTE(review): default=True combined with action="store_true" means this
    # flag is effectively always True and cannot be disabled from the CLI; the
    # "without prior preservation" warning branch below is unreachable. Confirm
    # whether an opt-out (e.g. BooleanOptionalAction) was intended.
    parser.add_argument(
        "--with_prior_preservation",
        default=True,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=200,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="log_hra/backpack-0",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=0,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop",
        default=False,
        action="store_true",
        help=(
            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
            " cropped. The images will be resized to the resolution first before cropping."
        ),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=4,
        help="Batch size (per device) for sampling images.",
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=2005,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=5000,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        default=False,
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=6e-05,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument(
        "--lr_power",
        type=float,
        default=1.0,
        help="Power factor of the polynomial scheduler.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    # Optimizer hyperparameters (standard AdamW defaults).
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="wandb",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    # NOTE(review): local_rank defaults to 6 here; distributed launchers
    # conventionally use -1 ("not distributed"). The LOCAL_RANK env override
    # below only fires when the env var is set, so confirm this default.
    parser.add_argument("--local_rank", type=int, default=6, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )
    parser.add_argument(
        "--name",
        type=str,
        default='backpack-0',
        help=(
            "The name of the current experiment run, consists of [data]-[prompt]"
        ),
    )
    parser.add_argument(
        "--hra_r",
        type=int,
        default=8,
        help=(
            "The rank of HRA across different layers. It is best to set 'r' to an even number; otherwise, the default initialization method will not work."
        ),
    )
    parser.add_argument(
        "--hra_apply_GS",
        action='store_true',
        default=False,
        help=(
            "Whether to apply Gram-Schmidt orthogonalization."
        ),
    )
    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    # Let a launcher-provided LOCAL_RANK env var win over the CLI value.
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    # Validate the prior-preservation option combination.
    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify prompt for class images.")
    else:
        # logger is not available yet
        if args.class_data_dir is not None:
            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

    return args
+
461
+
462
class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and the tokenizes prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exists.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is None:
            self.class_data_root = None
        else:
            # Prior-preservation images: cap at class_num when given, and
            # stretch the epoch length to cover whichever set is larger.
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            available = len(self.class_images_path)
            self.num_class_images = available if class_num is None else min(available, class_num)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt

        # Resize -> crop -> tensor -> normalize to [-1, 1].
        crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                crop,
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def _tokenize(self, prompt):
        # Pad/truncate the prompt to the tokenizer's max length; returns ids.
        return self.tokenizer(
            prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

    def _load_rgb(self, path):
        # Open an image and force 3-channel RGB.
        img = Image.open(path)
        return img if img.mode == "RGB" else img.convert("RGB")

    def __getitem__(self, index):
        # Index wraps around each image set, so instance and class images can
        # have different counts while sharing one dataset length.
        example = {}
        instance = self._load_rgb(self.instance_images_path[index % self.num_instance_images])
        example["instance_images"] = self.image_transforms(instance)
        example["instance_prompt_ids"] = self._tokenize(self.instance_prompt)

        if self.class_data_root:
            klass = self._load_rgb(self.class_images_path[index % self.num_class_images])
            example["class_images"] = self.image_transforms(klass)
            example["class_prompt_ids"] = self._tokenize(self.class_prompt)

        return example
+
546
+
547
def collate_fn(examples, with_prior_preservation=False):
    """Stack a list of dataset examples into a single training batch.

    With prior preservation enabled, the class examples are appended after the
    instance examples so both losses can be computed in one forward pass.
    """
    input_ids = [e["instance_prompt_ids"] for e in examples]
    pixel_values = [e["instance_images"] for e in examples]

    if with_prior_preservation:
        input_ids.extend(e["class_prompt_ids"] for e in examples)
        pixel_values.extend(e["class_images"] for e in examples)

    stacked_pixels = torch.stack(pixel_values)
    stacked_pixels = stacked_pixels.to(memory_format=torch.contiguous_format).float()

    return {
        "input_ids": torch.cat(input_ids, dim=0),
        "pixel_values": stacked_pixels,
    }
+
568
+
569
class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # Every item repeats the same prompt; the index lets each generated
        # image be uniquely named downstream.
        return {"prompt": self.prompt, "index": index}
+
585
+
586
+ def main(args):
587
+ logging_dir = Path(args.output_dir, args.logging_dir)
588
+
589
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) # total_limit=args.checkpoints_total_limit)
590
+
591
+ wandb_init = {
592
+ "wandb": {
593
+ "name": args.name,
594
+ # "project": args.project,
595
+ }
596
+ }
597
+
598
+ accelerator = Accelerator(
599
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
600
+ mixed_precision=args.mixed_precision,
601
+ log_with=args.report_to,
602
+ project_config=accelerator_project_config,
603
+ )
604
+
605
+ if args.report_to == "wandb":
606
+ if not is_wandb_available():
607
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
608
+ import wandb
609
+
610
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
611
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
612
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
613
+ # Make one log on every process with the configuration for debugging.
614
+ logging.basicConfig(
615
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
616
+ datefmt="%m/%d/%Y %H:%M:%S",
617
+ level=logging.INFO,
618
+ )
619
+ logger.info(accelerator.state, main_process_only=False)
620
+ if accelerator.is_local_main_process:
621
+ transformers.utils.logging.set_verbosity_warning()
622
+ diffusers.utils.logging.set_verbosity_info()
623
+ else:
624
+ transformers.utils.logging.set_verbosity_error()
625
+ diffusers.utils.logging.set_verbosity_error()
626
+
627
+ # If passed along, set the training seed now.
628
+ if args.seed is not None:
629
+ set_seed(args.seed)
630
+
631
+ # Generate class images if prior preservation is enabled.
632
+ if args.with_prior_preservation:
633
+ class_images_dir = Path(args.class_data_dir)
634
+ if not class_images_dir.exists():
635
+ class_images_dir.mkdir(parents=True)
636
+ cur_class_images = len(list(class_images_dir.iterdir()))
637
+
638
+ if cur_class_images < args.num_class_images:
639
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
640
+ if args.prior_generation_precision == "fp32":
641
+ torch_dtype = torch.float32
642
+ elif args.prior_generation_precision == "fp16":
643
+ torch_dtype = torch.float16
644
+ elif args.prior_generation_precision == "bf16":
645
+ torch_dtype = torch.bfloat16
646
+ pipeline = DiffusionPipeline.from_pretrained(
647
+ args.pretrained_model_name_or_path,
648
+ torch_dtype=torch_dtype,
649
+ safety_checker=None,
650
+ revision=args.revision,
651
+ )
652
+ pipeline.set_progress_bar_config(disable=True)
653
+
654
+ num_new_images = args.num_class_images - cur_class_images
655
+ logger.info(f"Number of class images to sample: {num_new_images}.")
656
+
657
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
658
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
659
+
660
+ sample_dataloader = accelerator.prepare(sample_dataloader)
661
+ pipeline.to(accelerator.device)
662
+
663
+ for example in tqdm(
664
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
665
+ ):
666
+ images = pipeline(example["prompt"]).images
667
+
668
+ for i, image in enumerate(images):
669
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
670
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
671
+ image.save(image_filename)
672
+
673
+ del pipeline
674
+ if torch.cuda.is_available():
675
+ torch.cuda.empty_cache()
676
+
677
+ # Handle the repository creation
678
+ if accelerator.is_main_process:
679
+ if args.output_dir is not None:
680
+ os.makedirs(args.output_dir, exist_ok=True)
681
+
682
+ if args.push_to_hub:
683
+ repo_id = create_repo(
684
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
685
+ ).repo_id
686
+
687
+ # Load the tokenizer
688
+ if args.tokenizer_name:
689
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
690
+ elif args.pretrained_model_name_or_path:
691
+ tokenizer = AutoTokenizer.from_pretrained(
692
+ args.pretrained_model_name_or_path,
693
+ subfolder="tokenizer",
694
+ revision=args.revision,
695
+ use_fast=False,
696
+ )
697
+
698
+ # import correct text encoder class
699
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
700
+
701
+ # Load scheduler and models
702
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
703
+ text_encoder = text_encoder_cls.from_pretrained(
704
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
705
+ )
706
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
707
+ unet = UNet2DConditionModel.from_pretrained(
708
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
709
+ )
710
+
711
+ # We only train the additional adapter OFT layers
712
+ vae.requires_grad_(False)
713
+ text_encoder.requires_grad_(False)
714
+ unet.requires_grad_(False)
715
+
716
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
717
+ # as these models are only used for inference, keeping weights in full precision is not required.
718
+ weight_dtype = torch.float32
719
+ if accelerator.mixed_precision == "fp16":
720
+ weight_dtype = torch.float16
721
+ elif accelerator.mixed_precision == "bf16":
722
+ weight_dtype = torch.bfloat16
723
+
724
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
725
+ unet.to(accelerator.device, dtype=weight_dtype)
726
+ vae.to(accelerator.device, dtype=weight_dtype)
727
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
728
+
729
+ if args.enable_xformers_memory_efficient_attention:
730
+ if is_xformers_available():
731
+ import xformers
732
+
733
+ xformers_version = version.parse(xformers.__version__)
734
+ if xformers_version == version.parse("0.0.16"):
735
+ logger.warn(
736
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
737
+ )
738
+ unet.enable_xformers_memory_efficient_attention()
739
+ else:
740
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
741
+
742
+ # now we will add new COT weights to the attention layers
743
+ # It's important to realize here how many attention weights will be added and of which sizes
744
+ # The sizes of the attention layers consist only of two different variables:
745
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
746
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
747
+
748
+ # Let's first see how many attention processors we will have to set.
749
+ # For Stable Diffusion, it should be equal to:
750
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
751
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
752
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
753
+ # => 32 layers
754
+
755
+ # Set correct oft layers
756
+ oft_attn_procs = {}
757
+ for name in unet.attn_processors.keys():
758
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
759
+ if name.startswith("mid_block"):
760
+ hidden_size = unet.config.block_out_channels[-1]
761
+ elif name.startswith("up_blocks"):
762
+ block_id = int(name[len("up_blocks.")])
763
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
764
+ elif name.startswith("down_blocks"):
765
+ block_id = int(name[len("down_blocks.")])
766
+ hidden_size = unet.config.block_out_channels[block_id]
767
+
768
+ oft_attn_procs[name] = HRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, r=args.hra_r, apply_GS=args.hra_apply_GS)
769
+
770
+ unet.set_attn_processor(oft_attn_procs)
771
+ print(f'Total parameters requiring grad: {sum([p.numel() for p in unet.parameters() if p.requires_grad == True])}')
772
+
773
+ oft_layers = AttnProcsLayers(unet.attn_processors)
774
+
775
+ accelerator.register_for_checkpointing(oft_layers)
776
+
777
+ # Enable TF32 for faster training on Ampere GPUs,
778
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
779
+ if args.allow_tf32:
780
+ torch.backends.cuda.matmul.allow_tf32 = True
781
+
782
+ if args.scale_lr:
783
+ args.learning_rate = (
784
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
785
+ )
786
+
787
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
788
+ if args.use_8bit_adam:
789
+ try:
790
+ import bitsandbytes as bnb
791
+ except ImportError:
792
+ raise ImportError(
793
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
794
+ )
795
+
796
+ optimizer_class = bnb.optim.AdamW8bit
797
+ else:
798
+ optimizer_class = torch.optim.AdamW
799
+
800
+ # Optimizer creation
801
+ optimizer = optimizer_class(
802
+ oft_layers.parameters(),
803
+ lr=args.learning_rate,
804
+ betas=(args.adam_beta1, args.adam_beta2),
805
+ weight_decay=args.adam_weight_decay,
806
+ eps=args.adam_epsilon,
807
+ )
808
+
809
+ # Dataset and DataLoaders creation:
810
+ train_dataset = DreamBoothDataset(
811
+ instance_data_root=args.instance_data_dir,
812
+ instance_prompt=args.instance_prompt,
813
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
814
+ class_prompt=args.class_prompt,
815
+ class_num=args.num_class_images,
816
+ tokenizer=tokenizer,
817
+ size=args.resolution,
818
+ center_crop=args.center_crop,
819
+ )
820
+
821
+ train_dataloader = torch.utils.data.DataLoader(
822
+ train_dataset,
823
+ batch_size=args.train_batch_size,
824
+ shuffle=True,
825
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
826
+ num_workers=args.dataloader_num_workers,
827
+ )
828
+
829
+ # Scheduler and math around the number of training steps.
830
+ overrode_max_train_steps = False
831
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
832
+ if args.max_train_steps is None:
833
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
834
+ overrode_max_train_steps = True
835
+
836
+ lr_scheduler = get_scheduler(
837
+ args.lr_scheduler,
838
+ optimizer=optimizer,
839
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
840
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
841
+ num_cycles=args.lr_num_cycles,
842
+ power=args.lr_power,
843
+ )
844
+
845
+ # Prepare everything with our `accelerator`.
846
+ oft_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
847
+ oft_layers, optimizer, train_dataloader, lr_scheduler
848
+ )
849
+
850
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
851
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
852
+ if overrode_max_train_steps:
853
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
854
+ # Afterwards we recalculate our number of training epochs
855
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
856
+
857
+ # We need to initialize the trackers we use, and also store our configuration.
858
+ # The trackers initializes automatically on the main process.
859
+ if accelerator.is_main_process:
860
+ accelerator.init_trackers("dreambooth-oft", config=vars(args), init_kwargs=wandb_init)
861
+
862
+ # Train!
863
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
864
+
865
+ logger.info("***** Running training *****")
866
+ logger.info(f" Num examples = {len(train_dataset)}")
867
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
868
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
869
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
870
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
871
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
872
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
873
+ global_step = 0
874
+ first_epoch = 0
875
+
876
+ # Potentially load in the weights and states from a previous save
877
+ if args.resume_from_checkpoint:
878
+ if args.resume_from_checkpoint != "latest":
879
+ path = os.path.basename(args.resume_from_checkpoint)
880
+ else:
881
+ # Get the mos recent checkpoint
882
+ dirs = os.listdir(args.output_dir)
883
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
884
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
885
+ path = dirs[-1] if len(dirs) > 0 else None
886
+
887
+ if path is None:
888
+ accelerator.print(
889
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
890
+ )
891
+ args.resume_from_checkpoint = None
892
+ else:
893
+ accelerator.print(f"Resuming from checkpoint {path}")
894
+ accelerator.load_state(os.path.join(args.output_dir, path))
895
+ global_step = int(path.split("-")[1])
896
+
897
+ resume_global_step = global_step * args.gradient_accumulation_steps
898
+ first_epoch = global_step // num_update_steps_per_epoch
899
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
900
+
901
+ # Only show the progress bar once on each machine.
902
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
903
+ progress_bar.set_description("Steps")
904
+
905
+ # calculate the hyperspherical energy fine-tuning
906
+ # mhe = MHE(unet, eps=args.eps, r=args.r)
907
+ # mhe_loss = mhe.calculate_mhe()
908
+ # accelerator.log({"mhe_loss": mhe_loss}, step=0)
909
+ accelerator.log({"hra_r": args.hra_r}, step=0)
910
+ accelerator.log({"hra_apply_GS": args.hra_apply_GS}, step=0)
911
+ # accelerator.log({"COFT": 1 if args.coft else 0}, step=0)
912
+
913
+ for epoch in range(first_epoch, args.num_train_epochs):
914
+ unet.train()
915
+ for step, batch in enumerate(train_dataloader):
916
+ # Skip steps until we reach the resumed step
917
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
918
+ if step % args.gradient_accumulation_steps == 0:
919
+ progress_bar.update(1)
920
+ continue
921
+
922
+ with accelerator.accumulate(unet):
923
+ # Convert images to latent space
924
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
925
+ latents = latents * vae.config.scaling_factor
926
+
927
+ # Sample noise that we'll add to the latents
928
+ noise = torch.randn_like(latents)
929
+ bsz = latents.shape[0]
930
+ # Sample a random timestep for each image
931
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
932
+ timesteps = timesteps.long()
933
+
934
+ # Add noise to the latents according to the noise magnitude at each timestep
935
+ # (this is the forward diffusion process)
936
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
937
+
938
+ # Get the text embedding for conditioning
939
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
940
+
941
+ # Predict the noise residual
942
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
943
+
944
+ # Get the target for loss depending on the prediction type
945
+ if noise_scheduler.config.prediction_type == "epsilon":
946
+ target = noise
947
+ elif noise_scheduler.config.prediction_type == "v_prediction":
948
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
949
+ else:
950
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
951
+
952
+ if args.with_prior_preservation:
953
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
954
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
955
+ target, target_prior = torch.chunk(target, 2, dim=0)
956
+
957
+ # Compute instance loss
958
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
959
+
960
+ # Compute prior loss
961
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
962
+
963
+ # Add the prior loss to the instance loss.
964
+ loss = loss + args.prior_loss_weight * prior_loss
965
+ else:
966
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
967
+
968
+ # --------------------------------------------------------
969
+ # orthogonality regularizer
970
+ # for name, param in unet.named_parameters():
971
+ # if 'hra_u' in name:
972
+ # device = param.device
973
+ # hra_u_norm = param / (param.norm(dim=0))
974
+ # orth_loss = torch.norm(torch.eye(8, device=device) - hra_u_norm.t() @ hra_u_norm)
975
+ # loss = loss + 1e-5 * orth_loss
976
+ # --------------------------------------------------------
977
+
978
+ accelerator.backward(loss)
979
+ if accelerator.sync_gradients:
980
+ params_to_clip = oft_layers.parameters()
981
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
982
+ optimizer.step()
983
+ lr_scheduler.step()
984
+ optimizer.zero_grad()
985
+
986
+ # Checks if the accelerator has performed an optimization step behind the scenes
987
+ if accelerator.sync_gradients:
988
+ progress_bar.update(1)
989
+ global_step += 1
990
+
991
+ if global_step % args.checkpointing_steps == 0:
992
+ if accelerator.is_main_process:
993
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
994
+ accelerator.save_state(save_path)
995
+ logger.info(f"Saved state to {save_path}")
996
+
997
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
998
+ progress_bar.set_postfix(**logs)
999
+ accelerator.log(logs, step=global_step)
1000
+
1001
+ if global_step >= args.max_train_steps:
1002
+ break
1003
+
1004
+ if accelerator.is_main_process:
1005
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # and epoch > 1:
1006
+ logger.info(
1007
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1008
+ f" {args.validation_prompt}."
1009
+ )
1010
+
1011
+ # mhe = MHE(unet, eps=args.eps, r=args.r)
1012
+ # mhe_loss = mhe.calculate_mhe()
1013
+ # accelerator.log({"mhe_loss": mhe_loss}, step=global_step)
1014
+
1015
+ # create pipeline
1016
+ pipeline = DiffusionPipeline.from_pretrained(
1017
+ args.pretrained_model_name_or_path,
1018
+ unet=accelerator.unwrap_model(unet),
1019
+ text_encoder=accelerator.unwrap_model(text_encoder),
1020
+ revision=args.revision,
1021
+ torch_dtype=weight_dtype,
1022
+ )
1023
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
1024
+ pipeline = pipeline.to(accelerator.device)
1025
+ pipeline.set_progress_bar_config(disable=True)
1026
+
1027
+ # run inference
1028
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
1029
+ images = [
1030
+ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1031
+ for _ in range(args.num_validation_images)
1032
+ ]
1033
+
1034
+ for tracker in accelerator.trackers:
1035
+ if tracker.name == "tensorboard":
1036
+ np_images = np.stack([np.asarray(img) for img in images])
1037
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1038
+ if tracker.name == "wandb":
1039
+ tracker.log(
1040
+ {
1041
+ "validation": [
1042
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1043
+ for i, image in enumerate(images)
1044
+ ]
1045
+ }
1046
+ )
1047
+
1048
+ # Create the output directory if it doesn't exist
1049
+ tmp_dir = os.path.join(args.output_dir, str(epoch))
1050
+ if not os.path.exists(tmp_dir):
1051
+ os.makedirs(tmp_dir)
1052
+
1053
+ for i, image in enumerate(images):
1054
+ np_image = np.array(image)
1055
+ pil_image = Image.fromarray(np_image)
1056
+ pil_image.save(os.path.join(args.output_dir, str(epoch), f"image_{i}.png"))
1057
+
1058
+ del pipeline
1059
+ torch.cuda.empty_cache()
1060
+
1061
+
1062
+ # Save the oft layers
1063
+ accelerator.wait_for_everyone()
1064
+ # if accelerator.is_main_process:
1065
+ # unet = unet.to(torch.float32)
1066
+ # unet.save_attn_procs(args.output_dir)
1067
+
1068
+ # # Final inference
1069
+ # # Load previous pipeline
1070
+ # pipeline = DiffusionPipeline.from_pretrained(
1071
+ # args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
1072
+ # )
1073
+ # pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
1074
+ # pipeline = pipeline.to(accelerator.device)
1075
+
1076
+ # # load attention processors
1077
+ # pipeline.unet.load_attn_procs(args.output_dir)
1078
+
1079
+ # # run inference
1080
+ # if args.validation_prompt and args.num_validation_images > 0:
1081
+ # generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1082
+ # images = [
1083
+ # pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1084
+ # for _ in range(args.num_validation_images)
1085
+ # ]
1086
+
1087
+ # for tracker in accelerator.trackers:
1088
+ # if tracker.name == "tensorboard":
1089
+ # np_images = np.stack([np.asarray(img) for img in images])
1090
+ # tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1091
+ # if tracker.name == "wandb":
1092
+ # tracker.log(
1093
+ # {
1094
+ # "test": [
1095
+ # wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1096
+ # for i, image in enumerate(images)
1097
+ # ]
1098
+ # }
1099
+ # )
1100
+
1101
+ # if args.push_to_hub:
1102
+ # save_model_card(
1103
+ # repo_id,
1104
+ # images=images,
1105
+ # base_model=args.pretrained_model_name_or_path,
1106
+ # prompt=args.instance_prompt,
1107
+ # repo_folder=args.output_dir,
1108
+ # )
1109
+ # upload_folder(
1110
+ # repo_id=repo_id,
1111
+ # folder_path=args.output_dir,
1112
+ # commit_message="End of training",
1113
+ # ignore_patterns=["step_*", "epoch_*"],
1114
+ # )
1115
+
1116
+ accelerator.end_training()
1117
+
1118
+
1119
if __name__ == "__main__":
    args = parse_args()
    # Echo the parsed configuration so every run's settings are captured in the log.
    for key, value in vars(args).items():
        print(f"{key}: {value}")
    main(args)
generation/subject/train_dreambooth_hra.sh ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launch one DreamBooth-HRA fine-tuning run for a single (subject, prompt) pair.
#
# Usage: bash train_dreambooth_hra.sh <prompt_idx> <class_idx>
#   prompt_idx — index (0-24) into the prompt lists selected below
#   class_idx  — index (0-29) into subject_names / class_tokens
#
# NOTE(review): no shebang line — invoke explicitly with `bash`.

prompt_idx=$1
class_idx=$2
# Fixed hyper-parameters for this sweep: learning rate and HRA rank.
lr=1e-4
hra_r=8

export MODEL_NAME="runwayml/stable-diffusion-v1-5"

# Define the unique_token, class_tokens, and subject_names.
# `unique_token` is the rare identifier token bound to the subject (DreamBooth-style).
unique_token="qwe"
# The 30 DreamBooth dataset subjects; index i pairs with class_tokens[i] below.
subject_names=(
    "backpack" "backpack_dog" "bear_plushie" "berry_bowl" "can"
    "candle" "cat" "cat2" "clock" "colorful_sneaker"
    "dog" "dog2" "dog3" "dog5" "dog6"
    "dog7" "dog8" "duck_toy" "fancy_boot" "grey_sloth_plushie"
    "monster_toy" "pink_sunglasses" "poop_emoji" "rc_car" "red_cartoon"
    "robot_toy" "shiny_sneaker" "teapot" "vase" "wolf_plushie"
)

# Coarse class noun for each subject (same indexing as subject_names).
class_tokens=(
    "backpack" "backpack" "stuffed animal" "bowl" "can"
    "candle" "cat" "cat" "clock" "sneaker"
    "dog" "dog" "dog" "dog" "dog"
    "dog" "dog" "toy" "boot" "stuffed animal"
    "toy" "glasses" "toy" "toy" "cartoon"
    "toy" "sneaker" "teapot" "vase" "stuffed animal"
)

echo "prompt_idx: $prompt_idx, class_idx: $class_idx"

class_token=${class_tokens[$class_idx]}
selected_subject=${subject_names[$class_idx]}

# The indices matched here appear to be the inanimate-object subjects: they get
# scene/placement prompts, while the remaining subjects (cats/dogs/cartoon) get
# outfit/accessory prompts in the else-branch. TODO(review): confirm the split.
if [[ $class_idx =~ ^(0|1|2|3|4|5|8|9|17|18|19|20|21|22|23|24|25|26|27|28|29)$ ]]; then
    # Prompts containing the unique token — used as the validation prompt during training.
    prompt_list=(
        "a ${unique_token} ${class_token} in the jungle"
        "a ${unique_token} ${class_token} in the snow"
        "a ${unique_token} ${class_token} on the beach"
        "a ${unique_token} ${class_token} on a cobblestone street"
        "a ${unique_token} ${class_token} on top of pink fabric"
        "a ${unique_token} ${class_token} on top of a wooden floor"
        "a ${unique_token} ${class_token} with a city in the background"
        "a ${unique_token} ${class_token} with a mountain in the background"
        "a ${unique_token} ${class_token} with a blue house in the background"
        "a ${unique_token} ${class_token} on top of a purple rug in a forest"
        "a ${unique_token} ${class_token} with a wheat field in the background"
        "a ${unique_token} ${class_token} with a tree and autumn leaves in the background"
        "a ${unique_token} ${class_token} with the Eiffel Tower in the background"
        "a ${unique_token} ${class_token} floating on top of water"
        "a ${unique_token} ${class_token} floating in an ocean of milk"
        "a ${unique_token} ${class_token} on top of green grass with sunflowers around it"
        "a ${unique_token} ${class_token} on top of a mirror"
        "a ${unique_token} ${class_token} on top of the sidewalk in a crowded street"
        "a ${unique_token} ${class_token} on top of a dirt road"
        "a ${unique_token} ${class_token} on top of a white rug"
        "a red ${unique_token} ${class_token}"
        "a purple ${unique_token} ${class_token}"
        "a shiny ${unique_token} ${class_token}"
        "a wet ${unique_token} ${class_token}"
        "a cube shaped ${unique_token} ${class_token}"
    )

    # Same prompts without the unique token (class-only baseline / test prompts).
    prompt_test_list=(
        "a ${class_token} in the jungle"
        "a ${class_token} in the snow"
        "a ${class_token} on the beach"
        "a ${class_token} on a cobblestone street"
        "a ${class_token} on top of pink fabric"
        "a ${class_token} on top of a wooden floor"
        "a ${class_token} with a city in the background"
        "a ${class_token} with a mountain in the background"
        "a ${class_token} with a blue house in the background"
        "a ${class_token} on top of a purple rug in a forest"
        "a ${class_token} with a wheat field in the background"
        "a ${class_token} with a tree and autumn leaves in the background"
        "a ${class_token} with the Eiffel Tower in the background"
        "a ${class_token} floating on top of water"
        "a ${class_token} floating in an ocean of milk"
        "a ${class_token} on top of green grass with sunflowers around it"
        "a ${class_token} on top of a mirror"
        "a ${class_token} on top of the sidewalk in a crowded street"
        "a ${class_token} on top of a dirt road"
        "a ${class_token} on top of a white rug"
        "a red ${class_token}"
        "a purple ${class_token}"
        "a shiny ${class_token}"
        "a wet ${class_token}"
        "a cube shaped ${class_token}"
    )

else
    # Live-subject / character prompts: first ten are shared scene prompts, the
    # rest are outfits and accessories that only make sense on animals/characters.
    prompt_list=(
        "a ${unique_token} ${class_token} in the jungle"
        "a ${unique_token} ${class_token} in the snow"
        "a ${unique_token} ${class_token} on the beach"
        "a ${unique_token} ${class_token} on a cobblestone street"
        "a ${unique_token} ${class_token} on top of pink fabric"
        "a ${unique_token} ${class_token} on top of a wooden floor"
        "a ${unique_token} ${class_token} with a city in the background"
        "a ${unique_token} ${class_token} with a mountain in the background"
        "a ${unique_token} ${class_token} with a blue house in the background"
        "a ${unique_token} ${class_token} on top of a purple rug in a forest"
        "a ${unique_token} ${class_token} wearing a red hat"
        "a ${unique_token} ${class_token} wearing a santa hat"
        "a ${unique_token} ${class_token} wearing a rainbow scarf"
        "a ${unique_token} ${class_token} wearing a black top hat and a monocle"
        "a ${unique_token} ${class_token} in a chef outfit"
        "a ${unique_token} ${class_token} in a firefighter outfit"
        "a ${unique_token} ${class_token} in a police outfit"
        "a ${unique_token} ${class_token} wearing pink glasses"
        "a ${unique_token} ${class_token} wearing a yellow shirt"
        "a ${unique_token} ${class_token} in a purple wizard outfit"
        "a red ${unique_token} ${class_token}"
        "a purple ${unique_token} ${class_token}"
        "a shiny ${unique_token} ${class_token}"
        "a wet ${unique_token} ${class_token}"
        "a cube shaped ${unique_token} ${class_token}"
    )

    prompt_test_list=(
        "a ${class_token} in the jungle"
        "a ${class_token} in the snow"
        "a ${class_token} on the beach"
        "a ${class_token} on a cobblestone street"
        "a ${class_token} on top of pink fabric"
        "a ${class_token} on top of a wooden floor"
        "a ${class_token} with a city in the background"
        "a ${class_token} with a mountain in the background"
        "a ${class_token} with a blue house in the background"
        "a ${class_token} on top of a purple rug in a forest"
        "a ${class_token} wearing a red hat"
        "a ${class_token} wearing a santa hat"
        "a ${class_token} wearing a rainbow scarf"
        "a ${class_token} wearing a black top hat and a monocle"
        "a ${class_token} in a chef outfit"
        "a ${class_token} in a firefighter outfit"
        "a ${class_token} in a police outfit"
        "a ${class_token} wearing pink glasses"
        "a ${class_token} wearing a yellow shirt"
        "a ${class_token} in a purple wizard outfit"
        "a red ${class_token}"
        "a purple ${class_token}"
        "a shiny ${class_token}"
        "a wet ${class_token}"
        "a cube shaped ${class_token}"
    )
fi


# Select this run's prompts and derive output/run naming from subject + prompt index.
validation_prompt=${prompt_list[$prompt_idx]}
test_prompt=${prompt_test_list[$prompt_idx]}
name="${selected_subject}-${prompt_idx}"
instance_prompt="a photo of ${unique_token} ${class_token}"
class_prompt="a photo of ${class_token}"

export OUTPUT_DIR="log_hra/lr_${lr}_r_${hra_r}/${name}"
export INSTANCE_DIR="dreambooth/dataset/${selected_subject}"
export CLASS_DIR="class_data/${class_token}"

# Warn if the output directory already exists ("该目录已存在" = "this directory
# already exists"). NOTE(review): the script only warns — it still proceeds and
# may overwrite previous results.
if [ -d "$OUTPUT_DIR" ]; then
    echo "该目录已存在：$OUTPUT_DIR"
fi

# Kick off training with prior preservation (class images regularize the subject).
accelerate launch train_dreambooth_hra.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --instance_data_dir=$INSTANCE_DIR \
    --class_data_dir="$CLASS_DIR" \
    --output_dir="$OUTPUT_DIR" \
    --instance_prompt="$instance_prompt" \
    --with_prior_preservation --prior_loss_weight=1.0 \
    --class_prompt="$class_prompt" \
    --resolution=512 \
    --train_batch_size=1 \
    --gradient_accumulation_steps=1 \
    --checkpointing_steps=5000 \
    --learning_rate=$lr \
    --report_to="wandb" \
    --lr_scheduler="constant" \
    --lr_warmup_steps=0 \
    --max_train_steps=2005 \
    --validation_prompt="$validation_prompt" \
    --validation_epochs=1 \
    --seed="0" \
    --name="$name" \
    --num_class_images=200 \
    --hra_r=$hra_r
llama/data/MATH_test.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
llama/data/gsm8k_test.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
llama/data/oft/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .config import OFTConfig
16
+ from .layer import Conv2d, Linear, OFTLayer
17
+ from .model import OFTModel
18
+
19
+
20
+ __all__ = ["OFTConfig", "OFTModel", "Conv2d", "Linear", "OFTLayer"]
llama/data/oft/config.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import List, Optional, Union
17
+
18
+ from peft.tuners.lycoris_utils import LycorisConfig
19
+ from peft.utils import PeftType
20
+
21
+
22
@dataclass
class OFTConfig(LycorisConfig):
    """
    This is the configuration class to store the configuration of a [`OFTModel`].

    Args:
        r (`int`): OFT rank.
        module_dropout (`int`): The dropout probability for disabling OFT modules during training.
        target_modules (`Optional[Union[List[str], str]]`):
            The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
            names will be replaced. When passing a string, a regex match will be performed. When passing a list of
            strings, either an exact match will be performed or it is checked if the name of the module ends with any
            of the passed strings. If this is specified as 'all-linear', then all linear modules are chosen, excluding
            the output layer. If this is not specified, modules will be chosen according to the model architecture. If
            the architecture is not known, an error will be raised -- in this case, you should specify the target
            modules manually.
        init_weights (`bool`):
            Whether to perform initialization of OFT weights.
        layers_to_transform (`Union[List[int], int]`):
            The layer indices to transform. If a list of ints is passed, it will apply the adapter to the layer indices
            that are specified in this list. If a single integer is passed, it will apply the transformations on the
            layer at this index.
        layers_pattern (`str`):
            The layer pattern name, used only if `layers_to_transform` is different from `None`.
        rank_pattern (`dict`):
            The mapping from layer names or regexp expression to ranks which are different from the default rank
            specified by `r`.
        modules_to_save (`List[str]`):
            List of modules apart from adapter layers to be set as trainable and saved in the final checkpoint.
        coft (`bool`):
            Whether to use the constrained variant of OFT or not, off by default.
        eps (`float`):
            The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True.
        block_share (`bool`):
            Whether to share the OFT parameters between blocks or not. This is `False` by default.
    """

    r: int = field(default=8, metadata={"help": "OFT rank"})
    module_dropout: float = field(
        default=0.0, metadata={"help": "The dropout probability for disabling OFT modules during training"}
    )
    target_modules: Optional[Union[List[str], str]] = field(
        default=None,
        metadata={
            "help": "List of module names or regex expression of the module names to replace with OFT."
            "For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$' "
            "This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer."
        },
    )
    init_weights: bool = field(
        default=True,
        metadata={
            "help": (
                "Whether to initialize the weights of the OFT layers with their default initialization. Don't change "
                "this setting, except if you know exactly what you're doing."
            ),
        },
    )
    layers_to_transform: Optional[Union[List[int], int]] = field(
        default=None,
        metadata={
            "help": "The layer indexes to transform, is this argument is specified, PEFT will transform only the layers indexes that are specified inside this list. If a single integer is passed, PEFT will transform only the layer at this index."
        },
    )
    layers_pattern: Optional[str] = field(
        default=None,
        metadata={
            "help": "The layer pattern name, used only if `layers_to_transform` is different to None and if the layer pattern is not in the common layers pattern."
        },
    )
    modules_to_save: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": "List of modules apart from OFT layers to be set as trainable and saved in the final checkpoint. "
            "For example, in Sequence Classification or Token Classification tasks, "
            "the final layer `classifier/score` are randomly initialized and as such need to be trainable and saved."
        },
    )
    coft: bool = field(
        default=False,
        metadata={"help": "Whether to use the constrained variant of OFT or not."},
    )
    eps: float = field(
        default=6e-5,
        metadata={
            "help": "The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True."
        },
    )
    block_share: bool = field(
        default=False,
        metadata={"help": "Whether to share the OFT parameters between blocks or not."},
    )

    def __post_init__(self):
        # NOTE(review): unlike most PEFT configs, `super().__post_init__()` is not
        # called here — confirm the parent classes have no required post-init logic.
        self.peft_type = PeftType.OFT
        # Normalize a list of module names into a set for O(1) membership tests;
        # a plain string is kept as-is (treated as a regex by the module matcher).
        self.target_modules = (
            set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules
        )
llama/data/oft/layer.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ import warnings
17
+ from typing import Any, List, Optional, Set, Tuple
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from peft.tuners.lycoris_utils import LycorisLayer, check_adapters_to_merge
23
+
24
+
25
+ class OFTLayer(nn.Module, LycorisLayer):
26
+ # All names of layers that may contain adapter weights
27
+ adapter_layer_names = ("oft_r",)
28
+ # other_param_names is defined on parent class
29
+
30
    def __init__(self, base_layer: nn.Module):
        """Wrap ``base_layer`` with OFT bookkeeping; no adapter is created yet.

        Adapter parameters are added later via ``update_layer``.
        """
        super().__init__()
        LycorisLayer.__init__(self, base_layer)

        # OFT info — per-adapter state, keyed by adapter name.
        self.oft_r = nn.ParameterDict({})  # trainable rotation-block parameters
        self.coft = {}         # whether the constrained (COFT) variant is active
        self.eps = {}          # COFT projection radius (scaled in update_layer)
        self.block_share = {}  # whether one block is shared across all diagonal slots
39
+
40
+ @property
41
+ def _available_adapters(self) -> Set[str]:
42
+ return {*self.oft_r}
43
+
44
+ def create_adapter_parameters(self, adapter_name: str, r: int, shape: Tuple[int, ...], block_share: bool):
45
+ if block_share:
46
+ self.oft_r[adapter_name] = nn.Parameter(torch.empty(1, math.ceil(shape[0] / r), math.ceil(shape[0] / r)))
47
+ else:
48
+ self.oft_r[adapter_name] = nn.Parameter(torch.empty(r, math.ceil(shape[0] / r), math.ceil(shape[0] / r)))
49
+
50
    def reset_adapter_parameters(self, adapter_name: str):
        # Zero init => Cayley(0) is the identity rotation, so the adapter starts as a no-op.
        nn.init.zeros_(self.oft_r[adapter_name])
52
+
53
    def reset_adapter_parameters_random(self, adapter_name: str):
        # Random (Kaiming-uniform) init: the adapter starts with a non-identity
        # rotation, i.e. it immediately perturbs the base layer's output.
        nn.init.kaiming_uniform_(self.oft_r[adapter_name], a=math.sqrt(5))
55
+
56
    def update_layer(
        self,
        adapter_name: str,
        r: int,
        module_dropout: float,
        init_weights: bool,
        coft: bool = False,
        eps: float = 6e-5,
        block_share: bool = False,
        **kwargs,
    ) -> None:
        """Internal function to create oft adapter

        Args:
            adapter_name (`str`): Name for the adapter to add.
            r (`int`): Rank for the added adapter.
            module_dropout (`float`): The dropout probability for disabling adapter during training.
            init_weights (`bool`): Whether to initialize weights.
            coft (`bool`): Whether to use the constrained variant of OFT or not.
            eps (`float`):
                The control strength of COFT. The freedom of rotation. Only has an effect if `coft` is set to True.
            block_share (`bool`): Whether to share the OFT parameters between blocks or not.

        Raises:
            ValueError: If `r` is not a positive integer.
            TypeError: If the base layer is neither `nn.Linear` nor `nn.Conv2d`.
        """
        if r <= 0:
            raise ValueError(f"`r` should be a positive integer value but the value passed is {r}")

        # Record the per-adapter hyper-parameters.
        self.r[adapter_name] = r
        self.module_dropout[adapter_name] = module_dropout
        self.coft[adapter_name] = coft
        self.block_share[adapter_name] = block_share

        # Determine shape of OFT weights
        base_layer = self.get_base_layer()
        if isinstance(base_layer, nn.Linear):
            # nn.Linear weight is (out_features, in_features); shape[0] is the rotated dim.
            shape = tuple(base_layer.weight.shape)
        elif isinstance(base_layer, nn.Conv2d):
            # Flatten the conv kernel to a 2-D (out_channels, in_channels*kh*kw) view.
            shape = (
                base_layer.out_channels,
                base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
            )
        else:
            raise TypeError(f"OFT is not implemented for base layers of type {type(base_layer).__name__}")

        # COFT projection radius is scaled by the block area (block_size ** 2).
        self.eps[adapter_name] = eps * math.ceil(shape[0] / r) * math.ceil(shape[0] / r)

        # Create weights with provided shape
        self.create_adapter_parameters(adapter_name, r, shape, block_share)

        # Initialize weights
        if init_weights:
            # Zeros => identity rotation: the adapter starts as a no-op.
            self.reset_adapter_parameters(adapter_name)
        else:
            self.reset_adapter_parameters_random(adapter_name)

        # Move new weights to device
        weight = getattr(self.get_base_layer(), "weight", None)
        if weight is not None:
            # the layer is already completely initialized, this is an update
            if weight.dtype.is_floating_point or weight.dtype.is_complex:
                self.to(weight.device, dtype=weight.dtype)
            else:
                # e.g. integer/quantized base weights: keep the adapter's own float dtype
                self.to(weight.device)
        self.set_adapter(self.active_adapters)
119
+
120
    def unscale_layer(self, scale=None) -> None:
        # scale is not used: OFT applies an orthogonal rotation rather than a scaled
        # delta, so there is nothing to unscale. The method exists only to satisfy
        # the LycorisLayer interface.
        pass
123
+
124
    def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None:
        """
        Merge the active adapter weights into the base weights

        Args:
            safe_merge (`bool`, *optional*):
                If `True`, the merge operation will be performed in a copy of the original weights and check for NaNs
                before merging the weights. This is useful if you want to check if the merge operation will produce
                NaNs. Defaults to `False`.
            adapter_names (`List[str]`, *optional*):
                The list of adapter names that should be merged. If `None`, all active adapters will be merged.
                Defaults to `None`.
        """
        adapter_names = check_adapters_to_merge(self, adapter_names)
        if not adapter_names:
            # no adapter to merge
            return

        for active_adapter in adapter_names:
            if active_adapter in self._available_adapters:
                base_layer = self.get_base_layer()

                # Bring the base weight into (in, out) orientation so the rotation can
                # be applied on the right: W_merged^T = W^T @ R.
                orig_weights = base_layer.weight.data
                if isinstance(base_layer, nn.Linear):
                    orig_weights = torch.transpose(orig_weights, 0, 1)
                elif isinstance(base_layer, nn.Conv2d):
                    # Flatten the 4-D conv kernel to 2-D before transposing.
                    orig_weights = orig_weights.view(
                        [
                            base_layer.out_channels,
                            base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
                        ]
                    )
                    orig_weights = torch.transpose(orig_weights, 0, 1)
                delta_weight = self.get_delta_weight(active_adapter)
                if orig_weights.shape[1] != delta_weight.shape[1]:
                    # when in channels is not divisible by r
                    # (the block-diagonal rotation is padded up to r * ceil(out/r);
                    # truncate it back down to the real output dimension)
                    delta_weight = delta_weight[: orig_weights.shape[1], : orig_weights.shape[1]]
                new_weights = torch.mm(orig_weights, delta_weight)
                # Restore the base layer's native weight layout.
                if isinstance(base_layer, nn.Linear):
                    new_weights = torch.transpose(new_weights, 0, 1)
                elif isinstance(base_layer, nn.Conv2d):
                    new_weights = torch.transpose(new_weights, 0, 1)
                    new_weights = new_weights.view(
                        [
                            base_layer.out_channels,
                            base_layer.in_channels,
                            base_layer.kernel_size[0],
                            base_layer.kernel_size[1],
                        ]
                    )

                # NaN check happens before the assignment below, so the base weight is
                # left untouched when the merge would be broken.
                if safe_merge and not torch.isfinite(new_weights).all():
                    raise ValueError(
                        f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken"
                    )

                base_layer.weight.data = new_weights
                self.merged_adapters.append(active_adapter)
182
+
183
    def unmerge(self) -> None:
        """
        This method unmerges all merged adapter layers from the base weights.
        """
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        # Pop adapters in reverse merge order and undo each rotation:
        # W_orig^T = W_merged^T @ R^{-1}.
        while len(self.merged_adapters) > 0:
            active_adapter = self.merged_adapters.pop()
            if active_adapter in self._available_adapters:
                base_layer = self.get_base_layer()
                new_weights = base_layer.weight.data
                if isinstance(base_layer, nn.Linear):
                    new_weights = torch.transpose(new_weights, 0, 1)
                elif isinstance(base_layer, nn.Conv2d):
                    # Flatten the 4-D conv kernel to 2-D before transposing.
                    new_weights = new_weights.view(
                        [
                            base_layer.out_channels,
                            base_layer.in_channels * base_layer.kernel_size[0] * base_layer.kernel_size[1],
                        ]
                    )
                    new_weights = torch.transpose(new_weights, 0, 1)
                delta_weight = self.get_delta_weight(active_adapter)
                if new_weights.shape[1] != delta_weight.shape[1]:
                    # when in channels is not divisible by r
                    delta_weight = delta_weight[: new_weights.shape[1], : new_weights.shape[1]]
                # NOTE(review): matrix inversion can be numerically lossy; repeated
                # merge/unmerge cycles may drift the base weights slightly.
                delta_inv = torch.inverse(delta_weight)
                orig_weights = torch.mm(new_weights, delta_inv)

                # Restore the base layer's native weight layout.
                if isinstance(base_layer, nn.Linear):
                    orig_weights = torch.transpose(orig_weights, 0, 1)
                elif isinstance(base_layer, nn.Conv2d):
                    orig_weights = torch.transpose(orig_weights, 0, 1)
                    orig_weights = orig_weights.reshape(
                        [
                            base_layer.out_channels,
                            base_layer.in_channels,
                            base_layer.kernel_size[0],
                            base_layer.kernel_size[1],
                        ]
                    )
                base_layer.weight.data = orig_weights
225
+
226
def get_delta_weight(self, adapter_name: str) -> torch.Tensor:
    """Build the block-diagonal orthogonal delta matrix for *adapter_name*."""
    opt_r = self.oft_r[adapter_name]

    if self.coft[adapter_name]:
        # Constrained OFT: pull the trainable blocks back into an eps-ball
        # before building the rotation (no gradient flows through this step).
        with torch.no_grad():
            opt_r.copy_(self._project_batch(opt_r, eps=self.eps[adapter_name]))

    rotations = self._cayley_batch(opt_r)
    return self._block_diagonal(rotations, self.r[adapter_name])
240
+
241
+ # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L144
242
+ def _cayley_batch(self, data: torch.Tensor) -> torch.Tensor:
243
+ b, r, c = data.shape
244
+ # Ensure the input matrix is skew-symmetric
245
+ skew = 0.5 * (data - data.transpose(1, 2))
246
+ I = torch.eye(r, device=data.device).unsqueeze(0).expand(b, r, c) # noqa: E741
247
+
248
+ # Perform the Cayley parametrization
249
+ Q = torch.bmm(I - skew, torch.inverse(I + skew))
250
+
251
+ return Q
252
+
253
+ # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L155
254
+ def _block_diagonal(self, oft_r: torch.Tensor, rank: int) -> torch.Tensor:
255
+ if oft_r.shape[0] == 1:
256
+ # block share
257
+ blocks = [oft_r[0, ...] for i in range(rank)]
258
+ else:
259
+ blocks = [oft_r[i, ...] for i in range(rank)]
260
+
261
+ # Use torch.block_diag to create the block diagonal matrix
262
+ A = torch.block_diag(*blocks)
263
+
264
+ return A
265
+
266
+ # Copied from https://github.com/Zeju1997/oft/blob/84cebb965df69781e3d9c3c875f5980b421eaf24/oft-control/oft.py#L52
267
+ def _project_batch(self, oft_r, eps=1e-5):
268
+ # scaling factor for each of the smaller block matrix
269
+ eps = eps * 1 / torch.sqrt(torch.tensor(oft_r.shape[0]))
270
+ I = ( # noqa: E741
271
+ torch.zeros((oft_r.size(1), oft_r.size(1)), device=oft_r.device, dtype=oft_r.dtype)
272
+ .unsqueeze(0)
273
+ .expand_as(oft_r)
274
+ )
275
+ diff = oft_r - I
276
+ norm_diff = torch.norm(oft_r - I, dim=(1, 2), keepdim=True)
277
+ mask = (norm_diff <= eps).bool()
278
+ out = torch.where(mask, oft_r, I + eps * (diff / norm_diff))
279
+ return out
280
+
281
def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
    """Run the base layer, then apply the active OFT rotations to its output.

    The rotation must act on the *pre-bias* activations, so the base bias is
    subtracted before the adapters run and re-added afterwards.
    """
    previous_dtype = x.dtype

    if self.disable_adapters:
        if self.merged:
            # Adapters are disabled but still folded in: fold them back out.
            self.unmerge()
        result = self.base_layer(x, *args, **kwargs)
    elif self.merged:
        # Adapters already merged into the base weights; plain forward.
        result = self.base_layer(x, *args, **kwargs)
    else:
        result = self.base_layer(x, *args, **kwargs)
        if len(result.shape) == 4:
            # Conv output: move channels last so matmul acts on the channel dim.
            result = result.permute(0, 2, 3, 1)

        base_layer = self.get_base_layer()
        base_bias = base_layer.bias
        if base_bias is not None:
            # Bias should be added after OFT forward
            result = result - base_bias.data

        # Execute all the adapters
        for active_adapter in self.active_adapters:
            if active_adapter not in self._available_adapters:
                continue

            module_dropout = self.module_dropout[active_adapter]

            # Modify current execution weights
            # (module dropout: during training the whole adapter is randomly skipped)
            if (not self.training) or (self.training and torch.rand(1) > module_dropout):
                result = self._get_delta_activations(active_adapter, result, *args, **kwargs)

        if base_bias is not None:
            result = result + base_bias.data
        if len(result.shape) == 4:
            # Restore channels-first layout for conv outputs.
            result = result.permute(0, 3, 1, 2)

    result = result.to(previous_dtype)
    return result
319
+
320
+
321
class Linear(OFTLayer):
    """OFT implemented in Linear layer"""

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str = "default",
        r: int = 0,
        module_dropout: float = 0.0,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Register the adapter and mark it as the active one.
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs)

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        """Rotate the base-layer activations with the adapter's orthogonal matrix."""
        rotation = self.get_delta_weight(adapter_name)

        out_dim = self.get_base_layer().weight.data.shape[0]
        # Crop in case the output dimension is not divisible by r.
        rotation = rotation[:out_dim, :out_dim]

        # The base bias is handled in OFTLayer.forward, so it is not added here.
        return torch.matmul(input, rotation)

    def __repr__(self) -> str:
        return "oft." + super().__repr__()
354
+
355
+
356
class Conv2d(OFTLayer):
    """OFT implemented in Conv2d layer"""

    def __init__(
        self,
        base_layer: nn.Module,
        adapter_name: str = "default",
        r: int = 0,
        module_dropout: float = 0.0,
        init_weights: bool = True,
        **kwargs,
    ):
        super().__init__(base_layer)

        # Register the adapter and mark it as the active one.
        self._active_adapter = adapter_name
        self.update_layer(adapter_name, r, module_dropout, init_weights, **kwargs)

    def _get_delta_activations(
        self, adapter_name: str, input: torch.Tensor, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        """Rotate the (channels-last) conv activations with the adapter's matrix."""
        rotation = self.get_delta_weight(adapter_name)

        out_dim = self.get_base_layer().weight.data.shape[0]
        # Crop in case the channel dimension is not divisible by r.
        rotation = rotation[:out_dim, :out_dim]

        # The base bias is handled in OFTLayer.forward, so it is not added here.
        return torch.matmul(input, rotation)

    def __repr__(self) -> str:
        return "oft." + super().__repr__()
llama/data/oft/model.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023-present the HuggingFace Inc. team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import Dict, Type, Union
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from peft.tuners.lycoris_utils import LycorisConfig, LycorisTuner
22
+
23
+ from .layer import Conv2d, Linear, OFTLayer
24
+
25
+
26
class OFTModel(LycorisTuner):
    """
    Creates Orthogonal Finetuning model from a pretrained model. The method is described in
    https://arxiv.org/abs/2306.07280

    Args:
        model (`torch.nn.Module`): The model to which the adapter tuner layers will be attached.
        config ([`OFTConfig`]): The configuration of the OFT model.
        adapter_name (`str`): The name of the adapter, defaults to `"default"`.

    Returns:
        `torch.nn.Module`: The OFT model.

    Example:
        ```py
        >>> from diffusers import StableDiffusionPipeline
        >>> from peft import OFTModel, OFTConfig

        >>> config_te = OFTConfig(
        ...     r=8,
        ...     target_modules=["k_proj", "q_proj", "v_proj", "out_proj", "fc1", "fc2"],
        ...     module_dropout=0.0,
        ...     init_weights=True,
        ... )
        >>> config_unet = OFTConfig(
        ...     r=8,
        ...     target_modules=[
        ...         "proj_in",
        ...         "proj_out",
        ...         "to_k",
        ...         "to_q",
        ...         "to_v",
        ...         "to_out.0",
        ...         "ff.net.0.proj",
        ...         "ff.net.2",
        ...     ],
        ...     module_dropout=0.0,
        ...     init_weights=True,
        ... )

        >>> model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        >>> model.text_encoder = OFTModel(model.text_encoder, config_te, "default")
        >>> model.unet = OFTModel(model.unet, config_unet, "default")
        ```

    **Attributes**:
        - **model** ([`~torch.nn.Module`]) -- The model to be adapted.
        - **peft_config** ([`OFTConfig`]): The configuration of the OFT model.
    """

    # Parameter-name prefix used by the Lycoris machinery to find OFT weights.
    prefix: str = "oft_"
    # Which adapter class wraps each supported base module type.
    layers_mapping: Dict[Type[torch.nn.Module], Type[OFTLayer]] = {
        torch.nn.Conv2d: Conv2d,
        torch.nn.Linear: Linear,
    }

    def _create_and_replace(
        self,
        config: LycorisConfig,
        adapter_name: str,
        target: Union[OFTLayer, nn.Module],
        target_name: str,
        parent: nn.Module,
        current_key: str,
    ) -> None:
        """
        A private method to create and replace the target module with the adapter module.
        """

        # Regexp matching - Find key which matches current target_name in patterns provided
        pattern_keys = list(config.rank_pattern.keys())
        target_name_key = next(filter(lambda key: re.match(rf"(.*\.)?{key}$", current_key), pattern_keys), target_name)

        kwargs = config.to_dict()
        # Per-module rank override from ``rank_pattern``, falling back to the global ``r``.
        kwargs["r"] = config.rank_pattern.get(target_name_key, config.r)

        if isinstance(target, OFTLayer):
            # Module is already wrapped: just register another adapter on it.
            target.update_layer(adapter_name, **kwargs)
        else:
            new_module = self._create_new_module(config, adapter_name, target, **kwargs)
            self._replace_module(parent, target_name, new_module, target)
llama/finetune_32.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import field, dataclass
3
+ from typing import Sequence, Literal, List, Dict, Tuple
4
+
5
+ import transformers
6
+ from transformers import Trainer
7
+ from transformers.modeling_utils import *
8
+ from transformers.trainer import _is_peft_model
9
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
10
+ from transformers.data.data_collator import DataCollator
11
+
12
+ from transformers.training_args import TrainingArguments
13
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
14
+ from transformers.trainer_callback import TrainerCallback
15
+ from transformers.trainer_utils import EvalPrediction
16
+ from torch.utils.data import Dataset, IterableDataset
17
+
18
+ from datasets import load_dataset
19
+ from peft import LoraConfig, get_peft_model, PeftModel, OFTConfig
20
+ from datetime import datetime
21
+
22
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
# Alpaca-style instruction prompt used for supervised fine-tuning.
PROMPT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)
32
+
33
+
34
class MyTrainer(Trainer):
    """``Trainer`` variant that caches OFT parameters and a regularizer weight.

    ``lamda`` scales an orthogonality regularizer over the ``oft_r`` parameters;
    the regularizer itself is currently disabled in :meth:`compute_loss` (see
    the NOTE there), so the loss matches the stock ``Trainer`` behavior.
    """

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        lamda: float = 1e-4
    ):
        print('optimizers', optimizers)
        # BUGFIX: compare against None with identity (`is None`), the correct
        # idiom, instead of equality.
        if optimizers is None:
            optimizers = (None, None)
        super().__init__(model=model, args=args, data_collator=data_collator, train_dataset=train_dataset,
                         eval_dataset=eval_dataset, processing_class=tokenizer, model_init=model_init,
                         compute_metrics=compute_metrics, callbacks=callbacks,
                         optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics)
        self.lamda = lamda

        # Cache the OFT rotation parameters once for the (optional) regularizer.
        self.oft_params: List[torch.nn.Parameter] = [
            p for n, p in self.model.named_parameters() if "oft_r" in n
        ]

    def compute_loss(self, model, inputs, return_outputs=False,
                     num_items_in_batch: Optional[torch.Tensor] = None,):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        if self.model_accepts_loss_kwargs:
            kwargs = {}
            if num_items_in_batch is not None:
                kwargs["num_items_in_batch"] = num_items_in_batch
            inputs = {**inputs, **kwargs}
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            # Causal-LM heads need the labels shifted by one position.
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        # NOTE: an orthogonality regularizer over ``self.oft_params`` scaled by
        # ``self.lamda`` was prototyped here and is intentionally disabled.

        return (loss, outputs) if return_outputs else loss

    def _move_model_to_device(self, model, device):
        # The model is already placed via ``device_map`` at load time, so the
        # default Trainer move is deliberately a no-op.
        pass
136
+
137
+
138
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """CLI arguments: HF ``TrainingArguments`` extended with model/data/OFT options.

    NOTE(review): this intentionally shadows ``transformers.TrainingArguments``
    in this module's namespace.
    """
    # Base model to fine-tune (HF hub id or local path).
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    # Existing PEFT adapter to resume from; None trains a fresh adapter.
    adapter_name_or_path: Optional[str] = field(default=None)
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    dataset_split: str = field(
        default="train[:100000]", metadata={"help": "(`['train', 'test', 'eval']`):"}
    )
    # Two-element list: [instruction column, response column].
    dataset_field: List[str] = field(
        default=None, metadata={"help": "Fields of dataset input and output."}
    )
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512, metadata={
        "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, )
    hrft_r: int = field(default=8, metadata={
        "help": "The rank of the adapter. When passing `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."})
    init_a: float = field(default=1e-4, metadata={"help": "The initial weights"})
    eps: float = field(default=1e-4, metadata={"help": "The control strength of COFT. The freedom of rotation."})
    lamda: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    add_orth: str = field(default='none', metadata={"help": ""})
    init_weights: Literal[True, "pissa"] = field(
        default=True,
        metadata={
            "help": (
                "Passing True (default) results in the LoRA initialization."
                "Passing `pissa` results in PiSSA initialization."
            ),
        },
    )
167
+
168
+
169
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Collects the state dict and dump to disk."""
    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        # Move tensors to CPU before serialization to avoid GPU-memory spikes.
        cpu_state_dict = {k: v.cpu() for k, v in state_dict.items()}
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
176
+
177
+
178
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        # Initialize the freshly appended embedding rows with the mean of the
        # pre-existing embeddings instead of leaving them random.
        in_embed = model.get_input_embeddings().weight.data
        out_embed = model.get_output_embeddings().weight.data

        in_embed[-num_new_tokens:] = in_embed[:-num_new_tokens].mean(dim=0, keepdim=True)
        out_embed[-num_new_tokens:] = out_embed[:-num_new_tokens].mean(dim=0, keepdim=True)
199
+
200
+
201
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize a list of strings."""
    encoded = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    ids = [enc.input_ids[0] for enc in encoded]
    # Sequence length excluding padding (used later to mask out prompt tokens).
    lengths = [enc.input_ids.ne(tokenizer.pad_token_id).sum().item() for enc in encoded]
    return dict(
        input_ids=ids,
        labels=ids,
        input_ids_lens=lengths,
        labels_lens=lengths,
    )
223
+
224
+
225
def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt+answer pairs and mask the prompt tokens out of the labels."""
    examples = [src + tgt for src, tgt in zip(sources, targets)]
    examples_tokenized = _tokenize_fn(examples, tokenizer)
    sources_tokenized = _tokenize_fn(sources, tokenizer)
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
    # Only the response tokens contribute to the training loss.
    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
        label[:source_len] = IGNORE_INDEX
    return dict(input_ids=input_ids, labels=labels)
238
+
239
+
240
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids = [torch.tensor(inst["input_ids"]) for inst in instances]
        labels = [torch.tensor(inst["labels"]) for inst in instances]
        # Right-pad every sequence to the longest one in the batch.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(
            input_ids=input_ids,
            labels=labels,
            # Mask out padding positions for attention.
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
259
+
260
+
261
def train_tokenize_function(examples, tokenizer, query, response):
    """Map a raw dataset batch to tokenized (input_ids, labels) pairs."""
    sources = [PROMPT.format_map(dict(instruction=q)) for q in examples[query]]
    # Append EOS so the model learns where a response ends.
    targets = [f"{a}{tokenizer.eos_token}" for a in examples[response]]
    return preprocess(sources, targets, tokenizer)
266
+
267
+
268
def train():
    """Entry point: build the model/tokenizer, attach an OFT adapter, fine-tune and save."""
    parser = transformers.HfArgumentParser(TrainingArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    # print(script_args)
    model = transformers.AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        device_map={"": 0},  # device_map="auto",
    )
    if script_args.adapter_name_or_path is not None:
        # Resume from a previously trained adapter.
        print(f"Load {script_args.init_weights} from {script_args.adapter_name_or_path}: ", )
        model = PeftModel.from_pretrained(model, script_args.model_name_or_path,
                                          subfolder=script_args.adapter_name_or_path, is_trainable=True)
    elif script_args.hrft_r is not None:
        print(f"Initilized {script_args.init_weights} layers")

        # NOTE(review): despite the variable name, this configures OFT (not HRA).
        hra_config = OFTConfig(
            r= script_args.hrft_r,
            eps=script_args.eps,
            init_weights=script_args.init_weights,
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, hra_config)
    else:
        print("Full Parameter Fine-Tuning")

    print(model)
    model.print_trainable_parameters()
    # import time
    # print("Program starts")
    # time.sleep(300)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        script_args.model_name_or_path,
        model_max_length=script_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        # Add a PAD token and grow the embedding matrix accordingly.
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )

    if "llama" in script_args.model_name_or_path:
        # LLaMA tokenizers ship without these special tokens.
        tokenizer.add_special_tokens(
            {
                "eos_token": DEFAULT_EOS_TOKEN,
                "bos_token": DEFAULT_BOS_TOKEN,
                "unk_token": DEFAULT_UNK_TOKEN,
            }
        )
    # if tokenizer.pad_token is None:
    #     if tokenizer.unk_token_id is not None:
    #         tokenizer.pad_token_id = tokenizer.unk_token_id
    #         tokenizer.pad_token = tokenizer.unk_token
    #         print("Set PAD token to UNK token.")
    #     elif tokenizer.eos_token_id is not None:
    #         tokenizer.pad_token_id = tokenizer.eos_token_id
    #         tokenizer.pad_token = tokenizer.eos_token
    #         print("Set PAD token to EOS token.")

    # if model is not None:
    #     model.config.pad_token_id = tokenizer.pad_token_id
    #     if model.config.pad_token_id != tokenizer.pad_token_id:
    #         raise ValueError("Failed to sync pad_token_id between tokenizer and model config")

    raw_train_datasets = load_dataset("json", data_files=script_args.data_path, split=script_args.dataset_split, name="metamath_qa",)
    train_dataset = raw_train_datasets.map(
        train_tokenize_function,
        batched=True,
        batch_size=3000,
        num_proc=32,
        remove_columns=raw_train_datasets.column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on train dataset",
        fn_kwargs={"tokenizer": tokenizer, "query": script_args.dataset_field[0],
                   "response": script_args.dataset_field[1]}
    )

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    data_module = dict(train_dataset=train_dataset, data_collator=data_collator)
    trainer = MyTrainer(model=model, tokenizer=tokenizer, lamda=script_args.lamda, args=script_args, **data_module)
    # KV-cache is useless during training and conflicts with checkpointing.
    model.config.use_cache = False

    start_time = datetime.now()
    print('start time: ', start_time.strftime("%Y-%m-%d %H:%M:%S"))

    trainer.train()

    end_time = datetime.now()
    print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)

    # trainer.save_state()
    tokenizer.save_pretrained(os.path.join(script_args.output_dir, 'ft'))
    model.save_pretrained(os.path.join(script_args.output_dir, 'ft'))
364
+
365
+
366
# Script entry point.
if __name__ == "__main__":

    train()
llama/inference/MATH_inference.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import pdb
4
+ import jsonlines
5
+
6
+ import util
7
+ from vllm import LLM, SamplingParams
8
+ import sys
9
MAX_INT = sys.maxsize  # sentinel "no end index" for slicing
INVALID_ANS = "[invalid]"

# Samples whose completion lacked an answer marker (filled by process_results).
invalid_outputs = []
13
def remove_boxed(s):
    """Strip a LaTeX ``\\boxed{...}`` wrapper, returning the inner text.

    Returns ``None`` when *s* is ``None`` or not of the expected form.
    (Replaces the original's bare ``except`` with explicit checks; the
    observable behavior is unchanged.)
    """
    left = "\\boxed{"
    if s is None or not s.startswith(left) or not s.endswith("}"):
        return None
    return s[len(left):-1]
21
+
22
def process_results(doc, completion, answer):
    """Check whether *completion* answers *doc* correctly (via ``util.is_equiv``).

    Completions without a 'The answer is: ' marker are recorded in the
    module-level ``invalid_outputs`` list and count as wrong.
    """
    parts = completion.split('The answer is: ')
    if len(parts) <= 1:
        invalid_outputs.append({'question': doc, 'output': completion, 'answer': answer})
        return False
    candidate = parts[-1].split('.\n')[0].strip()
    # Drop a single trailing period, then re-strip.
    if candidate.endswith('.'):
        candidate = candidate[:-1]
    candidate = candidate.strip()
    return True if util.is_equiv(candidate, answer) else False
41
def batch_data(data_list, batch_size=1):
    """Chunk *data_list* into ``batch_size``-sized batches.

    The final batch absorbs the remainder, so it may be larger than
    ``batch_size``.
    """
    n = len(data_list) // batch_size
    batches = [data_list[i * batch_size:(i + 1) * batch_size] for i in range(n - 1)]
    # Everything left over goes into one final (possibly larger) batch.
    batches.append(data_list[(n - 1) * batch_size:])
    return batches
53
+
54
def test_hendrycks_math(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    """Evaluate *model* on the Hendrycks MATH test set with greedy vLLM decoding.

    Reads instruction/solution pairs from *data_path* (jsonl), generates one
    completion per problem and prints the exact-match accuracy.
    """
    hendrycks_math_ins = []
    hendrycks_math_answers = []
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('promt =====', problem_prompt)
    with open(data_path, "r+", encoding="utf8") as f:
        for idx, item in enumerate(jsonlines.Reader(f)):
            hendrycks_math_ins.append(problem_prompt.format(instruction=item["instruction"]))
            # Ground truth is the last \boxed{...} value of the reference solution.
            hendrycks_math_answers.append(remove_boxed(util.last_boxed_only_string(item['output'])))

    print('total length ===', len(hendrycks_math_ins))
    hendrycks_math_ins = hendrycks_math_ins[start:end]
    hendrycks_math_answers = hendrycks_math_answers[start:end]
    print('lenght ====', len(hendrycks_math_ins))
    # NOTE: the original also built batches via batch_data() but never used
    # them (vLLM batches internally); that dead computation is removed.

    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
    # temperature=0 -> greedy decoding for a reproducible evaluation.
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=2048, stop=stop_tokens)
    print('sampleing =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)
    outputs = llm.generate(hendrycks_math_ins, sampling_params)
    res_completions = [output.outputs[0].text for output in outputs]

    results = []
    for prompt, completion, prompt_answer in zip(hendrycks_math_ins, res_completions, hendrycks_math_answers):
        results.append(process_results(prompt, completion, prompt_answer))

    acc = sum(results) / len(results)
    # BUGFIX: the original printed len(invalid_outputs) for BOTH counts.
    print('len invalid outputs ====', len(invalid_outputs), ', valid_outputs===', len(results) - len(invalid_outputs))
    # print('start===', start, ', end====',end)
    print('length====', len(results), ', acc====', acc)
93
+
94
def parse_args():
    """Parse the MATH-evaluation command-line arguments."""
    parser = argparse.ArgumentParser()
    # BUGFIX: the default used to be the int 0, which is not a valid model
    # path for a str-typed option; None makes the "not provided" case explicit.
    parser.add_argument("--model", type=str, default=None)  # model path
    parser.add_argument("--data_file", type=str, default='data/MATH_test.jsonl')  # data path
    parser.add_argument("--start", type=int, default=0)  # start index
    parser.add_argument("--end", type=int, default=MAX_INT)  # end index
    parser.add_argument("--batch_size", type=int, default=50)  # batch_size
    parser.add_argument("--tensor_parallel_size", type=int, default=1)  # tensor_parallel_size
    return parser.parse_args()
103
+
104
# Script entry point: run the MATH evaluation with the CLI-provided settings.
if __name__ == "__main__":
    args = parse_args()
    test_hendrycks_math(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)
107
+
108
+
llama/inference/grader.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from:
3
+ - https://github.com/microsoft/ProphetNet/tree/master/CRITIC
4
+ """
5
+ import multiprocessing
6
+ from math import isclose
7
+ from typing import Union
8
+
9
+ from sympy import simplify, N
10
+ from sympy.parsing.sympy_parser import parse_expr
11
+ from sympy.parsing.latex import parse_latex
12
+
13
+
14
def is_digit(s):
    """Return True when *s* (with thousands commas stripped) parses as a float."""
    try:
        float(str(s).replace(",", ""))
    except ValueError:
        return False
    return True
20
+
21
def math_equal(prediction: Union[bool, float, str],
               reference: Union[float, str],
               include_percentage: bool = True,
               is_close: bool = True,
               timeout: bool = False,
               ) -> bool:
    """
    Exact match of math if and only if:
    1. numerical equal: both can convert to float and are equal
    2. symbolic equal: both can convert to sympy expression and are equal
    """
    try:  # 1. numerical equal
        if is_digit(prediction) and is_digit(reference):
            prediction = float(str(prediction).replace(",", ""))
            reference = float(str(reference).replace(",", ""))
            # number questions
            if include_percentage:
                # Also accept answers off by a factor of 100 (percent vs fraction).
                gt_result = [reference / 100, reference, reference * 100]
            else:
                gt_result = [reference]
            for item in gt_result:
                try:
                    if is_close:
                        if isclose(item, prediction, rel_tol=1e-4):
                            return True
                    else:
                        if item == prediction:
                            return True
                except Exception:
                    continue
            return False
    except:
        pass

    # Empty/None predictions fail, but genuine 0/False answers are kept.
    if not prediction and prediction not in [0, False]:
        return False

    # 2. symbolic equal
    reference = str(reference).strip()
    prediction = str(prediction).strip()

    ## deal with [], (), {}
    pred_str, ref_str = prediction, reference
    if (prediction.startswith("[") and prediction.endswith("]") and not reference.startswith("(")) or \
        (prediction.startswith("(") and prediction.endswith(")") and not reference.startswith("[")):
        pred_str = pred_str.strip("[]()")
        ref_str = ref_str.strip("[]()")
    for s in ['{', "}", "(", ")"]:
        ref_str = ref_str.replace(s, "")
        pred_str = pred_str.replace(s, "")
    if pred_str == ref_str:
        return True

    ## [a, b] vs. [c, d], return a==c and b==d
    if (prediction.startswith("[") and prediction.endswith("]")) and (reference.startswith("[") and reference.endswith("]")) or \
        (prediction.startswith("(") and prediction.endswith(")")) and (reference.startswith("(") and reference.endswith(")")):
        pred_parts = prediction[1:-1].split(",")
        ref_parts = reference[1:-1].split(",")
        if len(pred_parts) == len(ref_parts):
            # Element-wise recursive comparison of the tuple/list entries.
            if all([math_equal(pred_parts[i], ref_parts[i], include_percentage, is_close) for i in range(len(pred_parts))]):
                return True

    # symbolic equal with sympy
    if timeout:
        # Guard sympy in a subprocess: simplify() can hang on pathological input.
        if call_with_timeout(symbolic_equal_process, prediction, reference):
            return True
    else:
        if symbolic_equal(prediction, reference):
            return True

    return False
+
93
+
94
def math_equal_process(param):
    """Pool adapter: compare the last two elements (prediction, reference) of *param*."""
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
96
+
97
+
98
def symbolic_equal(a, b):
    """Best-effort symbolic comparison of two expression strings via sympy."""

    def _to_expr(s):
        # Try LaTeX first, then plain sympy syntax; fall back to the raw string.
        for parser in (parse_latex, parse_expr):
            try:
                return parser(s)
            except:
                pass
        return s

    lhs = _to_expr(a)
    rhs = _to_expr(b)

    try:
        # Exact symbolic equality.
        if simplify(lhs - rhs) == 0:
            return True
    except:
        pass

    try:
        # Numeric fallback with a loose relative tolerance.
        if isclose(N(lhs), N(rhs), rel_tol=1e-3):
            return True
    except:
        pass
    return False
121
+
122
+
123
def symbolic_equal_process(a, b, output_queue):
    """Subprocess entry point: push the symbolic-comparison result onto *output_queue*."""
    output_queue.put(symbolic_equal(a, b))
126
+
127
+
128
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run *func* in a subprocess with a wall-clock *timeout* (seconds).

    *func* must accept an extra trailing ``output_queue`` argument and put its
    result there. Returns ``False`` on timeout or when the worker exited
    without producing a result; otherwise returns the queued value.
    """
    from queue import Empty  # stdlib; local import keeps module deps unchanged

    output_queue = multiprocessing.Queue()
    process_args = args + (output_queue,)
    process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs)
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    try:
        # BUGFIX: the worker may have died without putting anything on the
        # queue (e.g. an uncaught exception); a blocking get() would hang
        # this process forever.
        return output_queue.get_nowait()
    except Empty:
        return False
141
+
llama/inference/gsm8k_inference.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import re
4
+ import jsonlines
5
+ from fraction import Fraction
6
+ from vllm import LLM, SamplingParams
7
+ import sys
8
+ from grader import math_equal
9
+ MAX_INT = sys.maxsize
10
+
11
def is_number(s):
    """Return True when *s* parses as a number.

    Accepts anything float() understands, plus single unicode numeric
    characters (e.g. fraction glyphs) via unicodedata.numeric.
    """
    try:
        float(s)
    except ValueError:
        pass
    else:
        return True
    try:
        import unicodedata
        unicodedata.numeric(s)
    except (TypeError, ValueError):
        return False
    return True
24
+
25
def extract_answer_number(completion):
    """Pull the final numeric answer out of a model completion.

    Looks after the last 'The answer is: ' marker, grabs the first
    number-like token (optionally signed, possibly containing one of
    ',', '.', '/'), and returns it rounded to the nearest int.
    Returns None when no answer can be extracted.
    """
    parts = completion.split('The answer is: ')
    if len(parts) <= 1:
        return None
    match = re.search(r'[\-+]?\d*[\.,/]?\d+', parts[-1].strip())
    if not match:
        return None
    token = match.group()
    if '/' in token:
        numerator = token.split('/')[0]
        denominator = token.split('/')[1]
        if not (is_number(denominator) == True and is_number(numerator) == True):
            return None
        if denominator == '0':
            # Avoid division by zero: fall back to the numerator alone.
            return round(float(numerator.replace(',', '')))
        frac = Fraction(token.replace(',', ''))
        return round(float(frac.numerator / frac.denominator))
    value = float(token.replace(',', ''))
    if value == float('inf'):
        return None
    return round(value)
52
+
53
def batch_data(data_list, batch_size=1):
    """Split *data_list* into batches of *batch_size*.

    The final batch absorbs any remainder, so it may be larger than
    *batch_size* (or smaller when the whole list is shorter than one
    batch).  Mirrors the chunking originally used for vLLM requests.
    """
    n = len(data_list) // batch_size
    # Cleanup: the local used to shadow the function name `batch_data`,
    # and the last slice used a module-level MAX_INT sentinel where an
    # open-ended slice is equivalent.
    batches = [
        data_list[i * batch_size:(i + 1) * batch_size] for i in range(n - 1)
    ]
    batches.append(data_list[(n - 1) * batch_size:])
    return batches
65
+
66
+
67
def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
    """Evaluate *model* on the GSM8K test set with greedy vLLM decoding.

    Reads jsonl records with 'question' and 'answer' fields, prompts the
    model with an Alpaca-style instruction template, extracts the numeric
    answer from each completion, and prints the accuracy.

    batch_size is kept for CLI compatibility; vLLM batches internally, so
    the manual pre-batching the original computed (and never used) has
    been removed.
    """
    gsm8k_ins = []
    gsm8k_answers = []
    problem_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
    )
    print('prompt =====', problem_prompt)
    # "r" instead of "r+": the file is only scanned, never written.
    with open(data_path, "r", encoding="utf8") as f:
        for item in jsonlines.Reader(f):
            gsm8k_ins.append(problem_prompt.format(instruction=item["question"]))
            # The gold answer follows the '#### ' marker, possibly with
            # thousands separators.
            temp_ans = item['answer'].split('#### ')[1]
            gsm8k_answers.append(int(temp_ans.replace(',', '')))

    gsm8k_ins = gsm8k_ins[start:end]
    gsm8k_answers = gsm8k_answers[start:end]
    print('length ====', len(gsm8k_ins))

    stop_tokens = ["Instruction:", "Instruction", "Response:", "Response"]
    sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=1024, stop=stop_tokens)
    print('sampling =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)
    result = []

    outputs = llm.generate(gsm8k_ins, sampling_params)
    res_completions = [output.outputs[0].text for output in outputs]

    invalid_outputs = []
    for prompt, completion, prompt_answer in zip(gsm8k_ins, res_completions, gsm8k_answers):
        y_pred = extract_answer_number(completion)
        if y_pred is not None:
            result.append(float(y_pred) == float(prompt_answer) or math_equal(y_pred, prompt_answer))
        else:
            result.append(False)
            invalid_outputs.append({'question': prompt, 'output': completion, 'answer': prompt_answer})
    acc = sum(result) / len(result)
    # Bug fix: the 'valid_outputs' field used to re-print len(invalid_outputs).
    print('len invalid outputs ====', len(invalid_outputs),
          ', valid_outputs===', len(result) - len(invalid_outputs))
    print('gsm8k length====', len(result), ', gsm8k acc====', acc)
113
+
114
+
115
def parse_args():
    """Parse command-line options for the GSM8K evaluation script."""
    parser = argparse.ArgumentParser()
    # (flag, kwargs) specs: model path, jsonl test set, slice bounds,
    # request batch size, and GPU tensor-parallel degree.
    specs = [
        ("--model", dict(type=str)),
        ("--data_file", dict(type=str, default='data/gsm8k_test.jsonl')),
        ("--start", dict(type=int, default=0)),
        ("--end", dict(type=int, default=MAX_INT)),
        ("--batch_size", dict(type=int, default=60)),
        ("--tensor_parallel_size", dict(type=int, default=1)),
    ]
    for flag, kwargs in specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
124
+
125
if __name__ == "__main__":
    # Script entry point: run the GSM8K evaluation with CLI settings.
    args = parse_args()
    gsm8k_test(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)
llama/inference/util.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from grader import math_equal
3
+
4
def last_boxed_only(sample):
    """Reduce a (question, answer) pair to its last \\boxed{...} span.

    Returns None when the answer contains no boxed expression.
    """
    question, answer = sample
    boxed = last_boxed_only_string(answer)
    return None if boxed is None else (question, boxed)
10
+
11
def last_boxed_only_string(string):
    """Return the last '\\boxed{...}' (or '\\fbox{...}') substring.

    Scans forward from the last occurrence of the command, balancing
    braces; returns None when there is no box, or its braces never close.
    """
    idx = string.rfind("\\boxed")
    if idx < 0:
        idx = string.rfind("\\fbox")
        if idx < 0:
            return None

    depth = 0
    for i in range(idx, len(string)):
        ch = string[i]
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                # First balanced close after the command: end of the box.
                return string[idx:i + 1]
    return None
37
+
38
def only_until_first_boxed_from_tokens(string, tokens):
    """Return the prefix of *tokens* ending just before the first box.

    *tokens* must be a tokenization of *string*; the cut point is the
    token whose cumulative character length first reaches the position
    of '\\boxed' (or '\\fbox').  Returns None when *string* has no box.
    """
    idx = string.find("\\boxed")
    if idx < 0:
        idx = string.find("\\fbox")
        if idx < 0:
            return None

    # Bug fix: the loop variable was unbound when *tokens* was empty,
    # raising NameError at `tokens[:i]`; `cut` now defaults to 0.
    cum_length = 0
    cut = 0
    for cut, token in enumerate(tokens):
        cum_length += len(token)
        if cum_length >= idx:
            break

    return tokens[:cut]
52
+
53
+
54
+
55
def clean_numbers(sample):
    """Apply _clean_numbers to every string in *sample*.

    Returns a tuple of cleaned strings, or None for a falsy input.
    """
    if not sample:
        return None
    return tuple(_clean_numbers(s) for s in sample)
63
+
64
+ def _clean_numbers(string):
65
+ """
66
+ Clean Numbers in the given string
67
+
68
+ >>> _clean_numbers(None, "Hello 123")
69
+ 'Hello 123'
70
+ >>> _clean_numbers(None, "Hello 1234")
71
+ 'Hello 1,234'
72
+ >>> _clean_numbers(None, "Hello 1234324asdasd")
73
+ 'Hello 1,234,324asdasd'
74
+ """
75
+ num_prev_digits = 0
76
+ new_string = ""
77
+ for i, c in enumerate(string):
78
+ # isdigit() doesnt work here because of weird unicode chars.
79
+ if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}:
80
+ num_prev_digits += 1
81
+ else:
82
+ if num_prev_digits > 3:
83
+ # Some fixing
84
+ string_number = new_string[-num_prev_digits:]
85
+ new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
86
+ num_prev_digits = 0
87
+ new_string += c
88
+
89
+ if num_prev_digits > 3:
90
+ # Some fixing
91
+ string_number = new_string[-num_prev_digits:]
92
+ new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
93
+
94
+ return new_string
95
+
96
def fix_fracs(string):
    """Normalize \\frac arguments to brace form: \\frac12 -> \\frac{1}{2}.

    Already-braced fractions pass through unchanged.  When an argument
    cannot be recovered (fewer than two characters after \\frac), the
    input is returned unchanged, mirroring the original assert-and-bail
    behavior.
    """
    substrs = string.split("\\frac")
    new_str = substrs[0]
    for substr in substrs[1:]:
        new_str += "\\frac"
        # Bug fix: an empty tail (e.g. a trailing '\\frac') used to raise
        # IndexError on substr[0]; bail out like the short-tail case.
        if len(substr) == 0:
            return string
        if substr[0] == "{":
            # First argument already braced: keep as-is.
            new_str += substr
            continue
        if len(substr) < 2:
            return string
        a, b = substr[0], substr[1]
        rest = substr[2:]
        if b != "{":
            # \frac ab -> \frac{a}{b}, keeping any trailing text.
            new_str += "{" + a + "}{" + b + "}" + rest
        else:
            # \frac a{...} -> \frac{a}{...}
            new_str += "{" + a + "}" + b + rest
    return new_str
126
+
127
def fix_a_slash_b(string):
    """Rewrite a plain integer fraction 'a/b' as '\\frac{a}{b}'.

    Anything that is not exactly two integer parts joined by a single '/'
    (e.g. 'x/y', '1/2/3', '1.5/2') is returned unchanged.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a = int(parts[0])
        b = int(parts[1])
        # Reject inputs whose int round-trip differs (signs, leading
        # zeros, surrounding whitespace).
        assert string == "{}/{}".format(a, b)
    except (ValueError, AssertionError):
        # Bug fix: int() failures used to raise an uncaught ValueError
        # for non-numeric parts like 'x/y'.
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
140
+
141
def remove_right_units(string):
    """Drop a trailing '\\text{ ...}' unit annotation.

    In the MATH dataset "\\text{ " only appears when describing units, so
    everything from its first occurrence onward is discarded.
    """
    # Bug fix: the original asserted exactly one occurrence and crashed
    # with AssertionError on answers containing several \text{ } spans
    # (and the assert vanished under -O); take the prefix instead.
    head, sep, _ = string.partition("\\text{ ")
    return head if sep else string
149
+
150
def fix_sqrt(string):
    """Brace bare sqrt arguments: \\sqrt3 -> \\sqrt{3}.

    Already-braced radicals pass through unchanged.
    """
    if "\\sqrt" not in string:
        return string
    splits = string.split("\\sqrt")
    new_string = splits[0]
    for split in splits[1:]:
        # Bug fix: a trailing '\\sqrt' yields an empty tail; the original
        # raised IndexError on split[0].
        if split and split[0] != "{":
            new_substr = "\\sqrt{" + split[0] + "}" + split[1:]
        else:
            new_substr = "\\sqrt" + split
        new_string += new_substr
    return new_string
163
+
164
+
165
def strip_string(string):
    """Normalize a MATH-style answer string for comparison.

    Applies an order-dependent pipeline of LaTeX cleanups: strips
    whitespace/formatting commands, units and percent signs, then
    canonicalizes fractions and radicals.  The step order matters
    (e.g. units must go before spaces are removed), so do not reorder.
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces (\! negative thin space)
    string = string.replace("\\!", "")

    # replace \\ with \
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right delimiter sizing commands
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = remove_right_units(string)

    # remove percentage
    string = string.replace("\\%", "")
    string = string.replace("\%", "")  # noqa: W605

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if len(string) == 0:
        return string
    if string[0] == ".":
        string = "0" + string

    # drop a short left-hand side like "k = " or "q = " (at most 2 chars)
    if len(string.split("=")) == 2:
        if len(string.split("=")[0]) <= 2:
            string = string.split("=")[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = fix_a_slash_b(string)

    return string
228
+
229
+
230
def is_equiv(str1, str2, verbose=False):
    """Check whether two MATH answers are equivalent.

    Both strings are normalized with strip_string and compared via
    math_equal and literal equality.  If normalization itself raises,
    the raw strings are compared instead.
    """
    if str1 is None and str2 is None:
        print("WARNING: Both None")
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = strip_string(str1)
        ss2 = strip_string(str2)
        if verbose:
            print(ss1, ss2)
        return math_equal(ss1, ss2) or ss1 == ss2
    except Exception:
        # Bug fix: the fallback used to call math_equal(str1, str1) —
        # a self-comparison that made any normalization failure count
        # as equivalent.
        return math_equal(str1, str2) or str1 == str2
250
+
251
class NotEqual:
    """Sentinel whose equality comparison is always False.

    Useful as a placeholder that can never accidentally match a value,
    not even another NotEqual instance or itself.
    """

    def __eq__(self, other):
        return False
llama/merge_adapter_to_base_model.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Merge a PEFT (OFT) adapter into its base model and save the result as a
# standalone checkpoint loadable without peft.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import argparse
import torch

parser = argparse.ArgumentParser(description='Merge Adapter to Base Model')
# NOTE(review): '--base_mode' looks like a typo for '--base_model'; the flag
# name is kept as-is because existing launch scripts pass it.
parser.add_argument('--base_mode', type=str)    # base model path or hub id
parser.add_argument('--adapter', type=str)      # path to the trained adapter
parser.add_argument('--output_path', type=str)  # where to save the merged model
args = parser.parse_args()

model = AutoModelForCausalLM.from_pretrained(args.base_mode, torch_dtype=torch.bfloat16, device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained(args.base_mode, device_map='auto')
#
# tokenizer = AutoTokenizer.from_pretrained(args.adapter)
# Grow the embedding table to 32001 entries — presumably the 32000-token
# Llama vocab plus the added [PAD] token; TODO confirm against the adapter's
# added_tokens.json before reusing with another base model.
model.resize_token_embeddings(32001)
print('len', len(tokenizer))
print(f"Base model vocab size after resize: {model.get_input_embeddings().weight.shape[0]}")
#
lora_config = PeftConfig.from_pretrained(args.adapter)
lora_config.init_oft_weights=True
model = PeftModel.from_pretrained(model, args.adapter, config=lora_config)
model = model.merge_and_unload()

# safe_serialization=False writes pytorch_model.bin instead of safetensors.
model.save_pretrained(args.output_path, safe_serialization=False)
tokenizer.save_pretrained(args.output_path)
27
+
llama/output/cp1e4/ft/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Llama-2-7b-hf
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.10.0
llama/output/cp1e4/ft/adapter_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
5
+ "block_share": false,
6
+ "coft": false,
7
+ "eps": 0.0001,
8
+ "inference_mode": true,
9
+ "init_weights": true,
10
+ "layers_pattern": null,
11
+ "layers_to_transform": null,
12
+ "module_dropout": 0.0,
13
+ "modules_to_save": null,
14
+ "peft_type": "OFT",
15
+ "r": 32,
16
+ "rank_pattern": {},
17
+ "revision": null,
18
+ "target_modules": [
19
+ "v_proj",
20
+ "q_proj"
21
+ ],
22
+ "task_type": "CAUSAL_LM"
23
+ }
llama/output/cp1e4/ft/added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "[PAD]": 32000
3
+ }
llama/output/cp1e4/ft/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "</s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "</s>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
llama/output/cp1e4/ft/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
llama/output/cp1e4/ft/tokenizer_config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "32000": {
31
+ "content": "[PAD]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ }
38
+ },
39
+ "bos_token": "</s>",
40
+ "clean_up_tokenization_spaces": false,
41
+ "eos_token": "</s>",
42
+ "extra_special_tokens": {},
43
+ "legacy": false,
44
+ "model_max_length": 512,
45
+ "pad_token": "[PAD]",
46
+ "padding_side": "right",
47
+ "sp_model_kwargs": {},
48
+ "tokenizer_class": "LlamaTokenizer",
49
+ "unk_token": "</s>",
50
+ "use_default_system_prompt": false
51
+ }
llama/output/cp1e5/ft/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: facebook/opt-125m
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.10.0
llama/output/cp1e5/ft/adapter_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "facebook/opt-125m",
5
+ "block_share": false,
6
+ "coft": false,
7
+ "eps": 0.0001,
8
+ "inference_mode": true,
9
+ "init_weights": true,
10
+ "layers_pattern": null,
11
+ "layers_to_transform": null,
12
+ "module_dropout": 0.0,
13
+ "modules_to_save": null,
14
+ "peft_type": "OFT",
15
+ "r": 8,
16
+ "rank_pattern": {},
17
+ "revision": null,
18
+ "target_modules": [
19
+ "v_proj",
20
+ "q_proj"
21
+ ],
22
+ "task_type": "CAUSAL_LM"
23
+ }
llama/output/cp1e5/trainer_state.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 2e-05,
5
+ "eval_steps": 500,
6
+ "global_step": 2,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 2e-05,
13
+ "step": 2,
14
+ "total_flos": 368078929920.0,
15
+ "train_loss": 2.170874834060669,
16
+ "train_runtime": 0.7734,
17
+ "train_samples_per_second": 2.586,
18
+ "train_steps_per_second": 2.586
19
+ }
20
+ ],
21
+ "logging_steps": 1000,
22
+ "max_steps": 2,
23
+ "num_input_tokens_seen": 0,
24
+ "num_train_epochs": 1,
25
+ "save_steps": 0,
26
+ "total_flos": 368078929920.0,
27
+ "train_batch_size": 1,
28
+ "trial_name": null,
29
+ "trial_params": null
30
+ }
llama/output/cp1e5N/ft/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Llama-2-7b-hf
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.10.0
llama/output/cp1e5N/ft/adapter_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
5
+ "block_share": false,
6
+ "coft": false,
7
+ "eps": 0.0001,
8
+ "inference_mode": true,
9
+ "init_weights": true,
10
+ "layers_pattern": null,
11
+ "layers_to_transform": null,
12
+ "module_dropout": 0.0,
13
+ "modules_to_save": null,
14
+ "peft_type": "OFT",
15
+ "r": 32,
16
+ "rank_pattern": {},
17
+ "revision": null,
18
+ "target_modules": [
19
+ "q_proj",
20
+ "v_proj"
21
+ ],
22
+ "task_type": "CAUSAL_LM"
23
+ }
llama/output/cp1e5N/ft/added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "[PAD]": 32000
3
+ }
llama/output/cp1e5N/ft/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "</s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "</s>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
llama/output/cp1e5N/ft/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
llama/output/cp1e5N/ft/tokenizer_config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "32000": {
31
+ "content": "[PAD]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ }
38
+ },
39
+ "bos_token": "</s>",
40
+ "clean_up_tokenization_spaces": false,
41
+ "eos_token": "</s>",
42
+ "extra_special_tokens": {},
43
+ "legacy": false,
44
+ "model_max_length": 512,
45
+ "pad_token": "[PAD]",
46
+ "padding_side": "right",
47
+ "sp_model_kwargs": {},
48
+ "tokenizer_class": "LlamaTokenizer",
49
+ "unk_token": "</s>",
50
+ "use_default_system_prompt": false
51
+ }
llama/output/cp3e5/ft/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Llama-2-7b-hf
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.10.0
llama/output/cp3e5/ft/adapter_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
5
+ "block_share": false,
6
+ "coft": false,
7
+ "eps": 0.0001,
8
+ "inference_mode": true,
9
+ "init_weights": true,
10
+ "layers_pattern": null,
11
+ "layers_to_transform": null,
12
+ "module_dropout": 0.0,
13
+ "modules_to_save": null,
14
+ "peft_type": "OFT",
15
+ "r": 32,
16
+ "rank_pattern": {},
17
+ "revision": null,
18
+ "target_modules": [
19
+ "q_proj",
20
+ "v_proj"
21
+ ],
22
+ "task_type": "CAUSAL_LM"
23
+ }
llama/output/cp3e5/trainer_state.json ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_metric": null,
3
+ "best_model_checkpoint": null,
4
+ "epoch": 2.0,
5
+ "eval_steps": 500,
6
+ "global_step": 6250,
7
+ "is_hyper_param_search": false,
8
+ "is_local_process_zero": true,
9
+ "is_world_process_zero": true,
10
+ "log_history": [
11
+ {
12
+ "epoch": 0.32,
13
+ "grad_norm": 3.8161606788635254,
14
+ "learning_rate": 2.8241524699899885e-05,
15
+ "loss": 0.3071,
16
+ "step": 1000
17
+ },
18
+ {
19
+ "epoch": 0.64,
20
+ "grad_norm": 3.0841195583343506,
21
+ "learning_rate": 2.317615224686078e-05,
22
+ "loss": 0.2366,
23
+ "step": 2000
24
+ },
25
+ {
26
+ "epoch": 0.96,
27
+ "grad_norm": 2.72049880027771,
28
+ "learning_rate": 1.6067682497434074e-05,
29
+ "loss": 0.215,
30
+ "step": 3000
31
+ },
32
+ {
33
+ "epoch": 1.28,
34
+ "grad_norm": 2.600498914718628,
35
+ "learning_rate": 8.692414973449614e-06,
36
+ "loss": 0.186,
37
+ "step": 4000
38
+ },
39
+ {
40
+ "epoch": 1.6,
41
+ "grad_norm": 2.350414752960205,
42
+ "learning_rate": 2.893317942061826e-06,
43
+ "loss": 0.1763,
44
+ "step": 5000
45
+ },
46
+ {
47
+ "epoch": 1.92,
48
+ "grad_norm": 2.4736688137054443,
49
+ "learning_rate": 1.194984047782738e-07,
50
+ "loss": 0.1733,
51
+ "step": 6000
52
+ },
53
+ {
54
+ "epoch": 2.0,
55
+ "step": 6250,
56
+ "total_flos": 3.546740381344727e+18,
57
+ "train_loss": 0.21396501647949218,
58
+ "train_runtime": 35607.6251,
59
+ "train_samples_per_second": 5.617,
60
+ "train_steps_per_second": 0.176
61
+ }
62
+ ],
63
+ "logging_steps": 1000,
64
+ "max_steps": 6250,
65
+ "num_input_tokens_seen": 0,
66
+ "num_train_epochs": 2,
67
+ "save_steps": 0,
68
+ "total_flos": 3.546740381344727e+18,
69
+ "train_batch_size": 8,
70
+ "trial_name": null,
71
+ "trial_params": null
72
+ }
llama/output/cp3e5N/ft/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Llama-2-7b-hf
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.10.0
llama/output/cp3e5N/ft/adapter_config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "alpha_pattern": {},
3
+ "auto_mapping": null,
4
+ "base_model_name_or_path": "meta-llama/Llama-2-7b-hf",
5
+ "block_share": false,
6
+ "coft": false,
7
+ "eps": 0.0001,
8
+ "inference_mode": true,
9
+ "init_weights": true,
10
+ "layers_pattern": null,
11
+ "layers_to_transform": null,
12
+ "module_dropout": 0.0,
13
+ "modules_to_save": null,
14
+ "peft_type": "OFT",
15
+ "r": 32,
16
+ "rank_pattern": {},
17
+ "revision": null,
18
+ "target_modules": [
19
+ "q_proj",
20
+ "v_proj"
21
+ ],
22
+ "task_type": "CAUSAL_LM"
23
+ }
llama/output/cp3e5N/ft/added_tokens.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ {
2
+ "[PAD]": 32000
3
+ }
llama/output/cp3e5N/ft/special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "</s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "</s>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
llama/output/cp3e5N/ft/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
llama/output/cp3e5N/ft/tokenizer_config.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "add_prefix_space": null,
5
+ "added_tokens_decoder": {
6
+ "0": {
7
+ "content": "<unk>",
8
+ "lstrip": false,
9
+ "normalized": false,
10
+ "rstrip": false,
11
+ "single_word": false,
12
+ "special": true
13
+ },
14
+ "1": {
15
+ "content": "<s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false,
20
+ "special": true
21
+ },
22
+ "2": {
23
+ "content": "</s>",
24
+ "lstrip": false,
25
+ "normalized": false,
26
+ "rstrip": false,
27
+ "single_word": false,
28
+ "special": true
29
+ },
30
+ "32000": {
31
+ "content": "[PAD]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false,
36
+ "special": true
37
+ }
38
+ },
39
+ "bos_token": "</s>",
40
+ "clean_up_tokenization_spaces": false,
41
+ "eos_token": "</s>",
42
+ "extra_special_tokens": {},
43
+ "legacy": false,
44
+ "model_max_length": 512,
45
+ "pad_token": "[PAD]",
46
+ "padding_side": "right",
47
+ "sp_model_kwargs": {},
48
+ "tokenizer_class": "LlamaTokenizer",
49
+ "unk_token": "</s>",
50
+ "use_default_system_prompt": false
51
+ }
llama/output/cpr1/ft/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ base_model: meta-llama/Llama-2-7b-hf
3
+ library_name: peft
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
200
+ ### Framework versions
201
+
202
+ - PEFT 0.10.0