Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| import torch | |
| import easyocr | |
| from PIL import Image | |
| import numpy as np | |
# Hugging Face Hub identifier of the checkpoint used for classification.
MODEL_NAME = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"

# Class names, looked up by predicted class index (position == class id).
# NOTE(review): order is assumed to match the checkpoint's id2label mapping —
# confirm against the model config.
LABELS = [
    "sexism",              # 0
    "racism",              # 1
    "disability",          # 2
    "sexual_orientation",  # 3
    "religion",            # 4
    "other",               # 5
    "not_hate",            # 6
]
# Load everything once at import time so each request reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# OCR engine restricted to English text.
reader = easyocr.Reader(["en"])
def classify_text(text, max_length=512):
    """Classify a single piece of text into one of the categories in LABELS.

    Args:
        text: The string to classify.
        max_length: Maximum number of tokens kept after truncation.
            NOTE(review): 512 is the usual limit for RoBERTa-base
            checkpoints — confirm against the model config.

    Returns:
        A ``(label, confidence)`` tuple: the name of the highest-scoring
        class and its softmax probability as a float.
    """
    # Pass max_length explicitly: with truncation=True alone, truncation only
    # happens if the tokenizer config declares a maximum length. Long OCR
    # output could otherwise exceed what the model accepts.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    with torch.no_grad():
        logits = model(**inputs).logits

    # One input string -> one row; index it explicitly so the argmax below is
    # a class index rather than an index into the flattened batch.
    probs = torch.nn.functional.softmax(logits, dim=-1)[0]
    pred = int(torch.argmax(probs).item())
    return LABELS[pred], float(probs[pred])
def ocr_extract(image):
    """Run OCR on an image and return the recognised text as one string.

    Args:
        image: A PIL image (converted to a NumPy array first) or any other
            input that is passed to ``reader.readtext`` unchanged.

    Returns:
        The recognised text fragments joined with single spaces; an empty
        string when nothing was recognised.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    # detail=0 requests only the text of each detection.
    fragments = reader.readtext(pixels, detail=0)
    return " ".join(fragments)
def chatbot(image=None, text=None):
    """Classify the text from an uploaded image or, failing that, typed text.

    The image takes priority: when one is supplied, the typed text is ignored
    even if OCR finds nothing in the image.

    Args:
        image: Optional image to run OCR on.
        text: Optional text to classify directly.

    Returns:
        A ``(message, label)`` tuple. ``label`` is ``None`` whenever no
        classification was performed.
    """
    if image is None:
        # No image: fall back to the typed text, if there is any.
        if not text or not text.strip():
            return "Please provide an image or some text.", None
        predicted, score = classify_text(text)
        return f"Text: {text}\nPrediction: {predicted} (Confidence: {score:.2f})", predicted

    ocr_text = ocr_extract(image)
    if not ocr_text.strip():
        return "No text found in image.", None

    predicted, score = classify_text(ocr_text)
    return f"OCR Extracted: {ocr_text}\nPrediction: {predicted} (Confidence: {score:.2f})", predicted
# Web UI: an optional screenshot and an optional text box go in; a textual
# summary and the predicted label come out.
iface = gr.Interface(
    fn=chatbot,
    inputs=[
        gr.Image(type="pil", label="Upload Screenshot (optional)"),
        gr.Textbox(lines=2, placeholder="Or, type/paste text here"),
    ],
    outputs=[
        gr.Textbox(label="Prediction & OCR"),
        gr.Label(num_top_classes=7),
    ],
    # Title typo fixed ("Cyberbyully").
    title="Cyberbullying Detection System (with OCR)",
    # List every class in LABELS; the previous text omitted disability and
    # sexual_orientation.
    description=(
        "Detects: sexism, racism, disability, sexual_orientation, religion, "
        "other, not_hate. Enter text or upload screenshot."
    ),
)

if __name__ == "__main__":
    iface.launch()