Alexvatti commited on
Commit
1a953b4
·
verified ·
1 Parent(s): 5c4ac70

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +7 -1
main.py CHANGED
@@ -7,7 +7,13 @@ from transformers import pipeline, AutoModelForSequenceClassification, AutoToken
7
 
8
  app = FastAPI()
9
 
10
- classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli", from_flax=True)
 
 
 
 
 
 
11
 
12
  categories = ["Spam", "Not Spam"]
13
 
 
7
 
8
  app = FastAPI()
9
 
10
+ model_name = "facebook/bart-large-mnli"
11
+
12
+ # Force PyTorch model instead of Flax
13
+ model = AutoModelForSequenceClassification.from_pretrained(model_name, from_pt=True)
14
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+
16
+ classifier = pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)
17
 
18
  categories = ["Spam", "Not Spam"]
19