Keshavp08 commited on
Commit
61a5fad
·
verified ·
1 Parent(s): 0d352fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -27
app.py CHANGED
@@ -3,7 +3,6 @@ import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import matplotlib.pyplot as plt
5
 
6
- # Cache the loading of the tokenizer and model to speed up the app
7
  @st.cache_resource
8
  def load_model():
9
  tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
@@ -15,30 +14,33 @@ tokenizer, model = load_model()
15
  st.title("Sentiment Analysis App")
16
 
17
  text = st.text_input("Enter text to analyze:")
18
- if st.button("Analyze"):
19
- if text: # Ensure that the text is not empty
20
- encoding = tokenizer.encode_plus(text, return_tensors="pt", padding=True, truncation=True)
21
- input_ids = encoding["input_ids"]
22
- attention_mask = encoding["attention_mask"]
23
-
24
- with torch.no_grad():
25
- output = model(input_ids, attention_mask)
26
- prediction = int(torch.argmax(output.logits))
27
-
28
- # Define sentiments
29
- sentiments = ["Negative", "Neutral", "Positive"]
30
- sentiment = sentiments[prediction]
31
- st.write(f"Sentiment: {sentiment}")
32
-
33
- # Flatten the logits tensor to a list and check dimensions
34
- values = output.logits.squeeze().tolist()
35
- if len(values) != len(sentiments):
36
- st.error(f"Mismatch in the number of sentiments and values. Expected {len(sentiments)}, got {len(values)}")
37
- else:
38
- fig, ax = plt.subplots()
39
- ax.bar(sentiments, values, color=['red', 'blue', 'green'])
40
- ax.set_title("Sentiment Analysis Scores")
41
- ax.set_ylabel("Score")
42
- st.pyplot(fig)
43
  else:
44
- st.error("Please enter some text to analyze.")
 
 
 
 
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import matplotlib.pyplot as plt
5
 
 
6
  @st.cache_resource
7
  def load_model():
8
  tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
 
14
  st.title("Sentiment Analysis App")
15
 
16
  text = st.text_input("Enter text to analyze:")
17
+ if st.button("Analyze") and text:
18
+ encoding = tokenizer.encode_plus(text, return_tensors="pt", padding=True, truncation=True)
19
+ input_ids = encoding["input_ids"]
20
+ attention_mask = encoding["attention_mask"]
21
+
22
+ with torch.no_grad():
23
+ output = model(input_ids, attention_mask)
24
+ logits = output.logits.squeeze()
25
+
26
+ print("Logits Shape:", logits.shape)
27
+ print("Logits Contents:", logits)
28
+
29
+ if logits.shape[0] != 3:
30
+ st.error(f"Unexpected number of output values: {logits.shape[0]}")
31
+ st.stop()
32
+
33
+ prediction = int(torch.argmax(logits))
34
+
35
+ sentiments = ["Negative", "Neutral", "Positive"]
36
+ sentiment = sentiments[prediction]
37
+ st.write(f"Sentiment: {sentiment}")
38
+
39
+ if len(logits) != len(sentiments):
40
+ st.error(f"Mismatch in the number of sentiments and values. Expected {len(sentiments)}, got {len(logits)}")
 
41
  else:
42
+ fig, ax = plt.subplots()
43
+ ax.bar(sentiments, logits.tolist(), color=['red', 'blue', 'green'])
44
+ ax.set_title("Sentiment Analysis Scores")
45
+ ax.set_ylabel("Score")
46
+ st.pyplot(fig)