gravelcompbio committed on
Commit
96a687b
·
verified ·
1 Parent(s): dbdae42

Update claspp_forward.py

Browse files
Files changed (1) hide show
  1. claspp_forward.py +5 -5
claspp_forward.py CHANGED
@@ -234,8 +234,8 @@ def predict(input_batches):
234
  # print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().shape)
235
  # print(torch.tensor([tokenizer(batches)['attention_mask']]).cuda()["logits"][0].shape)
236
  #print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().squeeze().shape)
237
-
238
- pred=(sig(model(torch.tensor([tokenizer(batches)['input_ids']]).squeeze().cuda(),torch.tensor([tokenizer(batches)['attention_mask']]).squeeze().cuda())["logits"]).tolist())
239
  #print(len(pred[0]))
240
  for p in pred:
241
  outputpreds.append(p)
@@ -247,10 +247,10 @@ def write_output(pred,listofpeps):
247
  n="\n"
248
  writethisline="pep,"
249
  for i in range(len(labsoi)):
250
- writethisline+=pos2lab[i]
251
- hf.write(writethisline+n)
252
  for p,ip in zip(pred,listofpeps):
253
- writethisline=f"{ip}"
254
  r=ip[10]
255
  #print(p)
256
  easyreadlab=getlab(p,r)
 
234
  # print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().shape)
235
  # print(torch.tensor([tokenizer(batches)['attention_mask']]).cuda()["logits"][0].shape)
236
  #print(torch.tensor([tokenizer(batches)['input_ids']]).cuda().squeeze().shape)
237
+ with torch.no_grad():
238
+ pred=(sig(model(torch.tensor([tokenizer(batches)['input_ids']]).squeeze().cuda(),torch.tensor([tokenizer(batches)['attention_mask']]).squeeze().cuda())["logits"]).tolist())
239
  #print(len(pred[0]))
240
  for p in pred:
241
  outputpreds.append(p)
 
247
  n="\n"
248
  writethisline="pep,"
249
  for i in range(len(labsoi)):
250
+ writethisline+=pos2lab[i]+','
251
+ hf.write(writethisline[:-1]+n)
252
  for p,ip in zip(pred,listofpeps):
253
+ writethisline=f"{ip},"
254
  r=ip[10]
255
  #print(p)
256
  easyreadlab=getlab(p,r)