# Tools / src/pipelines/factory.py
# jebin2 — commit 5f00d5a: "feat: Implement VoiceOver AI Pipeline and Refactor Config"
"""
Pipeline Factory - Registry-based pipeline selection.
"""
from typing import Callable, List, Tuple
from src.config import get_config_value
from src.pipelines.base import ContentAutomationBase
# Registry of (priority, condition_fn, pipeline_class) entries, kept sorted by
# descending priority. get_automation_pipeline() instantiates the first
# pipeline_class whose condition_fn() returns a truthy value.
_PIPELINE_REGISTRY: List[Tuple[int, Callable[[], bool], type]] = []
def register_pipeline(condition: Callable[[], bool], priority: int = 0):
    """Build a class decorator that adds a pipeline to the registry.

    Args:
        condition: Zero-argument callable. The factory selects the decorated
            pipeline when this returns a truthy value and no entry ahead of
            it in the registry matched first.
        priority: Ordering key; entries with larger values are tried first.

    Returns:
        A decorator that records ``(priority, condition, cls)`` in
        ``_PIPELINE_REGISTRY``, re-sorts the registry, and returns ``cls``
        unchanged.

    NOTE: the sort is stable, so entries with equal priority keep their
    registration order. A pipeline registered with priority 0 *after* the
    built-in StandardAIPipeline catch-all (priority 0, condition always
    true) would therefore never be reached.
    """

    def _priority_of(entry):
        # Entries are (priority, condition, pipeline_class).
        return entry[0]

    def _register(pipeline_cls):
        new_entry = (priority, condition, pipeline_cls)
        _PIPELINE_REGISTRY.append(new_entry)
        # Highest priority first; in-place so the shared list object is kept.
        _PIPELINE_REGISTRY.sort(key=_priority_of, reverse=True)
        return pipeline_cls

    return _register
# True once _init_registry() has merged the built-in pipelines into
# _PIPELINE_REGISTRY. Tracked explicitly because entries added through
# @register_pipeline make the registry non-empty before the built-ins are
# loaded, so "registry is empty" is not a reliable "not yet initialized" test.
_DEFAULTS_REGISTERED = False


def _init_registry():
    """Merge the built-in pipelines into the pipeline registry (once).

    The built-in pipeline modules are imported inside this function
    (presumably to avoid import cycles with modules that import this
    factory — NOTE(review): confirm).

    The registry list is extended and re-sorted in place rather than
    rebound. Entries previously added through ``register_pipeline`` are
    therefore kept, and any holder of a reference to the list sees the
    update. Calls after the first successful one are no-ops.
    """
    global _DEFAULTS_REGISTERED
    if _DEFAULTS_REGISTERED:
        return

    from src.pipelines.standard_ai_pipeline import StandardAIPipeline
    from src.pipelines.avatar_ai_pipeline import AvatarAIPipeline
    from src.pipelines.hard_cut_pipeline import HardCutPipeline
    from src.pipelines.beats_cut_pipeline import BeatsCutPipeline
    from src.pipelines.voiceover_ai_pipeline import VoiceOverAIPipeline

    _PIPELINE_REGISTRY.extend([
        # (priority, condition, pipeline_class)
        (100, lambda: get_config_value("SETUP_TYPE", "") == "vo_video_gen", VoiceOverAIPipeline),
        (50, lambda: get_config_value("video_merge_type") == "hard_cut", HardCutPipeline),
        (50, lambda: get_config_value("video_merge_type") == "beats_cut", BeatsCutPipeline),
        (10, lambda: get_config_value("is_a2e_lip_sync", False), AvatarAIPipeline),
        (0, lambda: True, StandardAIPipeline),  # Default fallback
    ])
    # Highest priority first. list.sort is stable (also with reverse=True), so
    # equal priorities keep insertion order: pipelines registered before this
    # call stay ahead of the built-ins added here.
    _PIPELINE_REGISTRY.sort(key=lambda entry: entry[0], reverse=True)
    # Set only after the merge succeeded, so a failed import is retried on
    # the next factory call instead of leaving the defaults permanently absent.
    _DEFAULTS_REGISTERED = True


def get_automation_pipeline() -> ContentAutomationBase:
    """
    Factory function to return the appropriate automation pipeline.

    Ensures the built-in pipelines are registered, then walks the
    priority-ordered registry and instantiates the first pipeline class
    whose condition returns a truthy value.

    Returns:
        A new instance of the selected pipeline class.
    """
    # Lazy init keyed on a flag rather than on an empty registry, so that
    # pipelines added via @register_pipeline before the first call do not
    # prevent the built-in defaults from being loaded.
    if not _DEFAULTS_REGISTERED:
        _init_registry()
    for _priority, condition, pipeline_class in _PIPELINE_REGISTRY:
        if condition():
            return pipeline_class()
    # Not reached while the built-in fallback (condition always true) is
    # registered; kept as a last-resort safety net.
    from src.pipelines.standard_ai_pipeline import StandardAIPipeline
    return StandardAIPipeline()