# Hugging Face Space — sentiment analysis app.
# (Removed non-code page artifact "Spaces: Sleeping Sleeping" left over from a web scrape.)
import gradio as gr
from transformers import TFBertForSequenceClassification, BertTokenizer
import tensorflow as tf
import praw
import os
import pytesseract
from PIL import Image
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from scipy.special import softmax

# Setup for Tesseract (if running in cloud environment like Hugging Face Spaces).
# NOTE(review): runs a shell command at import time and ignores its exit code —
# on hosts without apt (or without root) this fails silently and OCR will break later.
os.system("apt-get update && apt-get install -y tesseract-ocr")
# Load main BERT model (TensorFlow) used for the primary prediction.
model = TFBertForSequenceClassification.from_pretrained("shrish191/sentiment-bert")
tokenizer = BertTokenizer.from_pretrained("shrish191/sentiment-bert")

# Mapping from the main model's class index to a human-readable label.
LABELS = {0: "Neutral", 1: "Positive", 2: "Negative"}

# Load fallback RoBERTa model (PyTorch) used when the main model's confidence is low.
fallback_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
fallback_tokenizer = AutoTokenizer.from_pretrained(fallback_model_name)
fallback_model = AutoModelForSequenceClassification.from_pretrained(fallback_model_name)

# Reddit API config.
# SECURITY: real-looking API credentials are committed here as getenv() fallbacks.
# These secrets should be rotated and supplied exclusively via environment variables.
reddit = praw.Reddit(
    client_id=os.getenv("REDDIT_CLIENT_ID", "ul9U7jc8BIHlTAh45jkpkw"),
    client_secret=os.getenv("REDDIT_CLIENT_SECRET", "TuwIBEKmlb1AptNMRYpuzuNTEabMYg"),
    user_agent=os.getenv("REDDIT_USER_AGENT", "myscript by u/usman_afzal")
)
| # Extract text from Reddit URL | |
def fetch_reddit_text(reddit_url):
    """Resolve a Reddit post URL and return its title and body as one string.

    Any PRAW/network failure is reported as an error string rather than raised,
    so the caller can surface it in the UI.
    """
    try:
        post = reddit.submission(url=reddit_url)
        # Attribute access triggers the lazy fetch, so it stays inside the try.
        return f"{post.title}\n\n{post.selftext}"
    except Exception as e:
        return f"Error fetching Reddit post: {str(e)}"
| # OCR from image | |
def extract_text_from_image(image):
    """Run Tesseract OCR on a PIL image; return stripped text or an error string."""
    try:
        ocr_text = pytesseract.image_to_string(image)
    except Exception as e:
        return f"Error reading image: {str(e)}"
    # str.strip() cannot raise on the OCR result, so it lives outside the try.
    return ocr_text.strip()
| # Fallback model logic | |
| def fallback_classifier(text): | |
| encoded_input = fallback_tokenizer(text, return_tensors='pt', truncation=True, padding=True) | |
| with torch.no_grad(): | |
| output = fallback_model(**encoded_input) | |
| scores = softmax(output.logits.numpy()[0]) | |
| labels = ['Negative', 'Neutral', 'Positive'] | |
| return f"Fallback Prediction: {labels[scores.argmax()]}" | |
| # Final classifier logic | |
| def classify_sentiment(text_input, reddit_url, image): | |
| # Source detection | |
| if reddit_url.strip(): | |
| text = fetch_reddit_text(reddit_url) | |
| elif image is not None: | |
| text = extract_text_from_image(image) | |
| elif text_input.strip(): | |
| text = text_input | |
| else: | |
| return "[!] Please provide text input, Reddit URL, or image." | |
| if text.lower().startswith("error") or "Unable to extract" in text: | |
| return f"[!] Error: {text}" | |
| # Classification using main model | |
| try: | |
| inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True) | |
| outputs = model(inputs) | |
| probs = tf.nn.softmax(outputs.logits, axis=1) | |
| confidence = float(tf.reduce_max(probs).numpy()) | |
| pred_label = tf.argmax(probs, axis=1).numpy()[0] | |
| if confidence < 0.5: | |
| return fallback_classifier(text) | |
| return f"Prediction: {LABELS[pred_label]} (Confidence: {confidence:.2f})" | |
| except Exception as e: | |
| return f"[!] Prediction error: {str(e)}" | |
| # Gradio UI | |
| demo = gr.Interface( | |
| fn=classify_sentiment, | |
| inputs=[ | |
| gr.Textbox(label="Text Input", placeholder="Paste any content (tweet, comment, etc)...", lines=3), | |
| gr.Textbox(label="Reddit Post URL", placeholder="Paste Reddit post URL (optional)"), | |
| gr.Image(label="Upload Image (Optional - text image)", type="pil") | |
| ], | |
| outputs="text", | |
| title="π Multilingual Sentiment Analyzer", | |
| description="π Paste text, Reddit URL, or upload an image (screenshot of tweet etc.) to analyze sentiment.\nSupports fallback model if confidence is low." | |
| ) | |
| demo.launch() | |