# CBDS_Basic / app.py
# Author: nit454 — last commit 4276495 ("Update app.py")
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import easyocr
from PIL import Image
import numpy as np
# --- Model, label and OCR setup ---------------------------------------------
# Hugging Face checkpoint used for the hate-speech classifier.
MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"

# Class names in model output order: LABELS[i] names logit/probability i
# (sexism=0, racism=1, disability=2, sexual_orientation=3, religion=4,
# other=5, not_hate=6).
# NOTE(review): this ordering is hard-coded; confirm it matches
# model.config.id2label for MODEL_NAME.
LABELS = [
    "sexism", "racism", "disability", "sexual_orientation",
    "religion", "other", "not_hate",
]

# Loaded once at import time and shared by every request handled below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
reader = easyocr.Reader(["en"])  # English-only OCR reader
def classify_text(text):
    """Classify *text* with the hate-speech model.

    Args:
        text: The string to classify. It is truncated to the tokenizer's
            maximum length.

    Returns:
        A ``(label, confidence)`` tuple: the ``LABELS`` entry for the most
        probable class, and that class's softmax probability as a float.
    """
    encoded = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        scores = model(**encoded).logits
    distribution = scores.softmax(dim=-1)
    # argmax without `dim` flattens the tensor; for a single input string
    # (batch of 1) the flat index is exactly the class index.
    best = int(distribution.argmax())
    return LABELS[best], float(distribution[0][best])
def ocr_extract(image):
    """Run OCR on *image* and return the recognised text as one string.

    Args:
        image: A ``PIL.Image.Image``, or anything else ``reader.readtext``
            accepts (it is passed through unchanged if not a PIL image).

    Returns:
        The recognised text fragments joined with single spaces; an empty
        string when nothing is recognised.
    """
    if isinstance(image, Image.Image):
        # Normalise the pixel mode first. Converting a palette ("P"),
        # two-channel ("LA"), CMYK or 16-bit image directly with np.array
        # yields index/odd-channel arrays that EasyOCR does not interpret
        # as pixels; RGB input is unaffected by the conversion.
        image = np.array(image.convert("RGB"))
    # detail=0 asks readtext for the text strings only (no boxes/scores).
    fragments = reader.readtext(image, detail=0)
    return " ".join(fragments)
def chatbot(image=None, text=None):
    """Classify an uploaded screenshot (via OCR) or typed text.

    The screenshot takes priority. If a screenshot is given but OCR finds no
    text in it, the typed text is classified instead when present, so the
    user's input is never silently ignored.

    Args:
        image: Optional image handed to ``ocr_extract``; ``None`` if absent.
        text: Optional text to classify; ``None`` or blank means absent.

    Returns:
        A ``(message, label)`` tuple for the two interface outputs: a
        human-readable summary, and the predicted label (``None`` when
        nothing could be classified).
    """
    no_text_in_image = False

    if image is not None:
        extracted = ocr_extract(image)
        if extracted.strip():
            label, confidence = classify_text(extracted)
            return (
                f"OCR Extracted: {extracted}\n"
                f"Prediction: {label} (Confidence: {confidence:.2f})",
                label,
            )
        # Remember the empty OCR result so we can fall back to typed text.
        no_text_in_image = True

    if text and text.strip():
        label, confidence = classify_text(text)
        message = f"Text: {text}\nPrediction: {label} (Confidence: {confidence:.2f})"
        if no_text_in_image:
            # Explain why the screenshot did not drive the prediction.
            message = (
                "No text found in image; classified the typed text instead.\n"
                + message
            )
        return message, label

    if no_text_in_image:
        return "No text found in image.", None
    return "Please provide an image or some text.", None
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Image(type="pil", label="Upload Screenshot (optional)"),
        gr.Textbox(lines=2, placeholder="Or, type/paste text here"),
    ],
    outputs=[
        gr.Textbox(label="Prediction & OCR"),
        gr.Label(num_top_classes=7),
    ],
    title="Cyberbully Detection System (with OCR)",
    # List every class from LABELS (the names classify_text returns) so the
    # description stays complete and in sync with the model's outputs.
    description=(
        "Detects: "
        + ", ".join(LABELS)
        + ". Enter text or upload screenshot"
        + " (text found in a screenshot takes priority)."
    ),
)

if __name__ == "__main__":
    iface.launch()