Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -59,13 +59,13 @@ cls_model_name = "saiteki-kai/QA-DeBERTa-v3-large-binary-3"
|
|
| 59 |
model = transformers.pipeline(
|
| 60 |
model=chat_model_name,
|
| 61 |
model_kwargs={"dtype": torch.bfloat16} | ({"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {}),
|
| 62 |
-
|
| 63 |
)
|
| 64 |
|
| 65 |
classifier = transformers.pipeline(
|
| 66 |
model=cls_model_name,
|
| 67 |
model_kwargs={"dtype": torch.bfloat16},
|
| 68 |
-
|
| 69 |
)
|
| 70 |
|
| 71 |
unsafe_idx = classifier.model.config.label2id["unsafe"]
|
|
|
|
| 59 |
model = transformers.pipeline(
|
| 60 |
model=chat_model_name,
|
| 61 |
model_kwargs={"dtype": torch.bfloat16} | ({"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {}),
|
| 62 |
+
device=device,
|
| 63 |
)
|
| 64 |
|
| 65 |
classifier = transformers.pipeline(
|
| 66 |
model=cls_model_name,
|
| 67 |
model_kwargs={"dtype": torch.bfloat16},
|
| 68 |
+
device=device
|
| 69 |
)
|
| 70 |
|
| 71 |
unsafe_idx = classifier.model.config.label2id["unsafe"]
|