nit454 commited on
Commit
f2f0fd9
Β·
verified Β·
1 Parent(s): abb68e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -32
app.py CHANGED
@@ -5,49 +5,47 @@ import easyocr
5
  from PIL import Image
6
  import numpy as np
7
 
8
- # Hate Speech model (example uses base CardiffNLP + extended labels for demonstration)
9
- MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
10
- LABELS = [
11
  "sexism",
12
  "racism",
13
  "disability",
14
  "sexual_orientation",
15
  "religion",
16
- "abusive_words", # added label - simulation only
17
- "threat", # added label - simulation only
18
- "harassment", # added label - simulation only
19
- "sarcastic", # added label - simulation only; we'll do actual sarcasm detection via separate model
20
  "not_hate"
21
  ]
22
 
23
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
24
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
25
 
26
- # Sarcasm Detection model (example pretrained; replace with your actual sarcasm model)
27
- SARCASM_MODEL_NAME = "microsoft/deberta-base-sarcasm" # example, replace if unavailable
 
28
  sarcasm_tokenizer = AutoTokenizer.from_pretrained(SARCASM_MODEL_NAME)
29
  sarcasm_model = AutoModelForSequenceClassification.from_pretrained(SARCASM_MODEL_NAME)
30
 
31
- reader = easyocr.Reader(['en'])
32
 
33
  def classify_text(text):
34
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
35
  with torch.no_grad():
36
- outputs = model(**inputs)
37
  logits = outputs.logits
38
  probs = torch.nn.functional.softmax(logits, dim=-1)
39
  pred = torch.argmax(probs).item()
40
  confidence = float(probs[0][pred])
41
- return LABELS[pred], confidence
42
 
43
  def is_sarcastic(text):
44
  inputs = sarcasm_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
45
  with torch.no_grad():
46
  outputs = sarcasm_model(**inputs)
47
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
48
- # assuming label 1 means sarcastic; adjust if needed
49
- sarcasm_prob = probs[0][1].item()
50
- return sarcasm_prob > 0.5, sarcasm_prob
51
 
52
  def ocr_extract(image):
53
  if isinstance(image, Image.Image):
@@ -56,27 +54,24 @@ def ocr_extract(image):
56
  return ' '.join(result)
57
 
58
  def chatbot(image=None, text=None):
59
- # Priority: image with OCR, else text box
60
  if image is not None:
61
  extracted = ocr_extract(image)
62
  if not extracted.strip():
63
  return "No text found in image.", None, None
64
  label, confidence = classify_text(extracted)
65
- sarcastic, sarcasm_prob = is_sarcastic(extracted)
66
- sarcasm_text = "Yes" if sarcastic else "No"
67
  return (
68
- f"OCR Extracted: {extracted}\nPrediction: {label} (Confidence: {confidence:.2f})\nSarcasm: {sarcasm_text} (Prob: {sarcasm_prob:.2f})",
69
  label,
70
- sarcasm_text
71
  )
72
  elif text and text.strip():
73
  label, confidence = classify_text(text)
74
- sarcastic, sarcasm_prob = is_sarcastic(text)
75
- sarcasm_text = "Yes" if sarcastic else "No"
76
  return (
77
- f"Text: {text}\nPrediction: {label} (Confidence: {confidence:.2f})\nSarcasm: {sarcasm_text} (Prob: {sarcasm_prob:.2f})",
78
  label,
79
- sarcasm_text
80
  )
81
  else:
82
  return "Please provide an image or some text.", None, None
@@ -89,14 +84,13 @@ iface = gr.Interface(
89
  ],
90
  outputs=[
91
  gr.Textbox(label="Prediction & Sarcasm Detection"),
92
- gr.Label(num_top_classes=len(LABELS), label="Hate Speech Class"),
93
  gr.Label(num_top_classes=2, label="Sarcasm")
94
  ],
95
- title="Cyberbully detection system Chatbot",
96
  description="""
97
- Classifies text (or text extracted from image) into hate speech categories including abusive words,
98
- threat, harassment, and detects sarcasm separately. Enter text or upload an image screenshot.
99
- """
100
  )
101
 
102
  if __name__ == "__main__":
 
5
  from PIL import Image
6
  import numpy as np
7
 
8
+ # Hate Speech Model: Only 7 valid categories!
9
+ HATE_MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
10
+ HATE_LABELS = [
11
  "sexism",
12
  "racism",
13
  "disability",
14
  "sexual_orientation",
15
  "religion",
16
+ "other",
 
 
 
17
  "not_hate"
18
  ]
19
 
20
+ hate_tokenizer = AutoTokenizer.from_pretrained(HATE_MODEL_NAME)
21
+ hate_model = AutoModelForSequenceClassification.from_pretrained(HATE_MODEL_NAME)
22
 
23
+ # Sarcasm Model: REAL public model
24
+ SARCASM_MODEL_NAME = "mrm8488/bert-tiny-finetuned-sarcasm-detection"
25
+ SARCASM_LABELS = ["Not Sarcastic", "Sarcastic"]
26
  sarcasm_tokenizer = AutoTokenizer.from_pretrained(SARCASM_MODEL_NAME)
27
  sarcasm_model = AutoModelForSequenceClassification.from_pretrained(SARCASM_MODEL_NAME)
28
 
29
+ reader = easyocr.Reader(['en'], gpu=False)
30
 
31
  def classify_text(text):
32
+ inputs = hate_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
33
  with torch.no_grad():
34
+ outputs = hate_model(**inputs)
35
  logits = outputs.logits
36
  probs = torch.nn.functional.softmax(logits, dim=-1)
37
  pred = torch.argmax(probs).item()
38
  confidence = float(probs[0][pred])
39
+ return HATE_LABELS[pred], confidence
40
 
41
  def is_sarcastic(text):
42
  inputs = sarcasm_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
43
  with torch.no_grad():
44
  outputs = sarcasm_model(**inputs)
45
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
46
+ pred = torch.argmax(probs).item()
47
+ sarcasm_prob = float(probs[0][pred])
48
+ return SARCASM_LABELS[pred], sarcasm_prob
49
 
50
  def ocr_extract(image):
51
  if isinstance(image, Image.Image):
 
54
  return ' '.join(result)
55
 
56
  def chatbot(image=None, text=None):
 
57
  if image is not None:
58
  extracted = ocr_extract(image)
59
  if not extracted.strip():
60
  return "No text found in image.", None, None
61
  label, confidence = classify_text(extracted)
62
+ sarcasm, sarcasm_prob = is_sarcastic(extracted)
 
63
  return (
64
+ f"OCR Extracted: {extracted}\nPrediction: {label} (Confidence: {confidence:.2f})\nSarcasm: {sarcasm} (Prob: {sarcasm_prob:.2f})",
65
  label,
66
+ sarcasm
67
  )
68
  elif text and text.strip():
69
  label, confidence = classify_text(text)
70
+ sarcasm, sarcasm_prob = is_sarcastic(text)
 
71
  return (
72
+ f"Text: {text}\nPrediction: {label} (Confidence: {confidence:.2f})\nSarcasm: {sarcasm} (Prob: {sarcasm_prob:.2f})",
73
  label,
74
+ sarcasm
75
  )
76
  else:
77
  return "Please provide an image or some text.", None, None
 
84
  ],
85
  outputs=[
86
  gr.Textbox(label="Prediction & Sarcasm Detection"),
87
+ gr.Label(num_top_classes=len(HATE_LABELS), label="Hate Speech Class"),
88
  gr.Label(num_top_classes=2, label="Sarcasm")
89
  ],
90
+ title="Cyberbully Detection System Chatbot",
91
  description="""
92
+ Classifies text (or text extracted from image) into hate speech categories (sexism, racism, disability, sexual_orientation, religion, other, not_hate) and detects sarcasm (separately). Enter text or upload an image screenshot.
93
+ """
 
94
  )
95
 
96
  if __name__ == "__main__":