MindLabUnimib commited on
Commit
d5b1c96
·
verified ·
1 Parent(s): fe84da2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -59,13 +59,13 @@ cls_model_name = "saiteki-kai/QA-DeBERTa-v3-large-binary-3"
59
  model = transformers.pipeline(
60
  model=chat_model_name,
61
  model_kwargs={"dtype": torch.bfloat16} | ({"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {}),
62
- device_map="cuda",
63
  )
64
 
65
  classifier = transformers.pipeline(
66
  model=cls_model_name,
67
  model_kwargs={"dtype": torch.bfloat16},
68
- device_map="cuda"
69
  )
70
 
71
  unsafe_idx = classifier.model.config.label2id["unsafe"]
 
59
  model = transformers.pipeline(
60
  model=chat_model_name,
61
  model_kwargs={"dtype": torch.bfloat16} | ({"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {}),
62
+ device=device,
63
  )
64
 
65
  classifier = transformers.pipeline(
66
  model=cls_model_name,
67
  model_kwargs={"dtype": torch.bfloat16},
68
+ device=device
69
  )
70
 
71
  unsafe_idx = classifier.model.config.label2id["unsafe"]