"""Multilingual sentiment analyzer (Gradio app).

Classifies the sentiment of raw text, a Reddit post (fetched via PRAW),
or text extracted from an uploaded image (Tesseract OCR). A fine-tuned
BERT model is tried first; when its confidence is below 0.5, a RoBERTa
Twitter-sentiment model is used as a fallback.
"""

import os

import gradio as gr
import praw
import pytesseract
import tensorflow as tf
import torch
from PIL import Image
from scipy.special import softmax
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    BertTokenizer,
    TFBertForSequenceClassification,
)

# Install the Tesseract binary when running in a cloud environment
# (e.g. Hugging Face Spaces) where it is not pre-installed.
os.system("apt-get update && apt-get install -y tesseract-ocr")

# Primary sentiment model (TensorFlow BERT fine-tune).
model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")

# Label order assumed to match the fine-tuned model's classification head
# — TODO confirm against the model card.
LABELS = {0: "Neutral", 1: "Positive", 2: "Negative"}

# Fallback model (PyTorch RoBERTa), used when BERT confidence < 0.5.
fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)

# Reddit API config.
# SECURITY NOTE(review): hard-coded fallback credentials are committed here;
# rotate these keys and rely on the environment variables only.
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID", "ul9U7jc8BIHlTAh45jkpkw"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET", "TuwIBEKmlb1AptNMRYpuzuNTEabMYg"),
    user_agent=os.getenv("REDDIT_USER_AGENT", "myscript by u/usman_afzal"),
)


def fetch_reddit_text(reddit_url):
    """Return ``"<title>\\n\\n<selftext>"`` for the Reddit post at *reddit_url*.

    On failure returns an ``"Error fetching Reddit post: ..."`` string
    instead of raising, so the caller can surface it in the UI.
    """
    try:
        submission = reddit.submission(url=reddit_url)
        return f"{submission.title}\n\n{submission.selftext}"
    except Exception as e:
        return f"Error fetching Reddit post: {str(e)}"


def extract_text_from_image(image):
    """OCR *image* (a PIL image) with Tesseract and return the stripped text.

    Returns an ``"Error reading image: ..."`` string on failure.
    """
    try:
        text = pytesseract.image_to_string(image)
        return text.strip()
    except Exception as e:
        return f"Error reading image: {str(e)}"


def fallback_classifier(text):
    """Classify *text* with the RoBERTa fallback model.

    Returns a ``"Fallback Prediction: <label>"`` string.
    """
    encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        output = fallback_model(**encoded_input)
    scores = softmax(output.logits.numpy()[0])
    # cardiffnlp/twitter-roberta-base-sentiment label order: 0=neg, 1=neu, 2=pos.
    labels = ['Negative', 'Neutral', 'Positive']
    return f"Fallback Prediction: {labels[scores.argmax()]}"


# Final classifier logic
def classify_sentiment(text_input, reddit_url, image):
    """Resolve the input source, then classify its sentiment.

    Source priority: Reddit URL > uploaded image > raw text. Defers to
    :func:`fallback_classifier` when the primary model's confidence is
    below 0.5. Always returns a user-facing string; ``[!]``-prefixed
    strings signal errors.
    """
    # Source detection (URL wins over image, image over raw text).
    if reddit_url.strip():
        text = fetch_reddit_text(reddit_url)
    elif image is not None:
        text = extract_text_from_image(image)
    elif text_input.strip():
        text = text_input
    else:
        return "[!] Please provide text input, Reddit URL, or image."

    # The helpers report failures as strings rather than raising.
    if text.lower().startswith("error") or "Unable to extract" in text:
        return f"[!] Error: {text}"

    # Guard against empty OCR output / empty post bodies before tokenizing.
    if not text.strip():
        return "[!] No text could be extracted from the given input."

    # Classification using the main model.
    try:
        inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True)
        outputs = model(inputs)
        probs = tf.nn.softmax(outputs.logits, axis=1)
        confidence = float(tf.reduce_max(probs).numpy())
        pred_label = tf.argmax(probs, axis=1).numpy()[0]

        # Low confidence: defer to the fallback model.
        if confidence < 0.5:
            return fallback_classifier(text)

        return f"Prediction: {LABELS[pred_label]} (Confidence: {confidence:.2f})"
    except Exception as e:
        return f"[!] Prediction error: {str(e)}"


# Gradio UI
demo = gr.Interface(
    fn=classify_sentiment,
    inputs=[
        gr.Textbox(label="Text Input", placeholder="Paste any content (tweet, comment, etc)...", lines=3),
        gr.Textbox(label="Reddit Post URL", placeholder="Paste Reddit post URL (optional)"),
        gr.Image(label="Upload Image (Optional - text image)", type="pil")
    ],
    outputs="text",
    title="🌍 Multilingual Sentiment Analyzer",
    description="📊 Paste text, Reddit URL, or upload an image (screenshot of tweet etc.) to analyze sentiment.\nSupports fallback model if confidence is low."
)

demo.launch()