Spaces:
Running
Running
Commit ·
132da4a
1
Parent(s): 251ecbb
improve performance
Browse files
app.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
import spacy
|
| 2 |
import nltk
|
| 3 |
nltk.download('wordnet', quiet=True)
|
| 4 |
-
spacy.
|
|
|
|
| 5 |
from compute_lng import compute_lng
|
| 6 |
|
| 7 |
import torch
|
|
@@ -111,7 +112,7 @@ def impute_targets():
|
|
| 111 |
shared_state.target = round_ling(interp_raw).tolist()
|
| 112 |
return shared_state.target
|
| 113 |
|
| 114 |
-
def generate_with_feedback(sent1, approx):
|
| 115 |
if sent1 == '':
|
| 116 |
raise gr.Error('Please input a source text.')
|
| 117 |
|
|
@@ -122,24 +123,25 @@ def generate_with_feedback(sent1, approx):
|
|
| 122 |
input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
|
| 123 |
ling2 = torch.tensor(scaler.transform([shared_state.target])).float().to(device)
|
| 124 |
inputs = {
|
| 125 |
-
'
|
| 126 |
'sentence2_ling': ling2,
|
| 127 |
-
'
|
| 128 |
}
|
| 129 |
|
| 130 |
-
|
|
|
|
| 131 |
|
| 132 |
interpolation = '-- ' + '\n-- '.join(interpolations)
|
| 133 |
# Return both the generation results and the updated slider values
|
| 134 |
return [pred_text, interpolation] + [gr.update(value=val) for val in shared_state.target]
|
| 135 |
|
| 136 |
-
def generate_random(sent1, count, approx):
|
| 137 |
if sent1 == '':
|
| 138 |
raise gr.Error('Please input a source text.')
|
| 139 |
preds, interpolations = [], []
|
| 140 |
orig_active_indices = shared_state.active_indices
|
| 141 |
shared_state.active_indices = set(range(len(lng_names)))
|
| 142 |
-
for c in range(count):
|
| 143 |
idx = np.random.randint(0, len(ling_collection))
|
| 144 |
ling_ex = ling_collection[idx]
|
| 145 |
shared_state.target = ling_ex.copy()
|
|
@@ -167,7 +169,7 @@ def generate_random(sent1, count, approx):
|
|
| 167 |
shared_state.active_indices = orig_active_indices
|
| 168 |
return '\n***\n'.join(preds), '\n***\n'.join(interpolations)
|
| 169 |
|
| 170 |
-
def estimate_gen(sent1, sent2, approx):
|
| 171 |
if 'approximate' in approx:
|
| 172 |
input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
|
| 173 |
with torch.no_grad():
|
|
@@ -183,7 +185,7 @@ def estimate_gen(sent1, sent2, approx):
|
|
| 183 |
|
| 184 |
orig_active_indices = shared_state.active_indices
|
| 185 |
shared_state.active_indices = set(range(len(lng_names)))
|
| 186 |
-
gen = generate_with_feedback(sent1, approx)[:2]
|
| 187 |
shared_state.active_indices = orig_active_indices
|
| 188 |
return gen + [gr.update(value=val) for val in shared_state.target]
|
| 189 |
|
|
|
|
| 1 |
import spacy
|
| 2 |
import nltk
|
| 3 |
nltk.download('wordnet', quiet=True)
|
| 4 |
+
if not spacy.util.is_package('en_core_web_sm'):
|
| 5 |
+
spacy.cli.download('en_core_web_sm')
|
| 6 |
from compute_lng import compute_lng
|
| 7 |
|
| 8 |
import torch
|
|
|
|
| 112 |
shared_state.target = round_ling(interp_raw).tolist()
|
| 113 |
return shared_state.target
|
| 114 |
|
| 115 |
+
def generate_with_feedback(sent1, approx, progress=gr.Progress()):
|
| 116 |
if sent1 == '':
|
| 117 |
raise gr.Error('Please input a source text.')
|
| 118 |
|
|
|
|
| 123 |
input_ids = tokenizer.encode(sent1, return_tensors='pt').to(device)
|
| 124 |
ling2 = torch.tensor(scaler.transform([shared_state.target])).float().to(device)
|
| 125 |
inputs = {
|
| 126 |
+
'input_ids': input_ids,
|
| 127 |
'sentence2_ling': ling2,
|
| 128 |
+
'attention_mask': torch.ones_like(input_ids)
|
| 129 |
}
|
| 130 |
|
| 131 |
+
progress((0, None), unit='intermediate paraphrase generated.')
|
| 132 |
+
pred, (pred_text, interpolations) = model.infer_with_feedback_BP(ling_disc, sem_emb, inputs, tokenizer, progress)
|
| 133 |
|
| 134 |
interpolation = '-- ' + '\n-- '.join(interpolations)
|
| 135 |
# Return both the generation results and the updated slider values
|
| 136 |
return [pred_text, interpolation] + [gr.update(value=val) for val in shared_state.target]
|
| 137 |
|
| 138 |
+
def generate_random(sent1, count, approx, progress=gr.Progress()):
|
| 139 |
if sent1 == '':
|
| 140 |
raise gr.Error('Please input a source text.')
|
| 141 |
preds, interpolations = [], []
|
| 142 |
orig_active_indices = shared_state.active_indices
|
| 143 |
shared_state.active_indices = set(range(len(lng_names)))
|
| 144 |
+
for c in progress.tqdm(range(count), desc='Generating random sentences', unit='paraphrases'):
|
| 145 |
idx = np.random.randint(0, len(ling_collection))
|
| 146 |
ling_ex = ling_collection[idx]
|
| 147 |
shared_state.target = ling_ex.copy()
|
|
|
|
| 169 |
shared_state.active_indices = orig_active_indices
|
| 170 |
return '\n***\n'.join(preds), '\n***\n'.join(interpolations)
|
| 171 |
|
| 172 |
+
def estimate_gen(sent1, sent2, approx, progress=gr.Progress()):
|
| 173 |
if 'approximate' in approx:
|
| 174 |
input_ids = tokenizer.encode(sent2, return_tensors='pt').to(device)
|
| 175 |
with torch.no_grad():
|
|
|
|
| 185 |
|
| 186 |
orig_active_indices = shared_state.active_indices
|
| 187 |
shared_state.active_indices = set(range(len(lng_names)))
|
| 188 |
+
gen = generate_with_feedback(sent1, approx, progress)[:2]
|
| 189 |
shared_state.active_indices = orig_active_indices
|
| 190 |
return gen + [gr.update(value=val) for val in shared_state.target]
|
| 191 |
|
const.py
CHANGED
|
@@ -1030,7 +1030,6 @@ used_indices = [
|
|
| 1030 |
63, 64, 65, 66, 67, 68, 73, 121, 124, 129, 134, 136, 254,
|
| 1031 |
257, 258, 261, 263, 272, 274
|
| 1032 |
]
|
| 1033 |
-
lftk_used_indices = [1, 7, 8, 9, 10, 11, 12, 17, 65, 68, 73, 78, 80, 198, 201, 202, 205, 207, 216, 218]
|
| 1034 |
|
| 1035 |
eval_indices = [4,5,6,18,257,272]
|
| 1036 |
eval_indices = [used_indices.index(idx) for idx in eval_indices]
|
|
|
|
| 1030 |
63, 64, 65, 66, 67, 68, 73, 121, 124, 129, 134, 136, 254,
|
| 1031 |
257, 258, 261, 263, 272, 274
|
| 1032 |
]
|
|
|
|
| 1033 |
|
| 1034 |
eval_indices = [4,5,6,18,257,272]
|
| 1035 |
eval_indices = [used_indices.index(idx) for idx in eval_indices]
|
model.py
CHANGED
|
@@ -10,6 +10,10 @@ from types import MethodType
|
|
| 10 |
from utils import *
|
| 11 |
from ling_disc import DebertaReplacedTokenizer
|
| 12 |
from const import *
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
|
|
@@ -77,9 +81,9 @@ class LingGenerator(nn.Module):
|
|
| 77 |
bs = inputs_embeds.shape[0]
|
| 78 |
|
| 79 |
if self.gen_input == 's+l':
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
inputs_embeds = inputs_embeds +
|
| 83 |
|
| 84 |
gen = self.gen(inputs_embeds=inputs_embeds,
|
| 85 |
attention_mask=inputs_att_mask).last_hidden_state.mean(1)
|
|
@@ -185,13 +189,13 @@ class SemEmb(T5EncoderModel):
|
|
| 185 |
nn.Linear(hidden_dim, 1))
|
| 186 |
|
| 187 |
def compare_sem(self, **batch):
|
| 188 |
-
bs = batch['
|
| 189 |
-
ones = torch.ones((bs, 1), device=batch['
|
| 190 |
sep = torch.ones((bs, 1), dtype=torch.long,
|
| 191 |
-
device=batch['
|
| 192 |
-
att_mask = torch.cat([batch['
|
| 193 |
if 'logits' in batch:
|
| 194 |
-
input_ids = torch.cat([batch['
|
| 195 |
embeds1 = self.shared(input_ids)
|
| 196 |
|
| 197 |
logits = batch['logits']
|
|
@@ -201,11 +205,11 @@ class SemEmb(T5EncoderModel):
|
|
| 201 |
|
| 202 |
embeds2 = onehot_ @ self.shared.weight
|
| 203 |
embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
|
| 204 |
-
hidden_units =
|
| 205 |
attention_mask=att_mask).last_hidden_state.mean(1)
|
| 206 |
elif 'sentence2_input_ids' in batch:
|
| 207 |
-
input_ids = torch.cat([batch['
|
| 208 |
-
hidden_units =
|
| 209 |
attention_mask=att_mask).last_hidden_state.mean(1)
|
| 210 |
probs = self.projection(hidden_units)
|
| 211 |
return probs
|
|
@@ -222,31 +226,36 @@ def prepare_inputs_for_generation(
|
|
| 222 |
cross_attn_head_mask=None,
|
| 223 |
use_cache=None,
|
| 224 |
encoder_outputs=None,
|
| 225 |
-
|
| 226 |
-
|
| 227 |
**kwargs
|
| 228 |
):
|
| 229 |
-
|
| 230 |
# cut decoder_input_ids if past is used
|
| 231 |
if past_key_values is not None:
|
| 232 |
input_ids = input_ids[:, -1:]
|
| 233 |
|
|
|
|
|
|
|
| 234 |
input_ids = input_ids.clone()
|
| 235 |
decoder_inputs_embeds = self.shared(input_ids)
|
| 236 |
|
| 237 |
-
if combine_method == '
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
| 241 |
if ling2_only:
|
| 242 |
-
decoder_inputs_embeds = torch.cat([
|
| 243 |
else:
|
| 244 |
-
decoder_inputs_embeds = torch.cat([
|
| 245 |
-
|
|
|
|
| 246 |
if ling2_only:
|
| 247 |
-
decoder_inputs_embeds = decoder_inputs_embeds +
|
| 248 |
else:
|
| 249 |
-
decoder_inputs_embeds = decoder_inputs_embeds +
|
| 250 |
|
| 251 |
return {
|
| 252 |
"decoder_inputs_embeds": decoder_inputs_embeds,
|
|
@@ -257,19 +266,27 @@ def prepare_inputs_for_generation(
|
|
| 257 |
"decoder_head_mask": decoder_head_mask,
|
| 258 |
"cross_attn_head_mask": cross_attn_head_mask,
|
| 259 |
"use_cache": use_cache,
|
|
|
|
| 260 |
}
|
| 261 |
|
| 262 |
class LogitsAdd(LogitsProcessor):
|
| 263 |
-
def __init__(self,
|
| 264 |
super().__init__()
|
| 265 |
-
self.
|
| 266 |
|
| 267 |
def __call__(self, input_ids, scores):
|
| 268 |
-
return scores + self.
|
| 269 |
|
| 270 |
-
class EncoderDecoderVAE(
|
| 271 |
def __init__(self, config, args, pad_token_id, sepeos_token_id, vocab_size = 32128):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
super().__init__(config)
|
|
|
|
| 273 |
self.prepare_inputs_for_generation = types.MethodType(
|
| 274 |
partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
|
| 275 |
self)
|
|
@@ -287,7 +304,7 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 287 |
nn.ReLU(),
|
| 288 |
nn.Linear(hidden_dim, hidden_dim),
|
| 289 |
)
|
| 290 |
-
elif 'concat' in args.combine_method or 'add' in args.combine_method:
|
| 291 |
if args.ling_embed_type == 'two-layer':
|
| 292 |
self.ling_embed = nn.Sequential(
|
| 293 |
nn.Linear(args.lng_dim, args.lng_dim),
|
|
@@ -297,6 +314,7 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 297 |
else:
|
| 298 |
self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
|
| 299 |
self.ling_dropout = nn.Dropout(args.ling_dropout)
|
|
|
|
| 300 |
|
| 301 |
if args.ling_vae:
|
| 302 |
self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
|
|
@@ -306,8 +324,20 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 306 |
nn.init.xavier_uniform_(self.ling_logvar.weight)
|
| 307 |
|
| 308 |
|
| 309 |
-
generate_with_grad = unwrap(
|
| 310 |
self.generate_with_grad = MethodType(generate_with_grad, self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
def get_fusion_layer(self):
|
| 313 |
if 'fusion' in self.args.combine_method:
|
|
@@ -321,122 +351,143 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 321 |
std = torch.exp(0.5 * logvar)
|
| 322 |
return mu + std * torch.randn_like(std)
|
| 323 |
|
| 324 |
-
def
|
| 325 |
-
|
| 326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
else:
|
| 328 |
-
|
| 329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
bs = inputs_embeds.shape[0]
|
| 331 |
-
|
| 332 |
if self.args.combine_method in ('input_concat', 'input_add'):
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
sent2_ling = batch['sent2_ling_embed']
|
| 339 |
-
else:
|
| 340 |
-
sent2_ling = self.ling_embed(self.ling_dropout(batch['sentence2_ling']))
|
| 341 |
-
if self.args.ling_vae:
|
| 342 |
-
sent1_ling = F.leaky_relu(sent1_ling)
|
| 343 |
-
sent1_mu, sent1_logvar = self.ling_mu(sent1_ling), self.ling_logvar(sent1_ling)
|
| 344 |
-
sent1_ling = self.sample(sent1_mu, sent1_logvar)
|
| 345 |
-
|
| 346 |
-
sent2_ling = F.leaky_relu(sent2_ling)
|
| 347 |
-
sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
|
| 348 |
-
sent2_ling = self.sample(sent2_mu, sent2_logvar)
|
| 349 |
-
cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
|
| 350 |
-
'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
|
| 351 |
-
'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
|
| 352 |
-
else:
|
| 353 |
-
cache.update({'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
|
| 354 |
-
sent1_ling = sent1_ling.view(bs, 1, -1)
|
| 355 |
-
sent2_ling = sent2_ling.view(bs, 1, -1)
|
| 356 |
if self.args.combine_method == 'input_concat':
|
| 357 |
if self.args.ling2_only:
|
| 358 |
-
inputs_embeds = torch.cat([inputs_embeds,
|
| 359 |
inputs_att_mask = torch.cat([inputs_att_mask,
|
| 360 |
torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
|
| 361 |
else:
|
| 362 |
-
inputs_embeds = torch.cat([inputs_embeds,
|
| 363 |
inputs_att_mask = torch.cat([inputs_att_mask,
|
| 364 |
torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
|
| 365 |
elif self.args.combine_method == 'input_add':
|
| 366 |
if self.args.ling2_only:
|
| 367 |
-
inputs_embeds = inputs_embeds +
|
| 368 |
else:
|
| 369 |
-
inputs_embeds = inputs_embeds +
|
|
|
|
|
|
|
|
|
|
| 370 |
return self.encoder(inputs_embeds=inputs_embeds,
|
| 371 |
attention_mask=inputs_att_mask), inputs_att_mask, cache
|
| 372 |
|
| 373 |
-
def decode(self,
|
| 374 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
cache = {}
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
sent2_ling = F.leaky_relu(sent2_ling)
|
| 393 |
-
sent2_mu, sent2_logvar = self.ling_mu(sent2_ling), self.ling_logvar(sent2_ling)
|
| 394 |
-
sent2_ling = self.sample(sent2_mu, sent2_logvar)
|
| 395 |
-
cache.update({'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
|
| 396 |
-
'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
|
| 397 |
-
'sent1_ling': sent1_ling, 'sent2_ling': sent2_ling})
|
| 398 |
-
else:
|
| 399 |
-
cache.update({'sent2_ling': sent2_ling})
|
| 400 |
-
if sent1_ling is not None:
|
| 401 |
-
cache.update({'sent1_ling': sent1_ling})
|
| 402 |
-
if sent1_ling is not None:
|
| 403 |
-
sent1_ling = sent1_ling.view(bs, 1, -1)
|
| 404 |
-
sent2_ling = sent2_ling.view(bs, 1, -1)
|
| 405 |
-
if self.args.combine_method == 'decoder_add_first' and not generate:
|
| 406 |
-
sent2_ling = torch.cat([sent2_ling,
|
| 407 |
-
torch.repeat_interleave(torch.zeros_like(sent2_ling), batch['sentence2_input_ids'].shape[1] - 1, dim=1)], dim = 1)
|
| 408 |
-
else:
|
| 409 |
-
sent1_ling, sent2_ling = None, None
|
| 410 |
-
|
| 411 |
-
if self.args.combine_method == 'embed_concat':
|
| 412 |
-
enc_output.last_hidden_state = torch.cat([enc_output.last_hidden_state,
|
| 413 |
-
sent1_ling, sent2_ling], dim=1)
|
| 414 |
-
inputs_att_mask = torch.cat([inputs_att_mask,
|
| 415 |
-
torch.ones((bs, 2)).to(inputs_att_mask.device)], dim=1)
|
| 416 |
-
elif 'fusion' in self.args.combine_method:
|
| 417 |
-
sent1_ling = batch['sentence1_ling'].unsqueeze(1)\
|
| 418 |
-
.expand(-1, enc_output.last_hidden_state.shape[1], -1)
|
| 419 |
-
sent2_ling = batch['sentence2_ling'].unsqueeze(1)\
|
| 420 |
-
.expand(-1, enc_output.last_hidden_state.shape[1], -1)
|
| 421 |
-
if self.args.ling2_only:
|
| 422 |
-
combined_embedding = torch.cat([enc_output.last_hidden_state, sent2_ling], dim=2)
|
| 423 |
else:
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
if generate:
|
| 428 |
if self.args.combine_method == 'logits_add':
|
| 429 |
-
logits_processor = LogitsProcessorList([LogitsAdd(
|
| 430 |
else:
|
| 431 |
logits_processor = LogitsProcessorList()
|
| 432 |
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
output_scores=True,
|
| 440 |
logits_processor = logits_processor,
|
| 441 |
# renormalize_logits=True,
|
| 442 |
# do_sample=True,
|
|
@@ -445,68 +496,135 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 445 |
# min_new_tokens=3,
|
| 446 |
# repetition_penalty=1.2,
|
| 447 |
max_length=self.args.max_length,
|
|
|
|
|
|
|
| 448 |
)
|
| 449 |
-
|
| 450 |
-
cache.update({'scores': scores})
|
| 451 |
-
return dec_output.sequences, cache
|
| 452 |
-
|
| 453 |
-
decoder_input_ids = self._shift_right(batch['sentence2_input_ids'])
|
| 454 |
-
decoder_inputs_embeds = self.shared(decoder_input_ids)
|
| 455 |
-
decoder_att_mask = batch['sentence2_attention_mask']
|
| 456 |
-
labels = batch['sentence2_input_ids'].clone()
|
| 457 |
-
labels[labels == self.pad_token_id] = -100
|
| 458 |
-
|
| 459 |
-
if self.args.combine_method == 'decoder_concat':
|
| 460 |
-
if self.args.ling2_only:
|
| 461 |
-
decoder_inputs_embeds = torch.cat([sent2_ling, decoder_inputs_embeds], dim=1)
|
| 462 |
-
decoder_att_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
|
| 463 |
-
labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
|
| 464 |
-
labels], dim=1)
|
| 465 |
-
else:
|
| 466 |
-
decoder_inputs_embeds = torch.cat([sent1_ling, sent2_ling, decoder_inputs_embeds], dim=1)
|
| 467 |
-
decoder_att_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_att_mask], dim=1)
|
| 468 |
-
labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
|
| 469 |
-
labels], dim=1)
|
| 470 |
-
elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first' :
|
| 471 |
-
if self.args.ling2_only:
|
| 472 |
-
decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sent2_ling
|
| 473 |
-
else:
|
| 474 |
-
decoder_inputs_embeds = decoder_inputs_embeds + sent1_ling + sent2_ling
|
| 475 |
|
| 476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 478 |
-
decoder_attention_mask=
|
| 479 |
-
encoder_outputs=
|
| 480 |
-
attention_mask=
|
| 481 |
labels=labels,
|
|
|
|
|
|
|
| 482 |
)
|
| 483 |
if self.args.combine_method == 'logits_add':
|
| 484 |
-
dec_output.logits = dec_output.logits + self.args.combine_weight *
|
| 485 |
vocab_size = dec_output.logits.size(-1)
|
| 486 |
dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
|
| 487 |
return dec_output, cache
|
| 488 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
cache.update(cache2)
|
| 494 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
|
| 496 |
def infer_with_cache(self, batch):
|
| 497 |
-
dec_output, _, cache = self
|
| 498 |
return dec_output, cache
|
| 499 |
|
| 500 |
def infer(self, batch):
|
| 501 |
dec_output, _ = self.infer_with_cache(batch)
|
| 502 |
return dec_output
|
| 503 |
|
| 504 |
-
def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer):
|
| 505 |
from torch.autograd import grad
|
| 506 |
interpolations = []
|
| 507 |
def line_search():
|
| 508 |
-
best_val = None
|
| 509 |
-
best_loss = None
|
| 510 |
eta = 1e3
|
| 511 |
sem_prob = 1
|
| 512 |
patience = 4
|
|
@@ -516,13 +634,11 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 516 |
new_loss, pred = get_loss(param_)
|
| 517 |
max_len = pred.shape[1]
|
| 518 |
lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
sem_prob = torch.sigmoid(sem_emb.compare_sem(**
|
| 524 |
-
# if sem_prob <= 0.1:
|
| 525 |
-
# patience -= 1
|
| 526 |
if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
|
| 527 |
return param_
|
| 528 |
eta *= 2.25
|
|
@@ -531,7 +647,7 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 531 |
|
| 532 |
def get_loss(param):
|
| 533 |
if self.args.feedback_param == 'l':
|
| 534 |
-
batch.update({'
|
| 535 |
elif self.args.feedback_param == 's':
|
| 536 |
batch.update({'inputs_embeds': param})
|
| 537 |
|
|
@@ -539,8 +655,9 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 539 |
logits = param
|
| 540 |
pred = param.argmax(-1)
|
| 541 |
else:
|
| 542 |
-
|
| 543 |
-
|
|
|
|
| 544 |
out = ling_disc(logits = logits)
|
| 545 |
probs = F.softmax(out, 1)
|
| 546 |
if ling_disc.quant:
|
|
@@ -553,13 +670,13 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 553 |
ling2_embed = self.ling_embed(batch['sentence2_ling'])
|
| 554 |
param = torch.nn.Parameter(ling2_embed, requires_grad = True)
|
| 555 |
elif self.args.feedback_param == 's':
|
| 556 |
-
inputs_embeds = self.shared(batch['
|
| 557 |
param = torch.nn.Parameter(inputs_embeds, requires_grad = True)
|
| 558 |
elif self.args.feedback_param == 'logits':
|
| 559 |
logits = self.infer_with_cache(batch)[1]['scores']
|
| 560 |
param = torch.nn.Parameter(logits, requires_grad = True)
|
| 561 |
-
|
| 562 |
-
while
|
| 563 |
loss, pred = get_loss(param)
|
| 564 |
pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
|
| 565 |
skip_special_tokens=True)[0]
|
|
@@ -571,6 +688,9 @@ class EncoderDecoderVAE(T5ForConditionalGeneration):
|
|
| 571 |
param = line_search()
|
| 572 |
if param is False:
|
| 573 |
break
|
|
|
|
|
|
|
|
|
|
| 574 |
return pred, [pred_text, interpolations]
|
| 575 |
|
| 576 |
def set_grad(module, state):
|
|
@@ -609,7 +729,7 @@ class LingDiscPipeline():
|
|
| 609 |
def __init__(self,
|
| 610 |
model_name="google/flan-t5-base",
|
| 611 |
disc_type='deberta',
|
| 612 |
-
disc_ckpt='/
|
| 613 |
# disc_type='t5',
|
| 614 |
# disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
|
| 615 |
):
|
|
@@ -629,15 +749,13 @@ def get_model(args, tokenizer, device):
|
|
| 629 |
ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_model_path).to(device)
|
| 630 |
else:
|
| 631 |
ling_disc = None
|
| 632 |
-
if args.linggen_type != 'none':
|
| 633 |
-
ling_gen = LingGenerator(args).to(device)
|
| 634 |
|
| 635 |
-
if
|
| 636 |
model = EncoderDecoderVAE.from_pretrained(args.model_path, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
|
| 637 |
else:
|
| 638 |
-
model =
|
| 639 |
|
| 640 |
-
if args.sem_loss or args.
|
| 641 |
if args.sem_loss_type == 'shared':
|
| 642 |
sem_emb = model.encoder
|
| 643 |
elif args.sem_loss_type == 'dedicated':
|
|
@@ -649,3 +767,14 @@ def get_model(args, tokenizer, device):
|
|
| 649 |
|
| 650 |
return model, ling_disc, sem_emb
|
| 651 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
from utils import *
|
| 11 |
from ling_disc import DebertaReplacedTokenizer
|
| 12 |
from const import *
|
| 13 |
+
from lingconv_t5 import LingConvT5ForConditionalGeneration
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from transformers.modeling_outputs import Seq2SeqLMOutput
|
| 16 |
+
from typing import Optional, Dict, Any
|
| 17 |
|
| 18 |
|
| 19 |
|
|
|
|
| 81 |
bs = inputs_embeds.shape[0]
|
| 82 |
|
| 83 |
if self.gen_input == 's+l':
|
| 84 |
+
sentence1_ling = self.ling_embed(batch['sentence1_ling'])
|
| 85 |
+
sentence1_ling = sentence1_ling.view(bs, 1, -1)
|
| 86 |
+
inputs_embeds = inputs_embeds + sentence1_ling
|
| 87 |
|
| 88 |
gen = self.gen(inputs_embeds=inputs_embeds,
|
| 89 |
attention_mask=inputs_att_mask).last_hidden_state.mean(1)
|
|
|
|
| 189 |
nn.Linear(hidden_dim, 1))
|
| 190 |
|
| 191 |
def compare_sem(self, **batch):
|
| 192 |
+
bs = batch['attention_mask'].shape[0]
|
| 193 |
+
ones = torch.ones((bs, 1), device=batch['attention_mask'].device)
|
| 194 |
sep = torch.ones((bs, 1), dtype=torch.long,
|
| 195 |
+
device=batch['attention_mask'].device) * self.sep_token_id
|
| 196 |
+
att_mask = torch.cat([batch['attention_mask'], ones, batch['sentence2_attention_mask']], dim=1)
|
| 197 |
if 'logits' in batch:
|
| 198 |
+
input_ids = torch.cat([batch['input_ids'], sep], dim=1)
|
| 199 |
embeds1 = self.shared(input_ids)
|
| 200 |
|
| 201 |
logits = batch['logits']
|
|
|
|
| 205 |
|
| 206 |
embeds2 = onehot_ @ self.shared.weight
|
| 207 |
embeds1_2 = torch.cat([embeds1, embeds2], dim=1)
|
| 208 |
+
hidden_units = super().forward(inputs_embeds=embeds1_2,
|
| 209 |
attention_mask=att_mask).last_hidden_state.mean(1)
|
| 210 |
elif 'sentence2_input_ids' in batch:
|
| 211 |
+
input_ids = torch.cat([batch['input_ids'], sep, batch['sentence2_input_ids']], dim=1)
|
| 212 |
+
hidden_units = super().forward(input_ids=input_ids,
|
| 213 |
attention_mask=att_mask).last_hidden_state.mean(1)
|
| 214 |
probs = self.projection(hidden_units)
|
| 215 |
return probs
|
|
|
|
| 226 |
cross_attn_head_mask=None,
|
| 227 |
use_cache=None,
|
| 228 |
encoder_outputs=None,
|
| 229 |
+
sentence1_ling=None,
|
| 230 |
+
sentence2_ling=None,
|
| 231 |
**kwargs
|
| 232 |
):
|
|
|
|
| 233 |
# cut decoder_input_ids if past is used
|
| 234 |
if past_key_values is not None:
|
| 235 |
input_ids = input_ids[:, -1:]
|
| 236 |
|
| 237 |
+
cached = use_cache and len(past_key_values) > 0
|
| 238 |
+
|
| 239 |
input_ids = input_ids.clone()
|
| 240 |
decoder_inputs_embeds = self.shared(input_ids)
|
| 241 |
|
| 242 |
+
if combine_method == 'layer_injection':
|
| 243 |
+
# For layer injection, we'll pass the ling embeddings separately
|
| 244 |
+
ling_embed = sentence2_ling if ling2_only else (sentence1_ling + sentence2_ling)
|
| 245 |
+
elif combine_method == 'decoder_add_first' and not cached:
|
| 246 |
+
sentence2_ling = torch.cat([sentence2_ling,
|
| 247 |
+
torch.repeat_interleave(torch.zeros_like(sentence2_ling), input_ids.shape[1] - 1, dim=1)], dim = 1)
|
| 248 |
+
elif combine_method == 'decoder_concat':
|
| 249 |
if ling2_only:
|
| 250 |
+
decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
|
| 251 |
else:
|
| 252 |
+
decoder_inputs_embeds = torch.cat([sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
|
| 253 |
+
|
| 254 |
+
if combine_method == 'decoder_add' or (not cached and combine_method == 'decoder_add_first'):
|
| 255 |
if ling2_only:
|
| 256 |
+
decoder_inputs_embeds = decoder_inputs_embeds + sentence2_ling
|
| 257 |
else:
|
| 258 |
+
decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling
|
| 259 |
|
| 260 |
return {
|
| 261 |
"decoder_inputs_embeds": decoder_inputs_embeds,
|
|
|
|
| 266 |
"decoder_head_mask": decoder_head_mask,
|
| 267 |
"cross_attn_head_mask": cross_attn_head_mask,
|
| 268 |
"use_cache": use_cache,
|
| 269 |
+
"ling_embed": ling_embed if combine_method == 'layer_injection' else None,
|
| 270 |
}
|
| 271 |
|
| 272 |
class LogitsAdd(LogitsProcessor):
|
| 273 |
+
def __init__(self, sentence2_ling):
|
| 274 |
super().__init__()
|
| 275 |
+
self.sentence2_ling = sentence2_ling
|
| 276 |
|
| 277 |
def __call__(self, input_ids, scores):
|
| 278 |
+
return scores + self.sentence2_ling
|
| 279 |
|
| 280 |
+
class EncoderDecoderVAE(LingConvT5ForConditionalGeneration):
|
| 281 |
def __init__(self, config, args, pad_token_id, sepeos_token_id, vocab_size = 32128):
|
| 282 |
+
if args.combine_method == 'layer_injection':
|
| 283 |
+
if args.injection_layer < 0 or args.injection_layer >= config.num_decoder_layers:
|
| 284 |
+
raise ValueError(f"Invalid injection layer: {args.injection_layer}. Must be between 0 and {config.num_decoder_layers - 1}.")
|
| 285 |
+
config.ling_injection_layer = args.injection_layer
|
| 286 |
+
config.ling_injection_type = args.injection_type # 'first' or 'all'
|
| 287 |
+
|
| 288 |
super().__init__(config)
|
| 289 |
+
|
| 290 |
self.prepare_inputs_for_generation = types.MethodType(
|
| 291 |
partial(prepare_inputs_for_generation, args.combine_method, args.ling2_only),
|
| 292 |
self)
|
|
|
|
| 304 |
nn.ReLU(),
|
| 305 |
nn.Linear(hidden_dim, hidden_dim),
|
| 306 |
)
|
| 307 |
+
elif 'concat' in args.combine_method or 'add' in args.combine_method or 'layer_injection' in args.combine_method:
|
| 308 |
if args.ling_embed_type == 'two-layer':
|
| 309 |
self.ling_embed = nn.Sequential(
|
| 310 |
nn.Linear(args.lng_dim, args.lng_dim),
|
|
|
|
| 314 |
else:
|
| 315 |
self.ling_embed = nn.Linear(args.lng_dim, hidden_dim)
|
| 316 |
self.ling_dropout = nn.Dropout(args.ling_dropout)
|
| 317 |
+
self.ling_embed.apply(self._init_weights)
|
| 318 |
|
| 319 |
if args.ling_vae:
|
| 320 |
self.ling_mu = nn.Linear(hidden_dim, hidden_dim)
|
|
|
|
| 324 |
nn.init.xavier_uniform_(self.ling_logvar.weight)
|
| 325 |
|
| 326 |
|
| 327 |
+
generate_with_grad = unwrap(super().generate)
|
| 328 |
self.generate_with_grad = MethodType(generate_with_grad, self)
|
| 329 |
+
self.generate_original = super().generate
|
| 330 |
+
|
| 331 |
+
def _init_weights(self, module):
|
| 332 |
+
std = self.args.initializer_range
|
| 333 |
+
if isinstance(module, nn.Linear):
|
| 334 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 335 |
+
if module.bias is not None:
|
| 336 |
+
module.bias.data.zero_()
|
| 337 |
+
elif isinstance(module, nn.Embedding):
|
| 338 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 339 |
+
if module.padding_idx is not None:
|
| 340 |
+
module.weight.data[module.padding_idx].zero_()
|
| 341 |
|
| 342 |
def get_fusion_layer(self):
|
| 343 |
if 'fusion' in self.args.combine_method:
|
|
|
|
| 351 |
std = torch.exp(0.5 * logvar)
|
| 352 |
return mu + std * torch.randn_like(std)
|
| 353 |
|
| 354 |
+
def _process_ling_embeddings(self, sentence1_ling, sentence2_ling,
|
| 355 |
+
sentence1_ling_embed, sentence2_ling_embed, bs):
|
| 356 |
+
"""Helper method to process linguistic embeddings"""
|
| 357 |
+
cache = {}
|
| 358 |
+
|
| 359 |
+
# Process sentence1 embedding
|
| 360 |
+
if sentence1_ling_embed is not None:
|
| 361 |
+
sentence1_ling = sentence1_ling_embed
|
| 362 |
+
elif sentence1_ling is not None:
|
| 363 |
+
sentence1_ling = self.ling_embed(self.ling_dropout(sentence1_ling))
|
| 364 |
else:
|
| 365 |
+
sentence1_ling = None
|
| 366 |
+
|
| 367 |
+
# Process sentence2 embedding
|
| 368 |
+
if sentence2_ling_embed is not None:
|
| 369 |
+
sentence2_ling = sentence2_ling_embed
|
| 370 |
+
elif sentence2_ling is not None:
|
| 371 |
+
sentence2_ling = self.ling_embed(self.ling_dropout(sentence2_ling))
|
| 372 |
+
else:
|
| 373 |
+
sentence2_ling = None
|
| 374 |
+
|
| 375 |
+
# Apply VAE if configured
|
| 376 |
+
if self.args.ling_vae and sentence1_ling is not None and sentence2_ling is not None:
|
| 377 |
+
sentence1_ling = F.leaky_relu(sentence1_ling)
|
| 378 |
+
sent1_mu, sent1_logvar = self.ling_mu(sentence1_ling), self.ling_logvar(sentence1_ling)
|
| 379 |
+
sentence1_ling = self.sample(sent1_mu, sent1_logvar)
|
| 380 |
+
|
| 381 |
+
sentence2_ling = F.leaky_relu(sentence2_ling)
|
| 382 |
+
sent2_mu, sent2_logvar = self.ling_mu(sentence2_ling), self.ling_logvar(sentence2_ling)
|
| 383 |
+
sentence2_ling = self.sample(sent2_mu, sent2_logvar)
|
| 384 |
+
|
| 385 |
+
cache.update({
|
| 386 |
+
'sent1_mu': sent1_mu, 'sent1_logvar': sent1_logvar,
|
| 387 |
+
'sent2_mu': sent2_mu, 'sent2_logvar': sent2_logvar,
|
| 388 |
+
'sentence1_ling': sentence1_ling, 'sentence2_ling': sentence2_ling
|
| 389 |
+
})
|
| 390 |
+
else:
|
| 391 |
+
if sentence2_ling is not None:
|
| 392 |
+
cache['sentence2_ling'] = sentence2_ling
|
| 393 |
+
if sentence1_ling is not None:
|
| 394 |
+
cache['sentence1_ling'] = sentence1_ling
|
| 395 |
+
|
| 396 |
+
# Reshape embeddings
|
| 397 |
+
if sentence1_ling is not None:
|
| 398 |
+
sentence1_ling = sentence1_ling.view(bs, 1, -1)
|
| 399 |
+
if sentence2_ling is not None:
|
| 400 |
+
sentence2_ling = sentence2_ling.view(bs, 1, -1)
|
| 401 |
+
|
| 402 |
+
return sentence1_ling, sentence2_ling, cache
|
| 403 |
+
|
| 404 |
+
def encode(self,
|
| 405 |
+
input_ids=None,
|
| 406 |
+
attention_mask=None,
|
| 407 |
+
sentence1_ling=None,
|
| 408 |
+
sentence2_ling=None,
|
| 409 |
+
sentence1_ling_embed=None,
|
| 410 |
+
sentence2_ling_embed=None,
|
| 411 |
+
inputs_embeds=None,
|
| 412 |
+
):
|
| 413 |
+
if inputs_embeds is None:
|
| 414 |
+
inputs_embeds = self.shared(input_ids)
|
| 415 |
+
inputs_att_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
|
| 416 |
bs = inputs_embeds.shape[0]
|
| 417 |
+
|
| 418 |
if self.args.combine_method in ('input_concat', 'input_add'):
|
| 419 |
+
sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
|
| 420 |
+
sentence1_ling, sentence2_ling,
|
| 421 |
+
sentence1_ling_embed, sentence2_ling_embed, bs
|
| 422 |
+
)
|
| 423 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
if self.args.combine_method == 'input_concat':
|
| 425 |
if self.args.ling2_only:
|
| 426 |
+
inputs_embeds = torch.cat([inputs_embeds, sentence2_ling], dim=1)
|
| 427 |
inputs_att_mask = torch.cat([inputs_att_mask,
|
| 428 |
torch.ones((bs, 1)).to(inputs_embeds.device)], dim=1)
|
| 429 |
else:
|
| 430 |
+
inputs_embeds = torch.cat([inputs_embeds, sentence1_ling, sentence2_ling], dim=1)
|
| 431 |
inputs_att_mask = torch.cat([inputs_att_mask,
|
| 432 |
torch.ones((bs, 2)).to(inputs_embeds.device)], dim=1)
|
| 433 |
elif self.args.combine_method == 'input_add':
|
| 434 |
if self.args.ling2_only:
|
| 435 |
+
inputs_embeds = inputs_embeds + sentence2_ling
|
| 436 |
else:
|
| 437 |
+
inputs_embeds = inputs_embeds + sentence1_ling + sentence2_ling
|
| 438 |
+
else:
|
| 439 |
+
cache = {}
|
| 440 |
+
|
| 441 |
return self.encoder(inputs_embeds=inputs_embeds,
|
| 442 |
attention_mask=inputs_att_mask), inputs_att_mask, cache
|
| 443 |
|
| 444 |
+
def decode(self,
|
| 445 |
+
sentence2_input_ids=None,
|
| 446 |
+
sentence1_ling=None,
|
| 447 |
+
sentence2_ling=None,
|
| 448 |
+
encoder_outputs=None,
|
| 449 |
+
encoder_attention_mask=None,
|
| 450 |
+
decoder_inputs_embeds=None,
|
| 451 |
+
decoder_attention_mask=None,
|
| 452 |
+
generate=False,
|
| 453 |
+
sentence1_ling_embed=None,
|
| 454 |
+
sentence2_ling_embed=None,
|
| 455 |
+
ling_embed=None,
|
| 456 |
+
generate_with_grad=False,
|
| 457 |
+
**kwargs
|
| 458 |
+
):
|
| 459 |
+
bs = encoder_outputs[0].shape[0]
|
| 460 |
cache = {}
|
| 461 |
+
|
| 462 |
+
if decoder_inputs_embeds is None:
|
| 463 |
+
if self.args.combine_method in ('embed_concat', 'decoder_concat', 'decoder_add',
|
| 464 |
+
'logits_add', 'decoder_add_first', 'layer_injection'):
|
| 465 |
+
sentence1_ling, sentence2_ling, cache = self._process_ling_embeddings(
|
| 466 |
+
sentence1_ling, sentence2_ling,
|
| 467 |
+
sentence1_ling_embed, sentence2_ling_embed, bs
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
if (self.args.combine_method == 'decoder_add_first' or
|
| 471 |
+
(self.args.combine_method == 'layer_injection' and
|
| 472 |
+
self.args.injection_type == 'first')) and not generate:
|
| 473 |
+
sentence2_ling = torch.cat([sentence2_ling,
|
| 474 |
+
torch.repeat_interleave(torch.zeros_like(sentence2_ling),
|
| 475 |
+
sentence2_input_ids.shape[1] - 1, dim=1)], dim = 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
else:
|
| 477 |
+
sentence1_ling, sentence2_ling = None, None
|
| 478 |
+
|
|
|
|
| 479 |
if generate:
|
| 480 |
if self.args.combine_method == 'logits_add':
|
| 481 |
+
logits_processor = LogitsProcessorList([LogitsAdd(sentence2_ling.view(bs, -1))])
|
| 482 |
else:
|
| 483 |
logits_processor = LogitsProcessorList()
|
| 484 |
|
| 485 |
+
generate_fn = self.generate_with_grad if generate_with_grad else self.generate_original
|
| 486 |
+
dec_output = generate_fn(
|
| 487 |
+
attention_mask=encoder_attention_mask,
|
| 488 |
+
encoder_outputs=encoder_outputs,
|
| 489 |
+
sentence1_ling=sentence1_ling,
|
| 490 |
+
sentence2_ling=sentence2_ling,
|
|
|
|
| 491 |
logits_processor = logits_processor,
|
| 492 |
# renormalize_logits=True,
|
| 493 |
# do_sample=True,
|
|
|
|
| 496 |
# min_new_tokens=3,
|
| 497 |
# repetition_penalty=1.2,
|
| 498 |
max_length=self.args.max_length,
|
| 499 |
+
use_cache=True,
|
| 500 |
+
**kwargs
|
| 501 |
)
|
| 502 |
+
return dec_output, cache
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
|
| 504 |
+
if sentence2_input_ids is not None:
|
| 505 |
+
labels = sentence2_input_ids.clone()
|
| 506 |
+
labels[labels == self.pad_token_id] = -100
|
| 507 |
+
else:
|
| 508 |
+
labels = None
|
| 509 |
+
|
| 510 |
+
if decoder_inputs_embeds is None:
|
| 511 |
+
decoder_input_ids = self._shift_right(sentence2_input_ids)
|
| 512 |
+
decoder_inputs_embeds = self.shared(decoder_input_ids)
|
| 513 |
+
|
| 514 |
+
if self.args.combine_method == 'decoder_concat':
|
| 515 |
+
if self.args.ling2_only:
|
| 516 |
+
decoder_inputs_embeds = torch.cat([sentence2_ling, decoder_inputs_embeds], dim=1)
|
| 517 |
+
decoder_attention_mask = torch.cat([torch.ones((bs, 1)).to(decoder_inputs_embeds.device), decoder_attention_mask], dim=1)
|
| 518 |
+
labels = torch.cat([torch.ones((bs, 1), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
|
| 519 |
+
labels], dim=1)
|
| 520 |
+
else:
|
| 521 |
+
decoder_inputs_embeds = torch.cat([sentence1_ling, sentence2_ling, decoder_inputs_embeds], dim=1)
|
| 522 |
+
decoder_attention_mask = torch.cat([torch.ones((bs, 2)).to(decoder_inputs_embeds.device), decoder_attention_mask], dim=1)
|
| 523 |
+
labels = torch.cat([torch.ones((bs, 2), dtype=torch.int64).to(decoder_inputs_embeds.device) * self.pad_token_id,
|
| 524 |
+
labels], dim=1)
|
| 525 |
+
elif self.args.combine_method == 'decoder_add' or self.args.combine_method == 'decoder_add_first' :
|
| 526 |
+
if self.args.ling2_only:
|
| 527 |
+
decoder_inputs_embeds = decoder_inputs_embeds + self.args.combine_weight * sentence2_ling
|
| 528 |
+
else:
|
| 529 |
+
decoder_inputs_embeds = decoder_inputs_embeds + sentence1_ling + sentence2_ling
|
| 530 |
+
|
| 531 |
+
if ling_embed is None:
|
| 532 |
+
ling_embed = sentence2_ling
|
| 533 |
+
|
| 534 |
+
dec_output = super().forward(
|
| 535 |
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 536 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 537 |
+
encoder_outputs=encoder_outputs,
|
| 538 |
+
attention_mask=encoder_attention_mask,
|
| 539 |
labels=labels,
|
| 540 |
+
ling_embed=ling_embed,
|
| 541 |
+
**kwargs
|
| 542 |
)
|
| 543 |
if self.args.combine_method == 'logits_add':
|
| 544 |
+
dec_output.logits = dec_output.logits + self.args.combine_weight * sentence2_ling
|
| 545 |
vocab_size = dec_output.logits.size(-1)
|
| 546 |
dec_output.loss = F.cross_entropy(dec_output.logits.view(-1, vocab_size), labels.view(-1))
|
| 547 |
return dec_output, cache
|
| 548 |
|
| 549 |
+
def generate(self, *args, **kwargs):
|
| 550 |
+
return self.forward(*args, **kwargs, generate=True)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
def forward(self,
|
| 554 |
+
input_ids=None,
|
| 555 |
+
attention_mask=None,
|
| 556 |
+
labels=None,
|
| 557 |
+
decoder_attention_mask=None,
|
| 558 |
+
decoder_inputs_embeds=None,
|
| 559 |
+
sentence1_ling=None,
|
| 560 |
+
sentence2_ling=None,
|
| 561 |
+
sentence1_ling_embed=None,
|
| 562 |
+
sentence2_ling_embed=None,
|
| 563 |
+
inputs_embeds=None,
|
| 564 |
+
generate=False,
|
| 565 |
+
encoder_outputs=None,
|
| 566 |
+
encoder_attention_mask=None,
|
| 567 |
+
ling_embed=None,
|
| 568 |
+
generate_with_grad=False,
|
| 569 |
+
**kwargs):
|
| 570 |
|
| 571 |
+
cache = {}
|
| 572 |
+
if encoder_outputs is None:
|
| 573 |
+
encoder_outputs, encoder_attention_mask, cache = self.encode(
|
| 574 |
+
input_ids=input_ids,
|
| 575 |
+
attention_mask=attention_mask,
|
| 576 |
+
sentence1_ling=sentence1_ling,
|
| 577 |
+
sentence2_ling=sentence2_ling,
|
| 578 |
+
sentence1_ling_embed=sentence1_ling_embed,
|
| 579 |
+
sentence2_ling_embed=sentence2_ling_embed,
|
| 580 |
+
inputs_embeds=inputs_embeds
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
dec_output, cache2 = self.decode(
|
| 584 |
+
sentence2_input_ids=labels,
|
| 585 |
+
sentence1_ling=sentence1_ling,
|
| 586 |
+
sentence2_ling=sentence2_ling,
|
| 587 |
+
decoder_inputs_embeds=decoder_inputs_embeds,
|
| 588 |
+
decoder_attention_mask=decoder_attention_mask,
|
| 589 |
+
encoder_outputs=encoder_outputs,
|
| 590 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 591 |
+
generate=generate,
|
| 592 |
+
sentence1_ling_embed=sentence1_ling_embed,
|
| 593 |
+
sentence2_ling_embed=sentence2_ling_embed,
|
| 594 |
+
ling_embed=ling_embed,
|
| 595 |
+
generate_with_grad=generate_with_grad,
|
| 596 |
+
**kwargs
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
cache.update(cache2)
|
| 600 |
+
if generate:
|
| 601 |
+
return dec_output
|
| 602 |
+
else:
|
| 603 |
+
return MySeq2SeqLMOutput(
|
| 604 |
+
loss=dec_output.loss,
|
| 605 |
+
logits=dec_output.logits,
|
| 606 |
+
past_key_values=dec_output.past_key_values,
|
| 607 |
+
decoder_hidden_states=dec_output.decoder_hidden_states,
|
| 608 |
+
decoder_attentions=dec_output.decoder_attentions,
|
| 609 |
+
cross_attentions=dec_output.cross_attentions,
|
| 610 |
+
encoder_last_hidden_state=encoder_outputs[0],
|
| 611 |
+
encoder_hidden_states=getattr(encoder_outputs, 'hidden_states', None),
|
| 612 |
+
encoder_attentions=getattr(encoder_outputs, 'attentions', None),
|
| 613 |
+
cache=cache
|
| 614 |
+
)
|
| 615 |
|
| 616 |
def infer_with_cache(self, batch):
|
| 617 |
+
dec_output, _, cache = self(batch, generate = True)
|
| 618 |
return dec_output, cache
|
| 619 |
|
| 620 |
def infer(self, batch):
|
| 621 |
dec_output, _ = self.infer_with_cache(batch)
|
| 622 |
return dec_output
|
| 623 |
|
| 624 |
+
def infer_with_feedback_BP(self, ling_disc, sem_emb, batch, tokenizer, progress=None):
|
| 625 |
from torch.autograd import grad
|
| 626 |
interpolations = []
|
| 627 |
def line_search():
|
|
|
|
|
|
|
| 628 |
eta = 1e3
|
| 629 |
sem_prob = 1
|
| 630 |
patience = 4
|
|
|
|
| 634 |
new_loss, pred = get_loss(param_)
|
| 635 |
max_len = pred.shape[1]
|
| 636 |
lens = torch.where(pred == self.eos_token_id, 1, 0).argmax(-1) + 1
|
| 637 |
+
sem_batch = {**batch,
|
| 638 |
+
'sentence2_input_ids': pred,
|
| 639 |
+
'sentence2_attention_mask': sequence_mask(lens, max_len = max_len)
|
| 640 |
+
}
|
| 641 |
+
sem_prob = torch.sigmoid(sem_emb.compare_sem(**sem_batch)).item()
|
|
|
|
|
|
|
| 642 |
if new_loss < loss and sem_prob >= 0.90 and lens.item() > 1:
|
| 643 |
return param_
|
| 644 |
eta *= 2.25
|
|
|
|
| 647 |
|
| 648 |
def get_loss(param):
|
| 649 |
if self.args.feedback_param == 'l':
|
| 650 |
+
batch.update({'sentence2_ling_embed': param})
|
| 651 |
elif self.args.feedback_param == 's':
|
| 652 |
batch.update({'inputs_embeds': param})
|
| 653 |
|
|
|
|
| 655 |
logits = param
|
| 656 |
pred = param.argmax(-1)
|
| 657 |
else:
|
| 658 |
+
outputs = self.generate(**batch, output_scores=True, return_dict_in_generate=True, generate_with_grad=True)
|
| 659 |
+
pred = outputs.sequences
|
| 660 |
+
logits = torch.stack(outputs.scores, dim=1)
|
| 661 |
out = ling_disc(logits = logits)
|
| 662 |
probs = F.softmax(out, 1)
|
| 663 |
if ling_disc.quant:
|
|
|
|
| 670 |
ling2_embed = self.ling_embed(batch['sentence2_ling'])
|
| 671 |
param = torch.nn.Parameter(ling2_embed, requires_grad = True)
|
| 672 |
elif self.args.feedback_param == 's':
|
| 673 |
+
inputs_embeds = self.shared(batch['input_ids'])
|
| 674 |
param = torch.nn.Parameter(inputs_embeds, requires_grad = True)
|
| 675 |
elif self.args.feedback_param == 'logits':
|
| 676 |
logits = self.infer_with_cache(batch)[1]['scores']
|
| 677 |
param = torch.nn.Parameter(logits, requires_grad = True)
|
| 678 |
+
num_iter = 0
|
| 679 |
+
while num_iter < 3:
|
| 680 |
loss, pred = get_loss(param)
|
| 681 |
pred_text = tokenizer.batch_decode(pred.cpu().numpy(),
|
| 682 |
skip_special_tokens=True)[0]
|
|
|
|
| 688 |
param = line_search()
|
| 689 |
if param is False:
|
| 690 |
break
|
| 691 |
+
num_iter += 1
|
| 692 |
+
if progress is not None:
|
| 693 |
+
progress((num_iter, None), unit='intermediate paraphrase generated.')
|
| 694 |
return pred, [pred_text, interpolations]
|
| 695 |
|
| 696 |
def set_grad(module, state):
|
|
|
|
| 729 |
def __init__(self,
|
| 730 |
model_name="google/flan-t5-base",
|
| 731 |
disc_type='deberta',
|
| 732 |
+
disc_ckpt='mohdelgaar/lingconv-discriminator',
|
| 733 |
# disc_type='t5',
|
| 734 |
# disc_ckpt='/data/mohamed/checkpoints/ling_conversion_ling_disc.pt',
|
| 735 |
):
|
|
|
|
| 749 |
ling_disc = LingDisc(args.model_name, args.disc_type, args.disc_model_path).to(device)
|
| 750 |
else:
|
| 751 |
ling_disc = None
|
|
|
|
|
|
|
| 752 |
|
| 753 |
+
if args.model_path:
|
| 754 |
model = EncoderDecoderVAE.from_pretrained(args.model_path, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
|
| 755 |
else:
|
| 756 |
+
model = EncoderDecoderVAE.from_pretrained(args.model_name, args, tokenizer.pad_token_id, tokenizer.eos_token_id).to(device)
|
| 757 |
|
| 758 |
+
if args.sem_loss or args.model_path:
|
| 759 |
if args.sem_loss_type == 'shared':
|
| 760 |
sem_emb = model.encoder
|
| 761 |
elif args.sem_loss_type == 'dedicated':
|
|
|
|
| 767 |
|
| 768 |
return model, ling_disc, sem_emb
|
| 769 |
|
| 770 |
+
@dataclass
|
| 771 |
+
class MySeq2SeqLMOutput(Seq2SeqLMOutput):
|
| 772 |
+
"""
|
| 773 |
+
Extends Seq2SeqLMOutput to include a cache dictionary for additional model outputs.
|
| 774 |
+
|
| 775 |
+
Args:
|
| 776 |
+
cache (`Dict[str, Any]`):
|
| 777 |
+
Dictionary containing additional model outputs like linguistic features,
|
| 778 |
+
VAE parameters, scores, etc.
|
| 779 |
+
"""
|
| 780 |
+
cache: Optional[Dict[str, Any]] = None
|
options.py
CHANGED
|
@@ -1,16 +1,28 @@
|
|
| 1 |
-
import os, json
|
| 2 |
import argparse
|
| 3 |
-
import numpy as np
|
| 4 |
from datetime import datetime
|
| 5 |
from const import lftkplus_names
|
|
|
|
| 6 |
from copy import deepcopy
|
|
|
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def parse_args(ckpt=None):
|
| 10 |
parser = argparse.ArgumentParser()
|
|
|
|
|
|
|
|
|
|
| 11 |
parser.add_argument('--data_dir', default='/data/mohamed/data')
|
| 12 |
parser.add_argument('--data', default='ling_conversion')
|
| 13 |
-
parser.add_argument('--data_sources')
|
| 14 |
parser.add_argument('--data_type', default='text')
|
| 15 |
parser.add_argument('--aim_repo', default='/data/mohamed/')
|
| 16 |
parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints')
|
|
@@ -25,7 +37,7 @@ def parse_args(ckpt=None):
|
|
| 25 |
parser.add_argument('--sem_loss_tao', default=0.5, type=float)
|
| 26 |
parser.add_argument('--sem_loss_eps', default=1, type=float)
|
| 27 |
parser.add_argument('--ckpt')
|
| 28 |
-
parser.add_argument('--disc_ckpt')
|
| 29 |
parser.add_argument('--sem_ckpt')
|
| 30 |
parser.add_argument('--lng_ids')
|
| 31 |
parser.add_argument('--lng_ids_idx', type=int)
|
|
@@ -36,30 +48,34 @@ def parse_args(ckpt=None):
|
|
| 36 |
parser.add_argument('--sem_path', default="mohdelgaar/lingconv-semantic-classifier")
|
| 37 |
parser.add_argument('--sem_model_path', default="mohdelgaar/lingconv-semantic-classifier")
|
| 38 |
parser.add_argument('--disc_model_path', default="mohdelgaar/lingconv-discriminator")
|
| 39 |
-
parser.add_argument('--disc_type', default="
|
| 40 |
-
parser.add_argument('--aim_exp', default='
|
| 41 |
parser.add_argument('--sem_loss_type', default='dedicated')
|
| 42 |
-
parser.add_argument('--combine_method', default='
|
|
|
|
|
|
|
| 43 |
parser.add_argument('--train_log', type=int, default=200)
|
| 44 |
-
parser.add_argument('--val_log', type=int, default=
|
| 45 |
-
parser.add_argument('--
|
| 46 |
-
parser.add_argument('--
|
| 47 |
-
parser.add_argument('--
|
| 48 |
-
parser.add_argument('--
|
|
|
|
| 49 |
parser.add_argument('--hidden_dim', type=int, default=500)
|
| 50 |
parser.add_argument('--latent_dim', type=int, default=150)
|
| 51 |
parser.add_argument('--lng_dim', type=int, default=40)
|
| 52 |
-
parser.add_argument('--disc_lng_dim', type=int)
|
| 53 |
parser.add_argument('--use_lora', action='store_true')
|
| 54 |
parser.add_argument('--lora_r', type=int, default=64)
|
| 55 |
parser.add_argument('--gpu', type=str, default='0')
|
| 56 |
-
parser.add_argument('--epochs', type=int, default=
|
| 57 |
parser.add_argument('--grad_accumulation', type=int, default=1)
|
| 58 |
parser.add_argument('--n_ica', type=int, default=10)
|
| 59 |
parser.add_argument('--max_length', type=int, default=200)
|
| 60 |
parser.add_argument('--total_steps', type=int)
|
| 61 |
parser.add_argument('--kld_const', type=float, default=1)
|
| 62 |
-
parser.add_argument('--lr', type=float, default=1e-
|
|
|
|
| 63 |
parser.add_argument('--kl_weight', type=float, default=1e-1)
|
| 64 |
parser.add_argument('--weight_decay', type=float, default=1e-2)
|
| 65 |
parser.add_argument('--ling_dropout', type=float, default=0.1)
|
|
@@ -71,12 +87,12 @@ def parse_args(ckpt=None):
|
|
| 71 |
parser.add_argument('--pretrain_disc', action='store_true')
|
| 72 |
parser.add_argument('--linggen_type', default='none')
|
| 73 |
parser.add_argument('--linggen_input', default='s+l')
|
| 74 |
-
parser.add_argument(
|
| 75 |
parser.add_argument('--ling_vae', action='store_true')
|
| 76 |
parser.add_argument('--process_lingpred', action='store_true')
|
| 77 |
parser.add_argument('--fudge_lambda', type=float, default=1.0)
|
| 78 |
parser.add_argument('--use_lingpred', action='store_true')
|
| 79 |
-
parser.add_argument('--ling2_only', action='store_true')
|
| 80 |
parser.add_argument('--cycle_loss', action='store_true')
|
| 81 |
parser.add_argument('--disc_loss', action='store_true')
|
| 82 |
parser.add_argument('--sem_loss', action='store_true')
|
|
@@ -96,19 +112,36 @@ def parse_args(ckpt=None):
|
|
| 96 |
parser.add_argument('--quant_nbins', type=int, default=20)
|
| 97 |
parser.add_argument('--src_lng', default = 'ling')
|
| 98 |
parser.add_argument('--to_restore', nargs='+', default=[])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
# args = parser.parse_args()
|
| 100 |
args, unknown = parser.parse_known_args()
|
| 101 |
args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}'
|
| 102 |
|
| 103 |
major_arg = args.major_arg
|
| 104 |
to_restore = [
|
|
|
|
|
|
|
|
|
|
| 105 |
] + args.to_restore
|
| 106 |
to_restore = {k: args.__dict__[k] for k in to_restore}
|
| 107 |
|
| 108 |
if not args.disc_loss or args.disc_ckpt:
|
| 109 |
args.disc_steps = 0
|
| 110 |
|
| 111 |
-
if args.data_sources
|
|
|
|
|
|
|
| 112 |
args.data_sources = args.data_sources.split(',')
|
| 113 |
|
| 114 |
if ckpt is not None:
|
|
@@ -120,13 +153,17 @@ def parse_args(ckpt=None):
|
|
| 120 |
ckpts = args.ckpt.split(',')
|
| 121 |
args_list = [deepcopy(args) for _ in range(len(ckpts))]
|
| 122 |
for i in range(len(ckpts)):
|
| 123 |
-
args_path = ckpts[i].replace('_best', '').replace('.pt', '.json'
|
| 124 |
with open(args_path) as f:
|
| 125 |
args_list[i].__dict__.update(json.load(f))
|
| 126 |
args_list[i].__dict__.update(to_restore)
|
| 127 |
args_list[i].ckpt = ckpts[i]
|
| 128 |
else:
|
| 129 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 130 |
ckpt = args.ckpt
|
| 131 |
with open(args_path) as f:
|
| 132 |
args.__dict__.update(json.load(f))
|
|
|
|
|
|
|
| 1 |
import argparse
|
|
|
|
| 2 |
from datetime import datetime
|
| 3 |
from const import lftkplus_names
|
| 4 |
+
import os, json
|
| 5 |
from copy import deepcopy
|
| 6 |
+
import numpy as np
|
| 7 |
|
| 8 |
+
def str2bool(v):
|
| 9 |
+
if isinstance(v, bool):
|
| 10 |
+
return v
|
| 11 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 12 |
+
return True
|
| 13 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 14 |
+
return False
|
| 15 |
+
else:
|
| 16 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 17 |
|
| 18 |
def parse_args(ckpt=None):
|
| 19 |
parser = argparse.ArgumentParser()
|
| 20 |
+
parser.add_argument('--do_train', action='store_true')
|
| 21 |
+
parser.add_argument('--do_eval', action='store_true')
|
| 22 |
+
parser.add_argument('--do_predict', action='store_true')
|
| 23 |
parser.add_argument('--data_dir', default='/data/mohamed/data')
|
| 24 |
parser.add_argument('--data', default='ling_conversion')
|
| 25 |
+
parser.add_argument('--data_sources', default='qqp,mrpc,stsb')
|
| 26 |
parser.add_argument('--data_type', default='text')
|
| 27 |
parser.add_argument('--aim_repo', default='/data/mohamed/')
|
| 28 |
parser.add_argument('--ckpt_dir', default='/data/mohamed/checkpoints')
|
|
|
|
| 37 |
parser.add_argument('--sem_loss_tao', default=0.5, type=float)
|
| 38 |
parser.add_argument('--sem_loss_eps', default=1, type=float)
|
| 39 |
parser.add_argument('--ckpt')
|
| 40 |
+
parser.add_argument('--disc_ckpt', default='mohdelgaar/lingconv-discriminator')
|
| 41 |
parser.add_argument('--sem_ckpt')
|
| 42 |
parser.add_argument('--lng_ids')
|
| 43 |
parser.add_argument('--lng_ids_idx', type=int)
|
|
|
|
| 48 |
parser.add_argument('--sem_path', default="mohdelgaar/lingconv-semantic-classifier")
|
| 49 |
parser.add_argument('--sem_model_path', default="mohdelgaar/lingconv-semantic-classifier")
|
| 50 |
parser.add_argument('--disc_model_path', default="mohdelgaar/lingconv-discriminator")
|
| 51 |
+
parser.add_argument('--disc_type', default="deberta")
|
| 52 |
+
parser.add_argument('--aim_exp', default='lingconv-1201')
|
| 53 |
parser.add_argument('--sem_loss_type', default='dedicated')
|
| 54 |
+
parser.add_argument('--combine_method', default='decoder_add_first')
|
| 55 |
+
parser.add_argument('--injection_type', default='first')
|
| 56 |
+
parser.add_argument('--injection_layer', type=int, default=1)
|
| 57 |
parser.add_argument('--train_log', type=int, default=200)
|
| 58 |
+
parser.add_argument('--val_log', type=int, default=1000)
|
| 59 |
+
parser.add_argument('--warmup_steps', type=int, default=1000)
|
| 60 |
+
parser.add_argument('--batch_size', type=int, default=16)
|
| 61 |
+
parser.add_argument('--eval_batch_size', type=int, default=256)
|
| 62 |
+
parser.add_argument('--max_eval_samples', type=int, default=3000)
|
| 63 |
+
parser.add_argument('--test_batch_size', type=int, default=256)
|
| 64 |
parser.add_argument('--hidden_dim', type=int, default=500)
|
| 65 |
parser.add_argument('--latent_dim', type=int, default=150)
|
| 66 |
parser.add_argument('--lng_dim', type=int, default=40)
|
| 67 |
+
parser.add_argument('--disc_lng_dim', type=int, default=40)
|
| 68 |
parser.add_argument('--use_lora', action='store_true')
|
| 69 |
parser.add_argument('--lora_r', type=int, default=64)
|
| 70 |
parser.add_argument('--gpu', type=str, default='0')
|
| 71 |
+
parser.add_argument('--epochs', type=int, default=2)
|
| 72 |
parser.add_argument('--grad_accumulation', type=int, default=1)
|
| 73 |
parser.add_argument('--n_ica', type=int, default=10)
|
| 74 |
parser.add_argument('--max_length', type=int, default=200)
|
| 75 |
parser.add_argument('--total_steps', type=int)
|
| 76 |
parser.add_argument('--kld_const', type=float, default=1)
|
| 77 |
+
parser.add_argument('--lr', type=float, default=1e-3)
|
| 78 |
+
parser.add_argument('--initializer_range', type=float, default=0.02)
|
| 79 |
parser.add_argument('--kl_weight', type=float, default=1e-1)
|
| 80 |
parser.add_argument('--weight_decay', type=float, default=1e-2)
|
| 81 |
parser.add_argument('--ling_dropout', type=float, default=0.1)
|
|
|
|
| 87 |
parser.add_argument('--pretrain_disc', action='store_true')
|
| 88 |
parser.add_argument('--linggen_type', default='none')
|
| 89 |
parser.add_argument('--linggen_input', default='s+l')
|
| 90 |
+
parser.add_argument("--aug_same", type=str2bool, nargs='?', const=True, default=False)
|
| 91 |
parser.add_argument('--ling_vae', action='store_true')
|
| 92 |
parser.add_argument('--process_lingpred', action='store_true')
|
| 93 |
parser.add_argument('--fudge_lambda', type=float, default=1.0)
|
| 94 |
parser.add_argument('--use_lingpred', action='store_true')
|
| 95 |
+
parser.add_argument('--ling2_only', action='store_true', default=True)
|
| 96 |
parser.add_argument('--cycle_loss', action='store_true')
|
| 97 |
parser.add_argument('--disc_loss', action='store_true')
|
| 98 |
parser.add_argument('--sem_loss', action='store_true')
|
|
|
|
| 112 |
parser.add_argument('--quant_nbins', type=int, default=20)
|
| 113 |
parser.add_argument('--src_lng', default = 'ling')
|
| 114 |
parser.add_argument('--to_restore', nargs='+', default=[])
|
| 115 |
+
parser.add_argument('--freeze_lm', action='store_true',
|
| 116 |
+
help='Freeze the language model and only train the linguistic embedding')
|
| 117 |
+
parser.add_argument('--prepend_prompt', action='store_true',
|
| 118 |
+
help='Prepend "generate a paraphrase: " to input text')
|
| 119 |
+
parser.add_argument('--prompt_text', type=str, default="generate a paraphrase: ",
|
| 120 |
+
help='Text to prepend to input if prepend_prompt is True')
|
| 121 |
+
parser.add_argument('--do_imputation', action='store_true',
|
| 122 |
+
help='Whether to perform imputation on linguistic features')
|
| 123 |
+
parser.add_argument('--imputation_percentage', type=int, default=20,
|
| 124 |
+
help='Percentage of features to impute (20, 40, 60, 80)')
|
| 125 |
+
parser.add_argument('--imputation_seed', type=int, default=0,
|
| 126 |
+
help='Seed for imputation set selection (0, 1, 2)')
|
| 127 |
# args = parser.parse_args()
|
| 128 |
args, unknown = parser.parse_known_args()
|
| 129 |
args.name = f'{datetime.now().strftime("%m%d_%H-%M-%S")}-{args.data}-{args.combine_method}'
|
| 130 |
|
| 131 |
major_arg = args.major_arg
|
| 132 |
to_restore = [
|
| 133 |
+
'total_steps','major_arg','gpu','demo', 'eval_only', 'save_predict', 'predict_fn', 'fudge', 'predict_with_feedback',
|
| 134 |
+
'feedback_param', 'fb_log', 'data_dir', 'data', 'disc_ckpt', 'disc_type', 'sem_ckpt', 'fudge_lambda', 'eval_batch_size', 'test_batch_size', 'max_eval_samples',
|
| 135 |
+
'do_train', 'do_eval', 'do_predict',
|
| 136 |
] + args.to_restore
|
| 137 |
to_restore = {k: args.__dict__[k] for k in to_restore}
|
| 138 |
|
| 139 |
if not args.disc_loss or args.disc_ckpt:
|
| 140 |
args.disc_steps = 0
|
| 141 |
|
| 142 |
+
if args.data_sources == 'all':
|
| 143 |
+
args.data_sources = None
|
| 144 |
+
elif args.data_sources is not None:
|
| 145 |
args.data_sources = args.data_sources.split(',')
|
| 146 |
|
| 147 |
if ckpt is not None:
|
|
|
|
| 153 |
ckpts = args.ckpt.split(',')
|
| 154 |
args_list = [deepcopy(args) for _ in range(len(ckpts))]
|
| 155 |
for i in range(len(ckpts)):
|
| 156 |
+
args_path = ckpts[i].replace('_best', '').replace('.pt', '') + '.json'
|
| 157 |
with open(args_path) as f:
|
| 158 |
args_list[i].__dict__.update(json.load(f))
|
| 159 |
args_list[i].__dict__.update(to_restore)
|
| 160 |
args_list[i].ckpt = ckpts[i]
|
| 161 |
else:
|
| 162 |
+
args.ckpt = args.ckpt.rstrip('/')
|
| 163 |
+
if 'checkpoint-' in args.ckpt:
|
| 164 |
+
args_path = os.path.dirname(args.ckpt) + '.json'
|
| 165 |
+
else:
|
| 166 |
+
args_path = args.ckpt.replace('.pt', '') + '.json'
|
| 167 |
ckpt = args.ckpt
|
| 168 |
with open(args_path) as f:
|
| 169 |
args.__dict__.update(json.load(f))
|