sitammeur commited on
Commit
5ba7182
·
verified ·
1 Parent(s): e3101e8

Update src/app/predict.py

Browse files
Files changed (1) hide show
  1. src/app/predict.py +7 -2
src/app/predict.py CHANGED
@@ -33,8 +33,13 @@ def ZeroShotTextClassification(
33
  logging.info(f"Attempting classification with {len(labels)} labels")
34
 
35
  # Perform zero-shot classification
36
- classifier = pipeline("zero-shot-classification")
37
- prediction = classifier(text_input, labels, multi_label=True)
 
 
 
 
 
38
 
39
  # Return the classification results
40
  logging.info("Classification completed successfully")
 
33
  logging.info(f"Attempting classification with {len(labels)} labels")
34
 
35
  # Perform zero-shot classification
36
+ hypothesis_template = "This text is about {}"
37
+ prediction = classifier(
38
+ text_input,
39
+ labels,
40
+ hypothesis_template=hypothesis_template,
41
+ multi_label=True,
42
+ )
43
 
44
  # Return the classification results
45
  logging.info("Classification completed successfully")