BERTNN commited on
Commit ·
6218f6c
1
Parent(s): b4537c6
Upload predefined_bertnn.py
Browse files- predefined_bertnn.py +11 -9
predefined_bertnn.py
CHANGED
|
@@ -471,9 +471,16 @@ def gen_new(Identity,Behavior,Modifier,n_df,word_type):
|
|
| 471 |
ys= torch.tensor(values)
|
| 472 |
inputs, masks = preprocessing_for_bert([sents])
|
| 473 |
yield inputs, masks, ys,indexx #torch.tensor(sents),
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 477 |
|
| 478 |
def gen_new(Identity,Behavior,Modifier,n_df,word_type):
|
| 479 |
|
|
@@ -526,12 +533,7 @@ def gen_alt(Identity,Behavior,Modifier,n_df,word_type):
|
|
| 526 |
|
| 527 |
yield inputs, masks, ys,indexx
|
| 528 |
|
| 529 |
-
|
| 530 |
-
if alt:
|
| 531 |
-
dt_ldr= [x for x in DataLoader([next(gen_alt(I,B,M,N_df,WT)) for x in range(batch_size)], batch_size=batch_size)][0]
|
| 532 |
-
else:
|
| 533 |
-
dt_ldr= [x for x in DataLoader([next(gen_new(I,B,M,N_df,WT)) for x in range(batch_size)], batch_size=batch_size)][0]
|
| 534 |
-
return(dt_ldr)
|
| 535 |
|
| 536 |
cols=['EEMA', 'EPMA', 'EAMA', 'EEA', 'EPA', 'EAA', 'EEB', 'EPB', 'EAB',
|
| 537 |
'EEMO', 'EPMO', 'EAMO', 'EEO', 'EPO', 'EAO', 'ModA', 'Actor', 'Behavior', 'ModO', 'Object']
|
|
|
|
| 471 |
ys= torch.tensor(values)
|
| 472 |
inputs, masks = preprocessing_for_bert([sents])
|
| 473 |
yield inputs, masks, ys,indexx #torch.tensor(sents),
|
| 474 |
+
|
| 475 |
+
def ldr_new(I,B,M,N_df,WT,batch_size=32,alt=0):
|
| 476 |
+
if alt:
|
| 477 |
+
dt_ldr= [x for x in DataLoader([next(gen_alt(I,B,M,N_df,WT)) for x in range(batch_size)], batch_size=batch_size)][0]
|
| 478 |
+
else:
|
| 479 |
+
dt_ldr= [x for x in DataLoader([next(gen_new(I,B,M,N_df,WT)) for x in range(batch_size)], batch_size=batch_size)][0]
|
| 480 |
+
return(dt_ldr)
|
| 481 |
+
# def ldr_new(I,B,M,N_df,WT,batch_size=32):
|
| 482 |
+
# dt_ldr= [x for x in DataLoader([next(gen_new(I,B,M,N_df,WT)) for x in range(batch_size)], batch_size=batch_size)][0]
|
| 483 |
+
# return(dt_ldr)
|
| 484 |
|
| 485 |
def gen_new(Identity,Behavior,Modifier,n_df,word_type):
|
| 486 |
|
|
|
|
| 533 |
|
| 534 |
yield inputs, masks, ys,indexx
|
| 535 |
|
| 536 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 537 |
|
| 538 |
cols=['EEMA', 'EPMA', 'EAMA', 'EEA', 'EPA', 'EAA', 'EEB', 'EPB', 'EAB',
|
| 539 |
'EEMO', 'EPMO', 'EAMO', 'EEO', 'EPO', 'EAO', 'ModA', 'Actor', 'Behavior', 'ModO', 'Object']
|