HemanM commited on
Commit
e0e3bb1
·
verified ·
1 Parent(s): cb92224

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +5 -2
inference.py CHANGED
@@ -45,6 +45,11 @@ def generate_response(goal, option1, option2):
45
  enc1 = tokenizer(text1, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
46
  enc2 = tokenizer(text2, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
47
 
 
 
 
 
 
48
  for k in enc1:
49
  enc1[k] = enc1[k].to(device)
50
  enc2[k] = enc2[k].to(device)
@@ -56,13 +61,11 @@ def generate_response(goal, option1, option2):
56
  logits1 = get_logits(out1)
57
  logits2 = get_logits(out2)
58
 
59
- # Sanity check shape
60
  if logits1.shape[-1] < 2 or logits2.shape[-1] < 2:
61
  raise ValueError("Model did not return 2-class logits.")
62
 
63
  score1 = logits1[0][1].item()
64
  score2 = logits2[0][1].item()
65
-
66
  evo_result = option1 if score1 > score2 else option2
67
 
68
  except Exception as e:
 
45
  enc1 = tokenizer(text1, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
46
  enc2 = tokenizer(text2, return_tensors="pt", truncation=True, padding="max_length", max_length=128)
47
 
48
+ # Remove token_type_ids to avoid crash in EvoTransformer
49
+ enc1.pop("token_type_ids", None)
50
+ enc2.pop("token_type_ids", None)
51
+
52
+ # Move tensors to device
53
  for k in enc1:
54
  enc1[k] = enc1[k].to(device)
55
  enc2[k] = enc2[k].to(device)
 
61
  logits1 = get_logits(out1)
62
  logits2 = get_logits(out2)
63
 
 
64
  if logits1.shape[-1] < 2 or logits2.shape[-1] < 2:
65
  raise ValueError("Model did not return 2-class logits.")
66
 
67
  score1 = logits1[0][1].item()
68
  score2 = logits2[0][1].item()
 
69
  evo_result = option1 if score1 > score2 else option2
70
 
71
  except Exception as e: