# Hugging Face Spaces app: multimodal movie sentiment analysis demo.
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModel
| class GatedAttentionFusion(nn.Module): | |
| def __init__(self, img_dim=512, text_dim=768, hidden_dim=256): | |
| super().__init__() | |
| self.img_proj = nn.Linear(img_dim, hidden_dim) | |
| self.text_proj = nn.Linear(text_dim, hidden_dim) | |
| self.gate = nn.Sequential( | |
| nn.Linear(hidden_dim * 2, hidden_dim), | |
| nn.Sigmoid() | |
| ) | |
| self.cross_attention = nn.MultiheadAttention(hidden_dim, 8, dropout=0.1) | |
| self.layer_norm = nn.LayerNorm(hidden_dim) | |
| def forward(self, img_feat, text_feat): | |
| img_proj = self.img_proj(img_feat) | |
| text_proj = self.text_proj(text_feat) | |
| concat_feat = torch.cat([img_proj, text_proj], dim=-1) | |
| gate_weight = self.gate(concat_feat) | |
| gated_img = img_proj * gate_weight | |
| gated_text = text_proj * (1 - gate_weight) | |
| fused_feat = gated_img + gated_text | |
| fused_feat = fused_feat.unsqueeze(0) | |
| attended_feat, attention_weights = self.cross_attention( | |
| fused_feat, fused_feat, fused_feat | |
| ) | |
| attended_feat = attended_feat.squeeze(0) | |
| attended_feat = self.layer_norm(attended_feat + fused_feat.squeeze(0)) | |
| return attended_feat, attention_weights | |
class SentimentClassifier(nn.Module):
    """MLP head with a self-attention refinement stage over fused features."""

    def __init__(self, input_dim=256, num_classes=3):
        super().__init__()
        # Widen the fused representation before the attention stage.
        self.feature_enhancer = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.LayerNorm(512),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        self.self_attention = nn.MultiheadAttention(512, 8, dropout=0.1)
        # Funnel 512 -> 256 -> 128 -> class logits.
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return (class logits, self-attention weights) for input features."""
        rich = self.feature_enhancer(x)
        # Sequence-first attention over a length-1 sequence.
        seq = rich.unsqueeze(0)
        refined, attn_weights = self.self_attention(seq, seq, seq)
        # Residual connection around attention, then classify.
        logits = self.classifier(rich + refined.squeeze(0))
        return logits, attn_weights
class MovieSentimentAnalyzer:
    """CLIP + BERT multimodal movie-sentiment pipeline.

    Extracts image features with CLIP and text features with BERT, fuses
    them with GatedAttentionFusion, and classifies via SentimentClassifier.

    NOTE(review): the fusion module and classifier are constructed with
    random weights and no checkpoint is loaded anywhere in this file, so
    predictions are effectively untrained — confirm whether fine-tuned
    weights should be restored here.
    """

    def __init__(self):
        # Prefer GPU when available; all models and inputs follow this device.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Load models (downloads pretrained weights on first run; needs network)
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.text_model = AutoModel.from_pretrained("bert-base-uncased")
        # Custom layers (randomly initialized — see class docstring)
        self.fusion_module = GatedAttentionFusion()
        self.classifier = SentimentClassifier()
        # Move to device
        self.clip_model.to(self.device)
        self.text_model.to(self.device)
        self.fusion_module.to(self.device)
        self.classifier.to(self.device)
        # Set eval mode (disables dropout for inference)
        self.clip_model.eval()
        self.text_model.eval()
        self.fusion_module.eval()
        self.classifier.eval()
        # Movie sentiment labels; index order must match classifier logits.
        self.labels = ['Not Recommended', 'Average', 'Highly Recommended']

    def extract_image_features(self, image):
        """Return CLIP image embeddings for `image` (a PIL image).

        Shape is (1, 512) for a single image with clip-vit-base-patch32.
        """
        inputs = self.clip_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(**inputs)
        return image_features

    def extract_text_features(self, text):
        """Return a BERT sentence embedding via mean pooling over tokens.

        Shape is (1, 768) for a single string (bert-base hidden size).
        Mean over dim=1 ignores the attention mask — fine for a single
        unpadded sequence, but would average in pad tokens for batches.
        """
        inputs = self.text_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.text_model(**inputs)
            text_features = outputs.last_hidden_state.mean(dim=1)
        return text_features

    def predict_sentiment(self, image, text):
        """Run the full pipeline on one (poster, review) pair.

        Returns (softmax probabilities over self.labels, fusion attention
        weights, classifier attention weights).
        """
        img_features = self.extract_image_features(image)
        text_features = self.extract_text_features(text)
        with torch.no_grad():
            fused_features, fusion_attention = self.fusion_module(img_features, text_features)
            logits, classification_attention = self.classifier(fused_features)
            probabilities = F.softmax(logits, dim=-1)
        return probabilities, fusion_attention, classification_attention

    def generate_gradcam(self, image, text):
        """Return the poster with a heat-map overlay as a PIL image.

        NOTE(review): despite the name, the attention map is random noise
        (see below) — this is a placeholder visualization, not real Grad-CAM.
        Also assumes `image` is RGB; an RGBA or grayscale upload would break
        the 3-channel overlay blend — TODO confirm upstream conversion.
        """
        img_array = np.array(image.resize((224, 224)))
        height, width = img_array.shape[:2]
        # Mock attention map (uniform random, smoothed to look plausible)
        attention_map = np.random.random((height, width))
        attention_map = cv2.GaussianBlur(attention_map, (21, 21), 0)
        # Normalize to [0, 1] before color-mapping.
        attention_map = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min())
        heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
        # OpenCV color maps are BGR; convert for PIL/display.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        # 60/40 alpha blend of poster and heat map.
        overlay = 0.6 * img_array + 0.4 * heatmap
        overlay = np.uint8(overlay)
        return Image.fromarray(overlay)

    def create_attention_visualization(self, text, attention_weights):
        """Return (word, intensity) pairs for gr.HighlightedText.

        NOTE(review): `attention_weights` is ignored — intensities come from
        random mock weights; wire in the real attention if that was intended.
        Returns a plain string when `text` has no words.
        """
        words = text.split()
        if len(words) == 0:
            return "No text provided"
        mock_weights = np.random.random(len(words))
        mock_weights = mock_weights / mock_weights.sum()
        highlighted_text = []
        for word, weight in zip(words, mock_weights):
            # Scale up so individual words stand out; cap at full intensity.
            intensity = min(1.0, weight * 3)
            highlighted_text.append((word, intensity))
        return highlighted_text
# Initialize analyzer
# Module-level singleton: constructing it downloads CLIP and BERT weights,
# so importing this module requires network access on first run.
analyzer = MovieSentimentAnalyzer()
def analyze_movie_sentiment(image, text):
    """Gradio callback: run the full multimodal pipeline on one poster/review.

    Parameters
    ----------
    image : PIL.Image.Image | None
        Uploaded poster; None when the user has not supplied one.
    text : str | None
        Review text; gradio may pass None for a cleared textbox.

    Returns
    -------
    tuple
        (label->probability dict, matplotlib figure, heat-map image,
        markdown explanation, highlighted-text word list). On bad input or
        failure the dict is {"Error": 1.0} and the figure/images are None.
    """
    # Guard against None text as well: gradio can deliver None for a cleared
    # textbox, and None.strip() would raise AttributeError before the try.
    if image is None or not text or not text.strip():
        return (
            {"Error": 1.0},
            None,
            None,
            "Please upload a movie poster and enter your review.",
            None
        )
    try:
        probabilities, fusion_attn, class_attn = analyzer.predict_sentiment(image, text)
        prob_dict = {
            analyzer.labels[i]: float(probabilities[0][i])
            for i in range(len(analyzer.labels))
        }
        gradcam_image = analyzer.generate_gradcam(image, text)
        text_attention = analyzer.create_attention_visualization(text, class_attn)
        # Create plot: bar chart of the three class probabilities.
        fig, ax = plt.subplots(figsize=(8, 5))
        labels = list(prob_dict.keys())
        values = list(prob_dict.values())
        colors = ['#ff6b6b', '#feca57', '#48dbfb']
        bars = ax.bar(labels, values, color=colors, alpha=0.8)
        ax.set_ylabel('Recommendation Score')
        ax.set_title('Movie Sentiment Analysis')
        ax.set_ylim(0, 1)
        for bar, value in zip(bars, values):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')
        plt.tight_layout()
        # Detach the figure from pyplot's global registry so repeated
        # requests don't accumulate figures (memory leak / "more than 20
        # figures" warning); the Figure object itself stays renderable.
        plt.close(fig)
        # Generate explanation
        predicted_label = max(prob_dict, key=prob_dict.get)
        confidence = prob_dict[predicted_label]
        explanation = f"""
**Movie Analysis Results:**
π― **Recommendation**: {predicted_label}
β **Confidence**: {confidence:.1%}
**Analysis Summary:**
The model analyzed the movie poster/image and your review text to determine the overall sentiment.
Visual elements like color scheme, composition, and textual sentiment patterns were considered.
**Score Breakdown:**
β’ Highly Recommended: {prob_dict['Highly Recommended']:.1%}
β’ Average: {prob_dict['Average']:.1%}
β’ Not Recommended: {prob_dict['Not Recommended']:.1%}
"""
        return prob_dict, fig, gradcam_image, explanation, text_attention
    except Exception as e:
        # UI boundary: surface the error message instead of crashing the app.
        return (
            {"Error": 1.0},
            None,
            None,
            f"Analysis error: {str(e)}",
            None
        )
def create_interface():
    """Build and return the Gradio Blocks UI for the movie sentiment demo."""
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Movie Sentiment Analysis",
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .main-header {
            text-align: center;
            margin-bottom: 30px;
        }
        """
    ) as ui:
        # Page banner.
        gr.HTML("""
        <div class="main-header">
            <h1>π¬ Movie Sentiment Analysis</h1>
            <p>AI-powered analysis of movie posters and reviews for recommendation insights</p>
        </div>
        """)
        # Top row: inputs on the left, scores and summary on the right.
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### π₯ Input")
                poster_in = gr.Image(
                    type="pil",
                    label="π¬ Upload Movie Poster",
                    height=300
                )
                review_in = gr.Textbox(
                    label="π Movie Review",
                    placeholder="Enter your movie review or thoughts here...",
                    lines=4
                )
                run_btn = gr.Button(
                    "π Analyze Movie",
                    variant="primary",
                    size="lg"
                )
                gr.Markdown("### π Results")
                label_out = gr.Label(
                    label="π― Recommendation",
                    num_top_classes=3
                )
            with gr.Column(scale=1):
                gr.Markdown("### π Confidence Scores")
                plot_out = gr.Plot(label="Analysis Results")
                gr.Markdown("### π Analysis Summary")
                summary_out = gr.Textbox(
                    label="Detailed Results",
                    lines=8,
                    max_lines=15
                )
        # Bottom row: attention visualizations for both modalities.
        with gr.Row():
            with gr.Column():
                gr.Markdown("### π₯ Visual Attention")
                heatmap_out = gr.Image(
                    label="Poster Analysis Heatmap",
                    height=300
                )
            with gr.Column():
                gr.Markdown("### π Text Attention")
                highlight_out = gr.HighlightedText(
                    label="Key Words",
                    combine_adjacent=True
                )
        # Example text suggestions
        gr.Markdown("### π― Example Reviews")
        gr.Markdown("""
        **Positive Example:**
        "This movie exceeded all my expectations! The visual effects were breathtaking and the storyline was incredibly engaging. Definitely worth watching!"
        **Negative Example:**
        "I found this film quite disappointing. The pacing was slow and the plot felt predictable. Not what I was hoping for."
        """)
        # Wire the button to the analysis callback.
        run_btn.click(
            fn=analyze_movie_sentiment,
            inputs=[poster_in, review_in],
            outputs=[label_out, plot_out, heatmap_out, summary_out, highlight_out]
        )
        # Footer.
        gr.HTML("""
        <div style="text-align: center; margin-top: 40px; padding: 20px; border-top: 1px solid #ddd;">
            <p><strong>π¬ Movie Industry AI Analysis</strong></p>
            <p>Powered by CLIP + BERT with cross-modal attention for movie recommendation</p>
        </div>
        """)
    return ui
if __name__ == "__main__":
    # Build the UI and serve it: bind to all interfaces on port 7860,
    # create a public share link, and surface tracebacks in the browser.
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True
    )