Spaces:
Running
Running
| """Model chaining engine for multi-stage AI pipelines.""" | |
from __future__ import annotations

import asyncio
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, Callable

from loguru import logger
@dataclass
class ChainStage:
    """A single stage in a model chain.

    Attributes:
        model_ref: Provider-qualified model reference, e.g. "zen/minimax-m2.5-free".
        stage_name: Short identifier for the stage, e.g. "vision_analysis".
        description: Human-readable description of what the stage does.
    """

    # NOTE: @dataclass is required here — CHAIN_TEMPLATES builds instances via
    # keyword arguments, which needs the generated __init__.
    model_ref: str  # e.g., "zen/minimax-m2.5-free"
    stage_name: str  # e.g., "vision_analysis", "code_generation"
    description: str
@dataclass
class ChainResult:
    """Result from executing a chain stage."""

    # NOTE: @dataclass is required — the `error` field default only works as a
    # generated dataclass field, and callers construct results by keyword.
    stage: ChainStage  # the stage that produced this result
    output: str  # accumulated text output from the stage's model
    success: bool  # True if the stage completed without error
    error: str | None = None  # error message when success is False
# Chain templates for common multi-capability tasks.
# Keys are template names returned by ChainEngine.get_chain_for_requirements;
# values are the ordered stages to execute (earlier stages feed later ones).
CHAIN_TEMPLATES: dict[str, list[ChainStage]] = {
    # Vision-capable model first, then a text model for the final answer.
    "vision_to_text": [
        ChainStage(
            model_ref="nvidia_nim/stepfun-ai/step-3.5-flash",
            stage_name="image_analysis",
            description="Analyze image content",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="response_generation",
            description="Generate final response",
        ),
    ],
    # Reasoning/planning model first, then a generation model.
    "reasoning_to_generation": [
        ChainStage(
            model_ref="nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
            stage_name="analysis",
            description="Analyze and plan",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="generation",
            description="Generate output",
        ),
    ],
}
class ChainEngine:
    """Execute multi-model pipelines for complex requests."""

    def __init__(self, provider_getter: Callable[[str], Any]):
        # Callable mapping a provider prefix (e.g. "zen") to a provider object.
        self._provider_getter = provider_getter

    async def execute_simple_chain(
        self,
        stages: list[ChainStage],
        initial_messages: list[Any],
        system_prompt: str | None = None,
    ) -> AsyncIterator[str]:
        """Execute a chain of models sequentially.

        Args:
            stages: List of chain stages to execute
            initial_messages: Initial user messages
            system_prompt: Optional system prompt

        Yields:
            SSE events from the final model in the chain
        """
        # Nothing to run for an empty chain.
        if not stages:
            return

        logger.info("ChainEngine: executing {} stages", len(stages))

        # For now, execute single model - full chaining requires more integration.
        # This is a placeholder for the full implementation.
        lead_stage = stages[0]
        provider_prefix = lead_stage.model_ref.split("/")[0]
        provider = self._provider_getter(provider_prefix)
        logger.info(
            "ChainEngine: using model {} for chain",
            lead_stage.model_ref,
        )

        # For Phase 1, just delegate to provider - full chaining comes later.
        # The infrastructure is now in place.
        async for event in provider.stream_response(
            initial_messages, system_prompt, {}
        ):
            yield event

    def get_chain_for_requirements(
        self,
        required_capabilities: set[str],
        available_models: list[str],
    ) -> list[ChainStage] | None:
        """Determine the appropriate chain based on required capabilities.

        Args:
            required_capabilities: Set of capabilities needed
            available_models: Available model references

        Returns:
            Chain stages or None if single model is sufficient
        """
        # A single capability (or none) never needs a chain.
        if len(required_capabilities) <= 1:
            return None

        # Map known capability pairings onto their chain templates. Any
        # pairing that includes vision routes through the vision template.
        needs = required_capabilities
        if {"vision", "coding"} <= needs or {"vision", "reasoning"} <= needs:
            return CHAIN_TEMPLATES.get("vision_to_text")
        if {"reasoning", "coding"} <= needs:
            return CHAIN_TEMPLATES.get("reasoning_to_generation")

        # Default: no chain for now.
        return None
async def execute_model_for_stage(
    provider: Any,
    messages: list[Any],
    system: str | None,
    metadata: dict[str, Any],
) -> str:
    """Execute a single model stage and return its accumulated text output.

    Args:
        provider: Provider exposing ``stream_response(messages, system, metadata)``
            as an async iterator of SSE event strings.
        messages: Messages to send to the model.
        system: Optional system prompt.
        metadata: Provider-specific request metadata.

    Returns:
        The concatenated text of every ``content_block_delta`` payload in the
        stream (empty string if the stream produced no text deltas).

    Raises:
        Exception: Re-raises whatever the provider stream raises, after logging.
    """
    # BUG FIX: the original collected nothing (the delta branch was `pass`),
    # so this function always returned "".
    output_parts: list[str] = []
    try:
        async for event in provider.stream_response(messages, system, metadata):
            # Cheap pre-filter: skip events that can't contain a text delta.
            if "content_block_delta" not in event:
                continue
            # An SSE event carries one or more "data: <json>" lines.
            # NOTE(review): assumes Anthropic-style payloads
            # ({"type": "content_block_delta", "delta": {"text": ...}}) —
            # confirm against the provider implementations.
            for line in event.splitlines():
                if not line.startswith("data:"):
                    continue
                try:
                    payload = json.loads(line[len("data:"):].strip())
                except ValueError:
                    continue  # non-JSON data line; ignore
                if payload.get("type") == "content_block_delta":
                    text = payload.get("delta", {}).get("text")
                    if text:
                        output_parts.append(text)
        return "".join(output_parts)
    except Exception as e:
        logger.error("Chain stage failed: {}", e)
        raise