# sms-spam-classifier / app_original.py
# (Hugging Face Space file by AndriusKei; renamed from app.py to
# app_original.py in commit a78b8b4, verified)
# app.py
## Quick start with UI text input plus an OCR image function (fifth iteration)
import gradio as gr
from transformers import pipeline
from PIL import Image
import pytesseract
# Text-classification pipeline backed by the fine-tuned spam-detector model.
classifier = pipeline("text-classification", model="didulantha/sms-spam-detector")
# Human-friendly names for the model's raw output labels.
# NOTE(review): assumes LABEL_1 == spam / LABEL_0 == ham in this fine-tune —
# confirm against the model card's id2label mapping.
custom_label_map = {
    'LABEL_1': 'Possible Scam',
    'LABEL_0': 'Harmless Perhaps',
}
def classify_image_or_text(input_text, input_image=None):
    """Classify an SMS message as spam, from typed text or an uploaded image.

    Args:
        input_text: SMS text entered by the user; may be empty or None.
        input_image: Optional image of an SMS. gr.Image(type="pil") hands us
            a PIL.Image object (not a file path). The image takes priority
            over the text when both are supplied.

    Returns:
        A human-readable prediction string, or a guidance/error message.
    """
    try:
        if input_image is not None:
            # gr.Image(type="pil") already yields a PIL.Image, so pass it
            # straight to pytesseract. The previous Image.open(input_image)
            # call raised on a PIL.Image and silently hit the except branch
            # for every upload.
            text = pytesseract.image_to_string(input_image)
            if not text.strip():
                return "No text detected in image. Please provide a clearer image."
        else:
            # Gradio can pass None for a cleared textbox; normalize before
            # .strip() so we return the guidance message, not a crash.
            text = input_text or ""
            if not text.strip():
                return "No text provided. Please enter text or upload an image with text."
        # Run the extracted/entered text through the model.
        output = classifier(text)
        label = output[0]['label']
        # Reuse the module-level mapping so label names stay in one place.
        return custom_label_map.get(label, f"Unexpected model label: {label}")
    except Exception:
        # Best-effort user-facing message; OCR and model errors land here.
        return ("Processing error: Something went wrong. "
                "Please check if the image quality is good and the text is readable.")
# Build the two input widgets up front so the Interface call stays short.
sms_textbox = gr.Textbox(label="SMS Message (or leave blank if uploading image)")
sms_image = gr.Image(label="Image with Text (optional)", type="pil")

# Wire inputs/outputs to the classifier and start the Gradio app.
iface = gr.Interface(
    fn=classify_image_or_text,
    inputs=[sms_textbox, sms_image],
    outputs=gr.Label(label="Prediction"),
    title="SMS Spam Detector",
    description="Paste SMS or upload image of SMS. Prioritizes text if both.",
)
iface.launch()
## Detailed Usage
## from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
## import torch
## tokenizer = DistilBertTokenizer.from_pretrained("didulantha/sms-spam-detector")
## model = DistilBertForSequenceClassification.from_pretrained("didulantha/sms-spam-detector")
## def predict(text):
## inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
## outputs = model(**inputs)
## probs = torch.softmax(outputs.logits, dim=1)
## return "SPAM" if probs[0][1] > 0.5 else "HAM"
## print(predict("Free prize!"))