# Spaces: Runtime error
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import torch | |
| import emoji | |
| import re | |
| import numpy as np | |
| from collections import Counter | |
| from instagrapi import Client | |
| from transformers import ( | |
| pipeline, | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| Trainer, | |
| TrainingArguments, | |
| RobertaForSequenceClassification, | |
| AlbertForSequenceClassification | |
| ) | |
| from datasets import Dataset, Features, Value | |
| from sklearn.metrics import accuracy_score, f1_score | |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Tunables read by the sentiment analyzer (tokenizer length, training
# hyper-parameters and the confidence thresholds used when combining models).
CONFIG = {
    "max_length": 128,
    "batch_size": 16,
    "learning_rate": 2e-5,
    "num_train_epochs": 3,
    "few_shot_examples": 5,
    "confidence_threshold": 0.7,
    "neutral_reanalysis_threshold": 0.33,
}

# ---------------------------------------------------------------------------
# Global state (populated lazily by the Gradio callbacks)
# ---------------------------------------------------------------------------
cl = None
explore_reels_list = []
sentiment_analyzer = None
content_classifier = None

# ---------------------------------------------------------------------------
# Content categories
# ---------------------------------------------------------------------------
# Keyword table used for the cheap keyword-based categorisation pass.
# Insertion order matters: the first category with a matching keyword wins.
CATEGORY_KEYWORDS = {
    "news": {"news", "update", "breaking", "reported", "headlines"},
    "meme": {"meme", "funny", "lol", "haha", "relatable"},
    "sports": {"sports", "cricket", "football", "match", "game", "team", "score"},
    "science": {"science", "research", "discovery", "experiment", "facts", "theory"},
    "music": {"music", "song", "album", "release", "artist", "beats"},
    "movie": {"movie", "film", "bollywood", "trailer", "series", "actor"},
    "gym": {"gym", "workout", "fitness", "exercise", "training", "bodybuilding"},
    "comedy": {"comedy", "joke", "humor", "standup", "skit", "laugh"},
    "food": {"food", "recipe", "cooking", "eat", "delicious", "restaurant", "kitchen"},
    "technology": {"tech", "phone", "computer", "ai", "gadget", "software", "innovation"},
    "travel": {"travel", "trip", "vacation", "explore", "destination", "adventure"},
    "fashion": {"fashion", "style", "ootd", "outfit", "trends", "clothing"},
    "art": {"art", "artist", "painting", "drawing", "creative", "design"},
    "business": {"business", "startup", "marketing", "money", "finance", "entrepreneur"},
}

# Candidate labels handed to the zero-shot classifier; same names, same order
# as the keyword table above.
CONTENT_CATEGORIES = list(CATEGORY_KEYWORDS)
class ReelSentimentAnalyzer:
    """Classify reel captions as positive / neutral / negative.

    English captions are scored by two models (an emotion classifier and a
    sentiment classifier) whose outputs are combined by a small rule set.
    Captions detected as Hindi (Devanagari) or romanised Hindi are scored by
    a 3-label classifier built on ``ai4bharat/indic-bert``.

    NOTE(review): the indic-bert classification head is created here with
    ``num_labels=3``; presumably it is untrained until ``train_hindi_model``
    has been run -- confirm before trusting Hindi predictions.
    """

    # Chat abbreviations expanded during preprocessing (case-insensitive).
    # The last three entries map a word to itself, which only normalises case.
    _ABBREVIATIONS = {
        r"\bomg\b": "oh my god",
        r"\btbh\b": "to be honest",
        r"\bky\b": "kyun",
        r"\bkb\b": "kab",
        r"\bkya\b": "kya",
        r"\bkahan\b": "kahan",
        r"\bkaisa\b": "kaisa",
    }

    # Romanised-Hindi marker words used by detect_language().
    _HINGLISH_KEYWORDS = ("hai", "kyun", "nahi", "kya", "acha", "bas", "yaar", "main")

    # Strong polarity words used to re-label low-confidence neutral results.
    _STRONG_POSITIVE_KEYWORDS = frozenset(
        {"amazing", "love", "best", "fantastic", "awesome", "superb", "great"}
    )
    _STRONG_NEGATIVE_KEYWORDS = frozenset(
        {"hate", "worst", "bad", "terrible", "awful", "disappointed", "horrible", "cringe"}
    )

    def __init__(self):
        """Pick the torch device, load all models and build the lookup tables."""
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self._initialize_models()
        self._setup_emotion_mappings()

    def _initialize_models(self):
        """Load the tokenizers and models and move the models to ``self.device``."""
        print("Loading sentiment analysis models...")
        # English models
        self.emotion_tokenizer = AutoTokenizer.from_pretrained(
            "finiteautomata/bertweet-base-emotion-analysis"
        )
        self.emotion_model = AutoModelForSequenceClassification.from_pretrained(
            "finiteautomata/bertweet-base-emotion-analysis"
        ).to(self.device)
        self.sentiment_tokenizer = AutoTokenizer.from_pretrained(
            "cardiffnlp/twitter-roberta-base-sentiment-latest"
        )
        self.sentiment_model = RobertaForSequenceClassification.from_pretrained(
            "cardiffnlp/twitter-roberta-base-sentiment-latest",
            ignore_mismatched_sizes=True,
        ).to(self.device)
        # Hindi/English model
        self.hindi_tokenizer = AutoTokenizer.from_pretrained("ai4bharat/indic-bert")
        self.hindi_model = AlbertForSequenceClassification.from_pretrained(
            "ai4bharat/indic-bert",
            num_labels=3,
            id2label={0: "negative", 1: "neutral", 2: "positive"},
            label2id={"negative": 0, "neutral": 1, "positive": 2},
        ).to(self.device)
        self.hindi_label2id = self.hindi_model.config.label2id

    def _setup_emotion_mappings(self):
        """Create the emotion->sentiment map and the promotional keyword set."""
        self.emotion_map = {
            "joy": "positive", "love": "positive", "happy": "positive",
            "anger": "negative", "sadness": "negative", "fear": "negative",
            "surprise": "neutral", "neutral": "neutral", "disgust": "negative",
            "shame": "negative",
        }
        # Captions containing any of these are treated as neutral (ads/promos).
        self.neutral_keywords = {
            "ad", "sponsored", "promo", "sale", "discount", "offer", "giveaway",
            "buy", "shop", "link in bio",
            "विज्ञापन", "प्रचार", "ऑफर", "डिस्काउंट", "बिक्री", "लिंक बायो में",
        }

    def train_hindi_model(self, train_data, eval_data=None):
        """Fine-tune the Hindi classifier and save it to ``./fine_tuned_hindi_sentiment``.

        Args:
            train_data: Anything ``pandas.DataFrame`` accepts that yields a
                ``text`` column and a ``label`` column of label strings.
            eval_data: Optional data in the same format; when truthy, the
                model is evaluated every epoch with accuracy and weighted F1.

        Labels that are not keys of ``self.hindi_label2id`` are mapped to
        "neutral" with a printed warning.
        """
        print("Fine-tuning Hindi sentiment model...")

        def map_labels_to_ids(examples):
            labels = []
            for label_str in examples["label"]:
                if label_str in self.hindi_label2id:
                    labels.append(self.hindi_label2id[label_str])
                else:
                    print(f"Warning: Unexpected label '{label_str}'. Mapping to neutral.")
                    labels.append(self.hindi_label2id["neutral"])
            examples["label"] = labels
            return examples

        def tokenize_function(examples):
            return self.hindi_tokenizer(
                examples["text"],
                padding="max_length",
                truncation=True,
                max_length=CONFIG["max_length"],
            )

        def build_dataset(data):
            # Label strings -> ids -> int64 column -> tokenized features.
            dataset = Dataset.from_pandas(pd.DataFrame(data))
            dataset = dataset.map(map_labels_to_ids, batched=True)
            dataset = dataset.cast_column("label", Value("int64"))
            return dataset.map(tokenize_function, batched=True)

        def compute_metrics(p):
            predictions, labels = p
            predictions = np.argmax(predictions, axis=1)
            return {
                "accuracy": accuracy_score(labels, predictions),
                "f1": f1_score(labels, predictions, average="weighted"),
            }

        tokenized_train = build_dataset(train_data)
        eval_dataset_processed = build_dataset(eval_data) if eval_data else None

        training_args = TrainingArguments(
            output_dir="./results",
            eval_strategy="epoch" if eval_data else "no",
            per_device_train_batch_size=CONFIG["batch_size"],
            per_device_eval_batch_size=CONFIG["batch_size"],
            learning_rate=CONFIG["learning_rate"],
            num_train_epochs=CONFIG["num_train_epochs"],
            weight_decay=0.01,
            save_strategy="no",
            logging_dir='./logs',
            logging_steps=10,
            report_to="none",
        )
        trainer = Trainer(
            model=self.hindi_model,
            args=training_args,
            train_dataset=tokenized_train,
            eval_dataset=eval_dataset_processed,
            compute_metrics=compute_metrics if eval_data else None,
        )
        trainer.train()
        # Training leaves the module in train mode; switch back so dropout is
        # disabled for the inference calls made through this same object.
        self.hindi_model.eval()
        self.hindi_model.save_pretrained("./fine_tuned_hindi_sentiment")
        self.hindi_tokenizer.save_pretrained("./fine_tuned_hindi_sentiment")

    def preprocess_text(self, text):
        """Normalise a caption: emojis to words, strip URLs/mentions, expand abbreviations.

        Returns an empty string for falsy input; otherwise the cleaned text
        with runs of whitespace collapsed to single spaces.
        """
        if not text:
            return ""
        text = emoji.demojize(text, delimiters=(" ", " "))
        text = re.sub(r"http\S+|@\w+", "", text)
        for pattern, replacement in self._ABBREVIATIONS.items():
            text = re.sub(pattern, replacement, text, flags=re.IGNORECASE)
        return re.sub(r"\s+", " ", text).strip()

    def detect_language(self, text):
        """Return "hi" for Devanagari text, "hi-latin" for romanised Hindi, else "en"."""
        if re.search(r"[\u0900-\u097F]", text):
            return "hi"
        lowered = text.lower()
        if any(re.search(rf"\b{kw}\b", lowered) for kw in self._HINGLISH_KEYWORDS):
            return "hi-latin"
        return "en"

    def _contains_neutral_keyword(self, text):
        """Return True when ``text`` contains a promotional keyword as a whole word.

        Lookarounds are used instead of ``\\b``: several Devanagari keywords
        end in a vowel sign, which is not a ``\\w`` character, so a trailing
        ``\\b`` would refuse to match them before a space or end of string.
        """
        lowered = text.lower()
        return any(
            re.search(rf"(?<!\w){re.escape(kw)}(?!\w)", lowered)
            for kw in self.neutral_keywords
        )

    def analyze_content(self, text):
        """Classify one caption.

        Returns:
            ``(label, confidence, details)`` where ``label`` is one of
            "positive", "neutral", "negative" and ``details`` is a dict
            explaining the decision. Model errors are caught, printed and
            reported as a neutral result with confidence 0.5.
        """
        processed = self.preprocess_text(text)
        if not processed:
            return "neutral", 0.5, {"reason": "empty_text"}
        lang = self.detect_language(processed)
        if self._contains_neutral_keyword(processed):
            return "neutral", 0.9, {"reason": "neutral_keyword"}
        try:
            if lang in ("hi", "hi-latin"):
                return self._analyze_hindi_content(processed)
            return self._analyze_english_content(processed)
        except Exception as e:
            print(f"Analysis error: {e}")
            return "neutral", 0.5, {"error": str(e), "original_text": text[:50]}

    def _top_prediction(self, tokenizer, model, text, **tokenizer_kwargs):
        """Run ``model`` on ``text`` and return ``(class_index, probability)`` of the top class."""
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=CONFIG["max_length"],
            **tokenizer_kwargs,
        ).to(self.device)
        with torch.no_grad():
            outputs = model(**inputs)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        index = torch.argmax(probs).item()
        return index, probs[0][index].item()

    def _analyze_hindi_content(self, text):
        """Score Hindi / romanised-Hindi text with the indic-bert classifier."""
        pred_idx, confidence = self._top_prediction(
            self.hindi_tokenizer, self.hindi_model, text, padding=True
        )
        label = self.hindi_model.config.id2label[pred_idx]
        return label, confidence, {"model": "fine-tuned-indic-bert", "lang": "hi"}

    def _analyze_english_content(self, text):
        """Score English text with the emotion + sentiment ensemble.

        Decision order: a confident sentiment prediction wins; then a
        confident non-neutral emotion; otherwise agreement between the two,
        then whichever non-neutral signal is stronger (and above 0.4), and
        finally a neutral fallback with confidence 0.6.
        """
        # Emotion analysis
        emotion_pred, emotion_score = self._top_prediction(
            self.emotion_tokenizer, self.emotion_model, text
        )
        emotion_label = self.emotion_model.config.id2label[emotion_pred]

        # Sentiment analysis
        sentiment_pred, sentiment_score = self._top_prediction(
            self.sentiment_tokenizer, self.sentiment_model, text
        )
        sentiment_label_mapping = {0: 'negative', 1: 'neutral', 2: 'positive'}
        sentiment_label = sentiment_label_mapping.get(sentiment_pred, 'neutral')

        # Combine results; emotions missing from the map count as neutral.
        mapped_emotion = self.emotion_map.get(emotion_label, "neutral")
        threshold = CONFIG["confidence_threshold"]
        if sentiment_score > threshold:
            final_label = sentiment_label
            final_confidence = sentiment_score
            reason = "high_sentiment_confidence"
        elif emotion_score > threshold and mapped_emotion != "neutral":
            final_label = mapped_emotion
            final_confidence = emotion_score
            reason = "high_emotion_confidence"
        elif sentiment_label == mapped_emotion and sentiment_label != "neutral":
            final_label = sentiment_label
            final_confidence = (sentiment_score + emotion_score) / 2
            reason = "emotion_sentiment_agreement"
        elif sentiment_label != "neutral" and sentiment_score > emotion_score and sentiment_score > 0.4:
            final_label = sentiment_label
            final_confidence = sentiment_score * 0.9
            reason = "sentiment_slightly_higher"
        elif mapped_emotion != "neutral" and emotion_score > sentiment_score and emotion_score > 0.4:
            final_label = mapped_emotion
            final_confidence = emotion_score * 0.9
            reason = "emotion_slightly_higher"
        else:
            final_label = "neutral"
            final_confidence = 0.6
            reason = "fallback_to_neutral"

        return final_label, final_confidence, {
            "emotion_label": emotion_label,
            "emotion_score": emotion_score,
            "sentiment_label": sentiment_label,
            "sentiment_score": sentiment_score,
            "mapped_emotion": mapped_emotion,
            "model": "ensemble",
            "lang": "en",
            "reason": reason,
        }

    def analyze_reels(self, reels, max_to_analyze=100):
        """Classify the captions of up to ``max_to_analyze`` reels.

        Args:
            reels: Sequence of objects exposing ``id`` and a ``caption_text``
                or ``caption`` attribute.
            max_to_analyze: Upper bound on how many reels are processed.

        Returns:
            ``(results, detailed_results)``: a ``Counter`` of labels and a
            list with one dict per reel. When more than
            ``CONFIG["neutral_reanalysis_threshold"]`` of the results are
            neutral, low-confidence neutrals get a keyword-based second look.
        """
        selected = reels[:max_to_analyze]
        # Report what is actually processed, not the requested upper bound.
        print(f"Analyzing {len(selected)} reels...")
        results = Counter()
        detailed_results = []
        for reel in selected:
            caption = getattr(reel, 'caption_text', '') or getattr(reel, 'caption', '') or ''
            label, confidence, details = self.analyze_content(caption)
            results[label] += 1
            detailed_results.append({
                "reel_id": reel.id,
                "text": caption,
                "label": label,
                "confidence": confidence,
                "details": details,
            })
        total = sum(results.values())
        if total > 0 and results["neutral"] / total > CONFIG["neutral_reanalysis_threshold"]:
            self._reduce_neutrals(results, detailed_results)
        return results, detailed_results

    @staticmethod
    def _has_any_keyword(text, keywords):
        """Return True when any of ``keywords`` occurs in ``text`` as a whole word."""
        return any(re.search(rf"\b{re.escape(kw)}\b", text) for kw in keywords)

    def _reduce_neutrals(self, results, detailed_results):
        """Re-label low-confidence neutral items that contain strong polarity words.

        Mutates both ``results`` (counts) and the matching dicts in
        ``detailed_results`` in place. Items containing both positive and
        negative keywords, or neither, are left neutral.
        """
        for item in detailed_results:
            if item["label"] != "neutral" or item["confidence"] >= 0.8:
                continue
            text_lower = self.preprocess_text(item["text"]).lower()
            is_strong_pos = self._has_any_keyword(text_lower, self._STRONG_POSITIVE_KEYWORDS)
            is_strong_neg = self._has_any_keyword(text_lower, self._STRONG_NEGATIVE_KEYWORDS)
            if is_strong_pos == is_strong_neg:
                continue
            new_label = "positive" if is_strong_pos else "negative"
            results["neutral"] -= 1
            results[new_label] += 1
            item.update({
                "label": new_label,
                "confidence": min(0.95, item["confidence"] + 0.3),
                "reanalyzed": True,
                "reanalysis_reason": "strong_pos_keywords" if is_strong_pos else "strong_neg_keywords",
            })
def plot_sentiment_pie(results, title="Reels Sentiment Analysis"):
    """Draw a pie chart of positive / neutral / negative counts.

    Args:
        results: Mapping (e.g. a ``Counter``) with optional 'positive',
            'neutral' and 'negative' keys; missing keys count as 0.
        title: Chart title.

    Returns:
        A matplotlib ``Figure``, or ``None`` when all three counts are zero.
    """
    sizes = [results.get(key, 0) for key in ('positive', 'neutral', 'negative')]
    if sum(sizes) == 0:
        return None

    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive in a global registry until plt.close() is called, which leaks
    # memory in a long-running server, and its "current figure" state is
    # shared between requests.
    from matplotlib.figure import Figure

    labels = ['Positive', 'Neutral', 'Negative']
    colors = ['#4CAF50', '#FFC107', '#F44336']
    explode = (0.05, 0, 0.05)
    fig = Figure(figsize=(8, 6))
    ax = fig.subplots()
    ax.pie(sizes, explode=explode, labels=labels, colors=colors,
           autopct='%1.1f%%', shadow=True, startangle=140,
           textprops={'fontsize': 12, 'color': 'black'})
    ax.axis('equal')
    ax.set_title(title, fontsize=16, pad=20)
    fig.tight_layout()
    return fig
def plot_category_distribution(counter, title="Reels Content Distribution"):
    """Draw a pie chart of content categories.

    Categories holding less than 2% of the total, and the literal "other"
    category, are merged into a single "Other" slice.

    Args:
        counter: ``collections.Counter`` mapping category name -> count.
        title: Chart title.

    Returns:
        A matplotlib ``Figure``, or ``None`` when there is nothing to plot.
    """
    total = sum(counter.values())
    if total == 0:
        return None

    threshold = total * 0.02
    other_count = 0
    labels = []
    sizes = []
    for category, count in counter.most_common():
        if count >= threshold and category != "other":
            labels.append(category.replace('_', ' ').title())
            sizes.append(count)
        else:
            other_count += count
    if other_count > 0:
        labels.append("Other")
        sizes.append(other_count)
    if not sizes:
        return None

    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive in a global registry until plt.close() is called, which leaks
    # memory in a long-running server, and its "current figure" state is
    # shared between requests.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(10, 8))
    ax = fig.subplots()
    colors = plt.cm.viridis(np.linspace(0, 1, len(sizes)))
    ax.pie(sizes, labels=labels, autopct='%1.1f%%', startangle=140, colors=colors,
           wedgeprops={'edgecolor': 'white', 'linewidth': 1}, textprops={'fontsize': 11})
    ax.set_title(title, pad=20, fontsize=15)
    ax.axis('equal')
    fig.tight_layout()
    return fig
def preprocess_text_cat(text):
    """Clean a caption for categorisation.

    Drops URLs, @mentions and #hashtags, lower-cases the rest and collapses
    whitespace. Falsy input yields an empty string.
    """
    if text:
        stripped = re.sub(r"http\S+|@\w+|#\w+", "", text)
        lowered = stripped.lower()
        return re.sub(r"\s+", " ", lowered).strip()
    return ""
def classify_reel_content(text):
    """Assign a content category to a caption.

    Tries, in order: a keyword lookup against ``CATEGORY_KEYWORDS`` (first
    matching category wins), then the zero-shot ``content_classifier`` when
    it has been initialised.

    Returns:
        ``(category, details)`` where ``category`` is a category name or
        "other", and ``details`` carries a ``reason`` (plus ``score`` for
        model predictions).
    """
    global content_classifier

    cleaned = preprocess_text_cat(text)
    # An empty string splits into zero words, so this also covers "no text".
    if len(cleaned.split()) < 2:
        return "other", {"reason": "short_text"}

    keyword_hit = next(
        (
            name
            for name, words in CATEGORY_KEYWORDS.items()
            if any(re.search(rf"\b{re.escape(word)}\b", cleaned) for word in words)
        ),
        None,
    )
    if keyword_hit is not None:
        return keyword_hit, {"reason": "keyword_match"}

    if content_classifier is None:
        return "other", {"reason": "classifier_not_initialized"}

    try:
        prediction = content_classifier(cleaned[:256], CONTENT_CATEGORIES, multi_label=False)
        best_label = prediction['labels'][0]
        best_score = prediction['scores'][0]
        # Only trust the model when it is more confident than a coin flip.
        category = best_label if best_score > 0.5 else "other"
        return category, {"reason": "model_prediction", "score": best_score}
    except Exception as e:
        print(f"Classification error: {e}")
        return "other", {"reason": "classification_error"}
| # Gradio Interface Functions | |
def login_gradio_auto():
    """Log in to Instagram using credentials taken from the environment.

    Reads ``INSTAGRAM_PASSWORD`` (required) and ``INSTAGRAM_USERNAME``
    (defaults to the account this app was written for). Credentials are no
    longer embedded in the source file.

    Returns:
        ``(status_message, otp_row_update)``. The OTP row is made visible
        only when the login error text indicates a two-factor challenge; in
        that case the client is kept in the global ``cl`` so that
        ``submit_otp_gradio`` has something to complete the login with.
    """
    global cl
    import os  # local import: only this callback needs it

    username = os.environ.get("INSTAGRAM_USERNAME", "jattman1993")
    password = os.environ.get("INSTAGRAM_PASSWORD")
    if not password:
        return "Error: Instagram password not found.", gr.update(visible=False)

    client = Client()
    try:
        client.login(username, password)
    except Exception as e:
        error_message = str(e)
        if "Two factor challenged" in error_message or "challenge_required" in error_message:
            # Keep the client: resetting it to None here would make the OTP
            # step fail immediately with "Not logged in".
            cl = client
            return "Login failed: Two-factor authentication required.", gr.update(visible=True)
        cl = None
        return f"Error during login: {error_message}", gr.update(visible=False)

    cl = client
    return f"Successfully logged in as {username}", gr.update(visible=False)
def submit_otp_gradio(otp_code):
    """Complete a two-factor login with the code typed by the user.

    Args:
        otp_code: Text from the OTP textbox; surrounding whitespace is ignored.

    Returns:
        ``(status_message, otp_textbox_value, otp_row_update)``. The textbox
        is always cleared; the OTP row stays visible when another attempt is
        possible.
    """
    global cl
    if cl is None:
        return "Error: Not logged in.", "", gr.update(visible=False)

    code = (otp_code or "").strip()
    if not code:
        # Don't spend an API call (and a login attempt) on an empty code.
        return "OTP failed: no code entered.", "", gr.update(visible=True)

    try:
        # NOTE(review): assumes the client object exposes two_factor_login();
        # verify against the installed instagrapi version.
        cl.two_factor_login(code)
    except Exception as e:
        return f"OTP failed: {e}", "", gr.update(visible=True)
    return "OTP successful. Logged in as jattman1993.", "", gr.update(visible=False)
def fetch_reels_gradio():
    """Fetch explore-page reels into the global ``explore_reels_list``.

    Keeps at most the first 100 items returned by the client. On any failure
    (not logged in, API error) the list is left empty.

    Returns:
        A status message for the UI.
    """
    global cl, explore_reels_list

    if cl is None:
        explore_reels_list = []
        return "Error: Not logged in."

    try:
        fetched = cl.explore_reels()
        explore_reels_list = fetched[:100]
    except Exception as e:
        explore_reels_list = []
        return f"Error fetching reels: {e}"
    else:
        return f"Fetched {len(explore_reels_list)} reels."
def analyze_reels_gradio(max_to_analyze):
    """Run sentiment and content analysis on the fetched reels.

    Args:
        max_to_analyze: Number of reels requested by the slider; coerced to
            ``int`` because it is used as a slice bound.

    Returns:
        ``(status_text, sentiment_figure, content_figure)``. Each figure is
        ``None`` when its analysis failed or produced nothing to plot; the
        status text carries one line per analysis.
    """
    global explore_reels_list, sentiment_analyzer, content_classifier

    if not explore_reels_list:
        return "Error: No reels fetched.", None, None

    num_reels = min(int(max_to_analyze), len(explore_reels_list))
    reels_to_analyze = explore_reels_list[:num_reels]

    # Models are loaded lazily on first use. A failed download or an
    # out-of-memory error is reported in the status box instead of raising
    # out of the callback.
    try:
        if sentiment_analyzer is None:
            sentiment_analyzer = ReelSentimentAnalyzer()
        if content_classifier is None:
            content_classifier = pipeline(
                "zero-shot-classification",
                model="facebook/bart-large-mnli",
                device=0 if torch.cuda.is_available() else -1,
            )
    except Exception as e:
        return f"Error loading models: {e}", None, None

    status_messages = []
    sentiment_plot = None
    content_plot = None

    # Sentiment Analysis
    try:
        sentiment_results, _ = sentiment_analyzer.analyze_reels(
            reels_to_analyze, max_to_analyze=num_reels
        )
        sentiment_plot = plot_sentiment_pie(sentiment_results)
        status_messages.append("Sentiment analysis complete.")
    except Exception as e:
        status_messages.append(f"Sentiment error: {e}")

    # Content Analysis
    try:
        category_counts = Counter()
        for reel in reels_to_analyze:
            caption = getattr(reel, 'caption_text', '') or getattr(reel, 'caption', '') or ''
            category, _ = classify_reel_content(caption)
            category_counts[category] += 1
        content_plot = plot_category_distribution(category_counts)
        status_messages.append("Content analysis complete.")
    except Exception as e:
        status_messages.append(f"Content error: {e}")

    return "\n".join(status_messages), sentiment_plot, content_plot
| # Gradio Interface | |
# Page layout and event wiring. Event listeners are registered inside the
# Blocks context, next to the components they connect.
with gr.Blocks() as demo:
    gr.Markdown("# Instagram Reels Analysis")

    # --- Login -----------------------------------------------------------
    with gr.Row():
        connect_btn = gr.Button("Connect Instagram")
        login_status = gr.Label(label="Login Status")

    # --- OTP input (hidden until a two-factor challenge is reported) -----
    with gr.Row(visible=False) as otp_row:
        otp_input = gr.Textbox(label="Enter OTP Code")
        otp_submit_btn = gr.Button("Submit OTP")

    # --- Fetch -----------------------------------------------------------
    with gr.Row():
        fetch_btn = gr.Button("Fetch Reels")
        fetch_status = gr.Label(label="Fetch Status")

    # --- Analysis controls -----------------------------------------------
    with gr.Row():
        max_reels = gr.Slider(1, 100, value=10, step=1, label="Number of Reels to Analyze")
        analyze_btn = gr.Button("Analyze Reels")
        analyze_status = gr.Label(label="Analysis Status")

    # --- Results: two side-by-side plots ---------------------------------
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Sentiment Analysis")
            sentiment_output = gr.Plot(label="Sentiment Distribution")
        with gr.Column():
            gr.Markdown("## Content Analysis")
            content_output = gr.Plot(label="Content Distribution")

    # --- Event handlers --------------------------------------------------
    connect_btn.click(login_gradio_auto, inputs=None, outputs=[login_status, otp_row])
    otp_submit_btn.click(
        submit_otp_gradio,
        inputs=otp_input,
        outputs=[login_status, otp_input, otp_row],
    )
    fetch_btn.click(fetch_reels_gradio, inputs=None, outputs=fetch_status)
    analyze_btn.click(
        analyze_reels_gradio,
        inputs=max_reels,
        outputs=[analyze_status, sentiment_output, content_output],
    )

if __name__ == "__main__":
    demo.launch()