shayankhan7 commited on
Commit
8c3561d
·
verified ·
1 Parent(s): 231fea0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -7
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
3
  from scipy.special import softmax
4
  import torch
 
5
 
6
  # Load sentiment model
7
  @st.cache_resource
@@ -17,7 +18,7 @@ def load_emotion_model():
17
  tokenizer = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
18
  return model, tokenizer
19
 
20
- # Load T5 paraphrasing model
21
  @st.cache_resource
22
  def load_paraphrase_model():
23
  model = AutoModelForSeq2SeqLM.from_pretrained("Vamsi/T5_Paraphrase_Paws")
@@ -42,7 +43,7 @@ def get_emotion(text, model, tokenizer):
42
  labels = ['anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise']
43
  return labels[probs.argmax()], float(probs.max()) * 100
44
 
45
- # Generate feedback
46
  def generate_feedback(sentiment, emotion):
47
  if sentiment == "Negative":
48
  if emotion in ["anger", "disgust", "sadness"]:
@@ -59,8 +60,16 @@ def generate_feedback(sentiment, emotion):
59
  else:
60
  return "🙂 Your message is positive, but think about whether it’s being fully understood."
61
 
62
- # Paraphrase / rewrite message
63
- def rewrite_message(text, model, tokenizer):
 
 
 
 
 
 
 
 
64
  text = "paraphrase: " + text + " </s>"
65
  encoding = tokenizer.encode_plus(text, return_tensors="pt", max_length=128, truncation=True)
66
  with torch.no_grad():
@@ -73,9 +82,9 @@ def rewrite_message(text, model, tokenizer):
73
  temperature=1.5
74
  )
75
  rewrites = [tokenizer.decode(o, skip_special_tokens=True) for o in output]
76
- return list(set(rewrites)) # remove duplicates
77
 
78
- # UI
79
  st.title("🗣️ Message Tone & Rewrite Checker (Phase 2)")
80
  st.write("Before you send that message, check how it might be received — and improve it if needed.")
81
 
@@ -101,7 +110,7 @@ if st.button("Analyze"):
101
  st.markdown("---")
102
  st.markdown("### ✨ Try Rewriting Your Message")
103
  para_model, para_token = load_paraphrase_model()
104
- rewrites = rewrite_message(text, para_model, para_token)
105
 
106
  for i, r in enumerate(rewrites, 1):
107
  st.write(f"**Version {i}:** {r}")
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
3
  from scipy.special import softmax
4
  import torch
5
+ import re
6
 
7
  # Load sentiment model
8
  @st.cache_resource
 
18
  tokenizer = AutoTokenizer.from_pretrained("j-hartmann/emotion-english-distilroberta-base")
19
  return model, tokenizer
20
 
21
+ # Load paraphrasing model
22
  @st.cache_resource
23
  def load_paraphrase_model():
24
  model = AutoModelForSeq2SeqLM.from_pretrained("Vamsi/T5_Paraphrase_Paws")
 
43
  labels = ['anger', 'disgust', 'fear', 'joy', 'neutral', 'sadness', 'surprise']
44
  return labels[probs.argmax()], float(probs.max()) * 100
45
 
46
+ # Feedback generation
47
  def generate_feedback(sentiment, emotion):
48
  if sentiment == "Negative":
49
  if emotion in ["anger", "disgust", "sadness"]:
 
60
  else:
61
  return "🙂 Your message is positive, but think about whether it’s being fully understood."
62
 
63
+ # Profanity detection
64
+ def contains_profanity(text):
65
+ profane_words = ['fuck', 'shit', 'bitch', 'stupid', 'idiot', 'dumb', 'asshole']
66
+ return any(re.search(rf"\b{word}\b", text.lower()) for word in profane_words)
67
+
68
+ # Smart rewrite logic
69
+ def smart_rewrite_message(text, model, tokenizer):
70
+ if contains_profanity(text):
71
+ return ["⚠️ Your message may contain harmful language. Please rephrase it with respect and calm."]
72
+
73
  text = "paraphrase: " + text + " </s>"
74
  encoding = tokenizer.encode_plus(text, return_tensors="pt", max_length=128, truncation=True)
75
  with torch.no_grad():
 
82
  temperature=1.5
83
  )
84
  rewrites = [tokenizer.decode(o, skip_special_tokens=True) for o in output]
85
+ return list(set(rewrites))
86
 
87
+ # Streamlit App UI
88
  st.title("🗣️ Message Tone & Rewrite Checker (Phase 2)")
89
  st.write("Before you send that message, check how it might be received — and improve it if needed.")
90
 
 
110
  st.markdown("---")
111
  st.markdown("### ✨ Try Rewriting Your Message")
112
  para_model, para_token = load_paraphrase_model()
113
+ rewrites = smart_rewrite_message(text, para_model, para_token)
114
 
115
  for i, r in enumerate(rewrites, 1):
116
  st.write(f"**Version {i}:** {r}")