nit454 commited on
Commit
dd4a9f5
Β·
verified Β·
1 Parent(s): 50362c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -58
app.py CHANGED
@@ -5,14 +5,13 @@ import easyocr
5
  from PIL import Image
6
  import numpy as np
7
 
8
- # Sarcasm Model
9
- SARCASM_MODEL_NAME = "j-hartmann/emotion-english-distilroberta-base"
10
  sarcasm_labels = ["not sarcastic", "sarcastic"]
11
- sarcasm_tokenizer = AutoTokenizer.from_pretrained(SARCASM_MODEL_NAME)
12
- sarcasm_model = AutoModelForSequenceClassification.from_pretrained(SARCASM_MODEL_NAME)
13
 
14
- # Hate Speech Model (RoBERTa multiclass)
15
- HATE_MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
16
  hate_labels = [
17
  "sexism",
18
  "racism",
@@ -22,17 +21,18 @@ hate_labels = [
22
  "other",
23
  "not_hate"
24
  ]
25
- hate_tokenizer = AutoTokenizer.from_pretrained(HATE_MODEL_NAME)
26
- hate_model = AutoModelForSequenceClassification.from_pretrained(HATE_MODEL_NAME)
27
 
28
- # OCR Reader
29
  reader = easyocr.Reader(['en'], gpu=False)
30
 
31
  def extract_text(image):
 
 
32
  if isinstance(image, Image.Image):
33
  image = np.array(image)
34
- result = reader.readtext(image, detail=0)
35
- return ' '.join(result)
36
 
37
  def detect_sarcasm(text):
38
  inputs = sarcasm_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -40,8 +40,8 @@ def detect_sarcasm(text):
40
  outputs = sarcasm_model(**inputs)
41
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
42
  pred = torch.argmax(probs).item()
43
- confidence = float(probs[0][pred])
44
- return sarcasm_labels[pred], confidence
45
 
46
  def classify_hate(text):
47
  inputs = hate_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -49,56 +49,54 @@ def classify_hate(text):
49
  outputs = hate_model(**inputs)
50
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
51
  pred = torch.argmax(probs).item()
52
- confidence = float(probs[0][pred])
53
- return hate_labels[pred], confidence
54
 
55
-
56
- def chatbot(conversation, user_message, image=None):
57
- # Process image first if available
58
- if image is not None:
59
- extracted_text = extract_text(image)
60
- if not extracted_text.strip():
61
- response = "No text found in the uploaded image."
62
- conversation.append(("User", user_message))
63
- conversation.append(("Cyber Bully Bot", response))
64
- return conversation, None, None
65
- text = extracted_text
66
- display_input = f"[Extracted from image] {text}"
67
  else:
68
- text = user_message.strip()
69
- display_input = text
70
 
71
- # Sarcasm detection
72
- sarcasm_label, sarcasm_conf = detect_sarcasm(text)
73
  if sarcasm_label == "sarcastic":
74
- response = f"Text detected as SARCASTIC (Confidence: {sarcasm_conf:.2f}). Hate speech classification skipped."
75
- conversation.append(("User", display_input))
76
- conversation.append(("Cyber Bully Bot", response))
77
- return conversation, None, None
78
-
79
- # Hate speech classification
80
- hate_label, hate_conf = classify_hate(text)
81
- response = (
82
- f"Hate Speech Category: {hate_label} (Confidence: {hate_conf:.2f})\n"
83
- f"Sarcasm: {sarcasm_label} (Confidence: {sarcasm_conf:.2f})"
84
- )
85
- conversation.append(("User", display_input))
86
- conversation.append(("Cyber Bully Bot", response))
87
- return conversation, None, None
88
-
89
- default_conversation = []
90
 
91
- iface = gr.ChatInterface(
92
- fn=chatbot,
93
- title="Cyber Bully Detection System",
94
- description="Upload images or enter text. Bot detects sarcasm first, then classifies hate speech categories.",
95
- )
 
 
 
 
 
 
 
96
 
97
- # Add an image upload component beside text input
98
- iface.add_component(
99
- gr.Image(source="upload", label="Upload Screenshot (optional)", interactive=True),
100
- insert_before=iface.input_components[0]
101
- )
102
 
103
  if __name__ == "__main__":
104
- iface.launch()
 
5
  from PIL import Image
6
  import numpy as np
7
 
8
+ # Models and labels (same as before)
9
+ SARCASM_MODEL = "j-hartmann/emotion-english-distilroberta-base"
10
  sarcasm_labels = ["not sarcastic", "sarcastic"]
11
+ sarcasm_tokenizer = AutoTokenizer.from_pretrained(SARCASM_MODEL)
12
+ sarcasm_model = AutoModelForSequenceClassification.from_pretrained(SARCASM_MODEL)
13
 
14
+ HATE_MODEL = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
 
15
  hate_labels = [
16
  "sexism",
17
  "racism",
 
21
  "other",
22
  "not_hate"
23
  ]
24
+ hate_tokenizer = AutoTokenizer.from_pretrained(HATE_MODEL)
25
+ hate_model = AutoModelForSequenceClassification.from_pretrained(HATE_MODEL)
26
 
 
27
  reader = easyocr.Reader(['en'], gpu=False)
28
 
29
  def extract_text(image):
30
+ if image is None:
31
+ return ""
32
  if isinstance(image, Image.Image):
33
  image = np.array(image)
34
+ texts = reader.readtext(image, detail=0)
35
+ return ' '.join(texts)
36
 
37
  def detect_sarcasm(text):
38
  inputs = sarcasm_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
40
  outputs = sarcasm_model(**inputs)
41
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
42
  pred = torch.argmax(probs).item()
43
+ conf = float(probs[0][pred])
44
+ return sarcasm_labels[pred], conf
45
 
46
  def classify_hate(text):
47
  inputs = hate_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
49
  outputs = hate_model(**inputs)
50
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
51
  pred = torch.argmax(probs).item()
52
+ conf = float(probs[0][pred])
53
+ return hate_labels[pred], conf
54
 
55
+ def respond(chat_history, user_text, user_image):
56
+ # Combine OCR and text input
57
+ if user_image is not None:
58
+ extracted = extract_text(user_image)
59
+ if extracted.strip():
60
+ input_text = extracted
61
+ elif user_text.strip():
62
+ input_text = user_text.strip()
63
+ else:
64
+ chat_history.append(("User", ""))
65
+ chat_history.append(("Bot", "Please provide text or an image with text."))
66
+ return chat_history, None, None
67
  else:
68
+ input_text = user_text.strip()
 
69
 
70
+ sarcasm_label, sarcasm_conf = detect_sarcasm(input_text)
 
71
  if sarcasm_label == "sarcastic":
72
+ response_text = f"Sarcasm detected (Confidence: {sarcasm_conf:.2f}). Hate speech detection skipped."
73
+ hate_label = None
74
+ else:
75
+ hate_label, hate_conf = classify_hate(input_text)
76
+ response_text = (
77
+ f"Hate Speech Category: {hate_label} (Confidence: {hate_conf:.2f})\n"
78
+ f"Text analyzed: \"{input_text}\""
79
+ )
80
+ chat_history.append(("User", input_text))
81
+ chat_history.append(("Bot", response_text))
82
+ return chat_history, None, None
 
 
 
 
 
83
 
84
+ with gr.Blocks() as demo:
85
+ gr.Markdown("# Cyber Bully Detection System (Chat Interface)")
86
+
87
+ chat_history = gr.State([])
88
+
89
+ with gr.Row():
90
+ chatbot = gr.Chatbot()
91
+ with gr.Row():
92
+ txt = gr.Textbox(show_label=False, placeholder="Type your message here and press enter")
93
+ img = gr.Image(label="Upload screenshot (optional)", type="pil")
94
+ with gr.Row():
95
+ clear = gr.Button("Clear Chat")
96
 
97
+ txt.submit(respond, [chatbot, txt, img], [chatbot, txt, img])
98
+ img.submit(respond, [chatbot, txt, img], [chatbot, txt, img])
99
+ clear.click(lambda: ([], None, None), None, [chatbot, txt, img])
 
 
100
 
101
  if __name__ == "__main__":
102
+ demo.launch()