sitammeur's picture
Upload 6 files
8606cde verified
raw
history blame
1.58 kB
# Necessary imports
import sys
from typing import Dict
from src.logger import logging
from src.exception import CustomExceptionHandling
from transformers import pipeline
# Module-level zero-shot classification pipeline, built once when this module
# is imported, from the ModernBERT-large zero-shot v2.0 checkpoint.
classifier = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
)
def ZeroShotTextClassification(
    text_input: str, candidate_labels: str
) -> Dict[str, float]:
    """
    Performs zero-shot classification on the given text input.

    Uses the module-level ``classifier`` pipeline rather than constructing a
    new pipeline on every call.

    Args:
        - text_input: The input text to classify.
        - candidate_labels: A comma-separated string of candidate labels.
          Surrounding whitespace is stripped from each label, empty entries
          (e.g. from "a,,b" or a trailing comma) are dropped, and duplicate
          labels are collapsed while keeping first-seen order.

    Returns:
        Dictionary mapping each candidate label to its score, in the order
        the pipeline returned them.

    Raises:
        CustomExceptionHandling: Wraps any error raised while parsing the
        labels or running the pipeline, including the case where no
        non-empty candidate labels were supplied.
    """
    try:
        # Split and clean the candidate labels. Drop empty entries and
        # de-duplicate them: the result dict cannot hold duplicate keys, and
        # duplicates would otherwise split the score mass between copies.
        labels = list(
            dict.fromkeys(
                label.strip()
                for label in candidate_labels.split(",")
                if label.strip()
            )
        )
        if not labels:
            raise ValueError("At least one non-empty candidate label is required")

        # Log the classification attempt
        logging.info(f"Attempting classification with {len(labels)} labels")

        # Reuse the pipeline loaded at import time. Re-creating one here would
        # both reload a model per call and fall back to the library's default
        # zero-shot model instead of the configured ModernBERT checkpoint.
        prediction = classifier(text_input, candidate_labels=labels)

        logging.info("Classification completed successfully")
        # "labels" and "scores" are parallel lists in the pipeline output.
        return dict(zip(prediction["labels"], prediction["scores"]))

    # Handle exceptions that may occur during the process
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e