nit454 commited on
Commit
fb4632b
Β·
verified Β·
1 Parent(s): 275a605

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -19
app.py CHANGED
@@ -5,20 +5,29 @@ import easyocr
5
  from PIL import Image
6
  import numpy as np
7
 
8
- # Set up model and labels
9
  MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
10
  LABELS = [
11
- "sexism", # 0
12
- "racism", # 1
13
- "disability", # 2
14
- "sexual_orientation", # 3
15
- "religion", # 4
16
- "other", # 5
17
- "not_hate" # 6
 
 
 
18
  ]
19
 
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
22
  reader = easyocr.Reader(['en'])
23
 
24
  def classify_text(text):
@@ -31,39 +40,63 @@ def classify_text(text):
31
  confidence = float(probs[0][pred])
32
  return LABELS[pred], confidence
33
 
 
 
 
 
 
 
 
 
 
34
  def ocr_extract(image):
35
- # Convert to numpy if Image
36
  if isinstance(image, Image.Image):
37
  image = np.array(image)
38
  result = reader.readtext(image, detail=0)
39
  return ' '.join(result)
40
 
41
  def chatbot(image=None, text=None):
42
- # Prioritize image
43
  if image is not None:
44
  extracted = ocr_extract(image)
45
  if not extracted.strip():
46
- return "No text found in image.", None
47
  label, confidence = classify_text(extracted)
48
- return f"OCR Extracted: {extracted}\nPrediction: {label} (Confidence: {confidence:.2f})", label
 
 
 
 
 
 
49
  elif text and text.strip():
50
  label, confidence = classify_text(text)
51
- return f"Text: {text}\nPrediction: {label} (Confidence: {confidence:.2f})", label
 
 
 
 
 
 
52
  else:
53
- return "Please provide an image or some text.", None
54
 
55
  iface = gr.Interface(
56
  fn=chatbot,
57
  inputs=[
58
  gr.Image(type="pil", label="Upload Screenshot (optional)"),
59
- gr.Textbox(lines=2, placeholder="Or, type/paste text here")
60
  ],
61
  outputs=[
62
- gr.Textbox(label="Prediction & OCR"),
63
- gr.Label(num_top_classes=7)
 
64
  ],
65
- title="Multiclass Hate Speech Classifier (with OCR)",
66
- description="Detects: sexism, racism, disability, sexual_orientation, religion, other, not_hate. Enter text or upload screenshot."
 
 
 
67
  )
68
 
69
  if __name__ == "__main__":
 
5
  from PIL import Image
6
  import numpy as np
7
 
8
+ # Hate Speech model (example uses base CardiffNLP + extended labels for demonstration)
9
  MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
10
  LABELS = [
11
+ "sexism",
12
+ "racism",
13
+ "disability",
14
+ "sexual_orientation",
15
+ "religion",
16
+ "abusive_words", # added label - simulation only
17
+ "threat", # added label - simulation only
18
+ "harassment", # added label - simulation only
19
+ "sarcastic", # added label - simulation only; we'll do actual sarcasm detection via separate model
20
+ "not_hate"
21
  ]
22
 
23
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
24
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
25
+
26
+ # Sarcasm Detection model (example pretrained; replace with your actual sarcasm model)
27
+ SARCASM_MODEL_NAME = "microsoft/deberta-base-sarcasm" # example, replace if unavailable
28
+ sarcasm_tokenizer = AutoTokenizer.from_pretrained(SARCASM_MODEL_NAME)
29
+ sarcasm_model = AutoModelForSequenceClassification.from_pretrained(SARCASM_MODEL_NAME)
30
+
31
  reader = easyocr.Reader(['en'])
32
 
33
  def classify_text(text):
 
40
  confidence = float(probs[0][pred])
41
  return LABELS[pred], confidence
42
 
43
+ def is_sarcastic(text):
44
+ inputs = sarcasm_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
45
+ with torch.no_grad():
46
+ outputs = sarcasm_model(**inputs)
47
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
48
+ # assuming label 1 means sarcastic; adjust if needed
49
+ sarcasm_prob = probs[0][1].item()
50
+ return sarcasm_prob > 0.5, sarcasm_prob
51
+
52
  def ocr_extract(image):
 
53
  if isinstance(image, Image.Image):
54
  image = np.array(image)
55
  result = reader.readtext(image, detail=0)
56
  return ' '.join(result)
57
 
58
  def chatbot(image=None, text=None):
59
+ # Priority: image with OCR, else text box
60
  if image is not None:
61
  extracted = ocr_extract(image)
62
  if not extracted.strip():
63
+ return "No text found in image.", None, None
64
  label, confidence = classify_text(extracted)
65
+ sarcastic, sarcasm_prob = is_sarcastic(extracted)
66
+ sarcasm_text = "Yes" if sarcastic else "No"
67
+ return (
68
+ f"OCR Extracted: {extracted}\nPrediction: {label} (Confidence: {confidence:.2f})\nSarcasm: {sarcasm_text} (Prob: {sarcasm_prob:.2f})",
69
+ label,
70
+ sarcasm_text
71
+ )
72
  elif text and text.strip():
73
  label, confidence = classify_text(text)
74
+ sarcastic, sarcasm_prob = is_sarcastic(text)
75
+ sarcasm_text = "Yes" if sarcastic else "No"
76
+ return (
77
+ f"Text: {text}\nPrediction: {label} (Confidence: {confidence:.2f})\nSarcasm: {sarcasm_text} (Prob: {sarcasm_prob:.2f})",
78
+ label,
79
+ sarcasm_text
80
+ )
81
  else:
82
+ return "Please provide an image or some text.", None, None
83
 
84
  iface = gr.Interface(
85
  fn=chatbot,
86
  inputs=[
87
  gr.Image(type="pil", label="Upload Screenshot (optional)"),
88
+ gr.Textbox(lines=3, placeholder="Or, type/paste text here")
89
  ],
90
  outputs=[
91
+ gr.Textbox(label="Prediction & Sarcasm Detection"),
92
+ gr.Label(num_top_classes=len(LABELS), label="Hate Speech Class"),
93
+ gr.Label(num_top_classes=2, label="Sarcasm")
94
  ],
95
+ title="Multiclass Hate Speech + Sarcasm Detection Chatbot",
96
+ description="""
97
+ Classifies text (or text extracted from image) into hate speech categories including abusive words,
98
+ threat, harassment, and detects sarcasm separately. Enter text or upload an image screenshot.
99
+ """
100
  )
101
 
102
  if __name__ == "__main__":