Keshavp08 commited on
Commit
0d352fa
·
verified ·
1 Parent(s): b011208

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -20
app.py CHANGED
@@ -3,8 +3,8 @@ import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import matplotlib.pyplot as plt
5
 
6
- # Initialize tokenizer and model once to avoid reloading them on every interaction
7
- @st.cache(allow_output_mutation=True)
8
  def load_model():
9
  tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
10
  model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
@@ -16,24 +16,29 @@ st.title("Sentiment Analysis App")
16
 
17
  text = st.text_input("Enter text to analyze:")
18
  if st.button("Analyze"):
19
- encoding = tokenizer.encode_plus(text, return_tensors="pt", padding=True, truncation=True)
20
- input_ids = encoding["input_ids"]
21
- attention_mask = encoding["attention_mask"]
 
22
 
23
- with torch.no_grad():
24
- output = model(input_ids, attention_mask)
25
- prediction = int(torch.argmax(output.logits))
26
 
27
- # Detailed sentiment output
28
- sentiment = ["Negative", "Neutral", "Positive"][prediction]
29
- st.write(f"Sentiment: {sentiment}")
 
30
 
31
- values = output.logits.squeeze().tolist() # Flatten the logits tensor to a list
32
- labels = ["Negative", "Neutral", "Positive"]
33
-
34
- # Plotting
35
- fig, ax = plt.subplots()
36
- ax.bar(labels, values, color=['red', 'blue', 'green'])
37
- ax.set_title("Sentiment Analysis Scores")
38
- ax.set_ylabel("Score")
39
- st.pyplot(fig)
 
 
 
 
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import matplotlib.pyplot as plt
5
 
6
+ # Cache the loading of the tokenizer and model to speed up the app
7
+ @st.cache_resource
8
  def load_model():
9
  tokenizer = AutoTokenizer.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
10
  model = AutoModelForSequenceClassification.from_pretrained("nlptown/bert-base-multilingual-uncased-sentiment")
 
16
 
17
  text = st.text_input("Enter text to analyze:")
18
  if st.button("Analyze"):
19
+ if text: # Ensure that the text is not empty
20
+ encoding = tokenizer.encode_plus(text, return_tensors="pt", padding=True, truncation=True)
21
+ input_ids = encoding["input_ids"]
22
+ attention_mask = encoding["attention_mask"]
23
 
24
+ with torch.no_grad():
25
+ output = model(input_ids, attention_mask)
26
+ prediction = int(torch.argmax(output.logits))
27
 
28
+ # Define sentiments
29
+ sentiments = ["Negative", "Neutral", "Positive"]
30
+ sentiment = sentiments[prediction]
31
+ st.write(f"Sentiment: {sentiment}")
32
 
33
+ # Flatten the logits tensor to a list and check dimensions
34
+ values = output.logits.squeeze().tolist()
35
+ if len(values) != len(sentiments):
36
+ st.error(f"Mismatch in the number of sentiments and values. Expected {len(sentiments)}, got {len(values)}")
37
+ else:
38
+ fig, ax = plt.subplots()
39
+ ax.bar(sentiments, values, color=['red', 'blue', 'green'])
40
+ ax.set_title("Sentiment Analysis Scores")
41
+ ax.set_ylabel("Score")
42
+ st.pyplot(fig)
43
+ else:
44
+ st.error("Please enter some text to analyze.")