mjpsm commited on
Commit
1001133
·
verified ·
1 Parent(s): 1aa09aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -3
app.py CHANGED
@@ -38,8 +38,6 @@ def predict(req: Request):
38
  # 🔥 FIX HERE
39
  inputs.pop("token_type_ids", None)
40
 
41
- inputs = {k: v.to(device) for k, v in inputs.items()}
42
-
43
  with torch.no_grad():
44
  outputs = model(**inputs)
45
 
@@ -49,7 +47,7 @@ def predict(req: Request):
49
  pred = torch.argmax(probs, dim=1).item()
50
  confidence = probs[0][pred].item()
51
 
52
- label = "ROADBLOCK" if pred == 1 else "NOT_ROADBLOCK"
53
 
54
  return {
55
  "input": req.text,
 
38
  # 🔥 FIX HERE
39
  inputs.pop("token_type_ids", None)
40
 
 
 
41
  with torch.no_grad():
42
  outputs = model(**inputs)
43
 
 
47
  pred = torch.argmax(probs, dim=1).item()
48
  confidence = probs[0][pred].item()
49
 
50
+ label = model.config.id2label[pred]
51
 
52
  return {
53
  "input": req.text,