nit454 commited on
Commit
7ba1745
Β·
verified Β·
1 Parent(s): 141ba65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -40
app.py CHANGED
@@ -20,13 +20,6 @@ HATE_LABELS = [
20
  hate_tokenizer = AutoTokenizer.from_pretrained(HATE_MODEL_NAME)
21
  hate_model = AutoModelForSequenceClassification.from_pretrained(HATE_MODEL_NAME)
22
 
23
- # Sarcasm Detection Model and Labels
24
- SARCASM_MODEL_NAME = "abhishek/sarcasm-detector-distilbert-base-uncased"
25
- SARCASM_LABELS = ["Not Sarcastic", "Sarcastic"]
26
-
27
- sarcasm_tokenizer = AutoTokenizer.from_pretrained(SARCASM_MODEL_NAME)
28
- sarcasm_model = AutoModelForSequenceClassification.from_pretrained(SARCASM_MODEL_NAME)
29
-
30
  reader = easyocr.Reader(['en'], gpu=False)
31
 
32
  def classify_text(text):
@@ -38,15 +31,6 @@ def classify_text(text):
38
  confidence = float(probs[0][pred])
39
  return HATE_LABELS[pred], confidence
40
 
41
- def detect_sarcasm(text):
42
- inputs = sarcasm_tokenizer(text, return_tensors="pt", truncation=True, padding=True)
43
- with torch.no_grad():
44
- outputs = sarcasm_model(**inputs)
45
- probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
46
- pred = torch.argmax(probs).item()
47
- confidence = float(probs[0][pred])
48
- return SARCASM_LABELS[pred], confidence
49
-
50
  def ocr_extract(image):
51
  if isinstance(image, Image.Image):
52
  image = np.array(image)
@@ -57,24 +41,14 @@ def chatbot(image=None, text=None):
57
  if image is not None:
58
  extracted = ocr_extract(image)
59
  if not extracted.strip():
60
- return "No text found in image.", None, None
61
- hate_label, hate_conf = classify_text(extracted)
62
- sarcasm_label, sarcasm_conf = detect_sarcasm(extracted)
63
- return (
64
- f"OCR Extracted: {extracted}\nHate Speech: {hate_label} (Confidence: {hate_conf:.2f})\nSarcasm: {sarcasm_label} (Confidence: {sarcasm_conf:.2f})",
65
- hate_label,
66
- sarcasm_label
67
- )
68
  elif text and text.strip():
69
- hate_label, hate_conf = classify_text(text)
70
- sarcasm_label, sarcasm_conf = detect_sarcasm(text)
71
- return (
72
- f"Text: {text}\nHate Speech: {hate_label} (Confidence: {hate_conf:.2f})\nSarcasm: {sarcasm_label} (Confidence: {sarcasm_conf:.2f})",
73
- hate_label,
74
- sarcasm_label
75
- )
76
  else:
77
- return "Please provide an image or some text.", None, None
78
 
79
  iface = gr.Interface(
80
  fn=chatbot,
@@ -83,16 +57,12 @@ iface = gr.Interface(
83
  gr.Textbox(lines=3, placeholder="Or, type/paste text here")
84
  ],
85
  outputs=[
86
- gr.Textbox(label="Prediction & Sarcasm Detection"),
87
  gr.Label(num_top_classes=len(HATE_LABELS), label="Hate Speech Class"),
88
- gr.Label(num_top_classes=2, label="Sarcasm")
89
  ],
90
- title="Hate Speech & Sarcasm Detection Chatbot",
91
- description="""
92
- Classifies text (or extracted text from image) into hate speech categories and detects sarcasm independently.
93
- Upload an image or enter text below.
94
- """
95
  )
96
 
97
  if __name__ == "__main__":
98
- iface.launch()
 
20
  hate_tokenizer = AutoTokenizer.from_pretrained(HATE_MODEL_NAME)
21
  hate_model = AutoModelForSequenceClassification.from_pretrained(HATE_MODEL_NAME)
22
 
 
 
 
 
 
 
 
23
  reader = easyocr.Reader(['en'], gpu=False)
24
 
25
  def classify_text(text):
 
31
  confidence = float(probs[0][pred])
32
  return HATE_LABELS[pred], confidence
33
 
 
 
 
 
 
 
 
 
 
34
  def ocr_extract(image):
35
  if isinstance(image, Image.Image):
36
  image = np.array(image)
 
41
  if image is not None:
42
  extracted = ocr_extract(image)
43
  if not extracted.strip():
44
+ return "No text found in image.", None
45
+ label, confidence = classify_text(extracted)
46
+ return f"OCR Extracted: {extracted}\nHate Speech: {label} (Confidence: {confidence:.2f})", label
 
 
 
 
 
47
  elif text and text.strip():
48
+ label, confidence = classify_text(text)
49
+ return f"Text: {text}\nHate Speech: {label} (Confidence: {confidence:.2f})", label
 
 
 
 
 
50
  else:
51
+ return "Please provide an image or some text.", None
52
 
53
  iface = gr.Interface(
54
  fn=chatbot,
 
57
  gr.Textbox(lines=3, placeholder="Or, type/paste text here")
58
  ],
59
  outputs=[
60
+ gr.Textbox(label="Prediction"),
61
  gr.Label(num_top_classes=len(HATE_LABELS), label="Hate Speech Class"),
 
62
  ],
63
+ title="Hate Speech Detection Chatbot",
64
+ description="Detects hate speech categories from text or screenshots."
 
 
 
65
  )
66
 
67
  if __name__ == "__main__":
68
+ iface.launch()