VSDatta commited on
Commit
d7064e7
·
verified ·
1 Parent(s): f9f151f

Update src/ai_classifier.py

Browse files
Files changed (1) hide show
  1. src/ai_classifier.py +4 -4
src/ai_classifier.py CHANGED
@@ -1,13 +1,13 @@
1
  from transformers import pipeline
2
 
3
- # Load the pipeline once and tell it where to cache the model
4
  classifier = pipeline(
5
  "zero-shot-classification",
6
  model="joeddav/xlm-roberta-large-xnli",
7
- cache_dir="./hf_cache" # This line fixes the permission error
8
  )
9
 
10
- # Define the candidate labels
11
  CATEGORIES = {
12
  "Family": "కుటుంబం",
13
  "Friendship": "స్నేహం",
@@ -35,4 +35,4 @@ def classify_proverb(text):
35
  """Classifies the proverb and returns the Telugu label."""
36
  result = classifier(text, list(CATEGORIES.keys()))
37
  top_label = result["labels"][0]
38
- return CATEGORIES[top_label]
 
1
  from transformers import pipeline
2
 
3
+ # Load the zero-shot-classification pipeline and specify a cache directory.
4
  classifier = pipeline(
5
  "zero-shot-classification",
6
  model="joeddav/xlm-roberta-large-xnli",
7
+ cache_dir="./hf_cache"
8
  )
9
 
10
+ # Mapping of English category labels to Telugu.
11
  CATEGORIES = {
12
  "Family": "కుటుంబం",
13
  "Friendship": "స్నేహం",
 
35
  """Classifies the proverb and returns the Telugu label."""
36
  result = classifier(text, list(CATEGORIES.keys()))
37
  top_label = result["labels"][0]
38
+ return CATEGORIES[top_label]