nit454 commited on
Commit
563d018
·
verified ·
1 Parent(s): beff589

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -19
app.py CHANGED
@@ -1,21 +1,25 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
 
 
4
 
5
- # Using Microsoft DeBERTa v3 base model (general-purpose, fine-tune recommended)
6
- MODEL_NAME = "microsoft/deberta-v3-base"
7
  LABELS = [
8
- "sexism",
9
- "racism",
10
- "disability",
11
- "sexual_orientation",
12
- "religion",
13
- "other",
14
- "not_hate"
15
  ]
16
 
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
18
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=len(LABELS))
 
19
 
20
  def classify_text(text):
21
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
@@ -27,18 +31,39 @@ def classify_text(text):
27
  confidence = float(probs[0][pred])
28
  return LABELS[pred], confidence
29
 
30
- def chatbot(text):
31
- if not text or not text.strip():
32
- return "Please enter some text."
33
- label, confidence = classify_text(text)
34
- return f"Prediction: {label} (Confidence: {confidence:.2f})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  iface = gr.Interface(
37
  fn=chatbot,
38
- inputs=gr.Textbox(lines=3, placeholder="Enter text for hate speech classification"),
39
- outputs="text",
40
- title="DeBERTa Hate Speech Classifier",
41
- description="Classifies text into hate speech categories with DeBERTa v3-base model."
 
 
 
 
 
 
42
  )
43
 
44
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
4
+ import easyocr
5
+ from PIL import Image
6
+ import numpy as np
7
 
8
+ # Set up model and labels
9
+ MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
10
  LABELS = [
11
+ "sexism", # 0
12
+ "racism", # 1
13
+ "disability", # 2
14
+ "sexual_orientation", # 3
15
+ "religion", # 4
16
+ "other", # 5
17
+ "not_hate" # 6
18
  ]
19
 
20
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
22
+ reader = easyocr.Reader(['en'])
23
 
24
  def classify_text(text):
25
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
 
31
  confidence = float(probs[0][pred])
32
  return LABELS[pred], confidence
33
 
34
+ def ocr_extract(image):
35
+ # Convert to numpy if Image
36
+ if isinstance(image, Image.Image):
37
+ image = np.array(image)
38
+ result = reader.readtext(image, detail=0)
39
+ return ' '.join(result)
40
+
41
+ def chatbot(image=None, text=None):
42
+ # Prioritize image
43
+ if image is not None:
44
+ extracted = ocr_extract(image)
45
+ if not extracted.strip():
46
+ return "No text found in image.", None
47
+ label, confidence = classify_text(extracted)
48
+ return f"OCR Extracted: {extracted}\nPrediction: {label} (Confidence: {confidence:.2f})", label
49
+ elif text and text.strip():
50
+ label, confidence = classify_text(text)
51
+ return f"Text: {text}\nPrediction: {label} (Confidence: {confidence:.2f})", label
52
+ else:
53
+ return "Please provide an image or some text.", None
54
 
55
  iface = gr.Interface(
56
  fn=chatbot,
57
+ inputs=[
58
+ gr.Image(type="pil", label="Upload Screenshot (optional)"),
59
+ gr.Textbox(lines=2, placeholder="Or, type/paste text here")
60
+ ],
61
+ outputs=[
62
+ gr.Textbox(label="Prediction & OCR"),
63
+ gr.Label(num_top_classes=7)
64
+ ],
65
+ title="RoBERTa Multiclass Hate Speech Classifier (with OCR)",
66
+ description="Detects: sexism, racism, disability, sexual_orientation, religion, other, not_hate. Enter text or upload screenshot."
67
  )
68
 
69
  if __name__ == "__main__":