sitammeur commited on
Commit
e3101e8
·
verified ·
1 Parent(s): 3e6b74e

Update src/app/predict.py

Browse files
Files changed (1) hide show
  1. src/app/predict.py +49 -49
src/app/predict.py CHANGED
@@ -1,49 +1,49 @@
1
- # Necessary imports
2
- import sys
3
- from typing import Dict
4
- from src.logger import logging
5
- from src.exception import CustomExceptionHandling
6
- from transformers import pipeline
7
-
8
-
9
- # Load the zero-shot classification model
10
- classifier = pipeline(
11
- "zero-shot-classification", model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0"
12
- )
13
-
14
-
15
- def ZeroShotTextClassification(
16
- text_input: str, candidate_labels: str
17
- ) -> Dict[str, float]:
18
- """
19
- Performs zero-shot classification on the given text input.
20
-
21
- Args:
22
- - text_input: The input text to classify.
23
- - candidate_labels: A comma-separated string of candidate labels.
24
-
25
- Returns:
26
- Dictionary containing label-score pairs.
27
- """
28
- try:
29
- # Split and clean the candidate labels
30
- labels = [label.strip() for label in candidate_labels.split(",")]
31
-
32
- # Log the classification attempt
33
- logging.info(f"Attempting classification with {len(labels)} labels")
34
-
35
- # Perform zero-shot classification
36
- classifier = pipeline("zero-shot-classification")
37
- prediction = classifier(text_input, labels)
38
-
39
- # Return the classification results
40
- logging.info("Classification completed successfully")
41
- return {
42
- prediction["labels"][i]: prediction["scores"][i]
43
- for i in range(len(prediction["labels"]))
44
- }
45
-
46
- # Handle exceptions that may occur during the process
47
- except Exception as e:
48
- # Custom exception handling
49
- raise CustomExceptionHandling(e, sys) from e
 
1
+ # Necessary imports
2
+ import sys
3
+ from typing import Dict
4
+ from src.logger import logging
5
+ from src.exception import CustomExceptionHandling
6
+ from transformers import pipeline
7
+
8
+
9
+ # Load the zero-shot classification model
10
+ classifier = pipeline(
11
+ "zero-shot-classification", model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0"
12
+ )
13
+
14
+
15
+ def ZeroShotTextClassification(
16
+ text_input: str, candidate_labels: str
17
+ ) -> Dict[str, float]:
18
+ """
19
+ Performs zero-shot classification on the given text input.
20
+
21
+ Args:
22
+ - text_input: The input text to classify.
23
+ - candidate_labels: A comma-separated string of candidate labels.
24
+
25
+ Returns:
26
+ Dictionary containing label-score pairs.
27
+ """
28
+ try:
29
+ # Split and clean the candidate labels
30
+ labels = [label.strip() for label in candidate_labels.split(",")]
31
+
32
+ # Log the classification attempt
33
+ logging.info(f"Attempting classification with {len(labels)} labels")
34
+
35
+ # Perform zero-shot classification
36
+ classifier = pipeline("zero-shot-classification")
37
+ prediction = classifier(text_input, labels, multi_label=True)
38
+
39
+ # Return the classification results
40
+ logging.info("Classification completed successfully")
41
+ return {
42
+ prediction["labels"][i]: prediction["scores"][i]
43
+ for i in range(len(prediction["labels"]))
44
+ }
45
+
46
+ # Handle exceptions that may occur during the process
47
+ except Exception as e:
48
+ # Custom exception handling
49
+ raise CustomExceptionHandling(e, sys) from e