Usman06's picture
Update app.py
66f83bf verified
import os
import shutil

import gradio as gr
import praw
import pytesseract
import tensorflow as tf
import torch
from PIL import Image
from scipy.special import softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TFBertForSequenceClassification, BertTokenizer
# Setup for Tesseract (if running in cloud environment like Hugging Face Spaces).
# pytesseract shells out to the `tesseract` executable, so the binary has to be
# on PATH. Install it only when it is missing: re-running apt-get on every
# start is slow and fails on hosts without root or network access.
if shutil.which("tesseract") is None:
    install_status = os.system("apt-get update && apt-get install -y tesseract-ocr")
    if install_status != 0:
        # Keep starting up: typed-text and Reddit inputs do not need OCR.
        print(
            f"[!] Tesseract installation failed (exit status {install_status}); "
            "image OCR will be unavailable."
        )
# Primary sentiment model: a TensorFlow BERT sequence classifier and its
# tokenizer, both loaded from the same checkpoint.
tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")
model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
# Class index -> display label for the primary model's predictions.
LABELS = dict(enumerate(("Neutral", "Positive", "Negative")))

# Fallback RoBERTa model, consulted when the primary model's confidence is low.
fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
# Reddit API config.
# SECURITY: credentials are read from the environment only. A client id and
# secret were previously committed here as fallback defaults; anything that has
# been published in source must be treated as compromised and rotated in the
# Reddit app settings, then supplied via REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET.
_reddit_client_id = os.getenv("REDDIT_CLIENT_ID")
_reddit_client_secret = os.getenv("REDDIT_CLIENT_SECRET")

if _reddit_client_id and _reddit_client_secret:
    reddit = praw.Reddit(
        client_id=_reddit_client_id,
        client_secret=_reddit_client_secret,
        # The user agent is an identifier, not a secret, so a default is fine.
        user_agent=os.getenv("REDDIT_USER_AGENT", "myscript by u/usman_afzal"),
    )
else:
    # Leave the rest of the app usable. fetch_reddit_text wraps its use of
    # `reddit` in a broad try/except, so a None client surfaces to the user as
    # an "Error fetching Reddit post: ..." message instead of a crash.
    reddit = None
    print(
        "[!] REDDIT_CLIENT_ID / REDDIT_CLIENT_SECRET are not set; "
        "Reddit URL input is disabled."
    )
# Extract text from Reddit URL
def fetch_reddit_text(reddit_url):
    """Return the title and self-text of the Reddit post at *reddit_url*.

    The title and body are separated by one blank line. Failures are never
    raised: any exception is converted into a string that starts with
    "Error fetching Reddit post: " followed by the exception message.
    """
    try:
        post = reddit.submission(url=reddit_url)
        # Attribute access stays inside the try block so that errors raised
        # while reading the post's fields are reported the same way.
        combined = f"{post.title}\n\n{post.selftext}"
    except Exception as exc:
        return f"Error fetching Reddit post: {str(exc)}"
    return combined
# OCR from image
def extract_text_from_image(image):
    """Run Tesseract OCR on *image* and return the recognised text.

    Leading and trailing whitespace is removed from the OCR output. Failures
    are never raised: any exception is converted into a string that starts
    with "Error reading image: " followed by the exception message.
    """
    try:
        raw_text = pytesseract.image_to_string(image)
        cleaned = raw_text.strip()
    except Exception as exc:
        return f"Error reading image: {str(exc)}"
    return cleaned
# Fallback model logic
def fallback_classifier(text):
    """Classify *text* with the fallback model and return a display string.

    The text is tokenized with `fallback_tokenizer`, scored by
    `fallback_model` with gradient tracking disabled, and the highest-scoring
    class is reported as "Fallback Prediction: <label>".
    """
    # Index -> label order used for the fallback model's output classes.
    label_names = ('Negative', 'Neutral', 'Positive')

    tokens = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        logits = fallback_model(**tokens).logits

    # A single input was encoded, so row 0 holds its class scores.
    probabilities = softmax(logits.numpy()[0])
    best_index = probabilities.argmax()
    return f"Fallback Prediction: {label_names[best_index]}"
# Final classifier logic
def classify_sentiment(text_input, reddit_url, image):
    """Classify the sentiment of text taken from one of three sources.

    Source priority is: Reddit URL, then uploaded image (OCR), then typed text.

    Args:
        text_input: Free text typed by the user; may be empty or None.
        reddit_url: URL of a Reddit post to fetch; may be empty or None.
        image: Image to run OCR on, or None when nothing was uploaded.

    Returns:
        A display string. One of:
        - "Prediction: <label> (Confidence: <p>)" from the primary model,
        - the fallback model's result when primary confidence is below 0.5,
        - a message starting with "[!]" when no usable input was given, a
          helper reported a failure, or prediction raised an exception.
    """
    # Exact prefixes the fetch/OCR helpers use to report a failure as a string.
    helper_error_prefixes = ("Error fetching Reddit post: ", "Error reading image: ")

    # Tolerate None as well as "" for the two text fields.
    reddit_url = (reddit_url or "").strip()
    text_input = text_input or ""

    # Source detection
    if reddit_url:
        text = fetch_reddit_text(reddit_url)
        from_helper = True
    elif image is not None:
        text = extract_text_from_image(image)
        from_helper = True
    elif text_input.strip():
        text = text_input
        from_helper = False
    else:
        return "[!] Please provide text input, Reddit URL, or image."

    # Only helper output can carry a failure report. Matching the helpers'
    # exact prefixes (instead of any text starting with "error") keeps
    # legitimate content such as "Error 404 made me sad" classifiable.
    if from_helper and text.startswith(helper_error_prefixes):
        return f"[!] Error: {text}"

    # OCR can legitimately produce nothing; do not feed an empty string to the model.
    if not text.strip():
        return "[!] No text could be extracted from the provided input."

    # Classification using main model
    try:
        inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
        outputs = model(inputs)
        probs = tf.nn.softmax(outputs.logits, axis=1)
        confidence = float(tf.reduce_max(probs).numpy())
        # Convert the NumPy integer to a plain int before using it as a dict key.
        pred_label = int(tf.argmax(probs, axis=1).numpy()[0])
        if confidence < 0.5:
            # Primary model is unsure: defer to the fallback model.
            return fallback_classifier(text)
        return f"Prediction: {LABELS[pred_label]} (Confidence: {confidence:.2f})"
    except Exception as e:
        # UI boundary: report the failure in the output box instead of crashing.
        return f"[!] Prediction error: {str(e)}"
# Gradio UI
# The three input components are passed to classify_sentiment positionally,
# so their order must match its (text_input, reddit_url, image) parameters.
demo = gr.Interface(
    fn=classify_sentiment,
    inputs=[
        gr.Textbox(label="Text Input", placeholder="Paste any content (tweet, comment, etc)...", lines=3),
        gr.Textbox(label="Reddit Post URL", placeholder="Paste Reddit post URL (optional)"),
        # type="pil" hands the upload over as a PIL image, which is what the OCR helper is given.
        gr.Image(label="Upload Image (Optional - text image)", type="pil"),
    ],
    outputs="text",
    title="🌍 Multilingual Sentiment Analyzer",
    # The leading emoji was stored as mis-decoded UTF-8 bytes ("πŸ“Š"); restored to the intended character.
    description="📊 Paste text, Reddit URL, or upload an image (screenshot of tweet etc.) to analyze sentiment.\nSupports fallback model if confidence is low.",
)
demo.launch()