MissingBreath commited on
Commit
e237f80
·
verified ·
1 Parent(s): 70a8b66

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +19 -1
api.py CHANGED
@@ -5,7 +5,25 @@ import io
5
  import tensorflow as tf
6
 
7
 
8
- model = tf.keras.models.load_model('_9217')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  app = FastAPI()
11
  @app.post("/classify")
 
5
  import tensorflow as tf
6
 
7
 
8
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
+
10
+ tokenizer = AutoTokenizer.from_pretrained("chillies/distilbert-course-review-classification")
11
+ model = AutoModelForSequenceClassification.from_pretrained("chillies/distilbert-course-review-classification")
12
+
13
+ def inference(review):
14
+ inputs = tokenizer(review, return_tensors="pt", padding=True, truncation=True)
15
+ outputs = model(**inputs)
16
+
17
+ # Assuming the model outputs logits
18
+ predicted_class = outputs.logits.argmax(dim=-1).item()
19
+
20
+ class_labels = [
21
+ 'Improvement Suggestions', 'Questions', 'Confusion', 'Support Request',
22
+ 'Discussion', 'Course Comparison', 'Related Course Suggestions',
23
+ 'Negative', 'Positive'
24
+ ]
25
+ return class_labels[predicted_class]
26
+
27
 
28
  app = FastAPI()
29
  @app.post("/classify")