nit454 committed on
Commit
9da1727
·
verified ·
1 Parent(s): d3388d4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+ import easyocr
5
+ from PIL import Image
6
+ import numpy as np
7
+
8
# RoBERTa multiclass hate-speech checkpoint, loaded by name via from_pretrained.
MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"

# Tokenizer and model are created once at import time and shared by all calls.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Class names indexed by class id, read from the checkpoint's own config.
# A hand-written list can disagree with the checkpoint in both length and
# order, which mislabels predictions or raises IndexError on a high class id.
# NOTE(review): the model card appears to list seven classes rather than the
# five that were hard-coded here before -- confirm against the card.
LABELS = [model.config.id2label[i] for i in range(model.config.num_labels)]

# EasyOCR reader configured for English, used to pull text out of screenshots.
reader = easyocr.Reader(['en'])
21
+
22
def classify_text(text):
    """Classify a single piece of text with the module-level model.

    Args:
        text: The string to classify. It is tokenized with truncation
            enabled, so over-long input is cut down by the tokenizer.

    Returns:
        A ``(label, probability)`` tuple: the name of the highest-scoring
        class and its softmax probability as a Python float.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # One input string means one row; index it explicitly so the argmax is
    # taken over classes rather than over the flattened tensor.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    pred = int(torch.argmax(probs))
    # Take the class name from the checkpoint's own mapping so the label can
    # never drift from (or run past the end of) a hand-maintained list.
    label = model.config.id2label[pred]
    return label, float(probs[pred])
29
+
30
def ocr_extract(image):
    """Extract text from an image with the module-level OCR reader.

    Args:
        image: A PIL image, which is converted to a NumPy array first, or
            any other value, which is handed to the reader unchanged.

    Returns:
        The recognised text fragments joined into one string with single
        spaces (an empty string when the reader returns nothing).
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    # detail=0 is passed so the joined items are text fragments
    # (presumably plain strings -- see the EasyOCR readtext docs).
    fragments = reader.readtext(pixels, detail=0)
    return " ".join(fragments)
35
+
36
def chatbot(image=None, text=None):
    """Classify text taken from a screenshot or typed in directly.

    The screenshot takes priority: if it is given and OCR finds text, that
    text is classified. If OCR finds nothing, typed text (when present) is
    classified instead of being silently discarded.

    Args:
        image: Screenshot passed to ``ocr_extract``, or ``None``.
        text: Typed or pasted text, or ``None``.

    Returns:
        A ``(message, label)`` tuple. ``message`` is a human-readable summary
        and ``label`` is the predicted class name, or ``None`` when nothing
        could be classified.
    """
    has_text = bool(text and text.strip())

    if image is not None:
        extracted = ocr_extract(image)
        if extracted.strip():
            hate_class, prob = classify_text(extracted)
            return f"OCR: {extracted}\n\nClass: {hate_class} (Prob: {prob:.2f})", hate_class
        # Only report an OCR failure when there is no typed text to fall
        # back on; otherwise continue to the text branch below.
        if not has_text:
            return "No text found in image.", None

    if has_text:
        hate_class, prob = classify_text(text)
        return f"Text: {text}\nClass: {hate_class} (Prob: {prob:.2f})", hate_class

    return "Please provide a screenshot or text input.", None
48
+
49
# Input widgets, in the positional order expected by chatbot(image, text).
screenshot_input = gr.Image(type="pil", label="Upload Screenshot (optional)")
text_input = gr.Textbox(lines=2, placeholder="Or, type/paste text here")

# Output widgets, matching chatbot's (message, label) return tuple.
summary_output = gr.Textbox(label="Prediction & OCR")
label_output = gr.Label(num_top_classes=5)

# Gradio interface that routes the two inputs through chatbot.
iface = gr.Interface(
    fn=chatbot,
    inputs=[screenshot_input, text_input],
    outputs=[summary_output, label_output],
    title="Multiclass Hate Speech Detector Chatbot (RoBERTa, with OCR)",
    description="Detects religious hate, sexual abuse, racism, sarcasm or no hate. Upload a screenshot or enter text.",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()