SFM2001 commited on
Commit
6a923ef
·
1 Parent(s): ed8bf25
Files changed (1) hide show
  1. inference/infer_single.py +2 -2
inference/infer_single.py CHANGED
@@ -46,7 +46,7 @@ def generate_and_score_essay(topic, essay):
46
  truncation=True,
47
  padding_side='left'
48
  ).to(device)
49
- with torch.no_grad(), autocast(device_type='cuda', dtype=torch.float16):
50
  outputs = QWEN_MODEL.generate(
51
  **inputs,
52
  generation_config=gen_config,
@@ -86,7 +86,7 @@ def generate_and_score_essay(topic, essay):
86
  padding=True
87
  ).to(device)
88
  LONGFORMER_MODEL.eval()
89
- with torch.no_grad(), autocast(device_type='cuda', dtype=torch.float16):
90
  outputs = LONGFORMER_MODEL(**score_inputs) # Get full outputs dictionary
91
  scores = outputs['logits'].cpu().numpy()
92
  scores = [round(x) for x in scores[0]]
 
46
  truncation=True,
47
  padding_side='left'
48
  ).to(device)
49
+ with torch.no_grad():
50
  outputs = QWEN_MODEL.generate(
51
  **inputs,
52
  generation_config=gen_config,
 
86
  padding=True
87
  ).to(device)
88
  LONGFORMER_MODEL.eval()
89
+ with torch.no_grad():
90
  outputs = LONGFORMER_MODEL(**score_inputs) # Get full outputs dictionary
91
  scores = outputs['logits'].cpu().numpy()
92
  scores = [round(x) for x in scores[0]]