# StoryLens / app.py
# Author: Marek4321 — "Update app.py" (commit 75752d4, verified)
import streamlit as st
import os
from PIL import Image
from config import INDUSTRIES, CAMPAIGN_GOALS, CATEGORY_COLORS, MAX_VIDEO_LENGTH_SECONDS
from video_loader import VideoLoader
from frame_extractor import FrameExtractor
from audio_extractor import AudioExtractor
from vision_analyzer import VisionAnalyzer
from segment_synchronizer import SegmentSynchronizer
from narrative_classifier import NarrativeClassifier
from report_generator import ReportGenerator
# Page config — must be the first Streamlit command executed in the script.
st.set_page_config(
    page_title="StoryLens - Ad Narrative Analyzer",
    page_icon="🎬",
    layout="wide"
)

# Initialize session state so later reads of these keys never fail,
# even on the first script run before any analysis has happened.
if 'analysis_result' not in st.session_state:
    st.session_state.analysis_result = None
if 'transcript' not in st.session_state:
    st.session_state.transcript = None
# Sidebar: API credentials and campaign settings.
with st.sidebar:
    st.header("Configuration")

    # API Settings — both providers grouped in one expander.
    with st.expander("API Settings", expanded=True):
        st.subheader("MiniMax (Vision & LLM)")
        api_key = st.text_input(
            "MiniMax API Key",
            type="password",
            value=os.getenv("MINIMAX_API_KEY", ""),
            help="Get your API key from MiniMax platform"
        )
        group_id = st.text_input(
            "MiniMax Group ID",
            value=os.getenv("MINIMAX_GROUP_ID", "")
        )
        # Persist credentials only when the pair is complete.
        if api_key and group_id:
            st.session_state.api_key = api_key
            st.session_state.group_id = group_id
            st.success("MiniMax configured")

        st.divider()

        st.subheader("OpenAI (Whisper)")
        openai_key = st.text_input(
            "OpenAI API Key",
            type="password",
            value=os.getenv("OPENAI_API_KEY", ""),
            help="For audio transcription (Whisper)"
        )
        if openai_key:
            st.session_state.openai_key = openai_key
            st.success("OpenAI configured")

    st.divider()

    # Campaign Settings — drive report benchmarking further down.
    st.subheader("Campaign Settings")
    industry = st.selectbox("Industry", INDUSTRIES)
    campaign_goal = st.selectbox("Campaign Goal", CAMPAIGN_GOALS)
# Main content
st.title("StoryLens")
st.markdown("*Diagnose your video ad's narrative structure*")

# Video Input — user supplies either a local file or a YouTube URL.
st.header("Video Input")
col1, col2 = st.columns(2)

with col1:
    st.subheader("Upload File")
    uploaded_file = st.file_uploader(
        "Choose video file",
        type=["mp4", "mov", "avi", "webm"],
        help="Max 120 seconds"
    )

with col2:
    st.subheader("YouTube URL")
    youtube_url = st.text_input(
        "Paste URL",
        placeholder="https://youtube.com/watch?v=..."
    )

# Readiness flags for the Analyze button.
video_source = uploaded_file or youtube_url
# `'key' in st.session_state` is the documented membership test for
# session state (equivalent to the previous hasattr() attribute probe).
minimax_ready = 'api_key' in st.session_state and st.session_state.api_key
openai_ready = 'openai_key' in st.session_state and st.session_state.openai_key
api_ready = minimax_ready and openai_ready
if video_source and api_ready:
    if st.button("Analyze", type="primary", use_container_width=True):
        # Progress container
        progress_container = st.container()
        with progress_container:
            progress_bar = st.progress(0)
            status_text = st.empty()

        def _run_analysis():
            """Run the full analysis pipeline.

            Returns (report, transcript) on success, or None when a
            user-facing validation problem was already reported via
            st.error() (failed load, video too long).

            NOTE(review): the original code called st.stop() inside the
            try block below; Streamlit's StopException subclasses
            Exception, so the broad handler caught it and printed a
            spurious "Analysis failed" traceback after the real error
            message. Early returns avoid that.
            """
            # Credentials were validated by api_ready above.
            api_key = st.session_state.api_key
            group_id = st.session_state.group_id
            openai_key = st.session_state.openai_key

            # Initialize pipeline components.
            video_loader = VideoLoader()
            frame_extractor = FrameExtractor()
            audio_extractor = AudioExtractor(openai_api_key=openai_key)
            vision_analyzer = VisionAnalyzer(api_key, group_id)
            synchronizer = SegmentSynchronizer()
            classifier = NarrativeClassifier(api_key, group_id)
            report_generator = ReportGenerator()

            # Step 1: Load video (local upload takes precedence).
            status_text.text("Loading video...")
            progress_bar.progress(10)
            if uploaded_file:
                video_path = video_loader.load_local(uploaded_file)
            else:
                video_path = video_loader.load_youtube(youtube_url)
            if not video_path:
                st.error("Failed to load video")
                return None

            # Enforce the configured maximum duration.
            duration = video_loader.get_video_duration(video_path)
            if duration > MAX_VIDEO_LENGTH_SECONDS:
                st.error(f"Video too long ({duration:.0f}s). Max allowed: {MAX_VIDEO_LENGTH_SECONDS}s")
                return None

            # Step 2: Extract frames
            status_text.text("Extracting frames...")
            progress_bar.progress(20)
            frames = frame_extractor.extract_frames(video_path)

            # Step 3: Extract & transcribe audio
            status_text.text("Transcribing audio...")
            progress_bar.progress(35)
            audio_path = audio_extractor.extract_audio(video_path)
            transcript = audio_extractor.transcribe(audio_path)

            # Step 4: Analyze frames visually
            status_text.text("Analyzing frames...")
            progress_bar.progress(50)
            frame_descriptions = vision_analyzer.describe_frames_batch(frames)

            # Step 5: Synchronize visual and speech tracks into segments
            status_text.text("Synchronizing segments...")
            progress_bar.progress(70)
            segments = synchronizer.synchronize(frame_descriptions, transcript)

            # Step 6: Classify narrative structure
            status_text.text("Classifying narrative structure...")
            progress_bar.progress(85)
            analysis = classifier.classify(segments)

            # Step 7: Generate report
            status_text.text("Generating report...")
            progress_bar.progress(95)
            report = report_generator.generate(analysis, industry, campaign_goal)

            progress_bar.progress(100)
            status_text.text("Analysis complete!")
            return report, transcript

        try:
            outcome = _run_analysis()
            if outcome is not None:
                # Store result for the display section below.
                st.session_state.analysis_result, st.session_state.transcript = outcome
        except Exception as e:
            st.error(f"Analysis failed: {str(e)}")
            import traceback
            st.code(traceback.format_exc())

elif not api_ready:
    # Tell the user exactly which credentials are still missing.
    missing = []
    if not minimax_ready:
        missing.append("MiniMax API Key + Group ID")
    if not openai_ready:
        missing.append("OpenAI API Key")
    st.warning(f"Please configure API settings in the sidebar: {', '.join(missing)}")

elif not video_source:
    st.info("Upload a video file or paste a YouTube URL to begin")
# Display results (persists across reruns via session state).
if st.session_state.analysis_result:
    result = st.session_state.analysis_result
    st.divider()

    # Summary metrics
    st.header("Analysis Results")
    col1, col2, col3, col4 = st.columns(4)
    with col1:
        story_status = "YES" if result['summary']['has_story'] else "NO"
        st.metric("Story Detected", story_status)
    with col2:
        st.metric("Detected Arc", result['summary']['detected_arc'])
    with col3:
        st.metric("Optimal Arc", result['summary']['optimal_arc_for_goal'])
    with col4:
        st.metric("Potential Uplift", result['summary']['potential_uplift'])

    # Story explanation (optional free-text rationale from the report).
    if result['summary']['story_explanation']:
        st.info(f"**Story Analysis:** {result['summary']['story_explanation']}")

    st.divider()

    # Timeline visualization — one row per analyzed segment.
    st.subheader("Narrative Timeline")
    for seg in result['segments']:
        col1, col2, col3, col4 = st.columns([1, 1, 2, 3])
        with col1:
            # Frame thumbnail, if the extracted frame file still exists.
            if seg.get('frame_path') and os.path.exists(seg['frame_path']):
                img = Image.open(seg['frame_path'])
                st.image(img, width=120)
            else:
                st.write("[Frame]")
        with col2:
            st.caption(f"**{seg['start']:.1f}s - {seg['end']:.1f}s**")
            # Role badge colored by category; gray fallback for unknowns.
            category = seg.get('role_category', 'OTHER')
            color = CATEGORY_COLORS.get(category, '#9E9E9E')
            role = seg.get('functional_role', 'Unknown')
            st.markdown(
                f'<span style="background-color: {color}; color: white; '
                f'padding: 4px 8px; border-radius: 4px; font-size: 12px;">'
                f'{role}</span>',
                unsafe_allow_html=True
            )
        with col3:
            visual_text = seg.get('visual', 'N/A')
            st.write(f"**Visual:** {visual_text}")
        with col4:
            if seg.get('speech'):
                st.write(f"**Speech:** \"{seg['speech']}\"")
            if seg.get('reasoning'):
                st.caption(f"*{seg['reasoning']}*")

    st.divider()

    # Detected sequence of narrative beats.
    if result.get('detected_sequence'):
        st.subheader("Story Arc Flow")
        arc_flow = " -> ".join(result['detected_sequence'])
        st.markdown(f"**{arc_flow}**")

    # Missing narrative elements, one warning each.
    if result.get('missing_elements'):
        st.subheader("Missing Elements")
        for element in result['missing_elements']:
            st.warning(f"- {element}")

    st.divider()

    # Recommendations — HIGH-priority items start expanded.
    st.subheader("Recommendations")
    for rec in result.get('recommendations', []):
        priority = rec.get('priority', 'LOW')
        icon = "[HIGH]" if priority == "HIGH" else "[MEDIUM]" if priority == "MEDIUM" else "[LOW]"
        with st.expander(f"{icon} {rec['action']}", expanded=(priority == "HIGH")):
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Expected Impact", rec.get('expected_impact', 'N/A'))
            with col2:
                st.metric("Priority", priority)
            st.write(f"**Reasoning:** {rec.get('reasoning', '')}")

    # Benchmark info for the selected campaign goal.
    with st.expander("Benchmark Details"):
        benchmark = result.get('benchmark', {})
        st.write(f"**Best Arc for {campaign_goal}:** {benchmark.get('best_arc', 'N/A')}")
        st.write(f"**Average Uplift:** +{benchmark.get('uplift_percent', '?')}%")
        st.write(f"**Recommendation:** {benchmark.get('recommendation', 'N/A')}")
# Full Transcript — shown whenever a transcript was stored by the pipeline.
# `'key' in st.session_state` is the documented membership test (same
# result as the previous hasattr() probe).
if 'transcript' in st.session_state and st.session_state.transcript:
    st.divider()
    st.subheader("Full Transcript")
    transcript = st.session_state.transcript

    # Display each segment with timestamps when they are non-zero.
    for seg in transcript:
        start = seg.get('start', 0)
        end = seg.get('end', 0)
        text = seg.get('text', '')
        if text:
            if start > 0 or end > 0:
                st.markdown(f"**[{start:.1f}s - {end:.1f}s]** {text}")
            else:
                st.markdown(text)

    # Also show as a single plain-text block for easy copying.
    with st.expander("Plain Text"):
        full_text = " ".join([seg.get('text', '') for seg in transcript if seg.get('text')])
        st.text_area("Full transcript", full_text, height=150, disabled=True)