File size: 1,576 Bytes
8606cde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Necessary imports
import sys
from typing import Dict
from src.logger import logging
from src.exception import CustomExceptionHandling
from transformers import pipeline


# Zero-shot classification pipeline backed by the ModernBERT zero-shot
# checkpoint. It is constructed once, when this module is imported.
classifier = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
)


def ZeroShotTextClassification(
    text_input: str, candidate_labels: str
) -> Dict[str, float]:
    """
    Performs zero-shot classification on the given text input.

    Classification is done with the module-level ``classifier`` pipeline,
    so the configured model is used and is not rebuilt on every call.

    Args:
        - text_input: The input text to classify.
        - candidate_labels: A comma-separated string of candidate labels.
          Surrounding whitespace is stripped from each label, and empty
          entries (e.g. from a trailing comma) are ignored.

    Returns:
        Dictionary mapping each label to its score, in the order the
        pipeline returned them.

    Raises:
        CustomExceptionHandling: wraps any error raised while parsing the
        labels or running the classifier, including the case where no
        non-empty label was supplied.
    """
    try:
        # Split on commas, strip whitespace and drop blank entries so that
        # inputs like "a, b," do not yield an empty label.
        labels = [
            label.strip() for label in candidate_labels.split(",") if label.strip()
        ]
        if not labels:
            raise ValueError("At least one non-empty candidate label is required.")

        # Log the classification attempt
        logging.info(f"Attempting classification with {len(labels)} labels")

        # Reuse the module-level pipeline. Creating one here (as before)
        # shadowed it, silently switched to the library's default model,
        # and reloaded model weights on every call.
        prediction = classifier(text_input, labels)

        logging.info("Classification completed successfully")
        # The pipeline returns parallel "labels" and "scores" lists.
        return dict(zip(prediction["labels"], prediction["scores"]))

    # Handle exceptions that may occur during the process
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e