Spaces:
Sleeping
Sleeping
Update src/app/predict.py
Browse files- src/app/predict.py +5 -1
src/app/predict.py
CHANGED
|
@@ -3,12 +3,16 @@ import sys
|
|
| 3 |
from typing import Dict
|
| 4 |
from src.logger import logging
|
| 5 |
from src.exception import CustomExceptionHandling
|
|
|
|
| 6 |
from transformers import pipeline
|
| 7 |
|
| 8 |
|
| 9 |
# Load the zero-shot classification model
|
|
|
|
| 10 |
classifier = pipeline(
|
| 11 |
-
"zero-shot-classification",
|
|
|
|
|
|
|
| 12 |
)
|
| 13 |
|
| 14 |
|
|
|
|
| 3 |
from typing import Dict
|
| 4 |
from src.logger import logging
|
| 5 |
from src.exception import CustomExceptionHandling
|
| 6 |
+
import torch
|
| 7 |
from transformers import pipeline
|
| 8 |
|
| 9 |
|
| 10 |
# Load the zero-shot classification model
|
| 11 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
classifier = pipeline(
|
| 13 |
+
"zero-shot-classification",
|
| 14 |
+
model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
|
| 15 |
+
device=device,
|
| 16 |
)
|
| 17 |
|
| 18 |
|