gravelcompbio committed on
Commit
f6a043a
·
verified ·
1 Parent(s): 6d6151d

Update claspp_forward.py

Browse files
Files changed (1) hide show
  1. claspp_forward.py +17 -10
claspp_forward.py CHANGED
@@ -16,8 +16,10 @@ from datasets import (
16
  from modeling_esm import EsmForSequenceClassificationCustomWidehead
17
 
18
 
 
19
  tokenizer = EsmTokenizer.from_pretrained("finalCheckpoint_25_05_11/")
20
  model = EsmForSequenceClassificationCustomWidehead.from_pretrained("finalCheckpoint_25_05_11/", num_labels=54).cuda()
 
21
 
22
 
23
  ###############################################################################
@@ -234,23 +236,25 @@ def predict(input_batches):
234
  # print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().shape)
235
  # print(torch.tensor([tokenizer(batches)['attention_mask']]).cuda()["logits"][0].shape)
236
  #print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().squeeze().shape)
237
- with torch.no_grad():
238
- pred=(sig(model(torch.tensor([tokenizer(batches)['input_ids']]).squeeze().cuda(),torch.tensor([tokenizer(batches)['attention_mask']]).squeeze().cuda())["logits"]).tolist())
 
239
  #print(len(pred[0]))
240
  for p in pred:
 
241
  outputpreds.append(p)
242
  return outputpreds
243
 
244
 
245
- def write_output(pred,listofpeps):
246
- hf=open("output_predictions.csv",'w+')
247
  n="\n"
248
  writethisline="pep,"
249
  for i in range(len(labsoi)):
250
- writethisline+=pos2lab[i]+','
251
- hf.write(writethisline[:-1]+n)
252
  for p,ip in zip(pred,listofpeps):
253
- writethisline=f"{ip},"
254
  r=ip[10]
255
  #print(p)
256
  easyreadlab=getlab(p,r)
@@ -384,11 +388,14 @@ def main():
384
  if i%batch_size==0 and i!=0:
385
  input_batches.append(temp)
386
  temp=[]
387
- temp.append(pep)
 
 
388
  input_batches.append(temp)
389
-
 
390
  pred=predict(input_batches=input_batches)
391
- write_output(pred,listofpeps)
392
 
393
 
394
 
 
16
  from modeling_esm import EsmForSequenceClassificationCustomWidehead
17
 
18
 
19
+ print("intilizing checkpoint --might take a few min if this is the first time--")
20
  tokenizer = EsmTokenizer.from_pretrained("finalCheckpoint_25_05_11/")
21
  model = EsmForSequenceClassificationCustomWidehead.from_pretrained("finalCheckpoint_25_05_11/", num_labels=54).cuda()
22
+ print("finished downloading")
23
 
24
 
25
  ###############################################################################
 
236
  # print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().shape)
237
  # print(torch.tensor([tokenizer(batches)['attention_mask']]).cuda()["logits"][0].shape)
238
  #print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().squeeze().shape)
239
+ print(tokenizer(batches)['input_ids'])
240
+ print(torch.tensor([tokenizer(batches)['input_ids']]).squeeze().cuda())
241
+ pred=(sig(model(torch.tensor([tokenizer(batches)['input_ids']]).squeeze().cuda(),torch.tensor([tokenizer(batches)['attention_mask']]).squeeze().cuda())["logits"]).tolist())
242
  #print(len(pred[0]))
243
  for p in pred:
244
+ print(p)
245
  outputpreds.append(p)
246
  return outputpreds
247
 
248
 
249
+ def write_output(pred,listofpeps,file_output):
250
+ hf=open(f"{file_output}",'w+')
251
  n="\n"
252
  writethisline="pep,"
253
  for i in range(len(labsoi)):
254
+ writethisline+=pos2lab[i]
255
+ hf.write(writethisline+n)
256
  for p,ip in zip(pred,listofpeps):
257
+ writethisline=f"{ip}"
258
  r=ip[10]
259
  #print(p)
260
  easyreadlab=getlab(p,r)
 
388
  if i%batch_size==0 and i!=0:
389
  input_batches.append(temp)
390
  temp=[]
391
+ if pep=='':
392
+ continue
393
+ temp.append(pep.replace("-", "<pad>"))
394
  input_batches.append(temp)
395
+ print(listofpeps)
396
+ print(input_batches)
397
  pred=predict(input_batches=input_batches)
398
+ write_output(pred,listofpeps,file_output)
399
 
400
 
401