|
|
import gradio as gr |
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
import easyocr |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
# Binary sarcasm/irony classifier (TweetEval irony task, used here as a
# sarcasm proxy). The previous checkpoint,
# "j-hartmann/emotion-english-distilroberta-base", is a 7-class *emotion*
# model. Its predicted indices did not line up with the two labels below,
# and any index >= 2 raised IndexError in detect_sarcasm.
SARCASM_MODEL_NAME = "cardiffnlp/twitter-roberta-base-irony"

# Output index i of sarcasm_model maps to sarcasm_labels[i]. The irony model
# emits (non_irony, irony) in that order.
sarcasm_labels = ["not sarcastic", "sarcastic"]

sarcasm_tokenizer = AutoTokenizer.from_pretrained(SARCASM_MODEL_NAME)

sarcasm_model = AutoModelForSequenceClassification.from_pretrained(SARCASM_MODEL_NAME)

# Fail fast at startup if the checkpoint and the label list drift apart again.
if sarcasm_model.config.num_labels != len(sarcasm_labels):
    raise ValueError(
        f"{SARCASM_MODEL_NAME} has {sarcasm_model.config.num_labels} output classes, "
        f"but sarcasm_labels lists {len(sarcasm_labels)}."
    )
|
|
|
|
|
|
|
|
# Multiclass hate-speech classifier from CardiffNLP.
HATE_MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"

# classify_hate maps output index i of hate_model to hate_labels[i].
# NOTE(review): assumes this order matches hate_model.config.id2label — confirm.
hate_labels = [
    "sexism", "racism", "disability", "sexual_orientation",
    "religion", "other", "not_hate",
]

# Tokenizer is loaded first, then the model (tuple elements evaluate left to right).
hate_tokenizer, hate_model = (
    AutoTokenizer.from_pretrained(HATE_MODEL_NAME),
    AutoModelForSequenceClassification.from_pretrained(HATE_MODEL_NAME),
)
|
|
|
|
|
|
|
|
# English-only OCR engine, forced onto the CPU.
reader = easyocr.Reader(lang_list=["en"], gpu=False)
|
|
|
|
|
def extract_text(image):
    """Run OCR over *image* and return the recognised fragments joined by spaces.

    PIL images are converted to a NumPy array first. Any other input is
    passed to the EasyOCR reader unchanged.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    # detail=0 asks EasyOCR for plain strings instead of (box, text, score) tuples.
    fragments = reader.readtext(pixels, detail=0)
    return " ".join(fragments)
|
|
|
|
|
def detect_sarcasm(text):
    """Classify *text* with the sarcasm model.

    Returns:
        A ``(label, confidence)`` pair. ``label`` is the entry of
        ``sarcasm_labels`` for the highest-probability class, and
        ``confidence`` is that class's softmax probability as a float.
    """
    encoded = sarcasm_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        logits = sarcasm_model(**encoded).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # argmax over the flattened tensor; with a single input string this is
    # the class index within row 0.
    best = torch.argmax(probabilities).item()
    return sarcasm_labels[best], float(probabilities[0][best])
|
|
|
|
|
def classify_hate(text):
    """Predict the hate-speech category of *text*.

    Returns:
        ``(label, confidence)`` where ``label`` comes from ``hate_labels`` and
        ``confidence`` is the softmax probability of that class as a float.
    """
    batch = hate_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        result = hate_model(**batch)
    dist = torch.nn.functional.softmax(result.logits, dim=-1)
    # Flattened argmax, then read the winning probability from row 0.
    idx = torch.argmax(dist).item()
    return hate_labels[idx], float(dist[0][idx])
|
|
|
|
|
def _select_text(user_text, user_image):
    """Return the text to analyse, or None when nothing usable was supplied.

    OCR text from *user_image* takes priority. Typed *user_text* (stripped)
    is the fallback when there is no image or the image has no readable text.
    """
    if user_image is not None:
        extracted_text = extract_text(user_image)
        if extracted_text.strip():
            return extracted_text
    if user_text and user_text.strip():
        return user_text.strip()
    return None


def _build_reply(text):
    """Gate on sarcasm first; only non-sarcastic text is hate-classified."""
    sarcasm_label, sarcasm_conf = detect_sarcasm(text)
    if sarcasm_label == "sarcastic":
        return f"Sarcasm detected (Confidence: {sarcasm_conf:.2f}). Hate speech detection skipped."
    hate_label, hate_conf = classify_hate(text)
    return (
        f"Hate Speech Category: {hate_label} (Confidence: {hate_conf:.2f})\n"
        f"Message: \"{text}\""
    )


def respond(chat_history, user_text, user_image):
    """Analyse the user's text or screenshot and append the result to the chat.

    Args:
        chat_history: Current Chatbot value, a list of ``(user, bot)`` pairs
            (``None`` is treated as empty). It is not mutated.
        user_text: Text typed by the user. May be ``None`` or blank.
        user_image: Uploaded image, or ``None``. Readable OCR text from it is
            preferred over ``user_text``.

    Returns:
        ``(new_history, None, None)``. The two ``None`` values clear the text
        box and the image input in the UI.
    """
    history = list(chat_history or [])
    text_to_analyze = _select_text(user_text, user_image)
    if text_to_analyze is None:
        # Bot-only message: a None user slot renders no user bubble.
        history.append((None, "Please provide text or an image with readable text."))
    else:
        # One (user message, bot reply) pair per exchange, as Chatbot expects.
        history.append((text_to_analyze, _build_reply(text_to_analyze)))
    return history, None, None
|
|
|
|
|
# Build the Gradio UI: a chat transcript, a text box, an optional screenshot
# upload, and buttons to analyse the image or clear the conversation.
# Components render in the order they are created below.
with gr.Blocks() as demo:
    gr.Markdown("# Cyber Bully Detection System")

    # NOTE(review): this State is not passed to any event handler below; the
    # Chatbot component's own value is what carries the history. Looks unused.
    chat_history = gr.State([])

    chatbot = gr.Chatbot()
    txt = gr.Textbox(show_label=False, placeholder="Type your message here and press Enter")
    # NOTE(review): `source=` is the Gradio 3.x keyword; Gradio 4 replaced it
    # with `sources=[...]`. Confirm against the pinned Gradio version.
    img = gr.Image(source="upload", type="pil", label="Upload Screenshot (optional)")
    clear_btn = gr.Button("Clear Chat")

    # Pressing Enter in the text box runs `respond`. Its second and third
    # outputs are written back to the text box and the image input.
    txt.submit(respond, [chatbot, txt, img], [chatbot, txt, img])

    submit_img_btn = gr.Button("Analyze Image")
    # Same handler and wiring as the text box, triggered by the button.
    submit_img_btn.click(respond, [chatbot, txt, img], [chatbot, txt, img])

    # Reset the transcript to empty and clear both inputs.
    clear_btn.click(lambda: ([], None, None), None, [chatbot, txt, img])


if __name__ == "__main__":
    # Start the Gradio server for the UI defined above.
    demo.launch()
|
|
|