"""Model chaining engine for multi-stage AI pipelines."""

from __future__ import annotations

import asyncio
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, Callable

from loguru import logger


@dataclass(frozen=True, slots=True)
class ChainStage:
    """A single stage in a model chain."""

    # "<provider>/<model>", e.g. "zen/minimax-m2.5-free". The text before the
    # first "/" is the provider id handed to ChainEngine's provider_getter.
    model_ref: str
    stage_name: str  # e.g., "vision_analysis", "code_generation"
    description: str  # human-readable summary of what the stage does


@dataclass(frozen=True, slots=True)
class ChainResult:
    """Result from executing a chain stage."""

    stage: ChainStage  # the stage that produced this result
    output: str  # text produced by the stage
    success: bool  # whether the stage completed
    error: str | None = None  # presumably an error message when success is False


# Chain templates for common multi-capability tasks. Stages are listed in the
# order they are meant to run.
CHAIN_TEMPLATES: dict[str, list[ChainStage]] = {
    "vision_to_text": [
        ChainStage(
            model_ref="nvidia_nim/stepfun-ai/step-3.5-flash",
            stage_name="image_analysis",
            description="Analyze image content",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="response_generation",
            description="Generate final response",
        ),
    ],
    "reasoning_to_generation": [
        ChainStage(
            model_ref="nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
            stage_name="analysis",
            description="Analyze and plan",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="generation",
            description="Generate output",
        ),
    ],
}

# Capability pairs mapped to a CHAIN_TEMPLATES key. Checked in order; the first
# pair fully contained in the requested capabilities wins.
_CAPABILITY_CHAINS: tuple[tuple[frozenset[str], str], ...] = (
    (frozenset({"vision", "coding"}), "vision_to_text"),
    (frozenset({"vision", "reasoning"}), "vision_to_text"),
    (frozenset({"reasoning", "coding"}), "reasoning_to_generation"),
)


class ChainEngine:
    """Execute multi-model pipelines for complex requests."""

    def __init__(self, provider_getter: Callable[[str], Any]):
        """Create an engine.

        Args:
            provider_getter: Called with a provider id (the part of a stage's
                ``model_ref`` before the first "/") and expected to return an
                object exposing ``stream_response(messages, system, metadata)``.
        """
        self._provider_getter = provider_getter

    async def execute_simple_chain(
        self,
        stages: list[ChainStage],
        initial_messages: list[Any],
        system_prompt: str | None = None,
    ) -> AsyncIterator[str]:
        """Stream a response for a chain of stages.

        Only the first stage is executed at present: its provider streams
        the response for ``initial_messages`` and every event is re-yielded
        unchanged. Remaining stages are not run yet.

        Args:
            stages: List of chain stages; an empty list yields nothing.
            initial_messages: Initial user messages.
            system_prompt: Optional system prompt.

        Yields:
            Events produced by the first stage's provider ``stream_response``.
        """
        if not stages:
            return

        logger.info("ChainEngine: executing {} stages", len(stages))

        # TODO: run every stage sequentially and feed each output into the
        # next one. Until then, delegate the whole request to the first stage.
        first_stage = stages[0]
        provider_id = first_stage.model_ref.split("/")[0]
        provider = self._provider_getter(provider_id)

        logger.info(
            "ChainEngine: using model {} for chain",
            first_stage.model_ref,
        )

        # A fresh, empty metadata dict is passed on every call.
        async for event in provider.stream_response(
            initial_messages, system_prompt, {}
        ):
            yield event

    def get_chain_for_requirements(
        self,
        required_capabilities: set[str],
        available_models: list[str],
    ) -> list[ChainStage] | None:
        """Determine the appropriate chain based on required capabilities.

        Args:
            required_capabilities: Set of capabilities needed.
            available_models: Available model references. Currently unused;
                the returned stages are not checked against it.

        Returns:
            A new list of chain stages, or None if a single model is
            sufficient or no template matches. The list is a copy, so callers
            may modify it without affecting ``CHAIN_TEMPLATES``.
        """
        # A single capability (or none) never needs a chain.
        if len(required_capabilities) <= 1:
            return None

        for capability_pair, template_name in _CAPABILITY_CHAINS:
            if capability_pair.issubset(required_capabilities):
                template = CHAIN_TEMPLATES.get(template_name)
                # Copy so the shared module-level template cannot be mutated.
                return list(template) if template is not None else None

        # Default: no chain for other capability combinations.
        return None


def _extract_text_deltas(event: str) -> list[str]:
    """Return the text fragments carried by text-delta payloads in *event*.

    *event* is treated as raw SSE text. Each ``data:`` line is decoded as
    JSON and kept only when it is a ``content_block_delta`` whose ``delta``
    is a ``text_delta`` with a string ``text``. Data lines that are not valid
    JSON (e.g. a ``[DONE]`` sentinel) and all other payload types are skipped.
    """
    texts: list[str] = []
    for raw_line in event.splitlines():
        line = raw_line.strip()
        if not line.startswith("data:"):
            continue
        try:
            payload = json.loads(line[len("data:"):])
        except json.JSONDecodeError:
            continue
        if not isinstance(payload, dict):
            continue
        if payload.get("type") != "content_block_delta":
            continue
        delta = payload.get("delta")
        if isinstance(delta, dict) and delta.get("type") == "text_delta":
            text = delta.get("text")
            if isinstance(text, str):
                texts.append(text)
    return texts


async def execute_model_for_stage(
    provider: Any,
    messages: list[Any],
    system: str | None,
    metadata: dict[str, Any],
) -> str:
    """Execute a single model stage and return its concatenated text output.

    NOTE(review): assumes ``provider.stream_response`` yields SSE-formatted
    strings whose ``data:`` payloads follow the Anthropic Messages streaming
    shape (``content_block_delta`` / ``text_delta``) — confirm against the
    provider implementations.

    Args:
        provider: Object exposing ``stream_response(messages, system, metadata)``.
        messages: Messages to send to the model.
        system: Optional system prompt.
        metadata: Extra request metadata forwarded to the provider.

    Returns:
        All streamed text deltas joined in arrival order ("" if none).

    Raises:
        Exception: Anything raised while streaming is logged and re-raised.
    """
    output_parts: list[str] = []
    try:
        async for event in provider.stream_response(messages, system, metadata):
            # Cheap substring pre-check before parsing the SSE payload.
            if "content_block_delta" in event:
                output_parts.extend(_extract_text_deltas(event))
    except Exception as e:
        logger.error("Chain stage failed: {}", e)
        raise
    return "".join(output_parts)