nit454 commited on
Commit
453afb5
Β·
verified Β·
1 Parent(s): 9da1727

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -16
app.py CHANGED
@@ -5,14 +5,16 @@ import easyocr
5
  from PIL import Image
6
  import numpy as np
7
 
8
- # RoBERTa Multiclass Model
9
  MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
10
  LABELS = [
11
- "no hate", # 0
12
- "racist", # 1
13
- "religious hate", # 2
14
- "sexual abuse", # 3
15
- "sarcastic" # 4
 
 
16
  ]
17
 
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
@@ -23,28 +25,32 @@ def classify_text(text):
23
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
 
27
  pred = torch.argmax(probs).item()
28
- return LABELS[pred], float(probs[0][pred])
 
29
 
30
  def ocr_extract(image):
 
31
  if isinstance(image, Image.Image):
32
  image = np.array(image)
33
  result = reader.readtext(image, detail=0)
34
  return ' '.join(result)
35
 
36
  def chatbot(image=None, text=None):
 
37
  if image is not None:
38
  extracted = ocr_extract(image)
39
  if not extracted.strip():
40
  return "No text found in image.", None
41
- hate_class, prob = classify_text(extracted)
42
- return f"OCR: {extracted}\n\nClass: {hate_class} (Prob: {prob:.2f})", hate_class
43
  elif text and text.strip():
44
- hate_class, prob = classify_text(text)
45
- return f"Text: {text}\nClass: {hate_class} (Prob: {prob:.2f})", hate_class
46
  else:
47
- return "Please provide a screenshot or text input.", None
48
 
49
  iface = gr.Interface(
50
  fn=chatbot,
@@ -54,10 +60,10 @@ iface = gr.Interface(
54
  ],
55
  outputs=[
56
  gr.Textbox(label="Prediction & OCR"),
57
- gr.Label(num_top_classes=5)
58
  ],
59
- title="Multiclass Hate Speech Detector Chatbot (RoBERTa, with OCR)",
60
- description="Detects religious hate, sexual abuse, racism, sarcasm or no hate. Upload a screenshot or enter text."
61
  )
62
 
63
  if __name__ == "__main__":
 
5
  from PIL import Image
6
  import numpy as np
7
 
8
+ # Set up model and labels
9
  MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
10
  LABELS = [
11
+ "sexism", # 0
12
+ "racism", # 1
13
+ "disability", # 2
14
+ "sexual_orientation", # 3
15
+ "religion", # 4
16
+ "other", # 5
17
+ "not_hate" # 6
18
  ]
19
 
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
25
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
26
  with torch.no_grad():
27
  outputs = model(**inputs)
28
+ logits = outputs.logits
29
+ probs = torch.nn.functional.softmax(logits, dim=-1)
30
  pred = torch.argmax(probs).item()
31
+ confidence = float(probs[0][pred])
32
+ return LABELS[pred], confidence
33
 
34
  def ocr_extract(image):
35
+ # Convert to numpy if Image
36
  if isinstance(image, Image.Image):
37
  image = np.array(image)
38
  result = reader.readtext(image, detail=0)
39
  return ' '.join(result)
40
 
41
  def chatbot(image=None, text=None):
42
+ # Prioritize image
43
  if image is not None:
44
  extracted = ocr_extract(image)
45
  if not extracted.strip():
46
  return "No text found in image.", None
47
+ label, confidence = classify_text(extracted)
48
+ return f"OCR Extracted: {extracted}\nPrediction: {label} (Confidence: {confidence:.2f})", label
49
  elif text and text.strip():
50
+ label, confidence = classify_text(text)
51
+ return f"Text: {text}\nPrediction: {label} (Confidence: {confidence:.2f})", label
52
  else:
53
+ return "Please provide an image or some text.", None
54
 
55
  iface = gr.Interface(
56
  fn=chatbot,
 
60
  ],
61
  outputs=[
62
  gr.Textbox(label="Prediction & OCR"),
63
+ gr.Label(num_top_classes=7)
64
  ],
65
+ title="RoBERTa Multiclass Hate Speech Classifier (with OCR)",
66
+ description="Detects: sexism, racism, disability, sexual_orientation, religion, other, not_hate. Enter text or upload screenshot."
67
  )
68
 
69
  if __name__ == "__main__":