meetmendapara commited on
Commit
6e547ba
Β·
1 Parent(s): 2d6c34c

Refactor sentiment prediction function to improve clarity and confidence extraction

Browse files
Files changed (1) hide show
  1. app.py +17 -2
app.py CHANGED
@@ -11,11 +11,26 @@ tokenizer = BertTokenizer.from_pretrained("./imdb_bert_model")
11
 
12
  # Prediction function
13
  def predict_sentiment(text):
 
 
 
 
 
 
 
 
 
14
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
15
  outputs = model(**inputs)
16
  logits = outputs.logits
17
- confidence = torch.max(logits, dim=1).item()
18
- prediction = torch.argmax(logits, dim=1).item()
 
 
 
 
 
 
19
  sentiment = "Positive 😊" if prediction == 1 else "Negative 😠"
20
  return f"{sentiment} with confidence {confidence * 100:.2f}% confidence"
21
 
 
11
 
12
  # Prediction function
13
  def predict_sentiment(text):
14
+ """
15
+ Predicts the sentiment of the given text using the fine-tuned BERT model.
16
+
17
+ Args:
18
+ text (str): The input movie review text.
19
+
20
+ Returns:
21
+ str: The predicted sentiment with confidence.
22
+ """
23
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
24
  outputs = model(**inputs)
25
  logits = outputs.logits
26
+
27
+ # Extract the maximum value (confidence) and its index (prediction)
28
+ confidence, prediction = torch.max(logits, dim=1)
29
+ confidence = confidence.item() # Convert tensor to Python float
30
+ prediction = prediction.item() # Convert tensor to Python int
31
+
32
+ # confidence = torch.max(logits, dim=1).item()
33
+ # prediction = torch.argmax(logits, dim=1).item()
34
  sentiment = "Positive 😊" if prediction == 1 else "Negative 😠"
35
  return f"{sentiment} with confidence {confidence * 100:.2f}% confidence"
36