sitammeur commited on
Commit
9b22410
·
verified ·
1 Parent(s): 66a505e

Update src/app/predict.py

Browse files
Files changed (1) hide show
  1. src/app/predict.py +5 -1
src/app/predict.py CHANGED
@@ -3,12 +3,16 @@ import sys
3
  from typing import Dict
4
  from src.logger import logging
5
  from src.exception import CustomExceptionHandling
 
6
  from transformers import pipeline
7
 
8
 
9
  # Load the zero-shot classification model
 
10
  classifier = pipeline(
11
- "zero-shot-classification", model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0"
 
 
12
  )
13
 
14
 
 
3
  from typing import Dict
4
  from src.logger import logging
5
  from src.exception import CustomExceptionHandling
6
+ import torch
7
  from transformers import pipeline
8
 
9
 
10
  # Load the zero-shot classification model
11
+ device = "cuda" if torch.cuda.is_available() else "cpu"
12
  classifier = pipeline(
13
+ "zero-shot-classification",
14
+ model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
15
+ device=device,
16
  )
17
 
18