# NOTE: Hugging Face Space page residue (not code) removed from the top of this
# file — uploader "entropy25", commit message "Update app.py", hash 6a0213a (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
from transformers import CLIPProcessor, CLIPModel, AutoTokenizer, AutoModel
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
class GatedAttentionFusion(nn.Module):
def __init__(self, img_dim=512, text_dim=768, hidden_dim=256):
super().__init__()
self.img_proj = nn.Linear(img_dim, hidden_dim)
self.text_proj = nn.Linear(text_dim, hidden_dim)
self.gate = nn.Sequential(
nn.Linear(hidden_dim * 2, hidden_dim),
nn.Sigmoid()
)
self.cross_attention = nn.MultiheadAttention(hidden_dim, 8, dropout=0.1)
self.layer_norm = nn.LayerNorm(hidden_dim)
def forward(self, img_feat, text_feat):
img_proj = self.img_proj(img_feat)
text_proj = self.text_proj(text_feat)
concat_feat = torch.cat([img_proj, text_proj], dim=-1)
gate_weight = self.gate(concat_feat)
gated_img = img_proj * gate_weight
gated_text = text_proj * (1 - gate_weight)
fused_feat = gated_img + gated_text
fused_feat = fused_feat.unsqueeze(0)
attended_feat, attention_weights = self.cross_attention(
fused_feat, fused_feat, fused_feat
)
attended_feat = attended_feat.squeeze(0)
attended_feat = self.layer_norm(attended_feat + fused_feat.squeeze(0))
return attended_feat, attention_weights
class SentimentClassifier(nn.Module):
def __init__(self, input_dim=256, num_classes=3):
super().__init__()
self.feature_enhancer = nn.Sequential(
nn.Linear(input_dim, 512),
nn.LayerNorm(512),
nn.ReLU(),
nn.Dropout(0.3)
)
self.self_attention = nn.MultiheadAttention(512, 8, dropout=0.1)
self.classifier = nn.Sequential(
nn.Linear(512, 256),
nn.LayerNorm(256),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(256, 128),
nn.ReLU(),
nn.Linear(128, num_classes)
)
def forward(self, x):
enhanced = self.feature_enhancer(x)
enhanced = enhanced.unsqueeze(0)
attended, attn_weights = self.self_attention(enhanced, enhanced, enhanced)
attended = attended.squeeze(0)
final_feat = enhanced.squeeze(0) + attended
logits = self.classifier(final_feat)
return logits, attn_weights
class MovieSentimentAnalyzer:
    """End-to-end poster + review sentiment pipeline.

    Uses a frozen CLIP vision encoder for the poster and a frozen BERT
    encoder for the review, fuses them with ``GatedAttentionFusion`` and
    scores three classes with ``SentimentClassifier``.

    NOTE(review): the fusion and classifier heads are randomly initialised
    and no checkpoint is loaded in this file, so predictions are not
    meaningful until trained weights are supplied — confirm intent.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Pretrained backbones (fetched from the Hugging Face hub on first run).
        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.text_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.text_model = AutoModel.from_pretrained("bert-base-uncased")
        # Custom heads (random init — see class docstring).
        self.fusion_module = GatedAttentionFusion()
        self.classifier = SentimentClassifier()
        # Move everything to the target device and freeze in eval mode.
        for module in (self.clip_model, self.text_model,
                       self.fusion_module, self.classifier):
            module.to(self.device)
            module.eval()
        # Index order must match the classifier's output logits.
        self.labels = ['Not Recommended', 'Average', 'Highly Recommended']

    def extract_image_features(self, image):
        """Return CLIP image embeddings for a PIL image (shape (1, 512))."""
        inputs = self.clip_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            image_features = self.clip_model.get_image_features(**inputs)
        return image_features

    def extract_text_features(self, text):
        """Return a BERT embedding for ``text``: mean over token states, shape (1, 768)."""
        inputs = self.text_tokenizer(text, return_tensors="pt", truncation=True,
                                     padding=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.text_model(**inputs)
            # Mean-pool over the sequence dimension.
            text_features = outputs.last_hidden_state.mean(dim=1)
        return text_features

    def predict_sentiment(self, image, text):
        """Return (class probabilities, fusion attention, classifier attention)."""
        img_features = self.extract_image_features(image)
        text_features = self.extract_text_features(text)
        with torch.no_grad():
            fused_features, fusion_attention = self.fusion_module(img_features, text_features)
            logits, classification_attention = self.classifier(fused_features)
            probabilities = F.softmax(logits, dim=-1)
        return probabilities, fusion_attention, classification_attention

    def generate_gradcam(self, image, text):
        """Return a heatmap overlay on the (resized) poster as a PIL image.

        NOTE: the attention map here is mock data (smoothed random noise) —
        a placeholder until a real Grad-CAM is wired to the CLIP encoder.
        """
        # Fix: force 3-channel RGB so the overlay broadcasts against the
        # (H, W, 3) heatmap even for grayscale / RGBA / palette posters.
        img_array = np.array(image.convert("RGB").resize((224, 224)))
        height, width = img_array.shape[:2]
        # Mock attention map, smoothed to look plausible.
        attention_map = np.random.random((height, width))
        attention_map = cv2.GaussianBlur(attention_map, (21, 21), 0)
        spread = attention_map.max() - attention_map.min()
        if spread > 0:  # guard: avoid division by zero on a constant map
            attention_map = (attention_map - attention_map.min()) / spread
        heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
        # Blend 60/40; clip defensively before the uint8 cast.
        overlay = np.uint8(np.clip(0.6 * img_array + 0.4 * heatmap, 0, 255))
        return Image.fromarray(overlay)

    def create_attention_visualization(self, text, attention_weights):
        """Return ``[(word, intensity), ...]`` pairs for gr.HighlightedText.

        NOTE: ``attention_weights`` is currently ignored; per-word intensities
        are mock (random, normalised) values — placeholder for real attention.
        Returns the string "No text provided" when ``text`` has no words.
        """
        words = text.split()
        if len(words) == 0:
            return "No text provided"
        mock_weights = np.random.random(len(words))
        mock_weights = mock_weights / mock_weights.sum()
        highlighted_text = []
        for word, weight in zip(words, mock_weights):
            # Scale up and clamp to 1.0 so typical words are visibly shaded.
            intensity = min(1.0, weight * 3)
            highlighted_text.append((word, intensity))
        return highlighted_text
# Initialize analyzer
# Module-level singleton: backbone models are downloaded/loaded once at import
# time and shared by every Gradio request.
analyzer = MovieSentimentAnalyzer()
def analyze_movie_sentiment(image, text):
    """Gradio callback: run the full pipeline on a poster + review pair.

    Args:
        image: PIL image of the poster (or None).
        text:  review text (may be empty/whitespace).

    Returns:
        A 5-tuple matching the interface outputs:
        (label->probability dict, bar-chart Figure, heatmap PIL image,
        markdown summary string, highlighted-text word list).
        On missing input or any failure, the figure/image/word slots are
        None and the summary carries the error message.
    """
    # Guard: both inputs are required.
    if image is None or not text.strip():
        return (
            {"Error": 1.0},
            None,
            None,
            "Please upload a movie poster and enter your review.",
            None
        )
    try:
        probabilities, fusion_attn, class_attn = analyzer.predict_sentiment(image, text)
        prob_dict = {
            analyzer.labels[i]: float(probabilities[0][i])
            for i in range(len(analyzer.labels))
        }
        gradcam_image = analyzer.generate_gradcam(image, text)
        text_attention = analyzer.create_attention_visualization(text, class_attn)
        # Bar chart of the three class probabilities.
        fig, ax = plt.subplots(figsize=(8, 5))
        labels = list(prob_dict.keys())
        values = list(prob_dict.values())
        colors = ['#ff6b6b', '#feca57', '#48dbfb']
        bars = ax.bar(labels, values, color=colors, alpha=0.8)
        ax.set_ylabel('Recommendation Score')
        ax.set_title('Movie Sentiment Analysis')
        ax.set_ylim(0, 1)
        # Annotate each bar with its exact score.
        for bar, value in zip(bars, values):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                    f'{value:.3f}', ha='center', va='bottom')
        plt.tight_layout()
        # Fix: detach the figure from pyplot's global registry so repeated
        # requests don't accumulate open figures (the returned Figure object
        # is still renderable by gr.Plot after plt.close).
        plt.close(fig)
        # Human-readable summary (fix: repaired mojibake bullet "β€’" -> "•").
        predicted_label = max(prob_dict, key=prob_dict.get)
        confidence = prob_dict[predicted_label]
        explanation = f"""
**Movie Analysis Results:**
🎯 **Recommendation**: {predicted_label}
⭐ **Confidence**: {confidence:.1%}
**Analysis Summary:**
The model analyzed the movie poster/image and your review text to determine the overall sentiment.
Visual elements like color scheme, composition, and textual sentiment patterns were considered.
**Score Breakdown:**
• Highly Recommended: {prob_dict['Highly Recommended']:.1%}
• Average: {prob_dict['Average']:.1%}
• Not Recommended: {prob_dict['Not Recommended']:.1%}
"""
        return prob_dict, fig, gradcam_image, explanation, text_attention
    except Exception as e:
        # UI boundary: surface the error in the interface instead of crashing.
        return (
            {"Error": 1.0},
            None,
            None,
            f"Analysis error: {str(e)}",
            None
        )
def create_interface():
    """Build and return the Gradio Blocks UI for the sentiment demo.

    Fix: repairs double-encoded (mojibake) emoji in user-facing labels and
    headers (e.g. "πŸ“₯" -> 📥, "πŸ”" -> 🔍); all other strings unchanged.
    """
    with gr.Blocks(
        theme=gr.themes.Soft(),
        title="Movie Sentiment Analysis",
        css="""
        .gradio-container {
            max-width: 1200px !important;
        }
        .main-header {
            text-align: center;
            margin-bottom: 30px;
        }
        """
    ) as interface:
        # Page header.
        gr.HTML("""
        <div class="main-header">
            <h1>🎬 Movie Sentiment Analysis</h1>
            <p>AI-powered analysis of movie posters and reviews for recommendation insights</p>
        </div>
        """)
        with gr.Row():
            # Left column: inputs and the top-level prediction label.
            with gr.Column(scale=1):
                gr.Markdown("### 📥 Input")
                image_input = gr.Image(
                    type="pil",
                    label="🎬 Upload Movie Poster",
                    height=300
                )
                text_input = gr.Textbox(
                    label="📝 Movie Review",
                    placeholder="Enter your movie review or thoughts here...",
                    lines=4
                )
                analyze_btn = gr.Button(
                    "🔍 Analyze Movie",
                    variant="primary",
                    size="lg"
                )
                gr.Markdown("### 📊 Results")
                sentiment_output = gr.Label(
                    label="🎯 Recommendation",
                    num_top_classes=3
                )
            # Right column: probability chart and textual summary.
            with gr.Column(scale=1):
                gr.Markdown("### 📈 Confidence Scores")
                confidence_plot = gr.Plot(label="Analysis Results")
                gr.Markdown("### 💭 Analysis Summary")
                explanation_output = gr.Textbox(
                    label="Detailed Results",
                    lines=8,
                    max_lines=15
                )
        with gr.Row():
            # Attention visualisations (mock data — see analyzer docstrings).
            with gr.Column():
                gr.Markdown("### 🔥 Visual Attention")
                gradcam_output = gr.Image(
                    label="Poster Analysis Heatmap",
                    height=300
                )
            with gr.Column():
                gr.Markdown("### 📝 Text Attention")
                text_attention_output = gr.HighlightedText(
                    label="Key Words",
                    combine_adjacent=True
                )
        # Example text suggestions
        gr.Markdown("### 🎯 Example Reviews")
        gr.Markdown("""
        **Positive Example:**
        "This movie exceeded all my expectations! The visual effects were breathtaking and the storyline was incredibly engaging. Definitely worth watching!"
        **Negative Example:**
        "I found this film quite disappointing. The pacing was slow and the plot felt predictable. Not what I was hoping for."
        """)
        # Wire the button to the analysis callback.
        analyze_btn.click(
            fn=analyze_movie_sentiment,
            inputs=[image_input, text_input],
            outputs=[sentiment_output, confidence_plot, gradcam_output, explanation_output, text_attention_output]
        )
        # Footer.
        gr.HTML("""
        <div style="text-align: center; margin-top: 40px; padding: 20px; border-top: 1px solid #ddd;">
            <p><strong>🎬 Movie Industry AI Analysis</strong></p>
            <p>Powered by CLIP + BERT with cross-modal attention for movie recommendation</p>
        </div>
        """)
    return interface
if __name__ == "__main__":
    # Build the Gradio UI and serve it.
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces (container/Spaces friendly)
        server_port=7860,
        share=True,  # NOTE(review): also opens a public tunnel link — confirm this is intended
        show_error=True
    )