memo / core /scene_planner.py
likhonsheikh's picture
Upload Memo: Production-grade Transformers + Safetensors implementation
a8fc815 verified
"""
Scene Planner - Uses Transformer Model for Intelligent Scene Generation
Replaces toy logic with proper ML-based scene planning
"""
import math
import logging
from typing import List, Dict, Tuple
from models.text.bangla_parser import extract_scenes, BanglaSceneParser
logger = logging.getLogger(__name__)
class ScenePlanner:
    """
    Production-grade scene planner using transformer models.
    Handles timing, pacing, and visual coherence.

    Scene descriptions are obtained from a ``BanglaSceneParser``; this class
    then assigns each scene a duration, contiguous start/end times, a coarse
    keyword-based visual style and a transition type.
    """

    # Keywords whose presence makes a scene "more complex"; each hit
    # lengthens the scene by 30% (see _calculate_scene_duration).
    _COMPLEXITY_INDICATORS = ('চলাচল', 'কথোপকথন', 'অনেক', 'জটিল')

    # Ordered (keywords, style) table; the first matching row wins.
    _STYLE_KEYWORDS = (
        (('প্রকৃতি', 'বন', 'নদী'), "nature_landscape"),
        (('শহর', 'রাস্তা', 'গাড়ি'), "urban_environment"),
        (('বাড়ি', 'ঘর', 'আসবাব'), "indoor_scene"),
        (('মানুষ', 'ব্যক্তি', 'দল'), "character_focused"),
    )

    def __init__(self, model_id: str = "google/mt5-small"):
        """
        Initialize the scene planner.

        Args:
            model_id: Model identifier forwarded to ``BanglaSceneParser``.
        """
        self.parser = BanglaSceneParser(model_id)
        logger.info("ScenePlanner initialized with transformer model")

    def plan_scenes(self, text_bn: str, duration: int = 15) -> List[Dict]:
        """
        Generate a scene plan from Bangla text.

        Args:
            text_bn: Input Bangla text. ``None``, empty or whitespace-only
                text yields the fallback plan.
            duration: Total video duration in seconds.

        Returns:
            List of scene dictionaries with timing and descriptions. Any
            exception raised while planning is logged and the fallback plan
            is returned instead of propagating.
        """
        # Check for None too: previously None hit .strip() outside the try
        # block and raised AttributeError to the caller.
        if not text_bn or not text_bn.strip():
            logger.warning("Empty text provided to scene planner")
            return self._fallback_scenes(duration)
        try:
            # Determine scene count from duration and sentence count.
            scene_count = self._calculate_scene_count(text_bn, duration)
            logger.info("Planning %s scenes for %ss video", scene_count, duration)
            # Extract scene descriptions using the transformer-backed parser.
            raw_scenes = self.parser.extract_scenes(text_bn, scene_count)
            scenes = self._generate_scene_timing(raw_scenes, duration, scene_count)
            logger.info("Generated %s scenes successfully", len(scenes))
            return scenes
        except Exception as e:
            # Deliberate best-effort boundary: keep the pipeline alive, but
            # keep the traceback (logger.exception) instead of only str(e).
            logger.exception("Scene planning failed: %s", e)
            return self._fallback_scenes(duration)

    def _calculate_scene_count(self, text_bn: str, duration: int) -> int:
        """
        Calculate the number of scenes from duration and sentence count.

        Args:
            text_bn: Input Bangla text.
            duration: Video duration in seconds.

        Returns:
            Scene count clamped to the range 3-12.
        """
        # Base scene count from duration.
        if duration <= 10:
            base_scenes = 3
        elif duration <= 20:
            base_scenes = 5
        elif duration <= 30:
            base_scenes = 7
        else:
            base_scenes = min(12, max(5, duration // 3))

        # Sentence terminators: Bangla danda plus '.' and '!'.
        sentences = text_bn.count('।') + text_bn.count('.') + text_bn.count('!')
        if sentences > 0:
            # Short texts should not be stretched over many scenes.
            content_based = min(10, sentences + 2)
            scene_count = min(base_scenes, content_based)
        else:
            scene_count = base_scenes

        return max(3, min(scene_count, 12))

    def _generate_scene_timing(self, scenes: List[str], duration: int, scene_count: int) -> List[Dict]:
        """
        Build timed scene dictionaries from scene descriptions.

        Args:
            scenes: List of scene descriptions.
            duration: Total video duration in seconds.
            scene_count: Requested scene count. Not used for timing (the
                actual number of scenes is ``len(scenes)``); kept for
                interface compatibility.

        Returns:
            List of scene dictionaries whose durations are rescaled to sum
            to ``duration``. Returns the fallback plan if ``scenes`` is empty.
        """
        if not scenes:
            return self._fallback_scenes(duration)

        total_scenes = len(scenes)
        base_duration = duration / total_scenes

        scenes_with_timing = []
        # Running start time; replaces re-summing all previous durations
        # for every scene (O(n^2)).
        start_time = 0
        for i, scene_desc in enumerate(scenes):
            scene_duration = self._calculate_scene_duration(
                scene_desc, base_duration, i, total_scenes
            )
            scenes_with_timing.append({
                "id": i + 1,
                "description": scene_desc,
                "duration": scene_duration,
                "start_time": start_time,
                "end_time": start_time + scene_duration,
                "visual_style": self._determine_visual_style(scene_desc),
                "transition_type": self._determine_transition(i, total_scenes),
            })
            start_time += scene_duration

        # Pacing adjustments and clamping change the total; rescale to target.
        self._adjust_timing_for_total_duration(scenes_with_timing, duration)
        return scenes_with_timing

    def _calculate_scene_duration(self, scene_desc: str, base_duration: float,
                                  scene_index: int, total_scenes: int) -> float:
        """
        Calculate the (pre-rescaling) duration of a single scene.

        Args:
            scene_desc: Scene description.
            base_duration: Even share of the total duration per scene.
            scene_index: Index of the current scene.
            total_scenes: Total number of scenes (currently unused).

        Returns:
            Duration for this scene, clamped to [1.5, 8.0] seconds.
        """
        # Cycle through multipliers 0.9, 1.0, 1.1 to vary the pacing.
        duration = base_duration * (0.9 + 0.2 * (scene_index % 3) / 2)

        # Each complexity keyword found lengthens the scene by 30%.
        complexity = sum(1 for indicator in self._COMPLEXITY_INDICATORS if indicator in scene_desc)
        if complexity > 0:
            duration *= (1 + 0.3 * complexity)

        return max(1.5, min(duration, 8.0))

    def _determine_visual_style(self, scene_desc: str) -> str:
        """Return the first visual style whose keywords occur in the description, else "general_visual"."""
        text = scene_desc.lower()
        for keywords, style in self._STYLE_KEYWORDS:
            if any(word in text for word in keywords):
                return style
        return "general_visual"

    def _determine_transition(self, scene_index: int, total_scenes: int) -> str:
        """Return "fade_in" for the first scene, "fade_out" for the last, else "cross_fade"."""
        if scene_index == 0:
            return "fade_in"
        elif scene_index == total_scenes - 1:
            return "fade_out"
        else:
            return "cross_fade"

    def _adjust_timing_for_total_duration(self, scenes: List[Dict], target_duration: float):
        """
        Rescale scene durations in place so they sum to ``target_duration``.

        Start/end times are recomputed from list order, so scenes are
        contiguous and the last ``end_time`` equals the target (up to
        floating-point rounding).

        Args:
            scenes: List of scenes with timing; mutated in place.
            target_duration: Target total duration.
        """
        current_total = sum(scene['duration'] for scene in scenes)
        if current_total <= 0:
            # Empty list or zero-length scenes: nothing to scale, and
            # dividing by current_total would raise ZeroDivisionError.
            return
        if math.isclose(current_total, target_duration):
            return  # Already matches; times are already contiguous.

        adjustment_factor = target_duration / current_total
        # Walk in list order instead of trusting scene['id'] as a position.
        cursor = 0
        for scene in scenes:
            scene['duration'] = scene['duration'] * adjustment_factor
            scene['start_time'] = cursor
            scene['end_time'] = cursor + scene['duration']
            cursor = scene['end_time']

    def _fallback_scenes(self, duration: int) -> List[Dict]:
        """
        Generate a basic three-scene plan when main planning fails.

        Args:
            duration: Video duration in seconds.

        Returns:
            Three equal-length scenes. Transitions come from
            ``_determine_transition``, matching the main planning path.
        """
        scene_count = 3
        scene_duration = duration / scene_count
        return [
            {
                "id": i + 1,
                "description": f"Fallback Scene {i+1}: Visual content for segment {i+1}",
                "duration": scene_duration,
                "start_time": i * scene_duration,
                "end_time": (i + 1) * scene_duration,
                "visual_style": "general_visual",
                "transition_type": self._determine_transition(i, scene_count),
            }
            for i in range(scene_count)
        ]

    def get_scene_statistics(self, scenes: List[Dict]) -> Dict:
        """
        Get statistics about a generated scene plan.

        Args:
            scenes: List of scene dictionaries (reads 'duration' and
                'visual_style').

        Returns:
            Dictionary with scene statistics. Styles are listed in the order
            they first appear, so the output is deterministic.
        """
        if not scenes:
            return {"total_scenes": 0, "total_duration": 0}

        durations = [scene['duration'] for scene in scenes]
        # Single-pass count; dict insertion order gives first-seen ordering
        # (list(set(...)) ordering varied between runs).
        distribution: Dict[str, int] = {}
        for scene in scenes:
            style = scene['visual_style']
            distribution[style] = distribution.get(style, 0) + 1

        total_duration = sum(durations)
        return {
            "total_scenes": len(scenes),
            "total_duration": total_duration,
            "avg_scene_duration": total_duration / len(durations),
            "min_scene_duration": min(durations),
            "max_scene_duration": max(durations),
            "visual_styles": list(distribution),
            "scene_distribution": distribution,
        }
# Lazily created planner shared by the module-level helper functions.
_planner_instance = None


def get_planner(model_id: str = "google/mt5-small") -> ScenePlanner:
    """
    Return the shared ScenePlanner, rebuilding it when the model changes.

    A new planner is constructed on first use, or when the cached planner's
    parser reports a different ``model_id`` than the one requested. No
    locking is performed.

    Args:
        model_id: Model identifier for the planner's parser.

    Returns:
        The cached (or freshly created) ScenePlanner.
    """
    global _planner_instance
    cached = _planner_instance
    # NOTE(review): assumes BanglaSceneParser exposes `model_id` — confirm.
    stale = cached is None or cached.parser.model_id != model_id
    if stale:
        cached = ScenePlanner(model_id)
        _planner_instance = cached
    return cached
def plan_scenes(text_bn: str, duration: int = 15) -> List[Dict]:
    """
    Plan scenes for Bangla text using the shared default-model planner.

    Args:
        text_bn: Input Bangla text.
        duration: Total video duration in seconds.

    Returns:
        The scene list produced by ``ScenePlanner.plan_scenes``.
    """
    shared_planner = get_planner()
    scene_plan = shared_planner.plan_scenes(text_bn, duration)
    return scene_plan