VSDatta commited on
Commit
f9f151f
·
verified ·
1 Parent(s): ac41f0e

Update src/ai_classifier.py

Browse files
Files changed (1) hide show
  1. src/ai_classifier.py +9 -4
src/ai_classifier.py CHANGED
@@ -1,9 +1,13 @@
1
- # ai_classifier.py
2
  from transformers import pipeline
3
 
4
- # Load once and reuse
5
- classifier = pipeline("zero-shot-classification", model="joeddav/xlm-roberta-large-xnli", cache_dir="./hf_cache")
6
- # Define the candidate labels (in English internally)
 
 
 
 
 
7
  CATEGORIES = {
8
  "Family": "కుటుంబం",
9
  "Friendship": "స్నేహం",
@@ -28,6 +32,7 @@ CATEGORIES = {
28
  }
29
 
30
  def classify_proverb(text):
 
31
  result = classifier(text, list(CATEGORIES.keys()))
32
  top_label = result["labels"][0]
33
  return CATEGORIES[top_label]
 
 
1
  from transformers import pipeline
2
 
3
+ # Load the pipeline once and tell it where to cache the model
4
+ classifier = pipeline(
5
+ "zero-shot-classification",
6
+ model="joeddav/xlm-roberta-large-xnli",
7
+ cache_dir="./hf_cache" # This line fixes the permission error
8
+ )
9
+
10
+ # Define the candidate labels
11
  CATEGORIES = {
12
  "Family": "కుటుంబం",
13
  "Friendship": "స్నేహం",
 
32
  }
33
 
34
  def classify_proverb(text):
35
+ """Classifies the proverb and returns the Telugu label."""
36
  result = classifier(text, list(CATEGORIES.keys()))
37
  top_label = result["labels"][0]
38
  return CATEGORIES[top_label]