Update claspp_forward.py
Browse files- claspp_forward.py +17 -10
claspp_forward.py
CHANGED
|
@@ -16,8 +16,10 @@ from datasets import (
|
|
| 16 |
from modeling_esm import EsmForSequenceClassificationCustomWidehead
|
| 17 |
|
| 18 |
|
|
|
|
| 19 |
tokenizer = EsmTokenizer.from_pretrained("finalCheckpoint_25_05_11/")
|
| 20 |
model = EsmForSequenceClassificationCustomWidehead.from_pretrained("finalCheckpoint_25_05_11/", num_labels=54).cuda()
|
|
|
|
| 21 |
|
| 22 |
|
| 23 |
###############################################################################
|
|
@@ -234,23 +236,25 @@ def predict(input_batches):
|
|
| 234 |
# print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().shape)
|
| 235 |
# print(torch.tensor([tokenizer(batches)['attention_mask']]).cuda()["logits"][0].shape)
|
| 236 |
#print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().squeeze().shape)
|
| 237 |
-
|
| 238 |
-
|
|
|
|
| 239 |
#print(len(pred[0]))
|
| 240 |
for p in pred:
|
|
|
|
| 241 |
outputpreds.append(p)
|
| 242 |
return outputpreds
|
| 243 |
|
| 244 |
|
| 245 |
-
def write_output(pred,listofpeps):
|
| 246 |
-
hf=open("
|
| 247 |
n="\n"
|
| 248 |
writethisline="pep,"
|
| 249 |
for i in range(len(labsoi)):
|
| 250 |
-
writethisline+=pos2lab[i]
|
| 251 |
-
hf.write(writethisline
|
| 252 |
for p,ip in zip(pred,listofpeps):
|
| 253 |
-
writethisline=f"{ip}
|
| 254 |
r=ip[10]
|
| 255 |
#print(p)
|
| 256 |
easyreadlab=getlab(p,r)
|
|
@@ -384,11 +388,14 @@ def main():
|
|
| 384 |
if i%batch_size==0 and i!=0:
|
| 385 |
input_batches.append(temp)
|
| 386 |
temp=[]
|
| 387 |
-
|
|
|
|
|
|
|
| 388 |
input_batches.append(temp)
|
| 389 |
-
|
|
|
|
| 390 |
pred=predict(input_batches=input_batches)
|
| 391 |
-
write_output(pred,listofpeps)
|
| 392 |
|
| 393 |
|
| 394 |
|
|
|
|
| 16 |
from modeling_esm import EsmForSequenceClassificationCustomWidehead
|
| 17 |
|
| 18 |
|
| 19 |
+
print("intilizing checkpoint --might take a few min if this is the first time--")
|
| 20 |
tokenizer = EsmTokenizer.from_pretrained("finalCheckpoint_25_05_11/")
|
| 21 |
model = EsmForSequenceClassificationCustomWidehead.from_pretrained("finalCheckpoint_25_05_11/", num_labels=54).cuda()
|
| 22 |
+
print("finished downloading")
|
| 23 |
|
| 24 |
|
| 25 |
###############################################################################
|
|
|
|
| 236 |
# print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().shape)
|
| 237 |
# print(torch.tensor([tokenizer(batches)['attention_mask']]).cuda()["logits"][0].shape)
|
| 238 |
#print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().squeeze().shape)
|
| 239 |
+
print(tokenizer(batches)['input_ids'])
|
| 240 |
+
print(torch.tensor([tokenizer(batches)['input_ids']]).squeeze().cuda())
|
| 241 |
+
pred=(sig(model(torch.tensor([tokenizer(batches)['input_ids']]).squeeze().cuda(),torch.tensor([tokenizer(batches)['attention_mask']]).squeeze().cuda())["logits"]).tolist())
|
| 242 |
#print(len(pred[0]))
|
| 243 |
for p in pred:
|
| 244 |
+
print(p)
|
| 245 |
outputpreds.append(p)
|
| 246 |
return outputpreds
|
| 247 |
|
| 248 |
|
| 249 |
+
def write_output(pred,listofpeps,file_output):
|
| 250 |
+
hf=open(f"{file_output}",'w+')
|
| 251 |
n="\n"
|
| 252 |
writethisline="pep,"
|
| 253 |
for i in range(len(labsoi)):
|
| 254 |
+
writethisline+=pos2lab[i]
|
| 255 |
+
hf.write(writethisline+n)
|
| 256 |
for p,ip in zip(pred,listofpeps):
|
| 257 |
+
writethisline=f"{ip}"
|
| 258 |
r=ip[10]
|
| 259 |
#print(p)
|
| 260 |
easyreadlab=getlab(p,r)
|
|
|
|
| 388 |
if i%batch_size==0 and i!=0:
|
| 389 |
input_batches.append(temp)
|
| 390 |
temp=[]
|
| 391 |
+
if pep=='':
|
| 392 |
+
continue
|
| 393 |
+
temp.append(pep.replace("-", "<pad>"))
|
| 394 |
input_batches.append(temp)
|
| 395 |
+
print(listofpeps)
|
| 396 |
+
print(input_batches)
|
| 397 |
pred=predict(input_batches=input_batches)
|
| 398 |
+
write_output(pred,listofpeps,file_output)
|
| 399 |
|
| 400 |
|
| 401 |
|